ConGrs / app.py
Shahzaib98's picture
Removed customization feature for now
cb83116 verified
#!/usr/bin/env python3
"""
Flask server for POA Graph Web Interface
Modified for Hugging Face Spaces deployment
"""
import glob
import os
import pickle
import re
import sys
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
# Get the directory where this script is located (should be project root)
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
# Add the repository root to the path so we can import the POA graph modules.
# Done before the `src.*` imports below so they can resolve.
sys.path.append(REPO_ROOT)
# NOTE(review): in the code visible here these two names are referenced only by
# the commented-out graph-creation endpoints — confirm before removing them.
from src.new_text_alignment import TextSeqGraphAlignment
from src.text_poa_graph import TextPOAGraph

# decode_consensus is optional: when the import fails the name is bound to
# None, and callers are expected to check for that before using it.
try:
    from src.generation_methods import decode_consensus
except ImportError:
    decode_consensus = None

# The Flask application object that all route decorators below attach to.
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes
# Dataset key -> graph directory, given relative to the repository root.
_GRAPH_SUBDIRS = (
    ("bio", "results/graphs/HALoGEN/bio"),
    ("fp", "results/graphs/HALoGEN/fp"),
    ("hist", "results/graphs/HALoGEN/hist"),
    ("refs", "results/graphs/HALoGEN/refs"),
    ("math", "results/graphs/HALoGEN/MATH"),
    ("aime", "results/graphs/HALoGEN/AIME"),
)

# Base paths for the different datasets, anchored at the repo root.
GRAPH_PATHS = {
    dataset_key: os.path.join(REPO_ROOT, relative_dir)
    for dataset_key, relative_dir in _GRAPH_SUBDIRS
}

# Model identifiers known to the interface.
MODELS = [
    "qwen72b",
    "qwen7b",
    "llama8b",
    "llama70b",
    "olmo7b",
    "olmo32b",
]
@app.route("/")
def index():
    """Return ``index.html`` from the ``web_interface`` folder under the repo root."""
    return send_from_directory(os.path.join(REPO_ROOT, "web_interface"), "index.html")
@app.route("/<path:path>")
def serve_static(path):
    """Return a file from the ``web_interface`` folder under the repo root.

    ``path`` is the portion of the URL captured by the route and is handed to
    ``send_from_directory`` unchanged.
    """
    static_root = os.path.join(REPO_ROOT, "web_interface")
    response = send_from_directory(static_root, path)
    return response
@app.route("/api/datasets", methods=["GET"])
def get_datasets():
    """List every configured dataset whose graph directory exists on disk.

    Responds with ``{"datasets": [...]}``. Each entry carries the dataset key
    (``name``), its upper-cased form (``display_name``), the directory
    (``path``) and the number of ``*.pkl`` files directly inside it (``count``).
    Datasets whose directory is missing are left out.
    """
    available = [
        {
            "name": key,
            "display_name": key.upper(),
            "path": graph_dir,
            # Number of pickled graphs stored for this dataset.
            "count": len(glob.glob(os.path.join(graph_dir, "*.pkl"))),
        }
        for key, graph_dir in GRAPH_PATHS.items()
        if os.path.exists(graph_dir)
    ]
    return jsonify({"datasets": available})
@app.route("/api/models", methods=["GET"])
def get_models():
    """List the graph files (one per model) available for a given entity.

    Query parameters:
        entity:  entity label as reported by ``/api/entities`` (required).
        dataset: a key of ``GRAPH_PATHS`` (required).

    Responds with ``{"models": [...]}`` where each item holds ``model``,
    ``filename`` and ``filepath``. Returns 400 for a missing entity or an
    unknown dataset, and 404 when the dataset directory does not exist.
    """
    wanted_entity = request.args.get("entity")
    dataset = request.args.get("dataset")

    if not wanted_entity:
        return jsonify({"error": "Entity parameter required"}), 400
    if not dataset or dataset not in GRAPH_PATHS:
        return jsonify({"error": "Invalid dataset"}), 400

    graph_dir = GRAPH_PATHS[dataset]
    if not os.path.exists(graph_dir):
        return jsonify({"error": "Dataset path not found"}), 404

    def describe(filename):
        """Return ``(entity label, model)`` for *filename*, or None if it does not fit."""
        if dataset == "bio":
            # Format: bio_graph_{entity}_merged_{model}.pkl
            found = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
            return (found.group(1), found.group(2)) if found else None
        if dataset == "fp":
            # Format: fp_graph_{number}_merged_{model}.pkl
            found = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
            return (f"Problem {found.group(1)}", found.group(2)) if found else None
        if dataset == "math":
            # Format: qwen72_math_{number}.pkl (model is implied by the name)
            found = re.match(r"qwen72_math_(\d+)\.pkl", filename)
            return (f"Math Problem {found.group(1)}", "qwen72b") if found else None
        if dataset == "aime":
            # Format: aime_qwen72b_{number}.pkl (model is implied by the name)
            found = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
            return (f"AIME Problem {found.group(1)}", "qwen72b") if found else None
        # Generic pattern for other datasets: {task}_graph_{entity}_merged_{model}.pkl
        found = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
        return (found.group(2), found.group(3)) if found else None

    matches = []
    for pkl_path in glob.glob(os.path.join(graph_dir, "*.pkl")):
        base_name = os.path.basename(pkl_path)
        described = describe(base_name)
        if described is None:
            continue
        label, model_name = described
        if label == wanted_entity:
            matches.append({"model": model_name, "filename": base_name, "filepath": pkl_path})

    return jsonify({"models": matches})
@app.route("/api/entities", methods=["GET"])
def get_entities():
    """List the distinct entities that have at least one graph in a dataset.

    Query parameters:
        dataset: a key of ``GRAPH_PATHS`` (required).

    Responds with ``{"entities": [...]}``. Each item holds ``entity``,
    ``model``, ``filename`` and ``filepath`` of one representative graph file
    for that entity. Returns 400 for an unknown dataset and 404 when the
    dataset directory does not exist.

    Files are visited in a deterministic, number-aware filename order, so the
    order of the result and the representative file chosen for each entity do
    not depend on the order in which the filesystem happens to list them.
    """
    dataset = request.args.get("dataset")
    if not dataset or dataset not in GRAPH_PATHS:
        return jsonify({"error": "Invalid dataset"}), 400

    path = GRAPH_PATHS[dataset]
    if not os.path.exists(path):
        return jsonify({"error": "Dataset path not found"}), 404

    def natural_key(pkl_path):
        """Sort key ordering embedded numbers numerically (2 before 10)."""
        pieces = re.split(r"(\d+)", os.path.basename(pkl_path))
        # With a capturing group, odd positions are always the digit runs, so
        # keys compare str with str and int with int position by position.
        return [int(piece) if idx % 2 else piece for idx, piece in enumerate(pieces)]

    def parse(filename):
        """Return ``(entity label, model)`` for *filename*, or None if it does not fit."""
        if dataset == "bio":
            # Format: bio_graph_{entity}_merged_{model}.pkl
            match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
            return (match.group(1), match.group(2)) if match else None
        if dataset == "fp":
            # Format: fp_graph_{number}_merged_{model}.pkl
            match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
            return (f"Problem {match.group(1)}", match.group(2)) if match else None
        if dataset == "math":
            # Format: qwen72_math_{number}.pkl
            match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
            return (f"Math Problem {match.group(1)}", "qwen72b") if match else None
        if dataset == "aime":
            # Format: aime_qwen72b_{number}.pkl
            match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
            return (f"AIME Problem {match.group(1)}", "qwen72b") if match else None
        # Generic pattern for other datasets: {task}_graph_{entity}_merged_{model}.pkl
        match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
        return (match.group(2), match.group(3)) if match else None

    # glob makes no ordering guarantee; sort so "first file wins" is stable.
    pkl_files = sorted(glob.glob(os.path.join(path, "*.pkl")), key=natural_key)

    # Keep only the first graph file encountered for each entity label.
    unique_entities = {}
    for pkl_file in pkl_files:
        filename = os.path.basename(pkl_file)
        parsed = parse(filename)
        if parsed is None:
            continue
        entity_name, model = parsed
        if entity_name not in unique_entities:
            unique_entities[entity_name] = {
                "entity": entity_name,
                "model": model,
                "filename": filename,
                "filepath": pkl_file,
            }

    return jsonify({"entities": list(unique_entities.values())})
def _is_allowed_graph_file(filepath):
    """Return True if *filepath* is an existing ``.pkl`` file inside a configured graph directory.

    The path is normalised with ``os.path.abspath`` (which collapses ``..``
    segments) and must sit under one of the ``GRAPH_PATHS`` directories.
    Symlinks are not resolved, so links placed inside a graph directory by the
    server operator keep working.
    """
    candidate = os.path.abspath(filepath)
    if not candidate.endswith(".pkl") or not os.path.isfile(candidate):
        return False
    for graph_dir in GRAPH_PATHS.values():
        root = os.path.abspath(graph_dir)
        try:
            if os.path.commonpath([candidate, root]) == root:
                return True
        except ValueError:
            # Raised e.g. for paths on different Windows drives: not inside root.
            continue
    return False


@app.route("/api/load_existing_graph", methods=["POST"])
def load_existing_graph():
    """Load a pickled POA graph and return it in a vis.js-friendly JSON form.

    Request body (JSON): ``{"filepath": "<path to a .pkl graph file>"}``. The
    path must point at a ``.pkl`` file inside one of the ``GRAPH_PATHS``
    directories (the paths returned by ``/api/entities`` and ``/api/models``
    qualify).

    Response (JSON): ``success``, ``nodes``, ``edges``, ``num_sequences``,
    ``num_nodes``, ``num_edges``, ``original_sequences`` and
    ``consensus_text``.

    Returns 400 with ``{"error": "Invalid filepath"}`` when the body is not a
    JSON object, the path is missing, or it is outside the allowed
    directories; returns 500 with the exception text for any other failure.
    """
    try:
        # silent=True: a missing/invalid JSON body yields None instead of
        # raising, so it is reported as a client error rather than a 500.
        payload = request.get_json(silent=True)
        filepath = payload.get("filepath") if isinstance(payload, dict) else None
        if not isinstance(filepath, str) or not _is_allowed_graph_file(filepath):
            return jsonify({"error": "Invalid filepath"}), 400

        # SECURITY: pickle.load can execute arbitrary code from the file it
        # reads. The path check above confines loading to the server's own
        # graph directories; never relax it to accept client-chosen files.
        with open(filepath, "rb") as f:
            graph = pickle.load(f)

        # Convert to JSON format for vis.js
        nodes = []
        edges = []

        # Get consensus nodes for coloring; best effort, fall back to none.
        try:
            consensus_nodes = set(graph.consensus_node_ids)
        except Exception:
            consensus_nodes = set()

        # Create nodes (nodeiterator() returns a callable that yields nodes).
        for node in graph.nodeiterator()():
            title_text = ""
            if node.sequences:
                title_text += f"Sequences: {node.sequences}"
            if node.variations:
                title_text += ";;;".join(
                    [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
                )
            title_text = title_text.replace('"', "'")
            # Consensus nodes are highlighted with a different fill colour.
            color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
            nodes.append(
                {
                    "id": node.ID,
                    "label": f"{node.ID}: {node.text}",
                    "title": title_text,
                    "color": color,
                }
            )

        # Create edges
        for node in graph.nodeiterator()():
            for target in node.outEdges:
                # NOTE(review): the +1.5 offset presumably keeps low-weight
                # edges visibly thick in the UI — confirm against the front end.
                weight = node.outEdges[target].weight + 1.5
                edges.append(
                    {
                        "from": node.ID,
                        "to": target,
                        "value": weight,
                        "color": "#cae0e6",
                        "arrows": "to",
                    }
                )

        # Get consensus text: non-blank texts of consensus nodes, in iteration order.
        consensus_text = ""
        try:
            consensus_node_texts = []
            for node in graph.nodeiterator()():
                if node.ID in consensus_nodes and node.text and node.text.strip():
                    consensus_node_texts.append(node.text.strip())
            consensus_text = " ".join(consensus_node_texts)
        except Exception:
            consensus_text = ""

        # Get original sequences; best effort, fall back to an empty list.
        try:
            raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
            original_sequences = []
            for seq in raw_sequences:
                if isinstance(seq, list):
                    processed_seq = " ".join(str(item) for item in seq)
                else:
                    processed_seq = str(seq)
                # Strip the "||" markers from the displayed text.
                processed_seq = processed_seq.replace("||", "")
                original_sequences.append(processed_seq)
        except Exception:
            original_sequences = []

        return jsonify(
            {
                "success": True,
                "nodes": nodes,
                "edges": edges,
                "num_sequences": len(original_sequences),
                "num_nodes": len(nodes),
                "num_edges": len(edges),
                "original_sequences": original_sequences,
                "consensus_text": consensus_text,
            }
        )
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# ============================================================================
# COMMENTED OUT: Create New Graph Feature - Disabled per user request
# ============================================================================
# @app.route("/api/create_graph", methods=["POST"])
# def create_graph():
# """Create a new POA graph from text sequences"""
# try:
# print("DEBUG: Received create_graph request")
# data = request.get_json()
# sequences = data.get("sequences", [])
#
# print(f"DEBUG: Number of sequences: {len(sequences)}")
#
# if len(sequences) < 2:
# return jsonify({"error": "At least 2 sequences are required"}), 400
#
# print("DEBUG: Creating initial graph")
# # Create the graph from first sequence
# graph = TextPOAGraph(sequences[0], label=0)
# print("DEBUG: Initial graph created")
#
# # Add remaining sequences
# for i, sequence in enumerate(sequences[1:], 1):
# print(f"DEBUG: Adding sequence {i}")
# alignment = TextSeqGraphAlignment(
# text=sequence,
# graph=graph,
# fastMethod=True,
# globalAlign=True,
# matchscore=1,
# mismatchscore=-2,
# gap_open=-1,
# )
# graph.incorporateSeqAlignment(alignment, sequence, label=i)
#
# print("DEBUG: All sequences added")
#
# # Refine the graph with proper domain and model parameters
# graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini")
# print("DEBUG: Graph refined")
#
# # Convert to JSON format for vis.js
# nodes = []
# edges = []
#
# try:
# print("DEBUG: Starting to process graph data")
# # Get consensus nodes for coloring (make it optional)
# try:
# consensus_nodes = set(graph.consensus_node_ids)
# print(f"DEBUG: Consensus nodes: {consensus_nodes}")
# except Exception as e:
# print(f"DEBUG: Error getting consensus nodes: {e}")
# consensus_nodes = set() # Fallback to empty set if consensus fails
#
# # Create nodes using the same logic as jsOutput
# for node in graph.nodeiterator()():
# title_text = ""
# if node.sequences:
# title_text += f"Sequences: {node.sequences}"
# if node.variations:
# title_text += ";;;".join(
# [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
# )
# title_text = title_text.replace('"', "'")
#
# # Use the same color logic as jsOutput
# color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
#
# node_data = {
# "id": node.ID,
# "label": f"{node.ID}: {node.text}",
# "title": title_text,
# "color": color,
# }
# nodes.append(node_data)
#
# print(f"DEBUG: Created {len(nodes)} nodes")
#
# # Create edges using the same logic as jsOutput
# for node in graph.nodeiterator()():
# nodeID = node.ID # Keep as integer
# for edge in node.outEdges:
# target = edge # Keep as integer
# weight = node.outEdges[edge].weight + 1.5
# edge_data = {
# "from": nodeID,
# "to": target,
# "value": weight,
# "color": "#cae0e6",
# "arrows": "to",
# }
# edges.append(edge_data)
#
# print(f"DEBUG: Created {len(edges)} edges")
# except Exception as e:
# print(f"DEBUG: Error processing graph data: {e}")
# return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
#
# # Extract text from consensus nodes
# consensus_text = ""
# try:
# consensus_node_texts = []
# for node in graph.nodeiterator()():
# if node.ID in consensus_nodes and node.text and node.text.strip():
# consensus_node_texts.append(node.text.strip())
# consensus_text = " ".join(consensus_node_texts)
# except Exception:
# consensus_text = ""
#
# # Check if we should compute consensus using decode_consensus
# compute_consensus = data.get("compute_consensus", False)
# if compute_consensus and decode_consensus:
# try:
# # Default to "bio" task for new graphs
# consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio")
# except Exception as e:
# print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
# # Keep the original consensus text if decode_consensus fails
#
# # Get original sequences
# try:
# raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
# # Process sequences: join with spaces and remove "||"
# original_sequences = []
# for seq in raw_sequences:
# if isinstance(seq, list):
# # Join list elements with spaces
# processed_seq = " ".join(str(item) for item in seq)
# else:
# processed_seq = str(seq)
# # Remove "||" characters
# processed_seq = processed_seq.replace("||", "")
# original_sequences.append(processed_seq)
# except Exception:
# original_sequences = []
#
# print("DEBUG: Returning success response")
# return jsonify(
# {
# "success": True,
# "nodes": nodes,
# "edges": edges,
# "num_sequences": len(sequences),
# "num_nodes": len(nodes),
# "num_edges": len(edges),
# "original_sequences": original_sequences,
# "consensus_text": consensus_text,
# }
# )
#
# except Exception as e:
# print(f"DEBUG: Main exception in create_graph: {e}")
# return jsonify({"error": str(e)}), 500
# @app.route("/api/save_graph", methods=["POST"])
# def save_graph():
# """Save a POA graph to a pickle file"""
# try:
# data = request.get_json()
# sequences = data.get("sequences", [])
# filename = data.get("filename", "graph.pkl")
#
# if len(sequences) < 2:
# return jsonify({"error": "At least 2 sequences are required"}), 400
#
# # Create the graph
# graph = TextPOAGraph(sequences[0], label=0)
#
# # Add remaining sequences
# for i, sequence in enumerate(sequences[1:], 1):
# alignment = TextSeqGraphAlignment(
# text=sequence,
# graph=graph,
# fastMethod=True,
# globalAlign=True,
# matchscore=1,
# mismatchscore=-2,
# gap_open=-1,
# )
# graph.incorporateSeqAlignment(alignment, sequence, label=i)
#
# # Refine the graph
# graph.refine_graph(verbose=False)
#
# # Save to pickle file
# graph.save_to_pickle(filename)
#
# return jsonify(
# {"success": True, "filename": filename, "message": f"Graph saved to {filename}"}
# )
#
# except Exception as e:
# return jsonify({"error": str(e)}), 500
# @app.route("/api/graph_info", methods=["POST"])
# def graph_info():
# """Get information about a graph without creating the full visualization"""
# try:
# data = request.get_json()
# sequences = data.get("sequences", [])
#
# if len(sequences) < 2:
# return jsonify({"error": "At least 2 sequences are required"}), 400
#
# # Create the graph
# graph = TextPOAGraph(sequences[0], label=0)
#
# # Add remaining sequences
# for i, sequence in enumerate(sequences[1:], 1):
# alignment = TextSeqGraphAlignment(
# text=sequence,
# graph=graph,
# fastMethod=True,
# globalAlign=True,
# matchscore=1,
# mismatchscore=-2,
# gap_open=-1,
# )
# graph.incorporateSeqAlignment(alignment, sequence, label=i)
#
# # Refine the graph
# graph.refine_graph(verbose=False)
#
# # Get consensus response
# consensus_text = graph.consensus_response()
#
# return jsonify(
# {
# "success": True,
# "num_sequences": len(sequences),
# "num_nodes": graph._nnodes,
# "consensus_text": consensus_text,
# "consensus_node_ids": graph.consensus_node_ids,
# }
# )
#
# except Exception as e:
# return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Hugging Face Spaces requires port 7860; the PORT environment variable
    # overrides it when set.
    port = int(os.environ.get("PORT", 7860))
    startup_lines = (
        "Starting POA Graph Web Interface Server...",
        f"Repository root: {REPO_ROOT}",
        f"Serving static files from: {os.path.join(REPO_ROOT, 'web_interface')}",
        f"Open http://localhost:{port} in your browser",
    )
    for line in startup_lines:
        print(line)
    # Listen on all interfaces with the debugger disabled.
    app.run(debug=False, host="0.0.0.0", port=port)