#!/usr/bin/env python3
"""
Flask server for POA Graph Web Interface
"""
import glob
import os
import pickle
import re
import sys
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
# Get the repository root directory (parent of web_interface)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Add the repository root to the path so we can import the POA graph modules
sys.path.append(REPO_ROOT)
from src.new_text_alignment import TextSeqGraphAlignment
from src.text_poa_graph import TextPOAGraph
try:
from src.generation_methods import decode_consensus
except ImportError:
decode_consensus = None
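# decode_consensus is optional: when it is unavailable, the endpoints below fall
# back to a plain concatenation of the consensus-node texts.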
app = Flask(__name__)
CORS(app) # Enable CORS for all routes
# Base paths for different datasets (relative to repo root)
GRAPH_PATHS = {
"bio": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/bio"),
"fp": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/fp"),
"hist": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/hist"),
"refs": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/refs"),
"math": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/MATH"),
"aime": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/AIME"),
}
MODELS = ["qwen72b", "qwen7b", "llama8b", "llama70b", "olmo7b", "olmo32b"]  # known model tags (informational; availability is derived from filenames)
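
# Hedged sketch: the three routes below repeat the same per-dataset filename
# parsing. A helper like this could consolidate it; it is illustrative only and
# is not wired into the routes, which keep their original inline logic.
def parse_graph_filename(dataset, filename):
    """Return (entity_display_name, model) for a graph pickle, or None if unmatched."""
    if dataset == "bio":
        # bio_graph_{entity}_merged_{model}.pkl
        match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
        if match:
            return match.group(1), match.group(2)
    elif dataset == "fp":
        # fp_graph_{number}_merged_{model}.pkl
        match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
        if match:
            return f"Problem {match.group(1)}", match.group(2)
    elif dataset == "math":
        # qwen72_math_{number}.pkl
        match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
        if match:
            return f"Math Problem {match.group(1)}", "qwen72b"
    elif dataset == "aime":
        # aime_qwen72b_{number}.pkl
        match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
        if match:
            return f"AIME Problem {match.group(1)}", "qwen72b"
    else:
        # Generic: {task}_graph_{entity}_merged_{model}.pkl
        match = re.match(r"\w+_graph_(.+?)_merged_(\w+)\.pkl", filename)
        if match:
            return match.group(1), match.group(2)
    return None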
@app.route("/")
def index():
"""Serve the main HTML file"""
return send_from_directory(".", "index.html")
@app.route("/api/datasets", methods=["GET"])
def get_datasets():
"""Get available datasets"""
datasets = []
for dataset_name, path in GRAPH_PATHS.items():
if os.path.exists(path):
# Count available graphs
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
datasets.append(
{
"name": dataset_name,
"display_name": dataset_name.upper(),
"path": path,
"count": len(pkl_files),
}
)
return jsonify({"datasets": datasets})
@app.route("/api/models", methods=["GET"])
def get_models():
"""Get available models for a specific entity"""
entity = request.args.get("entity")
dataset = request.args.get("dataset")
if not entity:
return jsonify({"error": "Entity parameter required"}), 400
if not dataset or dataset not in GRAPH_PATHS:
return jsonify({"error": "Invalid dataset"}), 400
path = GRAPH_PATHS[dataset]
if not os.path.exists(path):
return jsonify({"error": "Dataset path not found"}), 404
models = []
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
for pkl_file in pkl_files:
filename = os.path.basename(pkl_file)
# Different filename patterns for different datasets
if dataset == "bio":
# Format: bio_graph_{entity}_merged_{model}.pkl
match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
if match:
entity_name, model = match.groups()
if entity_name == entity:
models.append({"model": model, "filename": filename, "filepath": pkl_file})
elif dataset == "fp":
# Format: fp_graph_{number}_merged_{model}.pkl
match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
if match:
entity_name, model = match.groups()
if f"Problem {entity_name}" == entity:
models.append({"model": model, "filename": filename, "filepath": pkl_file})
elif dataset == "math":
# Format: qwen72_math_{number}.pkl
match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
if match:
entity_name = match.group(1)
if f"Math Problem {entity_name}" == entity:
models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
elif dataset == "aime":
# Format: aime_qwen72b_{number}.pkl
match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
if match:
entity_name = match.group(1)
if f"AIME Problem {entity_name}" == entity:
models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
else:
# Generic pattern for other datasets
match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
if match:
task, entity_name, model = match.groups()
if entity_name == entity:
models.append({"model": model, "filename": filename, "filepath": pkl_file})
return jsonify({"models": models})
@app.route("/api/entities", methods=["GET"])
def get_entities():
"""Get available entities for a dataset"""
dataset = request.args.get("dataset")
if not dataset or dataset not in GRAPH_PATHS:
return jsonify({"error": "Invalid dataset"}), 400
path = GRAPH_PATHS[dataset]
if not os.path.exists(path):
return jsonify({"error": "Dataset path not found"}), 404
entities = []
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
for pkl_file in pkl_files:
filename = os.path.basename(pkl_file)
# Different filename patterns for different datasets
if dataset == "bio":
# Format: bio_graph_{entity}_merged_{model}.pkl
match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
if match:
entity_name, model = match.groups()
entities.append(
{
"entity": entity_name,
"model": model,
"filename": filename,
"filepath": pkl_file,
}
)
elif dataset == "fp":
# Format: fp_graph_{number}_merged_{model}.pkl
match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
if match:
entity_name, model = match.groups()
entities.append(
{
"entity": f"Problem {entity_name}",
"model": model,
"filename": filename,
"filepath": pkl_file,
}
)
elif dataset == "math":
# Format: qwen72_math_{number}.pkl
match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
if match:
entity_name = match.group(1)
entities.append(
{
"entity": f"Math Problem {entity_name}",
"model": "qwen72b",
"filename": filename,
"filepath": pkl_file,
}
)
elif dataset == "aime":
# Format: aime_qwen72b_{number}.pkl
match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
if match:
entity_name = match.group(1)
entities.append(
{
"entity": f"AIME Problem {entity_name}",
"model": "qwen72b",
"filename": filename,
"filepath": pkl_file,
}
)
else:
# Generic pattern for other datasets
match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
if match:
task, entity_name, model = match.groups()
entities.append(
{
"entity": entity_name,
"model": model,
"filename": filename,
"filepath": pkl_file,
}
)
return jsonify({"entities": entities})
@app.route("/api/load_existing_graph", methods=["POST"])
def load_existing_graph():
"""Load an existing graph from the stored pickle files"""
try:
        data = request.get_json(silent=True) or {}  # tolerate missing/malformed JSON bodies
filepath = data.get("filepath")
if not filepath or not os.path.exists(filepath):
return jsonify({"error": "Graph file not found"}), 404
# Read and load the pickle file
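        # NOTE: the filepath comes straight from the client, and pickle.load can
        # execute arbitrary code; in production, restrict it to files under the
        # known GRAPH_PATHS directories.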
try:
with open(filepath, "rb") as f:
graph = pickle.load(f)
except Exception as e:
return jsonify({"error": f"Error loading pickle file: {str(e)}"}), 500
if not isinstance(graph, TextPOAGraph):
return jsonify({"error": "File does not contain a valid POA graph"}), 400
# Convert to JSON format for vis.js
nodes = []
edges = []
try:
# Get consensus nodes for coloring
consensus_nodes = set(graph.consensus_node_ids)
# Create nodes using the same logic as jsOutput
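            # nodeiterator() returns a generator *function*, hence the immediate
            # second call to obtain the iterator.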
for node in graph.nodeiterator()():
title_text = ""
if node.sequences:
title_text += f"Sequences: {node.sequences}"
if node.variations:
title_text += ";;;".join(
[f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
)
title_text = title_text.replace('"', "'")
# Use the same color logic as jsOutput
color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
node_data = {
"id": node.ID,
"label": f"{node.ID}: {node.text}",
"title": title_text,
"color": color,
}
nodes.append(node_data)
# Create edges using the same logic as jsOutput
for node in graph.nodeiterator()():
nodeID = node.ID # Keep as integer
for edge in node.outEdges:
target = edge # Keep as integer
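                    # The +1.5 offset keeps low-weight edges visibly thick when
                    # vis.js scales edge width by "value" (presumed intent).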
weight = node.outEdges[edge].weight + 1.5
edge_data = {
"from": nodeID,
"to": target,
"value": weight,
"color": "#cae0e6",
"arrows": "to",
}
edges.append(edge_data)
except Exception as e:
return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
# Extract metadata from filename
filename = os.path.basename(filepath)
metadata = {}
try:
# Different filename patterns for different datasets
if filename.startswith("bio_graph_"):
# Format: bio_graph_{entity}_merged_{model}.pkl
match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
if match:
entity_name, model = match.groups()
metadata = {
"task": "bio",
"entity": entity_name,
"model": model,
"filename": filename,
}
elif filename.startswith("fp_graph_"):
# Format: fp_graph_{number}_merged_{model}.pkl
match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
if match:
entity_name, model = match.groups()
metadata = {
"task": "fp",
"entity": f"Problem {entity_name}",
"model": model,
"filename": filename,
}
elif filename.startswith("qwen72_math_"):
# Format: qwen72_math_{number}.pkl
match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
if match:
entity_name = match.group(1)
metadata = {
"task": "math",
"entity": f"Math Problem {entity_name}",
"model": "qwen72b",
"filename": filename,
}
elif filename.startswith("aime_qwen72b_"):
# Format: aime_qwen72b_{number}.pkl
match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
if match:
entity_name = match.group(1)
metadata = {
"task": "aime",
"entity": f"AIME Problem {entity_name}",
"model": "qwen72b",
"filename": filename,
}
else:
# Generic pattern for other datasets
match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
if match:
task, entity_name, model = match.groups()
metadata = {
"task": task,
"entity": entity_name,
"model": model,
"filename": filename,
}
except Exception:
# Don't fail the request if metadata extraction fails
pass
        # Fallback consensus: join the texts of the consensus nodes in iteration order
consensus_text = ""
try:
consensus_nodes = set(graph.consensus_node_ids)
consensus_node_texts = []
for node in graph.nodeiterator()():
if node.ID in consensus_nodes and node.text and node.text.strip():
consensus_node_texts.append(node.text.strip())
consensus_text = " ".join(consensus_node_texts)
except Exception:
consensus_text = ""
# Check if we should compute consensus using decode_consensus
compute_consensus = data.get("compute_consensus", False)
if compute_consensus and decode_consensus:
try:
# Determine task from metadata or default to "bio"
task = metadata.get("task", "bio") if metadata else "bio"
consensus_text = decode_consensus(graph, selection_threshold=0.5, task=task)
except Exception as e:
print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
# Keep the original consensus text if decode_consensus fails
# Get original sequences
try:
raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
# Process sequences: join with spaces and remove "||"
print(f"DEBUG: Raw sequences: {raw_sequences}")
original_sequences = []
for seq in raw_sequences:
if isinstance(seq, list):
# Join list elements with spaces
processed_seq = " ".join(str(item) for item in seq)
else:
processed_seq = str(seq)
# Remove "||" characters
processed_seq = processed_seq.replace("||", "")
print(f"DEBUG: Processed sequence: {processed_seq}")
original_sequences.append(processed_seq)
except Exception:
original_sequences = []
result = {
"success": True,
"nodes": nodes,
"edges": edges,
"num_sequences": graph.num_sequences,
"num_nodes": len(nodes),
"num_edges": len(edges),
"metadata": metadata,
"consensus_text": consensus_text,
"original_sequences": original_sequences,
}
return jsonify(result)
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/api/create_graph", methods=["POST"])
def create_graph():
"""Create a POA graph from text sequences"""
try:
        data = request.get_json(silent=True) or {}  # tolerate missing/malformed JSON bodies
sequences = data.get("sequences", [])
if len(sequences) < 2:
return jsonify({"error": "At least 2 sequences are required"}), 400
print(f"DEBUG: Creating graph with sequences: {sequences}")
# Create the graph with first sequence as string
graph = TextPOAGraph(sequences[0], label=0)
print("DEBUG: Initial graph created")
# Add remaining sequences
for i, sequence in enumerate(sequences[1:], 1):
print(f"DEBUG: Adding sequence {i}: {sequence}")
alignment = TextSeqGraphAlignment(
text=sequence,
graph=graph,
fastMethod=True,
globalAlign=True,
matchscore=1,
mismatchscore=-2,
gap_open=-1,
)
graph.incorporateSeqAlignment(alignment, sequence, label=i)
print("DEBUG: All sequences added")
# Refine the graph with proper domain and model parameters
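        # refine_graph is given an LLM identifier ("gpt-4o-mini"); if refinement
        # queries that model, the server environment needs matching API credentials.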
graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini")
print("DEBUG: Graph refined")
# Convert to JSON format for vis.js
nodes = []
edges = []
try:
print("DEBUG: Starting to process graph data")
# Get consensus nodes for coloring (make it optional)
try:
consensus_nodes = set(graph.consensus_node_ids)
print(f"DEBUG: Consensus nodes: {consensus_nodes}")
except Exception as e:
print(f"DEBUG: Error getting consensus nodes: {e}")
consensus_nodes = set() # Fallback to empty set if consensus fails
# Create nodes using the same logic as jsOutput
for node in graph.nodeiterator()():
title_text = ""
if node.sequences:
title_text += f"Sequences: {node.sequences}"
if node.variations:
title_text += ";;;".join(
[f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
)
title_text = title_text.replace('"', "'")
# Use the same color logic as jsOutput
color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
node_data = {
"id": node.ID,
"label": f"{node.ID}: {node.text}",
"title": title_text,
"color": color,
}
nodes.append(node_data)
print(f"DEBUG: Created {len(nodes)} nodes")
# Create edges using the same logic as jsOutput
for node in graph.nodeiterator()():
nodeID = node.ID # Keep as integer
for edge in node.outEdges:
target = edge # Keep as integer
weight = node.outEdges[edge].weight + 1.5
edge_data = {
"from": nodeID,
"to": target,
"value": weight,
"color": "#cae0e6",
"arrows": "to",
}
edges.append(edge_data)
print(f"DEBUG: Created {len(edges)} edges")
except Exception as e:
print(f"DEBUG: Error processing graph data: {e}")
return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
        # Fallback consensus: join the texts of the consensus nodes in iteration order
consensus_text = ""
try:
consensus_node_texts = []
for node in graph.nodeiterator()():
if node.ID in consensus_nodes and node.text and node.text.strip():
consensus_node_texts.append(node.text.strip())
consensus_text = " ".join(consensus_node_texts)
except Exception:
consensus_text = ""
# Check if we should compute consensus using decode_consensus
compute_consensus = data.get("compute_consensus", False)
if compute_consensus and decode_consensus:
try:
# Default to "bio" task for new graphs
consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio")
except Exception as e:
print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
# Keep the original consensus text if decode_consensus fails
# Get original sequences
try:
raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
# Process sequences: join with spaces and remove "||"
original_sequences = []
for seq in raw_sequences:
if isinstance(seq, list):
# Join list elements with spaces
processed_seq = " ".join(str(item) for item in seq)
else:
processed_seq = str(seq)
# Remove "||" characters
processed_seq = processed_seq.replace("||", "")
original_sequences.append(processed_seq)
except Exception:
original_sequences = []
print("DEBUG: Returning success response")
return jsonify(
{
"success": True,
"nodes": nodes,
"edges": edges,
"num_sequences": len(sequences),
"num_nodes": len(nodes),
"num_edges": len(edges),
"original_sequences": original_sequences,
"consensus_text": consensus_text,
}
)
except Exception as e:
print(f"DEBUG: Main exception in create_graph: {e}")
return jsonify({"error": str(e)}), 500
@app.route("/api/save_graph", methods=["POST"])
def save_graph():
"""Save a POA graph to a pickle file"""
try:
        data = request.get_json(silent=True) or {}  # tolerate missing/malformed JSON bodies
sequences = data.get("sequences", [])
        # basename() guards against client-supplied directory components
        filename = os.path.basename(data.get("filename", "graph.pkl"))
if len(sequences) < 2:
return jsonify({"error": "At least 2 sequences are required"}), 400
# Create the graph
graph = TextPOAGraph(sequences[0], label=0)
# Add remaining sequences
for i, sequence in enumerate(sequences[1:], 1):
alignment = TextSeqGraphAlignment(
text=sequence,
graph=graph,
fastMethod=True,
globalAlign=True,
matchscore=1,
mismatchscore=-2,
gap_open=-1,
)
graph.incorporateSeqAlignment(alignment, sequence, label=i)
# Refine the graph
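        # Note: unlike /api/create_graph, refinement here relies on the library
        # defaults (no explicit domain/model arguments).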
graph.refine_graph(verbose=False)
# Save to pickle file
graph.save_to_pickle(filename)
return jsonify(
{"success": True, "filename": filename, "message": f"Graph saved to {filename}"}
)
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/api/graph_info", methods=["POST"])
def graph_info():
"""Get information about a graph without creating the full visualization"""
try:
        data = request.get_json(silent=True) or {}  # tolerate missing/malformed JSON bodies
sequences = data.get("sequences", [])
if len(sequences) < 2:
return jsonify({"error": "At least 2 sequences are required"}), 400
# Create the graph
graph = TextPOAGraph(sequences[0], label=0)
# Add remaining sequences
for i, sequence in enumerate(sequences[1:], 1):
alignment = TextSeqGraphAlignment(
text=sequence,
graph=graph,
fastMethod=True,
globalAlign=True,
matchscore=1,
mismatchscore=-2,
gap_open=-1,
)
graph.incorporateSeqAlignment(alignment, sequence, label=i)
# Refine the graph
graph.refine_graph(verbose=False)
# Get consensus response
consensus_text = graph.consensus_response()
return jsonify(
{
"success": True,
"num_sequences": len(sequences),
"num_nodes": graph._nnodes,
"consensus_text": consensus_text,
"consensus_node_ids": graph.consensus_node_ids,
}
)
except Exception as e:
return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
print("Starting POA Graph Web Interface Server...")
print("Open http://localhost:8080 in your browser")
app.run(debug=True, host="0.0.0.0", port=8080)
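
# Example requests against a locally running server (the entity name is hypothetical):
#   curl http://localhost:8080/api/datasets
#   curl "http://localhost:8080/api/entities?dataset=bio"
#   curl "http://localhost:8080/api/models?dataset=bio&entity=Some%20Entity"
#   curl -X POST http://localhost:8080/api/create_graph \
#        -H "Content-Type: application/json" \
#        -d '{"sequences": ["first response text", "second response text"], "compute_consensus": true}'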