#!/usr/bin/env python3
"""
Flask server for POA Graph Web Interface
"""
import glob
import os
import pickle
import re
import sys

from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS

# Get the repository root directory (parent of web_interface)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Add the repository root to the path so we can import the POA graph modules
sys.path.append(REPO_ROOT)

from src.new_text_alignment import TextSeqGraphAlignment
from src.text_poa_graph import TextPOAGraph

try:
    from src.generation_methods import decode_consensus
except ImportError:
    decode_consensus = None

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Base paths for different datasets (relative to repo root)
GRAPH_PATHS = {
    "bio": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/bio"),
    "fp": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/fp"),
    "hist": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/hist"),
    "refs": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/refs"),
    "math": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/MATH"),
    "aime": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/AIME"),
}
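
# Filename conventions the handlers below expect inside these directories
# (collected from the per-dataset patterns used throughout this file):
#   bio:   bio_graph_{entity}_merged_{model}.pkl
#   fp:    fp_graph_{number}_merged_{model}.pkl
#   math:  qwen72_math_{number}.pkl
#   aime:  aime_qwen72b_{number}.pkl
#   other: {task}_graph_{entity}_merged_{model}.pkl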
MODELS = ["qwen72b", "qwen7b", "llama8b", "llama70b", "olmo7b", "olmo32b"]

# NOTE: The @app.route paths below are assumed to match the frontend's API
# calls; adjust them if the UI requests different endpoints.
@app.route("/")
def index():
    """Serve the main HTML file"""
    return send_from_directory(".", "index.html")

@app.route("/api/datasets")
def get_datasets():
    """Get available datasets"""
    datasets = []
    for dataset_name, path in GRAPH_PATHS.items():
        if os.path.exists(path):
            # Count available graphs
            pkl_files = glob.glob(os.path.join(path, "*.pkl"))
            datasets.append(
                {
                    "name": dataset_name,
                    "display_name": dataset_name.upper(),
                    "path": path,
                    "count": len(pkl_files),
                }
            )
    return jsonify({"datasets": datasets})

@app.route("/api/models")
def get_models():
    """Get available models for a specific entity"""
    entity = request.args.get("entity")
    dataset = request.args.get("dataset")
    if not entity:
        return jsonify({"error": "Entity parameter required"}), 400
    if not dataset or dataset not in GRAPH_PATHS:
        return jsonify({"error": "Invalid dataset"}), 400

    path = GRAPH_PATHS[dataset]
    if not os.path.exists(path):
        return jsonify({"error": "Dataset path not found"}), 404

    models = []
    pkl_files = glob.glob(os.path.join(path, "*.pkl"))
    for pkl_file in pkl_files:
        filename = os.path.basename(pkl_file)
        # Different filename patterns for different datasets
        if dataset == "bio":
            # Format: bio_graph_{entity}_merged_{model}.pkl
            match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
            if match:
                entity_name, model = match.groups()
                if entity_name == entity:
                    models.append({"model": model, "filename": filename, "filepath": pkl_file})
        elif dataset == "fp":
            # Format: fp_graph_{number}_merged_{model}.pkl
            match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
            if match:
                entity_name, model = match.groups()
                if f"Problem {entity_name}" == entity:
                    models.append({"model": model, "filename": filename, "filepath": pkl_file})
        elif dataset == "math":
            # Format: qwen72_math_{number}.pkl
            match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
            if match:
                entity_name = match.group(1)
                if f"Math Problem {entity_name}" == entity:
                    models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
        elif dataset == "aime":
            # Format: aime_qwen72b_{number}.pkl
            match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
            if match:
                entity_name = match.group(1)
                if f"AIME Problem {entity_name}" == entity:
                    models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
        else:
            # Generic pattern for other datasets
            match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
            if match:
                task, entity_name, model = match.groups()
                if entity_name == entity:
                    models.append({"model": model, "filename": filename, "filepath": pkl_file})
    return jsonify({"models": models})

@app.route("/api/entities")
def get_entities():
    """Get available entities for a dataset"""
    dataset = request.args.get("dataset")
    if not dataset or dataset not in GRAPH_PATHS:
        return jsonify({"error": "Invalid dataset"}), 400

    path = GRAPH_PATHS[dataset]
    if not os.path.exists(path):
        return jsonify({"error": "Dataset path not found"}), 404

    entities = []
    pkl_files = glob.glob(os.path.join(path, "*.pkl"))
    for pkl_file in pkl_files:
        filename = os.path.basename(pkl_file)
        # Different filename patterns for different datasets
        if dataset == "bio":
            # Format: bio_graph_{entity}_merged_{model}.pkl
            match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
            if match:
                entity_name, model = match.groups()
                entities.append(
                    {
                        "entity": entity_name,
                        "model": model,
                        "filename": filename,
                        "filepath": pkl_file,
                    }
                )
        elif dataset == "fp":
            # Format: fp_graph_{number}_merged_{model}.pkl
            match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
            if match:
                entity_name, model = match.groups()
                entities.append(
                    {
                        "entity": f"Problem {entity_name}",
                        "model": model,
                        "filename": filename,
                        "filepath": pkl_file,
                    }
                )
        elif dataset == "math":
            # Format: qwen72_math_{number}.pkl
            match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
            if match:
                entity_name = match.group(1)
                entities.append(
                    {
                        "entity": f"Math Problem {entity_name}",
                        "model": "qwen72b",
                        "filename": filename,
                        "filepath": pkl_file,
                    }
                )
        elif dataset == "aime":
            # Format: aime_qwen72b_{number}.pkl
            match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
            if match:
                entity_name = match.group(1)
                entities.append(
                    {
                        "entity": f"AIME Problem {entity_name}",
                        "model": "qwen72b",
                        "filename": filename,
                        "filepath": pkl_file,
                    }
                )
        else:
            # Generic pattern for other datasets
            match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
            if match:
                task, entity_name, model = match.groups()
                entities.append(
                    {
                        "entity": entity_name,
                        "model": model,
                        "filename": filename,
                        "filepath": pkl_file,
                    }
                )
    return jsonify({"entities": entities})

@app.route("/api/load_graph", methods=["POST"])
def load_existing_graph():
    """Load an existing graph from the stored pickle files"""
    try:
        data = request.get_json()
        filepath = data.get("filepath")
        if not filepath or not os.path.exists(filepath):
            return jsonify({"error": "Graph file not found"}), 404

        # Read and load the pickle file
        try:
            with open(filepath, "rb") as f:
                graph = pickle.load(f)
        except Exception as e:
            return jsonify({"error": f"Error loading pickle file: {str(e)}"}), 500

        if not isinstance(graph, TextPOAGraph):
            return jsonify({"error": "File does not contain a valid POA graph"}), 400

        # Convert to JSON format for vis.js
        nodes = []
        edges = []
        try:
            # Get consensus nodes for coloring
            consensus_nodes = set(graph.consensus_node_ids)

            # Create nodes using the same logic as jsOutput
            for node in graph.nodeiterator()():
                title_text = ""
                if node.sequences:
                    title_text += f"Sequences: {node.sequences}"
                if node.variations:
                    title_text += ";;;".join(
                        [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
                    )
                title_text = title_text.replace('"', "'")
                # Use the same color logic as jsOutput
                color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
                node_data = {
                    "id": node.ID,
                    "label": f"{node.ID}: {node.text}",
                    "title": title_text,
                    "color": color,
                }
                nodes.append(node_data)

            # Create edges using the same logic as jsOutput
            for node in graph.nodeiterator()():
                nodeID = node.ID  # Keep as integer
                for edge in node.outEdges:
                    target = edge  # Keep as integer
                    weight = node.outEdges[edge].weight + 1.5
                    edge_data = {
                        "from": nodeID,
                        "to": target,
                        "value": weight,
                        "color": "#cae0e6",
                        "arrows": "to",
                    }
                    edges.append(edge_data)
        except Exception as e:
            return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
        # Extract metadata from filename
        filename = os.path.basename(filepath)
        metadata = {}
        try:
            # Different filename patterns for different datasets
            if filename.startswith("bio_graph_"):
                # Format: bio_graph_{entity}_merged_{model}.pkl
                match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
                if match:
                    entity_name, model = match.groups()
                    metadata = {
                        "task": "bio",
                        "entity": entity_name,
                        "model": model,
                        "filename": filename,
                    }
            elif filename.startswith("fp_graph_"):
                # Format: fp_graph_{number}_merged_{model}.pkl
                match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
                if match:
                    entity_name, model = match.groups()
                    metadata = {
                        "task": "fp",
                        "entity": f"Problem {entity_name}",
                        "model": model,
                        "filename": filename,
                    }
            elif filename.startswith("qwen72_math_"):
                # Format: qwen72_math_{number}.pkl
                match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
                if match:
                    entity_name = match.group(1)
                    metadata = {
                        "task": "math",
                        "entity": f"Math Problem {entity_name}",
                        "model": "qwen72b",
                        "filename": filename,
                    }
            elif filename.startswith("aime_qwen72b_"):
                # Format: aime_qwen72b_{number}.pkl
                match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
                if match:
                    entity_name = match.group(1)
                    metadata = {
                        "task": "aime",
                        "entity": f"AIME Problem {entity_name}",
                        "model": "qwen72b",
                        "filename": filename,
                    }
            else:
                # Generic pattern for other datasets
                match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
                if match:
                    task, entity_name, model = match.groups()
                    metadata = {
                        "task": task,
                        "entity": entity_name,
                        "model": model,
                        "filename": filename,
                    }
        except Exception:
            # Don't fail the request if metadata extraction fails
            pass
        # Extract text from consensus nodes
        consensus_text = ""
        try:
            consensus_nodes = set(graph.consensus_node_ids)
            consensus_node_texts = []
            for node in graph.nodeiterator()():
                if node.ID in consensus_nodes and node.text and node.text.strip():
                    consensus_node_texts.append(node.text.strip())
            consensus_text = " ".join(consensus_node_texts)
        except Exception:
            consensus_text = ""

        # Check if we should compute consensus using decode_consensus
        compute_consensus = data.get("compute_consensus", False)
        if compute_consensus and decode_consensus:
            try:
                # Determine task from metadata or default to "bio"
                task = metadata.get("task", "bio") if metadata else "bio"
                consensus_text = decode_consensus(graph, selection_threshold=0.5, task=task)
            except Exception as e:
                print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
                # Keep the original consensus text if decode_consensus fails

        # Get original sequences
        try:
            raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
            print(f"DEBUG: Raw sequences: {raw_sequences}")
            # Process sequences: join with spaces and remove "||"
            original_sequences = []
            for seq in raw_sequences:
                if isinstance(seq, list):
                    # Join list elements with spaces
                    processed_seq = " ".join(str(item) for item in seq)
                else:
                    processed_seq = str(seq)
                # Remove "||" characters
                processed_seq = processed_seq.replace("||", "")
                print(f"DEBUG: Processed sequence: {processed_seq}")
                original_sequences.append(processed_seq)
        except Exception:
            original_sequences = []

        result = {
            "success": True,
            "nodes": nodes,
            "edges": edges,
            "num_sequences": graph.num_sequences,
            "num_nodes": len(nodes),
            "num_edges": len(edges),
            "metadata": metadata,
            "consensus_text": consensus_text,
            "original_sequences": original_sequences,
        }
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route("/api/create_graph", methods=["POST"])
def create_graph():
    """Create a POA graph from text sequences"""
    try:
        data = request.get_json()
        sequences = data.get("sequences", [])
        if len(sequences) < 2:
            return jsonify({"error": "At least 2 sequences are required"}), 400
        print(f"DEBUG: Creating graph with sequences: {sequences}")

        # Create the graph from the first sequence
        graph = TextPOAGraph(sequences[0], label=0)
        print("DEBUG: Initial graph created")

        # Align and incorporate the remaining sequences
        for i, sequence in enumerate(sequences[1:], 1):
            print(f"DEBUG: Adding sequence {i}: {sequence}")
            alignment = TextSeqGraphAlignment(
                text=sequence,
                graph=graph,
                fastMethod=True,
                globalAlign=True,
                matchscore=1,
                mismatchscore=-2,
                gap_open=-1,
            )
            graph.incorporateSeqAlignment(alignment, sequence, label=i)
        print("DEBUG: All sequences added")

        # Refine the graph with proper domain and model parameters
        graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini")
        print("DEBUG: Graph refined")

        # Convert to JSON format for vis.js
        nodes = []
        edges = []
        try:
            print("DEBUG: Starting to process graph data")
            # Get consensus nodes for coloring (make it optional)
            try:
                consensus_nodes = set(graph.consensus_node_ids)
                print(f"DEBUG: Consensus nodes: {consensus_nodes}")
            except Exception as e:
                print(f"DEBUG: Error getting consensus nodes: {e}")
                consensus_nodes = set()  # Fall back to an empty set if consensus fails

            # Create nodes using the same logic as jsOutput
            for node in graph.nodeiterator()():
                title_text = ""
                if node.sequences:
                    title_text += f"Sequences: {node.sequences}"
                if node.variations:
                    title_text += ";;;".join(
                        [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
                    )
                title_text = title_text.replace('"', "'")
                # Use the same color logic as jsOutput
                color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
                node_data = {
                    "id": node.ID,
                    "label": f"{node.ID}: {node.text}",
                    "title": title_text,
                    "color": color,
                }
                nodes.append(node_data)
            print(f"DEBUG: Created {len(nodes)} nodes")

            # Create edges using the same logic as jsOutput
            for node in graph.nodeiterator()():
                nodeID = node.ID  # Keep as integer
                for edge in node.outEdges:
                    target = edge  # Keep as integer
                    weight = node.outEdges[edge].weight + 1.5
                    edge_data = {
                        "from": nodeID,
                        "to": target,
                        "value": weight,
                        "color": "#cae0e6",
                        "arrows": "to",
                    }
                    edges.append(edge_data)
            print(f"DEBUG: Created {len(edges)} edges")
        except Exception as e:
            print(f"DEBUG: Error processing graph data: {e}")
            return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
        # Extract text from consensus nodes
        consensus_text = ""
        try:
            consensus_node_texts = []
            for node in graph.nodeiterator()():
                if node.ID in consensus_nodes and node.text and node.text.strip():
                    consensus_node_texts.append(node.text.strip())
            consensus_text = " ".join(consensus_node_texts)
        except Exception:
            consensus_text = ""

        # Check if we should compute consensus using decode_consensus
        compute_consensus = data.get("compute_consensus", False)
        if compute_consensus and decode_consensus:
            try:
                # Default to "bio" task for new graphs
                consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio")
            except Exception as e:
                print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
                # Keep the original consensus text if decode_consensus fails

        # Get original sequences
        try:
            raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
            # Process sequences: join with spaces and remove "||"
            original_sequences = []
            for seq in raw_sequences:
                if isinstance(seq, list):
                    # Join list elements with spaces
                    processed_seq = " ".join(str(item) for item in seq)
                else:
                    processed_seq = str(seq)
                # Remove "||" characters
                processed_seq = processed_seq.replace("||", "")
                original_sequences.append(processed_seq)
        except Exception:
            original_sequences = []

        print("DEBUG: Returning success response")
        return jsonify(
            {
                "success": True,
                "nodes": nodes,
                "edges": edges,
                "num_sequences": len(sequences),
                "num_nodes": len(nodes),
                "num_edges": len(edges),
                "original_sequences": original_sequences,
                "consensus_text": consensus_text,
            }
        )
    except Exception as e:
        print(f"DEBUG: Main exception in create_graph: {e}")
        return jsonify({"error": str(e)}), 500

@app.route("/api/save_graph", methods=["POST"])
def save_graph():
    """Save a POA graph to a pickle file"""
    try:
        data = request.get_json()
        sequences = data.get("sequences", [])
        filename = data.get("filename", "graph.pkl")
        if len(sequences) < 2:
            return jsonify({"error": "At least 2 sequences are required"}), 400

        # Create the graph
        graph = TextPOAGraph(sequences[0], label=0)

        # Add remaining sequences
        for i, sequence in enumerate(sequences[1:], 1):
            alignment = TextSeqGraphAlignment(
                text=sequence,
                graph=graph,
                fastMethod=True,
                globalAlign=True,
                matchscore=1,
                mismatchscore=-2,
                gap_open=-1,
            )
            graph.incorporateSeqAlignment(alignment, sequence, label=i)

        # Refine the graph
        graph.refine_graph(verbose=False)

        # Save to pickle file
        graph.save_to_pickle(filename)
        return jsonify(
            {"success": True, "filename": filename, "message": f"Graph saved to {filename}"}
        )
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route("/api/graph_info", methods=["POST"])
def graph_info():
    """Get information about a graph without creating the full visualization"""
    try:
        data = request.get_json()
        sequences = data.get("sequences", [])
        if len(sequences) < 2:
            return jsonify({"error": "At least 2 sequences are required"}), 400

        # Create the graph
        graph = TextPOAGraph(sequences[0], label=0)

        # Add remaining sequences
        for i, sequence in enumerate(sequences[1:], 1):
            alignment = TextSeqGraphAlignment(
                text=sequence,
                graph=graph,
                fastMethod=True,
                globalAlign=True,
                matchscore=1,
                mismatchscore=-2,
                gap_open=-1,
            )
            graph.incorporateSeqAlignment(alignment, sequence, label=i)

        # Refine the graph
        graph.refine_graph(verbose=False)

        # Get consensus response
        consensus_text = graph.consensus_response()
        return jsonify(
            {
                "success": True,
                "num_sequences": len(sequences),
                "num_nodes": graph._nnodes,
                "consensus_text": consensus_text,
                "consensus_node_ids": graph.consensus_node_ids,
            }
        )
    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    print("Starting POA Graph Web Interface Server...")
    print("Open http://localhost:8080 in your browser")
    app.run(debug=True, host="0.0.0.0", port=8080)