#!/usr/bin/env python3 """ Flask server for POA Graph Web Interface """ import glob import os import pickle import re import sys from flask import Flask, jsonify, request, send_from_directory from flask_cors import CORS # Get the repository root directory (parent of web_interface) REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Add the repository root to the path so we can import the POA graph modules sys.path.append(REPO_ROOT) from src.new_text_alignment import TextSeqGraphAlignment from src.text_poa_graph import TextPOAGraph try: from src.generation_methods import decode_consensus except ImportError: decode_consensus = None app = Flask(__name__) CORS(app) # Enable CORS for all routes # Base paths for different datasets (relative to repo root) GRAPH_PATHS = { "bio": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/bio"), "fp": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/fp"), "hist": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/hist"), "refs": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/refs"), "math": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/MATH"), "aime": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/AIME"), } MODELS = ["qwen72b", "qwen7b", "llama8b", "llama70b", "olmo7b", "olmo32b"] @app.route("/") def index(): """Serve the main HTML file""" return send_from_directory(".", "index.html") @app.route("/api/datasets", methods=["GET"]) def get_datasets(): """Get available datasets""" datasets = [] for dataset_name, path in GRAPH_PATHS.items(): if os.path.exists(path): # Count available graphs pkl_files = glob.glob(os.path.join(path, "*.pkl")) datasets.append( { "name": dataset_name, "display_name": dataset_name.upper(), "path": path, "count": len(pkl_files), } ) return jsonify({"datasets": datasets}) @app.route("/api/models", methods=["GET"]) def get_models(): """Get available models for a specific entity""" entity = request.args.get("entity") dataset = request.args.get("dataset") if not entity: return jsonify({"error": "Entity parameter required"}), 400 if not dataset or dataset not in GRAPH_PATHS: return jsonify({"error": "Invalid dataset"}), 400 path = GRAPH_PATHS[dataset] if not os.path.exists(path): return jsonify({"error": "Dataset path not found"}), 404 models = [] pkl_files = glob.glob(os.path.join(path, "*.pkl")) for pkl_file in pkl_files: filename = os.path.basename(pkl_file) # Different filename patterns for different datasets if dataset == "bio": # Format: bio_graph_{entity}_merged_{model}.pkl match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename) if match: entity_name, model = match.groups() if entity_name == entity: models.append({"model": model, "filename": filename, "filepath": pkl_file}) elif dataset == "fp": # Format: fp_graph_{number}_merged_{model}.pkl match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename) if match: entity_name, model = match.groups() if f"Problem {entity_name}" == entity: models.append({"model": model, "filename": filename, "filepath": pkl_file}) elif dataset == "math": # Format: qwen72_math_{number}.pkl match = re.match(r"qwen72_math_(\d+)\.pkl", filename) if match: entity_name = match.group(1) if f"Math Problem {entity_name}" == entity: models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file}) elif dataset == "aime": # Format: aime_qwen72b_{number}.pkl match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename) if match: entity_name = match.group(1) if f"AIME Problem {entity_name}" == entity: models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file}) else: # Generic pattern for other datasets match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename) if match: task, entity_name, model = match.groups() if entity_name == entity: models.append({"model": model, "filename": filename, "filepath": pkl_file}) return jsonify({"models": models}) @app.route("/api/entities", methods=["GET"]) def get_entities(): """Get available entities for a dataset""" dataset = request.args.get("dataset") if not dataset or dataset not in GRAPH_PATHS: return jsonify({"error": "Invalid dataset"}), 400 path = GRAPH_PATHS[dataset] if not os.path.exists(path): return jsonify({"error": "Dataset path not found"}), 404 entities = [] pkl_files = glob.glob(os.path.join(path, "*.pkl")) for pkl_file in pkl_files: filename = os.path.basename(pkl_file) # Different filename patterns for different datasets if dataset == "bio": # Format: bio_graph_{entity}_merged_{model}.pkl match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename) if match: entity_name, model = match.groups() entities.append( { "entity": entity_name, "model": model, "filename": filename, "filepath": pkl_file, } ) elif dataset == "fp": # Format: fp_graph_{number}_merged_{model}.pkl match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename) if match: entity_name, model = match.groups() entities.append( { "entity": f"Problem {entity_name}", "model": model, "filename": filename, "filepath": pkl_file, } ) elif dataset == "math": # Format: qwen72_math_{number}.pkl match = re.match(r"qwen72_math_(\d+)\.pkl", filename) if match: entity_name = match.group(1) entities.append( { "entity": f"Math Problem {entity_name}", "model": "qwen72b", "filename": filename, "filepath": pkl_file, } ) elif dataset == "aime": # Format: aime_qwen72b_{number}.pkl match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename) if match: entity_name = match.group(1) entities.append( { "entity": f"AIME Problem {entity_name}", "model": "qwen72b", "filename": filename, "filepath": pkl_file, } ) else: # Generic pattern for other datasets match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename) if match: task, entity_name, model = match.groups() entities.append( { "entity": entity_name, "model": model, "filename": filename, "filepath": pkl_file, } ) return jsonify({"entities": entities}) @app.route("/api/load_existing_graph", methods=["POST"]) def load_existing_graph(): """Load an existing graph from the stored pickle files""" try: data = request.get_json() filepath = data.get("filepath") if not filepath or not os.path.exists(filepath): return jsonify({"error": "Graph file not found"}), 404 # Read and load the pickle file try: with open(filepath, "rb") as f: graph = pickle.load(f) except Exception as e: return jsonify({"error": f"Error loading pickle file: {str(e)}"}), 500 if not isinstance(graph, TextPOAGraph): return jsonify({"error": "File does not contain a valid POA graph"}), 400 # Convert to JSON format for vis.js nodes = [] edges = [] try: # Get consensus nodes for coloring consensus_nodes = set(graph.consensus_node_ids) # Create nodes using the same logic as jsOutput for node in graph.nodeiterator()(): title_text = "" if node.sequences: title_text += f"Sequences: {node.sequences}" if node.variations: title_text += ";;;".join( [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()] ) title_text = title_text.replace('"', "'") # Use the same color logic as jsOutput color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6" node_data = { "id": node.ID, "label": f"{node.ID}: {node.text}", "title": title_text, "color": color, } nodes.append(node_data) # Create edges using the same logic as jsOutput for node in graph.nodeiterator()(): nodeID = node.ID # Keep as integer for edge in node.outEdges: target = edge # Keep as integer weight = node.outEdges[edge].weight + 1.5 edge_data = { "from": nodeID, "to": target, "value": weight, "color": "#cae0e6", "arrows": "to", } edges.append(edge_data) except Exception as e: return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500 # Extract metadata from filename filename = os.path.basename(filepath) metadata = {} try: # Different filename patterns for different datasets if filename.startswith("bio_graph_"): # Format: bio_graph_{entity}_merged_{model}.pkl match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename) if match: entity_name, model = match.groups() metadata = { "task": "bio", "entity": entity_name, "model": model, "filename": filename, } elif filename.startswith("fp_graph_"): # Format: fp_graph_{number}_merged_{model}.pkl match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename) if match: entity_name, model = match.groups() metadata = { "task": "fp", "entity": f"Problem {entity_name}", "model": model, "filename": filename, } elif filename.startswith("qwen72_math_"): # Format: qwen72_math_{number}.pkl match = re.match(r"qwen72_math_(\d+)\.pkl", filename) if match: entity_name = match.group(1) metadata = { "task": "math", "entity": f"Math Problem {entity_name}", "model": "qwen72b", "filename": filename, } elif filename.startswith("aime_qwen72b_"): # Format: aime_qwen72b_{number}.pkl match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename) if match: entity_name = match.group(1) metadata = { "task": "aime", "entity": f"AIME Problem {entity_name}", "model": "qwen72b", "filename": filename, } else: # Generic pattern for other datasets match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename) if match: task, entity_name, model = match.groups() metadata = { "task": task, "entity": entity_name, "model": model, "filename": filename, } except Exception: # Don't fail the request if metadata extraction fails pass # Extract text from consensus nodes consensus_text = "" try: consensus_nodes = set(graph.consensus_node_ids) consensus_node_texts = [] for node in graph.nodeiterator()(): if node.ID in consensus_nodes and node.text and node.text.strip(): consensus_node_texts.append(node.text.strip()) consensus_text = " ".join(consensus_node_texts) except Exception: consensus_text = "" # Check if we should compute consensus using decode_consensus compute_consensus = data.get("compute_consensus", False) if compute_consensus and decode_consensus: try: # Determine task from metadata or default to "bio" task = metadata.get("task", "bio") if metadata else "bio" consensus_text = decode_consensus(graph, selection_threshold=0.5, task=task) except Exception as e: print(f"DEBUG: Error computing consensus with decode_consensus: {e}") # Keep the original consensus text if decode_consensus fails # Get original sequences try: raw_sequences = graph._seqs if hasattr(graph, "_seqs") else [] # Process sequences: join with spaces and remove "||" print(f"DEBUG: Raw sequences: {raw_sequences}") original_sequences = [] for seq in raw_sequences: if isinstance(seq, list): # Join list elements with spaces processed_seq = " ".join(str(item) for item in seq) else: processed_seq = str(seq) # Remove "||" characters processed_seq = processed_seq.replace("||", "") print(f"DEBUG: Processed sequence: {processed_seq}") original_sequences.append(processed_seq) except Exception: original_sequences = [] result = { "success": True, "nodes": nodes, "edges": edges, "num_sequences": graph.num_sequences, "num_nodes": len(nodes), "num_edges": len(edges), "metadata": metadata, "consensus_text": consensus_text, "original_sequences": original_sequences, } return jsonify(result) except Exception as e: return jsonify({"error": str(e)}), 500 @app.route("/api/create_graph", methods=["POST"]) def create_graph(): """Create a POA graph from text sequences""" try: data = request.get_json() sequences = data.get("sequences", []) if len(sequences) < 2: return jsonify({"error": "At least 2 sequences are required"}), 400 print(f"DEBUG: Creating graph with sequences: {sequences}") # Create the graph with first sequence as string graph = TextPOAGraph(sequences[0], label=0) print("DEBUG: Initial graph created") # Add remaining sequences for i, sequence in enumerate(sequences[1:], 1): print(f"DEBUG: Adding sequence {i}: {sequence}") alignment = TextSeqGraphAlignment( text=sequence, graph=graph, fastMethod=True, globalAlign=True, matchscore=1, mismatchscore=-2, gap_open=-1, ) graph.incorporateSeqAlignment(alignment, sequence, label=i) print("DEBUG: All sequences added") # Refine the graph with proper domain and model parameters graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini") print("DEBUG: Graph refined") # Convert to JSON format for vis.js nodes = [] edges = [] try: print("DEBUG: Starting to process graph data") # Get consensus nodes for coloring (make it optional) try: consensus_nodes = set(graph.consensus_node_ids) print(f"DEBUG: Consensus nodes: {consensus_nodes}") except Exception as e: print(f"DEBUG: Error getting consensus nodes: {e}") consensus_nodes = set() # Fallback to empty set if consensus fails # Create nodes using the same logic as jsOutput for node in graph.nodeiterator()(): title_text = "" if node.sequences: title_text += f"Sequences: {node.sequences}" if node.variations: title_text += ";;;".join( [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()] ) title_text = title_text.replace('"', "'") # Use the same color logic as jsOutput color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6" node_data = { "id": node.ID, "label": f"{node.ID}: {node.text}", "title": title_text, "color": color, } nodes.append(node_data) print(f"DEBUG: Created {len(nodes)} nodes") # Create edges using the same logic as jsOutput for node in graph.nodeiterator()(): nodeID = node.ID # Keep as integer for edge in node.outEdges: target = edge # Keep as integer weight = node.outEdges[edge].weight + 1.5 edge_data = { "from": nodeID, "to": target, "value": weight, "color": "#cae0e6", "arrows": "to", } edges.append(edge_data) print(f"DEBUG: Created {len(edges)} edges") except Exception as e: print(f"DEBUG: Error processing graph data: {e}") return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500 # Extract text from consensus nodes consensus_text = "" try: consensus_node_texts = [] for node in graph.nodeiterator()(): if node.ID in consensus_nodes and node.text and node.text.strip(): consensus_node_texts.append(node.text.strip()) consensus_text = " ".join(consensus_node_texts) except Exception: consensus_text = "" # Check if we should compute consensus using decode_consensus compute_consensus = data.get("compute_consensus", False) if compute_consensus and decode_consensus: try: # Default to "bio" task for new graphs consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio") except Exception as e: print(f"DEBUG: Error computing consensus with decode_consensus: {e}") # Keep the original consensus text if decode_consensus fails # Get original sequences try: raw_sequences = graph._seqs if hasattr(graph, "_seqs") else [] # Process sequences: join with spaces and remove "||" original_sequences = [] for seq in raw_sequences: if isinstance(seq, list): # Join list elements with spaces processed_seq = " ".join(str(item) for item in seq) else: processed_seq = str(seq) # Remove "||" characters processed_seq = processed_seq.replace("||", "") original_sequences.append(processed_seq) except Exception: original_sequences = [] print("DEBUG: Returning success response") return jsonify( { "success": True, "nodes": nodes, "edges": edges, "num_sequences": len(sequences), "num_nodes": len(nodes), "num_edges": len(edges), "original_sequences": original_sequences, "consensus_text": consensus_text, } ) except Exception as e: print(f"DEBUG: Main exception in create_graph: {e}") return jsonify({"error": str(e)}), 500 @app.route("/api/save_graph", methods=["POST"]) def save_graph(): """Save a POA graph to a pickle file""" try: data = request.get_json() sequences = data.get("sequences", []) filename = data.get("filename", "graph.pkl") if len(sequences) < 2: return jsonify({"error": "At least 2 sequences are required"}), 400 # Create the graph graph = TextPOAGraph(sequences[0], label=0) # Add remaining sequences for i, sequence in enumerate(sequences[1:], 1): alignment = TextSeqGraphAlignment( text=sequence, graph=graph, fastMethod=True, globalAlign=True, matchscore=1, mismatchscore=-2, gap_open=-1, ) graph.incorporateSeqAlignment(alignment, sequence, label=i) # Refine the graph graph.refine_graph(verbose=False) # Save to pickle file graph.save_to_pickle(filename) return jsonify( {"success": True, "filename": filename, "message": f"Graph saved to {filename}"} ) except Exception as e: return jsonify({"error": str(e)}), 500 @app.route("/api/graph_info", methods=["POST"]) def graph_info(): """Get information about a graph without creating the full visualization""" try: data = request.get_json() sequences = data.get("sequences", []) if len(sequences) < 2: return jsonify({"error": "At least 2 sequences are required"}), 400 # Create the graph graph = TextPOAGraph(sequences[0], label=0) # Add remaining sequences for i, sequence in enumerate(sequences[1:], 1): alignment = TextSeqGraphAlignment( text=sequence, graph=graph, fastMethod=True, globalAlign=True, matchscore=1, mismatchscore=-2, gap_open=-1, ) graph.incorporateSeqAlignment(alignment, sequence, label=i) # Refine the graph graph.refine_graph(verbose=False) # Get consensus response consensus_text = graph.consensus_response() return jsonify( { "success": True, "num_sequences": len(sequences), "num_nodes": graph._nnodes, "consensus_text": consensus_text, "consensus_node_ids": graph.consensus_node_ids, } ) except Exception as e: return jsonify({"error": str(e)}), 500 if __name__ == "__main__": print("Starting POA Graph Web Interface Server...") print("Open http://localhost:8080 in your browser") app.run(debug=True, host="0.0.0.0", port=8080)