Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Flask server for POA Graph Web Interface | |
| Modified for Hugging Face Spaces deployment | |
| """ | |
| import glob | |
| import os | |
| import pickle | |
| import re | |
| import sys | |
| from flask import Flask, jsonify, request, send_from_directory | |
| from flask_cors import CORS | |
# Absolute path of the directory holding this script; used as the project root
# for locating graph pickles and the web_interface assets.
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))

# The POA graph modules live in the `src` package under the project root, so
# the root has to be on sys.path before the imports below can resolve.
sys.path.append(REPO_ROOT)

from src.new_text_alignment import TextSeqGraphAlignment
from src.text_poa_graph import TextPOAGraph

# decode_consensus is optional: when its module cannot be imported the name is
# bound to None so callers can test for availability.
try:
    from src.generation_methods import decode_consensus
except ImportError:
    decode_consensus = None

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Dataset key -> sub-directory (below results/graphs/HALoGEN) holding its pickles.
_GRAPH_SUBDIRS = {
    "bio": "bio",
    "fp": "fp",
    "hist": "hist",
    "refs": "refs",
    "math": "MATH",
    "aime": "AIME",
}

# Dataset key -> directory of pickled graphs for that dataset.
GRAPH_PATHS = {
    dataset: os.path.join(REPO_ROOT, f"results/graphs/HALoGEN/{subdir}")
    for dataset, subdir in _GRAPH_SUBDIRS.items()
}

# Model identifiers known to the interface.
MODELS = [
    "qwen72b",
    "qwen7b",
    "llama8b",
    "llama70b",
    "olmo7b",
    "olmo32b",
]
# NOTE(review): no @app.route decorator is visible on this view in the reviewed
# source -- confirm it is registered with `app`.
def index():
    """Return ``index.html`` from the ``web_interface`` directory under REPO_ROOT."""
    return send_from_directory(os.path.join(REPO_ROOT, "web_interface"), "index.html")
# NOTE(review): no @app.route decorator is visible on this view in the reviewed
# source -- confirm it is registered with `app`.
def serve_static(path):
    """Return the file at *path*, resolved inside the ``web_interface`` directory.

    Args:
        path: Location of the requested file relative to ``web_interface``.
    """
    static_root = os.path.join(REPO_ROOT, "web_interface")
    return send_from_directory(static_root, path)
# NOTE(review): no @app.route decorator is visible on this view in the reviewed
# source -- confirm it is registered with `app`.
def get_datasets():
    """List every dataset in GRAPH_PATHS whose directory exists on disk.

    Returns:
        A JSON response ``{"datasets": [...]}`` where each item carries the
        dataset key (``name``), its upper-cased form (``display_name``), the
        directory (``path``) and the number of ``*.pkl`` files in it (``count``).
    """
    available = [
        {
            "name": key,
            "display_name": key.upper(),
            "path": directory,
            # Number of pickled graphs directly inside the dataset directory.
            "count": len(glob.glob(os.path.join(directory, "*.pkl"))),
        }
        for key, directory in GRAPH_PATHS.items()
        if os.path.exists(directory)
    ]
    return jsonify({"datasets": available})
# Filename conventions per dataset, as
# (compiled pattern, entity label template, fixed model name).
# A fixed model of None means the model is captured from the filename.
#   bio : bio_graph_{entity}_merged_{model}.pkl
#   fp  : fp_graph_{number}_merged_{model}.pkl
#   math: qwen72_math_{number}.pkl
#   aime: aime_qwen72b_{number}.pkl
_GRAPH_FILENAME_RULES = {
    "bio": (re.compile(r"bio_graph_(?P<entity>.+?)_merged_(?P<model>\w+)\.pkl"), "{}", None),
    "fp": (re.compile(r"fp_graph_(?P<entity>\d+)_merged_(?P<model>\w+)\.pkl"), "Problem {}", None),
    "math": (re.compile(r"qwen72_math_(?P<entity>\d+)\.pkl"), "Math Problem {}", "qwen72b"),
    "aime": (re.compile(r"aime_qwen72b_(?P<entity>\d+)\.pkl"), "AIME Problem {}", "qwen72b"),
}

# Fallback for every other dataset: {task}_graph_{entity}_merged_{model}.pkl
_GENERIC_GRAPH_FILENAME_RULE = (
    re.compile(r"\w+_graph_(?P<entity>.+?)_merged_(?P<model>\w+)\.pkl"),
    "{}",
    None,
)


def _parse_graph_filename(dataset, filename):
    """Return ``(entity_label, model)`` for *filename*, or None if it does not match.

    The pattern is chosen by *dataset*; datasets without a dedicated rule use
    the generic ``{task}_graph_{entity}_merged_{model}.pkl`` convention.
    """
    pattern, label_template, fixed_model = _GRAPH_FILENAME_RULES.get(
        dataset, _GENERIC_GRAPH_FILENAME_RULE
    )
    match = pattern.match(filename)
    if match is None:
        return None
    model = fixed_model if fixed_model is not None else match.group("model")
    return label_template.format(match.group("entity")), model


def _iter_graph_records(dataset):
    """Yield one record dict per recognised ``*.pkl`` file of *dataset*.

    Files are visited in sorted path order so that responses (and the record
    kept per entity by ``get_entities``) do not depend on directory listing order.
    """
    for filepath in sorted(glob.glob(os.path.join(GRAPH_PATHS[dataset], "*.pkl"))):
        filename = os.path.basename(filepath)
        parsed = _parse_graph_filename(dataset, filename)
        if parsed is None:
            continue
        entity_label, model = parsed
        yield {
            "entity": entity_label,
            "model": model,
            "filename": filename,
            "filepath": filepath,
        }


# NOTE(review): no @app.route decorator is visible on this view in the reviewed
# source -- confirm it is registered with `app`.
def get_models():
    """Get available models for a specific entity.

    Reads the ``entity`` and ``dataset`` query parameters.

    Returns:
        ``{"models": [...]}`` with one ``model``/``filename``/``filepath`` item
        per graph file whose entity label equals ``entity``; otherwise a JSON
        error with status 400 (missing entity or unknown dataset) or 404
        (dataset directory missing).
    """
    entity = request.args.get("entity")
    dataset = request.args.get("dataset")
    if not entity:
        return jsonify({"error": "Entity parameter required"}), 400
    if not dataset or dataset not in GRAPH_PATHS:
        return jsonify({"error": "Invalid dataset"}), 400
    if not os.path.exists(GRAPH_PATHS[dataset]):
        return jsonify({"error": "Dataset path not found"}), 404

    models = [
        {
            "model": record["model"],
            "filename": record["filename"],
            "filepath": record["filepath"],
        }
        for record in _iter_graph_records(dataset)
        if record["entity"] == entity
    ]
    return jsonify({"models": models})


# NOTE(review): no @app.route decorator is visible on this view in the reviewed
# source -- confirm it is registered with `app`.
def get_entities():
    """Get available entities for a dataset.

    Reads the ``dataset`` query parameter.

    Returns:
        ``{"entities": [...]}`` with one ``entity``/``model``/``filename``/
        ``filepath`` item per distinct entity label (the first matching file in
        sorted order represents the entity); otherwise a JSON error with status
        400 (unknown dataset) or 404 (dataset directory missing).
    """
    dataset = request.args.get("dataset")
    if not dataset or dataset not in GRAPH_PATHS:
        return jsonify({"error": "Invalid dataset"}), 400
    if not os.path.exists(GRAPH_PATHS[dataset]):
        return jsonify({"error": "Dataset path not found"}), 404

    # Keep only the first record seen for each entity label.
    unique_entities = {}
    for record in _iter_graph_records(dataset):
        unique_entities.setdefault(record["entity"], record)
    return jsonify({"entities": list(unique_entities.values())})
def _is_served_graph_file(filepath):
    """Return True if *filepath* is an existing ``.pkl`` file under a GRAPH_PATHS directory."""
    if not isinstance(filepath, str) or not filepath.endswith(".pkl"):
        return False
    candidate = os.path.abspath(filepath)  # also collapses any ".." components
    if not os.path.isfile(candidate):
        return False
    for root in GRAPH_PATHS.values():
        root = os.path.abspath(root)
        try:
            if os.path.commonpath([candidate, root]) == root:
                return True
        except ValueError:
            # commonpath refuses paths that cannot share a prefix (e.g.
            # different drives); such a root simply does not contain the file.
            continue
    return False


def _graph_to_vis(graph_nodes, consensus_nodes):
    """Convert graph nodes into the vis.js ``(nodes, edges)`` lists."""
    nodes = []
    edges = []
    for node in graph_nodes:
        title_text = ""
        if node.sequences:
            title_text += f"Sequences: {node.sequences}"
        if node.variations:
            title_text += ";;;".join(
                f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()
            )
        # Double quotes are swapped for single quotes in the tooltip text.
        title_text = title_text.replace('"', "'")
        nodes.append(
            {
                "id": node.ID,
                "label": f"{node.ID}: {node.text}",
                "title": title_text,
                # Consensus nodes get a distinct colour.
                "color": "#ceeab2" if node.ID in consensus_nodes else "#cae0e6",
            }
        )
        for target in node.outEdges:
            edges.append(
                {
                    "from": node.ID,
                    "to": target,
                    "value": node.outEdges[target].weight + 1.5,
                    "color": "#cae0e6",
                    "arrows": "to",
                }
            )
    return nodes, edges


def _original_sequences(graph):
    """Return the graph's stored sequences as plain strings with "||" removed."""
    processed = []
    for seq in getattr(graph, "_seqs", []):
        if isinstance(seq, list):
            text = " ".join(str(item) for item in seq)
        else:
            text = str(seq)
        processed.append(text.replace("||", ""))
    return processed


# NOTE(review): no @app.route decorator is visible on this view in the reviewed
# source -- confirm it is registered with `app`.
def load_existing_graph():
    """Load an existing graph from a pickle file and return it in vis.js form.

    Expects a JSON body with a ``filepath`` key. Only ``.pkl`` files located
    inside one of the GRAPH_PATHS directories are accepted.

    Returns:
        On success a JSON object with ``success``, ``nodes``, ``edges``,
        ``num_sequences``, ``num_nodes``, ``num_edges``, ``original_sequences``
        and ``consensus_text``. ``{"error": "Invalid filepath"}`` with status
        400 when the body or path is missing/not allowed, and
        ``{"error": <message>}`` with status 500 when loading or conversion fails.
    """
    # silent=True: a missing or malformed JSON body is a client error (400
    # below), not a server failure.
    data = request.get_json(silent=True)
    filepath = data.get("filepath") if isinstance(data, dict) else None

    # SECURITY: pickle.load can execute arbitrary code from the file it reads,
    # so never unpickle a client-chosen path outside the served graph folders.
    if not _is_served_graph_file(filepath):
        return jsonify({"error": "Invalid filepath"}), 400

    try:
        with open(filepath, "rb") as f:
            graph = pickle.load(f)

        # Consensus information is optional: fall back to "no consensus nodes"
        # when the loaded graph cannot provide it.
        try:
            consensus_nodes = set(graph.consensus_node_ids)
        except Exception:
            consensus_nodes = set()

        # nodeiterator() returns a callable that produces the nodes (hence the
        # double call); walk the graph once and reuse the result.
        graph_nodes = list(graph.nodeiterator()())
        nodes, edges = _graph_to_vis(graph_nodes, consensus_nodes)

        # Best effort: consensus text is the stripped text of the consensus
        # nodes in iteration order, or "" if it cannot be built.
        try:
            consensus_text = " ".join(
                node.text.strip()
                for node in graph_nodes
                if node.ID in consensus_nodes and node.text and node.text.strip()
            )
        except Exception:
            consensus_text = ""

        # Best effort: an unreadable sequence store yields an empty list.
        try:
            original_sequences = _original_sequences(graph)
        except Exception:
            original_sequences = []

        return jsonify(
            {
                "success": True,
                "nodes": nodes,
                "edges": edges,
                "num_sequences": len(original_sequences),
                "num_nodes": len(nodes),
                "num_edges": len(edges),
                "original_sequences": original_sequences,
                "consensus_text": consensus_text,
            }
        )
    except Exception as e:
        return jsonify({"error": str(e)}), 500
| # ============================================================================ | |
| # COMMENTED OUT: Create New Graph Feature - Disabled per user request | |
| # ============================================================================ | |
| # @app.route("/api/create_graph", methods=["POST"]) | |
| # def create_graph(): | |
| # """Create a new POA graph from text sequences""" | |
| # try: | |
| # print("DEBUG: Received create_graph request") | |
| # data = request.get_json() | |
| # sequences = data.get("sequences", []) | |
| # | |
| # print(f"DEBUG: Number of sequences: {len(sequences)}") | |
| # | |
| # if len(sequences) < 2: | |
| # return jsonify({"error": "At least 2 sequences are required"}), 400 | |
| # | |
| # print("DEBUG: Creating initial graph") | |
| # # Create the graph from first sequence | |
| # graph = TextPOAGraph(sequences[0], label=0) | |
| # print("DEBUG: Initial graph created") | |
| # | |
| # # Add remaining sequences | |
| # for i, sequence in enumerate(sequences[1:], 1): | |
| # print(f"DEBUG: Adding sequence {i}") | |
| # alignment = TextSeqGraphAlignment( | |
| # text=sequence, | |
| # graph=graph, | |
| # fastMethod=True, | |
| # globalAlign=True, | |
| # matchscore=1, | |
| # mismatchscore=-2, | |
| # gap_open=-1, | |
| # ) | |
| # graph.incorporateSeqAlignment(alignment, sequence, label=i) | |
| # | |
| # print("DEBUG: All sequences added") | |
| # | |
| # # Refine the graph with proper domain and model parameters | |
| # graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini") | |
| # print("DEBUG: Graph refined") | |
| # | |
| # # Convert to JSON format for vis.js | |
| # nodes = [] | |
| # edges = [] | |
| # | |
| # try: | |
| # print("DEBUG: Starting to process graph data") | |
| # # Get consensus nodes for coloring (make it optional) | |
| # try: | |
| # consensus_nodes = set(graph.consensus_node_ids) | |
| # print(f"DEBUG: Consensus nodes: {consensus_nodes}") | |
| # except Exception as e: | |
| # print(f"DEBUG: Error getting consensus nodes: {e}") | |
| # consensus_nodes = set() # Fallback to empty set if consensus fails | |
| # | |
| # # Create nodes using the same logic as jsOutput | |
| # for node in graph.nodeiterator()(): | |
| # title_text = "" | |
| # if node.sequences: | |
| # title_text += f"Sequences: {node.sequences}" | |
| # if node.variations: | |
| # title_text += ";;;".join( | |
| # [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()] | |
| # ) | |
| # title_text = title_text.replace('"', "'") | |
| # | |
| # # Use the same color logic as jsOutput | |
| # color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6" | |
| # | |
| # node_data = { | |
| # "id": node.ID, | |
| # "label": f"{node.ID}: {node.text}", | |
| # "title": title_text, | |
| # "color": color, | |
| # } | |
| # nodes.append(node_data) | |
| # | |
| # print(f"DEBUG: Created {len(nodes)} nodes") | |
| # | |
| # # Create edges using the same logic as jsOutput | |
| # for node in graph.nodeiterator()(): | |
| # nodeID = node.ID # Keep as integer | |
| # for edge in node.outEdges: | |
| # target = edge # Keep as integer | |
| # weight = node.outEdges[edge].weight + 1.5 | |
| # edge_data = { | |
| # "from": nodeID, | |
| # "to": target, | |
| # "value": weight, | |
| # "color": "#cae0e6", | |
| # "arrows": "to", | |
| # } | |
| # edges.append(edge_data) | |
| # | |
| # print(f"DEBUG: Created {len(edges)} edges") | |
| # except Exception as e: | |
| # print(f"DEBUG: Error processing graph data: {e}") | |
| # return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500 | |
| # | |
| # # Extract text from consensus nodes | |
| # consensus_text = "" | |
| # try: | |
| # consensus_node_texts = [] | |
| # for node in graph.nodeiterator()(): | |
| # if node.ID in consensus_nodes and node.text and node.text.strip(): | |
| # consensus_node_texts.append(node.text.strip()) | |
| # consensus_text = " ".join(consensus_node_texts) | |
| # except Exception: | |
| # consensus_text = "" | |
| # | |
| # # Check if we should compute consensus using decode_consensus | |
| # compute_consensus = data.get("compute_consensus", False) | |
| # if compute_consensus and decode_consensus: | |
| # try: | |
| # # Default to "bio" task for new graphs | |
| # consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio") | |
| # except Exception as e: | |
| # print(f"DEBUG: Error computing consensus with decode_consensus: {e}") | |
| # # Keep the original consensus text if decode_consensus fails | |
| # | |
| # # Get original sequences | |
| # try: | |
| # raw_sequences = graph._seqs if hasattr(graph, "_seqs") else [] | |
| # # Process sequences: join with spaces and remove "||" | |
| # original_sequences = [] | |
| # for seq in raw_sequences: | |
| # if isinstance(seq, list): | |
| # # Join list elements with spaces | |
| # processed_seq = " ".join(str(item) for item in seq) | |
| # else: | |
| # processed_seq = str(seq) | |
| # # Remove "||" characters | |
| # processed_seq = processed_seq.replace("||", "") | |
| # original_sequences.append(processed_seq) | |
| # except Exception: | |
| # original_sequences = [] | |
| # | |
| # print("DEBUG: Returning success response") | |
| # return jsonify( | |
| # { | |
| # "success": True, | |
| # "nodes": nodes, | |
| # "edges": edges, | |
| # "num_sequences": len(sequences), | |
| # "num_nodes": len(nodes), | |
| # "num_edges": len(edges), | |
| # "original_sequences": original_sequences, | |
| # "consensus_text": consensus_text, | |
| # } | |
| # ) | |
| # | |
| # except Exception as e: | |
| # print(f"DEBUG: Main exception in create_graph: {e}") | |
| # return jsonify({"error": str(e)}), 500 | |
| # @app.route("/api/save_graph", methods=["POST"]) | |
| # def save_graph(): | |
| # """Save a POA graph to a pickle file""" | |
| # try: | |
| # data = request.get_json() | |
| # sequences = data.get("sequences", []) | |
| # filename = data.get("filename", "graph.pkl") | |
| # | |
| # if len(sequences) < 2: | |
| # return jsonify({"error": "At least 2 sequences are required"}), 400 | |
| # | |
| # # Create the graph | |
| # graph = TextPOAGraph(sequences[0], label=0) | |
| # | |
| # # Add remaining sequences | |
| # for i, sequence in enumerate(sequences[1:], 1): | |
| # alignment = TextSeqGraphAlignment( | |
| # text=sequence, | |
| # graph=graph, | |
| # fastMethod=True, | |
| # globalAlign=True, | |
| # matchscore=1, | |
| # mismatchscore=-2, | |
| # gap_open=-1, | |
| # ) | |
| # graph.incorporateSeqAlignment(alignment, sequence, label=i) | |
| # | |
| # # Refine the graph | |
| # graph.refine_graph(verbose=False) | |
| # | |
| # # Save to pickle file | |
| # graph.save_to_pickle(filename) | |
| # | |
| # return jsonify( | |
| # {"success": True, "filename": filename, "message": f"Graph saved to {filename}"} | |
| # ) | |
| # | |
| # except Exception as e: | |
| # return jsonify({"error": str(e)}), 500 | |
| # @app.route("/api/graph_info", methods=["POST"]) | |
| # def graph_info(): | |
| # """Get information about a graph without creating the full visualization""" | |
| # try: | |
| # data = request.get_json() | |
| # sequences = data.get("sequences", []) | |
| # | |
| # if len(sequences) < 2: | |
| # return jsonify({"error": "At least 2 sequences are required"}), 400 | |
| # | |
| # # Create the graph | |
| # graph = TextPOAGraph(sequences[0], label=0) | |
| # | |
| # # Add remaining sequences | |
| # for i, sequence in enumerate(sequences[1:], 1): | |
| # alignment = TextSeqGraphAlignment( | |
| # text=sequence, | |
| # graph=graph, | |
| # fastMethod=True, | |
| # globalAlign=True, | |
| # matchscore=1, | |
| # mismatchscore=-2, | |
| # gap_open=-1, | |
| # ) | |
| # graph.incorporateSeqAlignment(alignment, sequence, label=i) | |
| # | |
| # # Refine the graph | |
| # graph.refine_graph(verbose=False) | |
| # | |
| # # Get consensus response | |
| # consensus_text = graph.consensus_response() | |
| # | |
| # return jsonify( | |
| # { | |
| # "success": True, | |
| # "num_sequences": len(sequences), | |
| # "num_nodes": graph._nnodes, | |
| # "consensus_text": consensus_text, | |
| # "consensus_node_ids": graph.consensus_node_ids, | |
| # } | |
| # ) | |
| # | |
| # except Exception as e: | |
| # return jsonify({"error": str(e)}), 500 | |
if __name__ == "__main__":
    # Hugging Face Spaces serves the app on port 7860; the PORT environment
    # variable overrides that default when set.
    port = int(os.environ.get("PORT", 7860))
    startup_banner = (
        "Starting POA Graph Web Interface Server...",
        f"Repository root: {REPO_ROOT}",
        f"Serving static files from: {os.path.join(REPO_ROOT, 'web_interface')}",
        f"Open http://localhost:{port} in your browser",
    )
    for banner_line in startup_banner:
        print(banner_line)
    # Listen on all interfaces with Flask's debug mode switched off.
    app.run(debug=False, host="0.0.0.0", port=port)