Shahzaib98 committed on
Commit 102ae18 · 1 Parent(s): 2cb9f34

initial commit

Dockerfile ADDED
@@ -0,0 +1,28 @@
+ FROM python:3.10-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first for better caching
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Download spaCy model
+ RUN python -m spacy download en_core_web_sm
+
+ # Copy the entire application
+ COPY . .
+
+ # Expose port 7860 (required by Hugging Face Spaces)
+ EXPOSE 7860
+
+ # Use gunicorn for production
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "--timeout", "120", "--workers", "2", "app:app"]
README.md CHANGED
@@ -1,12 +1,155 @@
  ---
- title: ConGrs
- emoji: 👀
- colorFrom: pink
- colorTo: gray
  sdk: docker
  pinned: false
- license: apache-2.0
- short_description: Explore and visualize ConGrs (https://www.google.com/search)
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: ConGr Visualizer
+ emoji: 🔗
+ colorFrom: blue
+ colorTo: green
  sdk: docker
+ app_port: 7860
  pinned: false
+ license: mit
  ---

+ # ConGr Visualizer
+
+ A standalone web-based interface for exploring and visualizing ConGrs (Consensus Graphs) from research datasets.
+
+ ## Overview
+
+ This repository contains the web interface and necessary dependencies for visualizing ConGrs. It has been separated from the main sample-fusion repository to provide a standalone visualization tool.
+
+ ## Features
+
+ ### Browse Existing Graphs
+ - **Dataset Selection**: Choose from available datasets (BIO, FP, HIST, REFS, MATH, AIME)
+ - **Entity Selection**: Browse entities within each dataset
+ - **Model Information**: See which language models were used for each graph
+ - **Graph Visualization**: Interactive network visualization using vis.js
+ - **Metadata Display**: View graph statistics and consensus text
+
+ ### Create New Graphs
+ - **Text Input**: Enter multiple text sequences to create new ConGrs
+ - **Real-time Visualization**: See the graph structure as it's created
+ - **Save Functionality**: Save created graphs to pickle files
+
+ ## Deployment on Hugging Face Spaces
+
+ This application is configured to run on Hugging Face Spaces using Docker.
+
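+ For local testing, a plain Docker build and run should also work; the image tag below is just an illustrative name, not anything defined by this repo:
+
+ ```bash
+ # Build the image from the repo root (tag name is arbitrary)
+ docker build -t congr-visualizer .
+
+ # Run it, mapping the Spaces port to localhost
+ docker run -p 7860:7860 congr-visualizer
+ ```
+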
+ ### Project Structure
+
+ ```
+ congr-visualizer/
+ ├── Dockerfile              # Docker configuration for HF Spaces
+ ├── app.py                  # Main Flask application
+ ├── requirements.txt        # Python dependencies
+ ├── README.md               # This file
+ ├── web_interface/          # Web interface files
+ │   └── index.html          # Web interface
+ ├── src/                    # Source code modules
+ │   ├── alignment.py
+ │   ├── new_alignment.py
+ │   ├── poa_graph.py
+ │   ├── new_text_alignment.py
+ │   ├── text_poa_graph.py
+ │   ├── text_poa_graph_utils.py
+ │   ├── global_edit_utils.py
+ │   ├── generation_utils.py
+ │   ├── generation_methods.py
+ │   └── utils.py
+ └── results/                # Graph data
+     └── graphs/
+         └── HALoGEN/
+             ├── bio/
+             ├── fp/
+             ├── hist/
+             ├── refs/
+             ├── MATH/
+             └── AIME/
+ ```
+
+ ## Local Development
+
+ If you want to run this locally:
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/congr-visualizer
+ cd congr-visualizer
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ python -m spacy download en_core_web_sm
+ ```
+
+ 3. Run the application:
+ ```bash
+ python app.py
+ ```
+
+ The server will start on `http://localhost:7860`.
+
+ ## Available Datasets
+
+ - **BIO**: Biography datasets with various public figures
+ - **FP**: False Presupposition datasets
+ - **HIST**: Historical events datasets
+ - **REFS**: Reference datasets
+ - **MATH**: Mathematical problem datasets
+ - **AIME**: American Invitational Mathematics Examination datasets
+
+ ## Models
+
+ The graphs are generated using various language models:
+ - olmo7b, olmo32b
+ - qwen72b, qwen7b
+ - llama70b, llama8b
+
+ ## API Endpoints
+
+ - `GET /api/datasets` - Get available datasets
+ - `GET /api/entities?dataset=<dataset>` - Get entities for a dataset
+ - `GET /api/models?dataset=<dataset>&entity=<entity>` - Get models for an entity
+ - `POST /api/load_existing_graph` - Load an existing graph
+ - `POST /api/create_graph` - Create a new graph from text sequences
+ - `POST /api/save_graph` - Save a graph to file
+ - `POST /api/graph_info` - Get graph information without full visualization
+
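+ As a rough sketch of how these endpoints can be exercised from the command line once the server is running (the example sequences are placeholders; `create_graph` expects a JSON body with at least two sequences, per `app.py`):
+
+ ```bash
+ # List the datasets the server has found on disk
+ curl http://localhost:7860/api/datasets
+
+ # Build a new graph from two or more text sequences
+ curl -X POST http://localhost:7860/api/create_graph \
+   -H "Content-Type: application/json" \
+   -d '{"sequences": ["The cat sat on the mat.", "A cat sat on a mat."]}'
+ ```
+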
+ ## Graph Information
+
+ When viewing a graph, you can see:
+ - **Dataset**: The source dataset
+ - **Entity**: The specific entity or topic
+ - **Model**: The language model used
+ - **Sequences**: Number of input sequences
+ - **Nodes**: Number of nodes in the graph
+ - **Edges**: Number of edges in the graph
+ - **Consensus**: The consensus text generated from the graph
+
+ ## Visualization Features
+
+ - **Hierarchical Layout**: Graphs are displayed in a hierarchical structure
+ - **Color Coding**: Consensus nodes are highlighted in green
+ - **Interactive**: Zoom, pan, and hover for more information
+ - **Responsive**: Works on desktop and mobile devices
+
+ ## Environment Variables
+
+ For full functionality (especially consensus decoding), you may need to set:
+ - `OPENAI_API_KEY`: For OpenAI API calls
+ - `HUGGINGFACE_API_KEY`: For HuggingFace API calls
+
+ These can be set in the Hugging Face Spaces settings.
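+
+ For local runs, the same variables can be exported in the shell before starting the server (the values below are placeholders):
+
+ ```bash
+ export OPENAI_API_KEY="sk-..."
+ export HUGGINGFACE_API_KEY="hf_..."
+ python app.py
+ ```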
144
+
145
+ ## Technical Details
146
+
147
+ - **Framework**: Flask with CORS enabled
148
+ - **Server**: Gunicorn for production
149
+ - **Port**: 7860 (required by Hugging Face Spaces)
150
+ - **Visualization**: vis.js for interactive graph rendering
151
+ - **Graph Format**: Pickle files for serialized POA graphs
152
+
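+ Since the stored graphs are ordinary pickle files, a quick way to inspect one from Python is sketched below; the path and filename follow the bio naming pattern used in `app.py` and are placeholders:
+
+ ```python
+ import pickle
+
+ # Run from the repo root so the `src` package (which defines the graph classes)
+ # is importable when unpickling.
+ with open("results/graphs/HALoGEN/bio/bio_graph_SomeEntity_merged_llama8b.pkl", "rb") as f:
+     graph = pickle.load(f)
+
+ # The graph exposes its nodes via a node iterator (same pattern as app.py)
+ for node in graph.nodeiterator()():
+     print(node.ID, node.text)
+ ```
+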
+ ## License
+
+ This is a standalone visualization tool extracted from the sample-fusion research project.
app.py ADDED
@@ -0,0 +1,584 @@
+ #!/usr/bin/env python3
+ """
+ Flask server for POA Graph Web Interface
+ Modified for Hugging Face Spaces deployment
+ """
+
+ import glob
+ import os
+ import pickle
+ import re
+ import sys
+
+ from flask import Flask, jsonify, request, send_from_directory
+ from flask_cors import CORS
+
+ # Get the directory where this script is located (should be project root)
+ REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
+
+ # Add the repository root to the path so we can import the POA graph modules
+ sys.path.append(REPO_ROOT)
+
+ from src.new_text_alignment import TextSeqGraphAlignment
+ from src.text_poa_graph import TextPOAGraph
+
+ try:
+     from src.generation_methods import decode_consensus
+ except ImportError:
+     decode_consensus = None
+
+ app = Flask(__name__)
+ CORS(app)  # Enable CORS for all routes
+
+ # Base paths for different datasets (relative to repo root)
+ GRAPH_PATHS = {
+     "bio": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/bio"),
+     "fp": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/fp"),
+     "hist": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/hist"),
+     "refs": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/refs"),
+     "math": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/MATH"),
+     "aime": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/AIME"),
+ }
+
+ MODELS = ["qwen72b", "qwen7b", "llama8b", "llama70b", "olmo7b", "olmo32b"]
+
+
+ @app.route("/")
+ def index():
+     """Serve the main HTML file from web_interface directory"""
+     web_interface_path = os.path.join(REPO_ROOT, "web_interface")
+     return send_from_directory(web_interface_path, "index.html")
+
+
+ @app.route("/<path:path>")
+ def serve_static(path):
+     """Serve static files from web_interface directory"""
+     web_interface_path = os.path.join(REPO_ROOT, "web_interface")
+     return send_from_directory(web_interface_path, path)
+
+
+ @app.route("/api/datasets", methods=["GET"])
+ def get_datasets():
+     """Get available datasets"""
+     datasets = []
+     for dataset_name, path in GRAPH_PATHS.items():
+         if os.path.exists(path):
+             # Count available graphs
+             pkl_files = glob.glob(os.path.join(path, "*.pkl"))
+             datasets.append(
+                 {
+                     "name": dataset_name,
+                     "display_name": dataset_name.upper(),
+                     "path": path,
+                     "count": len(pkl_files),
+                 }
+             )
+     return jsonify({"datasets": datasets})
+
+
+ @app.route("/api/models", methods=["GET"])
+ def get_models():
+     """Get available models for a specific entity"""
+     entity = request.args.get("entity")
+     dataset = request.args.get("dataset")
+
+     if not entity:
+         return jsonify({"error": "Entity parameter required"}), 400
+
+     if not dataset or dataset not in GRAPH_PATHS:
+         return jsonify({"error": "Invalid dataset"}), 400
+
+     path = GRAPH_PATHS[dataset]
+     if not os.path.exists(path):
+         return jsonify({"error": "Dataset path not found"}), 404
+
+     models = []
+     pkl_files = glob.glob(os.path.join(path, "*.pkl"))
+
+     for pkl_file in pkl_files:
+         filename = os.path.basename(pkl_file)
+
+         # Different filename patterns for different datasets
+         if dataset == "bio":
+             # Format: bio_graph_{entity}_merged_{model}.pkl
+             match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
+             if match:
+                 entity_name, model = match.groups()
+                 if entity_name == entity:
+                     models.append({"model": model, "filename": filename, "filepath": pkl_file})
+         elif dataset == "fp":
+             # Format: fp_graph_{number}_merged_{model}.pkl
+             match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
+             if match:
+                 entity_name, model = match.groups()
+                 if f"Problem {entity_name}" == entity:
+                     models.append({"model": model, "filename": filename, "filepath": pkl_file})
+         elif dataset == "math":
+             # Format: qwen72_math_{number}.pkl
+             match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
+             if match:
+                 entity_name = match.group(1)
+                 if f"Math Problem {entity_name}" == entity:
+                     models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
+         elif dataset == "aime":
+             # Format: aime_qwen72b_{number}.pkl
+             match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
+             if match:
+                 entity_name = match.group(1)
+                 if f"AIME Problem {entity_name}" == entity:
+                     models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
+         else:
+             # Generic pattern for other datasets
+             match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
+             if match:
+                 task, entity_name, model = match.groups()
+                 if entity_name == entity:
+                     models.append({"model": model, "filename": filename, "filepath": pkl_file})
+
+     return jsonify({"models": models})
+
+
+ @app.route("/api/entities", methods=["GET"])
+ def get_entities():
+     """Get available entities for a dataset"""
+     dataset = request.args.get("dataset")
+     if not dataset or dataset not in GRAPH_PATHS:
+         return jsonify({"error": "Invalid dataset"}), 400
+
+     path = GRAPH_PATHS[dataset]
+     if not os.path.exists(path):
+         return jsonify({"error": "Dataset path not found"}), 404
+
+     entities = []
+     pkl_files = glob.glob(os.path.join(path, "*.pkl"))
+
+     for pkl_file in pkl_files:
+         filename = os.path.basename(pkl_file)
+
+         # Different filename patterns for different datasets
+         if dataset == "bio":
+             # Format: bio_graph_{entity}_merged_{model}.pkl
+             match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
+             if match:
+                 entity_name, model = match.groups()
+                 entities.append(
+                     {
+                         "entity": entity_name,
+                         "model": model,
+                         "filename": filename,
+                         "filepath": pkl_file,
+                     }
+                 )
+         elif dataset == "fp":
+             # Format: fp_graph_{number}_merged_{model}.pkl
+             match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
+             if match:
+                 entity_name, model = match.groups()
+                 entities.append(
+                     {
+                         "entity": f"Problem {entity_name}",
+                         "model": model,
+                         "filename": filename,
+                         "filepath": pkl_file,
+                     }
+                 )
+         elif dataset == "math":
+             # Format: qwen72_math_{number}.pkl
+             match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
+             if match:
+                 entity_name = match.group(1)
+                 entities.append(
+                     {
+                         "entity": f"Math Problem {entity_name}",
+                         "model": "qwen72b",
+                         "filename": filename,
+                         "filepath": pkl_file,
+                     }
+                 )
+         elif dataset == "aime":
+             # Format: aime_qwen72b_{number}.pkl
+             match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
+             if match:
+                 entity_name = match.group(1)
+                 entities.append(
+                     {
+                         "entity": f"AIME Problem {entity_name}",
+                         "model": "qwen72b",
+                         "filename": filename,
+                         "filepath": pkl_file,
+                     }
+                 )
+         else:
+             # Generic pattern for other datasets
+             match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
+             if match:
+                 task, entity_name, model = match.groups()
+                 entities.append(
+                     {
+                         "entity": entity_name,
+                         "model": model,
+                         "filename": filename,
+                         "filepath": pkl_file,
+                     }
+                 )
+
+     # Get unique entities
+     unique_entities = {}
+     for entity_data in entities:
+         entity_key = entity_data["entity"]
+         if entity_key not in unique_entities:
+             unique_entities[entity_key] = entity_data
+
+     return jsonify({"entities": list(unique_entities.values())})
+
+
+ @app.route("/api/load_existing_graph", methods=["POST"])
+ def load_existing_graph():
+     """Load an existing graph from pickle file"""
+     try:
+         data = request.get_json()
+         filepath = data.get("filepath")
+
+         if not filepath or not os.path.exists(filepath):
+             return jsonify({"error": "Invalid filepath"}), 400
+
+         # Load the graph from pickle
+         with open(filepath, "rb") as f:
+             graph = pickle.load(f)
+
+         # Convert to JSON format for vis.js
+         nodes = []
+         edges = []
+
+         # Get consensus nodes for coloring
+         try:
+             consensus_nodes = set(graph.consensus_node_ids)
+         except Exception:
+             consensus_nodes = set()
+
+         # Create nodes
+         for node in graph.nodeiterator()():
+             title_text = ""
+             if node.sequences:
+                 title_text += f"Sequences: {node.sequences}"
+             if node.variations:
+                 title_text += ";;;".join(
+                     [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
+                 )
+             title_text = title_text.replace('"', "'")
+
+             color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
+
+             node_data = {
+                 "id": node.ID,
+                 "label": f"{node.ID}: {node.text}",
+                 "title": title_text,
+                 "color": color,
+             }
+             nodes.append(node_data)
+
+         # Create edges
+         for node in graph.nodeiterator()():
+             nodeID = node.ID
+             for edge in node.outEdges:
+                 target = edge
+                 weight = node.outEdges[edge].weight + 1.5
+                 edge_data = {
+                     "from": nodeID,
+                     "to": target,
+                     "value": weight,
+                     "color": "#cae0e6",
+                     "arrows": "to",
+                 }
+                 edges.append(edge_data)
+
+         # Get consensus text
+         consensus_text = ""
+         try:
+             consensus_node_texts = []
+             for node in graph.nodeiterator()():
+                 if node.ID in consensus_nodes and node.text and node.text.strip():
+                     consensus_node_texts.append(node.text.strip())
+             consensus_text = " ".join(consensus_node_texts)
+         except Exception:
+             consensus_text = ""
+
+         # Get original sequences
+         try:
+             raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
+             original_sequences = []
+             for seq in raw_sequences:
+                 if isinstance(seq, list):
+                     processed_seq = " ".join(str(item) for item in seq)
+                 else:
+                     processed_seq = str(seq)
+                 processed_seq = processed_seq.replace("||", "")
+                 original_sequences.append(processed_seq)
+         except Exception:
+             original_sequences = []
+
+         return jsonify(
+             {
+                 "success": True,
+                 "nodes": nodes,
+                 "edges": edges,
+                 "num_sequences": len(original_sequences),
+                 "num_nodes": len(nodes),
+                 "num_edges": len(edges),
+                 "original_sequences": original_sequences,
+                 "consensus_text": consensus_text,
+             }
+         )
+
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500
+
+
+ @app.route("/api/create_graph", methods=["POST"])
+ def create_graph():
+     """Create a new POA graph from text sequences"""
+     try:
+         print("DEBUG: Received create_graph request")
+         data = request.get_json()
+         sequences = data.get("sequences", [])
+
+         print(f"DEBUG: Number of sequences: {len(sequences)}")
+
+         if len(sequences) < 2:
+             return jsonify({"error": "At least 2 sequences are required"}), 400
+
+         print("DEBUG: Creating initial graph")
+         # Create the graph from first sequence
+         graph = TextPOAGraph(sequences[0], label=0)
+         print("DEBUG: Initial graph created")
+
+         # Add remaining sequences
+         for i, sequence in enumerate(sequences[1:], 1):
+             print(f"DEBUG: Adding sequence {i}")
+             alignment = TextSeqGraphAlignment(
+                 text=sequence,
+                 graph=graph,
+                 fastMethod=True,
+                 globalAlign=True,
+                 matchscore=1,
+                 mismatchscore=-2,
+                 gap_open=-1,
+             )
+             graph.incorporateSeqAlignment(alignment, sequence, label=i)
+
+         print("DEBUG: All sequences added")
+
+         # Refine the graph with proper domain and model parameters
+         graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini")
+         print("DEBUG: Graph refined")
+
+         # Convert to JSON format for vis.js
+         nodes = []
+         edges = []
+
+         try:
+             print("DEBUG: Starting to process graph data")
+             # Get consensus nodes for coloring (make it optional)
+             try:
+                 consensus_nodes = set(graph.consensus_node_ids)
+                 print(f"DEBUG: Consensus nodes: {consensus_nodes}")
+             except Exception as e:
+                 print(f"DEBUG: Error getting consensus nodes: {e}")
+                 consensus_nodes = set()  # Fallback to empty set if consensus fails
+
+             # Create nodes using the same logic as jsOutput
+             for node in graph.nodeiterator()():
+                 title_text = ""
+                 if node.sequences:
+                     title_text += f"Sequences: {node.sequences}"
+                 if node.variations:
+                     title_text += ";;;".join(
+                         [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
+                     )
+                 title_text = title_text.replace('"', "'")
+
+                 # Use the same color logic as jsOutput
+                 color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
+
+                 node_data = {
+                     "id": node.ID,
+                     "label": f"{node.ID}: {node.text}",
+                     "title": title_text,
+                     "color": color,
+                 }
+                 nodes.append(node_data)
+
+             print(f"DEBUG: Created {len(nodes)} nodes")
+
+             # Create edges using the same logic as jsOutput
+             for node in graph.nodeiterator()():
+                 nodeID = node.ID  # Keep as integer
+                 for edge in node.outEdges:
+                     target = edge  # Keep as integer
+                     weight = node.outEdges[edge].weight + 1.5
+                     edge_data = {
+                         "from": nodeID,
+                         "to": target,
+                         "value": weight,
+                         "color": "#cae0e6",
+                         "arrows": "to",
+                     }
+                     edges.append(edge_data)
+
+             print(f"DEBUG: Created {len(edges)} edges")
+         except Exception as e:
+             print(f"DEBUG: Error processing graph data: {e}")
+             return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
+
+         # Extract text from consensus nodes
+         consensus_text = ""
+         try:
+             consensus_node_texts = []
+             for node in graph.nodeiterator()():
+                 if node.ID in consensus_nodes and node.text and node.text.strip():
+                     consensus_node_texts.append(node.text.strip())
+             consensus_text = " ".join(consensus_node_texts)
+         except Exception:
+             consensus_text = ""
+
+         # Check if we should compute consensus using decode_consensus
+         compute_consensus = data.get("compute_consensus", False)
+         if compute_consensus and decode_consensus:
+             try:
+                 # Default to "bio" task for new graphs
+                 consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio")
+             except Exception as e:
+                 print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
+                 # Keep the original consensus text if decode_consensus fails
+
+         # Get original sequences
+         try:
+             raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
+             # Process sequences: join with spaces and remove "||"
+             original_sequences = []
+             for seq in raw_sequences:
+                 if isinstance(seq, list):
+                     # Join list elements with spaces
+                     processed_seq = " ".join(str(item) for item in seq)
+                 else:
+                     processed_seq = str(seq)
+                 # Remove "||" characters
+                 processed_seq = processed_seq.replace("||", "")
+                 original_sequences.append(processed_seq)
+         except Exception:
+             original_sequences = []
+
+         print("DEBUG: Returning success response")
+         return jsonify(
+             {
+                 "success": True,
+                 "nodes": nodes,
+                 "edges": edges,
+                 "num_sequences": len(sequences),
+                 "num_nodes": len(nodes),
+                 "num_edges": len(edges),
+                 "original_sequences": original_sequences,
+                 "consensus_text": consensus_text,
+             }
+         )
+
+     except Exception as e:
+         print(f"DEBUG: Main exception in create_graph: {e}")
+         return jsonify({"error": str(e)}), 500
+
+
+ @app.route("/api/save_graph", methods=["POST"])
+ def save_graph():
+     """Save a POA graph to a pickle file"""
+     try:
+         data = request.get_json()
+         sequences = data.get("sequences", [])
+         filename = data.get("filename", "graph.pkl")
+
+         if len(sequences) < 2:
+             return jsonify({"error": "At least 2 sequences are required"}), 400
+
+         # Create the graph
+         graph = TextPOAGraph(sequences[0], label=0)
+
+         # Add remaining sequences
+         for i, sequence in enumerate(sequences[1:], 1):
+             alignment = TextSeqGraphAlignment(
+                 text=sequence,
+                 graph=graph,
+                 fastMethod=True,
+                 globalAlign=True,
+                 matchscore=1,
+                 mismatchscore=-2,
+                 gap_open=-1,
+             )
+             graph.incorporateSeqAlignment(alignment, sequence, label=i)
+
+         # Refine the graph
+         graph.refine_graph(verbose=False)
+
+         # Save to pickle file
+         graph.save_to_pickle(filename)
+
+         return jsonify(
+             {"success": True, "filename": filename, "message": f"Graph saved to {filename}"}
+         )
+
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500
+
+
+ @app.route("/api/graph_info", methods=["POST"])
+ def graph_info():
+     """Get information about a graph without creating the full visualization"""
+     try:
+         data = request.get_json()
+         sequences = data.get("sequences", [])
+
+         if len(sequences) < 2:
+             return jsonify({"error": "At least 2 sequences are required"}), 400
+
+         # Create the graph
+         graph = TextPOAGraph(sequences[0], label=0)
+
+         # Add remaining sequences
+         for i, sequence in enumerate(sequences[1:], 1):
+             alignment = TextSeqGraphAlignment(
+                 text=sequence,
+                 graph=graph,
+                 fastMethod=True,
+                 globalAlign=True,
+                 matchscore=1,
+                 mismatchscore=-2,
+                 gap_open=-1,
+             )
+             graph.incorporateSeqAlignment(alignment, sequence, label=i)
+
+         # Refine the graph
+         graph.refine_graph(verbose=False)
+
+         # Get consensus response
+         consensus_text = graph.consensus_response()
+
+         return jsonify(
+             {
+                 "success": True,
+                 "num_sequences": len(sequences),
+                 "num_nodes": graph._nnodes,
+                 "consensus_text": consensus_text,
+                 "consensus_node_ids": graph.consensus_node_ids,
+             }
+         )
+
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500
+
+
+ if __name__ == "__main__":
+     # For HF Spaces, port must be 7860
+     port = int(os.environ.get("PORT", 7860))
+     print("Starting POA Graph Web Interface Server...")
+     print(f"Repository root: {REPO_ROOT}")
+     print(f"Serving static files from: {os.path.join(REPO_ROOT, 'web_interface')}")
+     print(f"Open http://localhost:{port} in your browser")
+     app.run(debug=False, host="0.0.0.0", port=port)
dockerignore ADDED
@@ -0,0 +1,46 @@
+ # Git files
+ .git
+ .gitignore
+ .gitattributes
+
+ # Python cache
+ __pycache__
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ *.egg-info/
+ dist/
+ build/
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # OS files
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+
+ # Testing
+ .pytest_cache/
+ .coverage
+ htmlcov/
+
+ # Documentation
+ docs/
+ *.md.backup
+
+ # Large unnecessary files
+ *.tar.gz
+ *.zip
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ flask==2.3.3
+ flask-cors==4.0.0
+ numpy==1.24.3
+ tqdm==4.66.1
+ huggingface_hub==0.25.1
+ openai==1.63.2
+ python-dotenv==1.0.1
+ sentence_transformers==3.3.1
+ torch==2.5.1
+ transformers==4.46.3
+ nltk==3.9.1
+ spacy==3.7.6
+ gunicorn==21.2.0
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (162 Bytes)
src/__pycache__/alignment.cpython-312.pyc ADDED
Binary file (13 kB)
src/__pycache__/generation_methods.cpython-312.pyc ADDED
Binary file (11.9 kB)
src/__pycache__/generation_utils.cpython-312.pyc ADDED
Binary file (9.5 kB)
src/__pycache__/global_edit_utils.cpython-312.pyc ADDED
Binary file (5.54 kB)
src/__pycache__/new_alignment.cpython-312.pyc ADDED
Binary file (8.39 kB)
src/__pycache__/new_text_alignment.cpython-312.pyc ADDED
Binary file (7.24 kB)
src/__pycache__/poa_graph.cpython-312.pyc ADDED
Binary file (28.9 kB)
src/__pycache__/text_poa_graph.cpython-312.pyc ADDED
Binary file (34.7 kB)
src/__pycache__/text_poa_graph_utils.cpython-312.pyc ADDED
Binary file (6.21 kB)
src/alignment.py ADDED
@@ -0,0 +1,256 @@
+ """
+ Adapted from Jonathan Dursi
+ https://github.com/ljdursi/poapy
+ """
+
+ import numpy
+
+
+ class SeqGraphAlignment(object):
+     __matchscore = 1
+     __mismatchscore = -2
+     __gap = -1
+
+     def __init__(
+         self,
+         sequence,
+         graph,
+         fastMethod=True,
+         globalAlign=False,
+         matchscore=__matchscore,
+         mismatchscore=__mismatchscore,
+         gapscore=__gap,
+         *args,
+         **kwargs,
+     ):
+         self._mismatchscore = mismatchscore
+         self._matchscore = matchscore
+         self._gap = gapscore
+         self.sequence = sequence
+         self.graph = graph
+         self.stringidxs = None
+         self.nodeidxs = None
+         self.globalAlign = globalAlign
+         if fastMethod:
+             matches = self.alignStringToGraphFast(*args, **kwargs)
+         else:
+             matches = self.alignStringToGraphSimple(*args, **kwargs)
+         self.stringidxs, self.nodeidxs = matches
+
+     def alignmentStrings(self):
+         return "".join(
+             self.sequence[i] if i is not None else "-" for i in self.stringidxs
+         ), "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs)
+
+     def matchscore(self, c1, c2):
+         if c1 == c2:
+             return self._matchscore
+         else:
+             return self._mismatchscore
+
+     def matchscoreVec(self, c, v):
+         return numpy.where(v == c, self._matchscore, self._mismatchscore)
+
+     def alignStringToGraphSimple(self):
+         """Align string to graph, following same approach as smith waterman
+         example"""
+         if type(self.sequence) is not str:
+             raise TypeError("Invalid Type")
+
+         nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx = (
+             self.initializeDynamicProgrammingData()
+         )
+
+         # Dynamic Programming
+         ni = self.graph.nodeiterator()
+         for i, node in enumerate(ni()):
+             pbase = node.text
+
+             for j, sbase in enumerate(self.sequence):
+                 # add all candidates to a list, pick the best
+                 candidates = [(scores[i + 1, j] + self._gap, i + 1, j, "INS")]
+                 for predIndex in self.prevIndices(node, nodeIDtoIndex):
+                     candidates += [
+                         (scores[predIndex + 1, j + 1] + self._gap, predIndex + 1, j + 1, "DEL")
+                     ]
+                     candidates += [
+                         (
+                             scores[predIndex + 1, j] + self.matchscore(sbase, pbase),
+                             predIndex + 1,
+                             j,
+                             "MATCH",
+                         )
+                     ]
+
+                 (
+                     scores[i + 1, j + 1],
+                     backGrphIdx[i + 1, j + 1],
+                     backStrIdx[i + 1, j + 1],
+                     movetype,
+                 ) = max(candidates)
+
+                 if not self.globalAlign and scores[i + 1, j + 1] < 0:
+                     scores[i + 1, j + 1] = 0.0
+                     backGrphIdx[i + 1, j + 1] = -1
+                     backStrIdx[i + 1, j + 1] = -1
+
+         return self.backtrack(scores, backStrIdx, backGrphIdx, nodeIndexToID)
+
+     def alignStringToGraphFast(self):
+         """Align string to graph - using numpy to vectorize across the string
+         at each iteration."""
+         if type(self.sequence) is not str:
+             raise TypeError("Invalid Type")
+
+         l2 = len(self.sequence)
+         seqvec = numpy.array(list(self.sequence))
+
+         nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx = (
+             self.initializeDynamicProgrammingData()
+         )
+         inserted = numpy.zeros((l2), dtype=bool)
+
+         # having the inner loop as a function improves performance
+         # can use Cython, etc on this for significant further improvements
+         # can't vectorize this since there's a loop-carried dependency
+         # along the string
+         def insertions(i, l2, scores, inserted):
+             inserted[:] = False
+             for j in range(l2):
+                 insscore = scores[i + 1, j] + self._gap
+                 if insscore >= scores[i + 1, j + 1]:
+                     scores[i + 1, j + 1] = insscore
+                     inserted[j] = True
+
+         # Dynamic Programming
+         ni = self.graph.nodeiterator()
+         for i, node in enumerate(ni()):
+             gbase = node.text
+             predecessors = self.prevIndices(node, nodeIDtoIndex)
+
+             # calculate all best deletions, matches in one go over all
+             # predecessors.
+
+             # First calculate for the first predecessor, over all string posns:
+             deletescore = scores[predecessors[0] + 1, 1:] + self._gap
+             bestdelete = numpy.zeros((l2), dtype=numpy.int32) + predecessors[0] + 1
+
+             matchpoints = self.matchscoreVec(gbase, seqvec)
+             matchscore = scores[predecessors[0] + 1, 0:-1] + matchpoints
+             bestmatch = numpy.zeros((l2), dtype=numpy.int32) + predecessors[0] + 1
+
+             # then, the remaining
+             for predecessor in predecessors[1:]:
+                 newdeletescore = scores[predecessor + 1, 1:] + self._gap
+                 bestdelete = numpy.where(newdeletescore > deletescore, predecessor + 1, bestdelete)
+                 deletescore = numpy.maximum(newdeletescore, deletescore)
+
+                 gbase = self.graph.nodeIdxToBase(predecessor)
+                 matchpoints = self.matchscoreVec(gbase, seqvec)
+                 newmatchscore = scores[predecessor + 1, 0:-1] + matchpoints
+                 bestmatch = numpy.where(newmatchscore > matchscore, predecessor + 1, bestmatch)
+                 matchscore = numpy.maximum(newmatchscore, matchscore)
+
+             # choose best options available of match, delete
+             deleted = deletescore >= matchscore
+             backGrphIdx[i + 1, 1:] = numpy.where(deleted, bestdelete, bestmatch)
+             backStrIdx[i + 1, 1:] = numpy.where(
+                 deleted, numpy.arange(1, l2 + 1), numpy.arange(0, l2)
+             )
+             scores[i + 1, 1:] = numpy.where(deleted, deletescore, matchscore)
+
+             # insertions: updated in place, don't depend on predecessors
+             insertions(i, l2, scores, inserted)
+             backGrphIdx[i + 1, 1:] = numpy.where(inserted, i + 1, backGrphIdx[i + 1, 1:])
+             backStrIdx[i + 1, 1:] = numpy.where(inserted, numpy.arange(l2), backStrIdx[i + 1, 1:])
+
+             # if we're doing local alignment, don't let bad global alignment
+             # drag us negative
+             if not self.globalAlign:
+                 backGrphIdx[i + 1, :] = numpy.where(scores[i + 1, :] > 0, backGrphIdx[i + 1, :], -1)
+                 backStrIdx[i + 1, :] = numpy.where(scores[i + 1, :] > 0, backStrIdx[i + 1, :], -1)
+                 scores[i + 1, :] = numpy.maximum(scores[i + 1, :], 0)
+
+         return self.backtrack(scores, backStrIdx, backGrphIdx, nodeIndexToID)
+
+     def prevIndices(self, node, nodeIDtoIndex):
+         """Return a list of the previous dynamic programming table indices
+         corresponding to predecessors of the current node."""
+         prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
+         # if no predecessors, point to just before the graph
+         if not prev:
+             prev = [-1]
+         return prev
+
+     def initializeDynamicProgrammingData(self):
+         """Initialize the dynamic programming tables:
+         - set up scores array
+         - set up backtracking array
+         - create index to Node ID table and vice versa"""
+         l1 = self.graph.nNodes
+         l2 = len(self.sequence)
+
+         nodeIDtoIndex = {}
+         nodeIndexToID = {-1: None}
+         # generate a dict of (nodeID) -> (index into nodelist (and thus matrix))
+         ni = self.graph.nodeiterator()
+         for index, node in enumerate(ni()):
+             nodeIDtoIndex[node.ID] = index
+             nodeIndexToID[index] = node.ID
+
+         # Dynamic Programming data structures; scores matrix and backtracking
+         # matrix
+         scores = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)
+
+         # initialize insertion score
+         # if global align, penalty for starting at head != 0
+         if self.globalAlign:
+             scores[0, :] = numpy.arange(l2 + 1) * self._gap
+
+             ni = self.graph.nodeiterator()
+             for index, node in enumerate(ni()):
+                 prevIdxs = self.prevIndices(node, nodeIDtoIndex)
+                 best = scores[prevIdxs[0] + 1, 0]
+                 for prevIdx in prevIdxs:
+                     best = max(best, scores[prevIdx + 1, 0])
+                 scores[index + 1, 0] = best + self._gap
+
+         # backtracking matrices
+         backStrIdx = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)
+         backGrphIdx = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)
+
+         return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx
+
+     def backtrack(self, scores, backStrIdx, backGrphIdx, nodeIndexToID):
+         """Backtrack through the scores and backtrack arrays.
+         Return a list of sequence indices and node IDs (not indices, which
+         depend on ordering)."""
+         besti, bestj = scores.shape
+         besti -= 1
+         bestj -= 1
+         if not self.globalAlign:
+             besti, bestj = numpy.argwhere(scores == numpy.amax(scores))[-1]
+         else:
+             ni = self.graph.nodeiterator()
+             # still have to find best final index to start from
+             terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
+             print(terminalIndices)
+             besti = terminalIndices[0] + 1
+             bestscore = scores[besti, bestj]
+             for i in terminalIndices[1:]:
+                 score = scores[i + 1, bestj]
+                 if score > bestscore:
+                     bestscore, besti = score, i + 1
+
+         matches = []
+         strindexes = []
+         while (self.globalAlign or scores[besti, bestj] > 0) and (besti != 0 or bestj != 0):
+             nexti, nextj = backGrphIdx[besti, bestj], backStrIdx[besti, bestj]
+             curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]
+
+             strindexes.insert(0, curstridx if nextj != bestj else None)
+             matches.insert(0, curnodeidx if nexti != besti else None)
+
+             besti, bestj = nexti, nextj
+
+         return strindexes, matches
src/generation_methods.py ADDED
@@ -0,0 +1,299 @@
+ from typing import Optional
+
+ from src.generation_utils import (
+     extract_alternative_paths,
+     extract_context,
+     extract_equivalent_classes,
+     self_complete,
+     verify_correctness_pairwise,
+ )
+ from src.global_edit_utils import clean_up_text
+ from src.text_poa_graph import TextPOAGraph
+
+
+ def decode_consensus(
+     text_poa_graph: TextPOAGraph,
+     selection_threshold: Optional[float] = 0.5,
+     task: str = "bio",
+     verbose: bool = False,
+     **kwargs,
+ ) -> str:
+     """
+     Decodes from a TextPOAGraph object to a string by sequentially selecting nodes based on the selection threshold.
+     Only the primary variation of selected variable nodes is selected.
+     Text is edited using the global_edit_function (e.g. to clean up text by removing incoherencies, disfluencies, and redundancies).
+
+     Args:
+         text_poa_graph: The TextPOAGraph object to decode.
+         selection_threshold: The threshold for selecting nodes.
+         model: The model to use for decoding.
+
+     Returns:
+         A string of the decoded text.
+     """
+     if text_poa_graph.failed:
+         return "Abstain"
+
+     text_poa_graph.toposort()
+
+     consensus_node_ids = text_poa_graph.consensus_node_ids
+
+     selected_node_ids = []
+
+     for node_id in consensus_node_ids:
+         if node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id:
+             continue
+
+         selected_node_ids.append(node_id)
+
+         for neighbor_id in text_poa_graph.nodedict[node_id].outEdges:
+             if neighbor_id in consensus_node_ids:
+                 continue
+
+             if (
+                 len(text_poa_graph.nodedict[neighbor_id].labels) / text_poa_graph.num_sequences
+                 >= selection_threshold
+             ):
+                 selected_node_ids.append(neighbor_id)
+
+     texts = []
+     for node_id in selected_node_ids:
+         if not text_poa_graph.nodedict[node_id].variations:
+             texts.append(text_poa_graph.nodedict[node_id].text)
+         else:
+             all_texts = [v for v in text_poa_graph.nodedict[node_id].variations.values()]
+             all_texts.append(text_poa_graph.nodedict[node_id].text)
+             # select the variation that is longest
+             texts.append(max(all_texts, key=len))
+     text = " ".join(texts)
+     edited_text = clean_up_text(text=text, task=task, api="openai", **kwargs)
+     if verbose:
+         return text, edited_text
+     else:
+         return edited_text
+
+
+ def decode_self_verified(
+     text_poa_graph: TextPOAGraph,
+     problem: str,
+     uncertainty_threshold: float = 0.6,
+     verification_api: str = "openai",
+     verification_model: str = "gpt-4o-mini",
+     grace_period: bool = True,
+ ):
+     high_uncertainty_nodes = []
+     for node_id in text_poa_graph.consensus_node_ids:
+         if node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id:
+             continue
+
+         outgoing_edges = text_poa_graph.nodedict[node_id].outEdges
+         branching_factor = len(outgoing_edges) / text_poa_graph.num_sequences
+
+         if branching_factor > uncertainty_threshold:
+             high_uncertainty_nodes.append(node_id)
+
+     selected_labels = list(text_poa_graph._seq_paths.keys())
+     masked_candidates = {}
+     uncertain_region = False
+     for label in selected_labels:
+         text = ""
+         for node_id in text_poa_graph._seq_paths[label]:
+             if uncertain_region:
+                 text += f" *START_SEPARATOR*_{node_id} "
+             if node_id in high_uncertainty_nodes:
+                 uncertain_region = True
+
+             if len(text_poa_graph.nodedict[node_id].variations) > 0:
+                 text += text_poa_graph.nodedict[node_id].variations[label]
+                 text += " "
+             else:
+                 text += text_poa_graph.nodedict[node_id].text
+                 text += " "
+
+             if uncertain_region and node_id not in high_uncertainty_nodes:
+                 text += f" *END_SEPARATOR*_{node_id} "
+                 uncertain_region = False
+         masked_candidates[label] = text
+
+     patch_start_node = None
+     uncertain_ids = []
+
+     # give a grace period for the first incorrect step
+     prev_step = {label: None for label in selected_labels}
+
+     for node_id in high_uncertainty_nodes:
+         uncertain_ids.append(node_id)
+         context_before = extract_context(text_poa_graph, node_id)
+         alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
+         equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)
+         new_labels = selected_labels.copy()
+
+         # Only do self-verification for labels from different semantically equivalent branches
+         if len(equivalent_classes) <= 1:
+             continue
+         i = 0
+         while i < len(equivalent_classes):
+             if i + 1 < len(equivalent_classes):
+                 label_a = equivalent_classes[i][0]
+                 label_b = equivalent_classes[i + 1][0]
+                 full_a = context_before[label_a] + alternative_paths[label_a]
+                 full_b = context_before[label_b] + alternative_paths[label_b]
+
+                 score = verify_correctness_pairwise(
+                     full_text_1=full_a,
+                     full_text_2=full_b,
+                     verification_model=verification_model,
+                     problem=problem,
+                     api=verification_api,
+                 )
+                 if float(score[0]) < 1.0:
+                     print(f"Label {label_a} is incorrect at node {node_id}")
+                     masked_candidates[label_a] = (
+                         masked_candidates[label_a]
+                         .replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
+                         .replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
+                     )
+                     if not prev_step[label_a]:
+                         prev_step[label_a] = True
+                     if prev_step[label_a] and grace_period or not grace_period:
+                         for label_i in equivalent_classes[i]:
+                             new_labels.remove(label_i)
+                             print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
+                 if float(score[0]) == 1.0:
+                     prev_step[label_a] = False
+                 if float(score[1]) < 1.0:
+                     print(f"Label {label_b} is incorrect at node {node_id}")
+                     masked_candidates[label_b] = (
+                         masked_candidates[label_b]
+                         .replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
+                         .replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
+                     )
+                     if not prev_step[label_b]:
+                         prev_step[label_b] = True
+                     if prev_step[label_b] and grace_period or not grace_period:
+                         for label_i in equivalent_classes[i + 1]:
+                             new_labels.remove(label_i)
+                             print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
+                 if float(score[1]) == 1.0:
+                     prev_step[label_b] = False
+                 i += 2
+             else:
+                 break
+
+         if len(new_labels) == 0:
+             patch_start_node = node_id
+             break
+
+         selected_labels = new_labels.copy()
+
+     # These are the pruned approaches with masking
+     print(masked_candidates)
+     masked_approaches = "\n".join(
+         [
+             f"Approach {label}: {masked_candidates[label].replace('START_SEPARATOR', 'START_UNCERTAIN_REGION').replace('END_SEPARATOR', 'END_UNCERTAIN_REGION')}"
+             for label in selected_labels
+         ]
+     )
+     # These are all approaches with masking
+     all_approaches = "\n".join(
+         [f"Approach {label}: {masked_candidates[label]}" for label in masked_candidates.keys()]
+     )
+
+     default_prompt = f"""
+     Solve the following math problem with mathematical precision and clarity.
+
+     Problem: {problem}
+
+     Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*).
+     These sections may contain conceptual or computational errors.
+
+     There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
+     A verification step indicated that these steps are highly likely to contain errors.
+
+     Potential Approaches:
+     {masked_approaches}
+
+     Your task:
+     1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses.
+        If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
+     2. Using the sections with special markers, identify potential errors.
+     3. Develop a rigorous, step-by-step solution based on sound mathematical principles.
+     4. For uncertain regions:
+        - Verify each step using algebraic or numerical validation
+        - If correct, incorporate these steps with appropriate justification
+        - If incorrect, provide clear corrections with mathematical reasoning for your changes
+     5. Follow a comparative approach, using the differences between approaches to identify potential errors.
+     6. Do not blindly follow the approaches, but rather use them to identify potential errors.
+
+     Guidelines for your solution:
+     - Begin with a strategic overview of your chosen approach
+     - Present each mathematical step with clear notation and justification
+     - Pay special attention to areas that were previously marked uncertain
+
+     Conclude your solution with:
+     Therefore, the final answer is: $\\boxed{{answer}}$.
+
+     Solution:
+     """
+
+     patch_prompt = f"""
+     Solve the following mathematical problem with precision and clarity.
+
+     Problem: {problem}
+
+     You have been provided with several partial solution approaches that attempted to solve this problem.
+     None of these approaches are correct, but they may contain valuable insights.
+     Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
+     A verification step indicated that these steps are likely to contain errors.
+
+     INSTRUCTIONS:
+     1. Synthesize a correct solution using insights from the previous approaches
+     2. Pay special attention to fixing the problematic areas marked by separators
+     3. Develop your solution step-by-step, showing clear mathematical reasoning
+     4. Focus especially on mathematical correctness in areas where previous solutions diverged
+     5. Present your work in a logical, sequential manner suitable for an advanced reader
+
+     GUIDELINES FOR MATHEMATICAL RIGOR:
+     1. MAINTAIN MATHEMATICAL RIGOR
+        - Verify that all mathematical operations follow from established principles and definitions
+        - Ensure dimensional consistency throughout calculations
+        - Check that algebraic manipulations preserve equality and do not introduce errors
+
+     2. CONSIDER ALTERNATIVE PERSPECTIVES
+        - Even when approaches reach the same conclusion, examine their reasoning independently
+        - Look for more elegant or insightful connections that may be missed across all approaches
+        - Consider whether fundamental mathematical principles suggest a different path
+
+     3. CRITICAL VALIDATION
+        - Test conclusions using known mathematical properties and relationships
+        - When possible, verify results using alternative methods
+        - Be especially cautious when all approaches agree on a result but use similar reasoning
+
+     4. USE PRECISION IN CORRECTIONS
+        - When correcting uncertain regions, specify exactly what was incorrect and why
+        - Provide clear mathematical justification for any changes
+        - Ensure corrections align with standard mathematical principles and notations
+
+     Previous Approaches (for reference only):
+     {all_approaches}
+
+     Your Solution:
+     [Begin with a clear statement of your approach]
+     [Provide detailed mathematical steps]
+     [Ensure correct handling of complex mathematical operations]
+     [Verify your work at key points, especially in previously problematic areas]
+
+     Always conclude with:
+     Therefore, the final answer is: $\\boxed{{answer}}$
+     """
+
+     if patch_start_node is not None or len(masked_candidates.keys()) == 1:
+         print("None correct, patching")
+         prompt = patch_prompt
+     else:
+         prompt = default_prompt
+
+     return self_complete(
+         verification_prompt=prompt, verification_model=verification_model, api=verification_api
+     ), masked_candidates
src/generation_utils.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from huggingface_hub import InferenceClient
4
+ from openai import OpenAI
5
+ from together import Together
6
+
7
+ from src.text_poa_graph import TextPOAGraph
8
+
9
+
10
+ def extract_context(text_poa_graph, node_id):
11
+ """Extract context up to and including the specified node_id."""
12
+ contexts = {}
13
+ for label, path in text_poa_graph._seq_paths.items():
14
+ idx = path.index(node_id)
15
+ context = path[: idx + 1]
16
+ contexts[label] = " ".join(
17
+ text_poa_graph.nodedict[nid].variations.get(label, text_poa_graph.nodedict[nid].text)
18
+ for nid in context
19
+ )
20
+ return contexts
21
+
22
+
23
+ def extract_alternative_paths(text_poa_graph: TextPOAGraph, node_id):
24
+ """Extract all alternative paths from this uncertainty point to the next consensus node."""
25
+ alternative_paths = {}
26
+ for label, path in text_poa_graph._seq_paths.items():
27
+ idx = path.index(node_id)
28
+ next_cn = None
29
+ for i in range(idx + 1, len(path)):
30
+ if path[i] in text_poa_graph.consensus_node_ids:
31
+ next_cn = path[i]
32
+ break
33
+
34
+ if next_cn:
35
+ next_cn_idx = path.index(next_cn)
36
+ alternative_segment = path[idx + 1 : next_cn_idx + 1]
37
+ else:
38
+ alternative_segment = []
39
+
40
+ alternative_paths[label] = " ".join(
41
+ text_poa_graph.nodedict[nid].variations.get(label, text_poa_graph.nodedict[nid].text)
42
+ for nid in alternative_segment
43
+ )
44
+ return alternative_paths
45
+
46
+
47
+ def is_same_branch(text_poa_graph: TextPOAGraph, node_id, lable_1, label_2):
48
+ """Check if the next vaiable nodes for two sequences are the same after node_id."""
49
+ path_1 = text_poa_graph._seq_paths[lable_1]
50
+ path_2 = text_poa_graph._seq_paths[label_2]
51
+ idx_1 = path_1.index(node_id)
52
+ idx_2 = path_2.index(node_id)
53
+ return path_1[idx_1 + 1] == path_2[idx_2 + 1]
54
+
55
+
56
+ def extract_equivalent_classes(text_poa_graph: TextPOAGraph, node_id, selected_labels):
57
+ """Extract equivalent classes from the text POA graph."""
58
+ if not selected_labels:
59
+ return []
60
+
61
+ equivalent_classes = []
62
+ for label in selected_labels:
63
+ matched = False
64
+ for class_group in equivalent_classes:
65
+ if is_same_branch(text_poa_graph, node_id, class_group[0], label):
66
+ class_group.append(label)
67
+ matched = True
68
+ break
69
+ if not matched:
70
+ equivalent_classes.append([label])
71
+ return equivalent_classes
72
+
73
+
74
+ def verify_correctness_pairwise(
75
+ full_text_1: str, full_text_2: str, verification_model: str, problem: str, api: str = "openai"
76
+ ):
77
+ """Pairwise verification of two partial solution paths."""
78
+ if api == "openai":
79
+ client = OpenAI()
80
+ elif api == "hf":
81
+ client = InferenceClient()
82
+ elif api == "together":
83
+ client = Together()
84
+ else:
85
+ raise ValueError(f"Invalid API: {api}")
86
+
87
+ prompt = f"""
88
+ You will be given a problem and 2 partial solutions.
89
+ Your task is to use comparison as an EFFICIENCY TOOL to quickly identify potential errors.
90
+ You will be given guidelines to follow, and you will be penalized if you do not follow them.
91
+
92
+ Problem: {problem}
93
+
94
+ Partial Solution 1: {full_text_1}
95
+ Partial Solution 2: {full_text_2}
96
+
97
+ CRITICAL GUIDELINES:
98
+ - DO NOT penalize a solution for being incomplete or having missing steps
99
+ - DO NOT make a comparison of which solution is better
100
+ - DO NOT consider steps incorrect just because they differ between solutions
101
+ - DO NOT prematurely evaluate based on final answers or future steps
102
+ - DO NOT expect both solutions to be at the same stage of completion
103
+ - DO NOT consider a step incorrect just because it lacks sufficient detail or justification
104
+
105
+ KEY EFFICIENCY PRINCIPLE:
106
+ - Use agreement between solutions as evidence of correctness
107
+ - Use disagreement as a signal to investigate more deeply
108
+ - Only label a step as an error if it contains a specific mathematical mistake
109
+ - Incompleteness is not a mathematical error.
110
+
111
+ Here are the instructions for how to complete your task:
112
+
113
+ EFFICIENT VERIFICATION APPROACH:
114
+
115
+ 1. QUICK COMPARISON (Use this to focus your attention):
116
+ - Immediately identify where the solutions differ in approach or results
117
+ - Use these differences as "error hotspots" to prioritize your verification
118
+ - When solutions agree, you can generally assume that part is correct
119
+ - When solutions disagree, investigate those specific points deeply
120
+
121
+ 2. TARGETED VERIFICATION (Only where needed):
122
+ - Most important: Do not consider any incomplete steps as errors
123
+ - Focus your mathematical verification on the "hotspots" identified above
124
+ - Check mathematical validity only at points of difference or uncertainty
125
+ - Avoid line-by-line checking of steps where solutions agree
126
+ - For each potential error spot, verify if the mathematical reasoning is valid
127
+ - If an intermediate step is later corrected, do not penalize the solution for having the incorrect intermediate step
128
+
129
+ After your targeted verification, propose a score tuple (score_1, score_2):
130
+ - Score (1,1) if both partial solutions are valid
131
+ - Score (1,0) if only the first solution is valid
132
+ - Score (0,1) if only the second solution is valid
133
+ - Score (0,0) if both solutions are invalid
134
+
135
+ In case you score a solution as 0, you must give an explanation for each check below:
136
+ 3. FINAL CHECKS:
137
+ - If you score a solution as 0, you MUST identify the specific mathematical error.
138
+ - You must also double check the problem statement. Reconsider your score and determine if you have misinterpreted the problem statement.
139
+ - You must also check whether you have penalized a solution for being incomplete or having missing steps.
140
+
141
+ Before outputting your final score, you must answer these questions:
142
+ STOP! Did you give a score of 0 to a solution that was incomplete?
143
+ STOP! Did you penalize a solution for being incomplete or having missing steps?
144
+ STOP! Did you make a comparison of which solution is better?
145
+ STOP! Did you consider steps incorrect just because they differ between solutions?
146
+ STOP! Did you prematurely evaluate based on final answers?
147
+ STOP! Did you consider a step incorrect just because it lacks sufficient detail or justification?
148
+
149
+ Now give your final score:
150
+ Final score:
151
+ """
152
+ completion = client.chat.completions.create(
153
+ model=verification_model,
154
+ messages=[
155
+ {"role": "system", "content": "You are a helpful assistant."},
156
+ {"role": "user", "content": prompt},
157
+ ],
158
+ temperature=0.0,
159
+ )
160
+ response = completion.choices[0].message.content.strip()
161
+ print(full_text_1)
162
+ print(full_text_2)
163
+ print(f"Correctness score: {response} \n")
164
+ score_match = re.findall(r"\(\s*([01](?:\.0)?)\s*,\s*([01](?:\.0)?)\s*\)", response)
165
+ score = score_match[-1] if score_match else ("0", "0")  # findall yields tuples of strings
166
+ return score
167
+
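The score tuple is recovered from the model's free-form reply by taking the last parenthesized `(0/1, 0/1)` pair; note that `re.findall` yields tuples of strings. A quick check of the pattern in isolation (the reply text is made up):

```python
import re

reply = "Both agree through step 3. Final score: (1, 0)"
pairs = re.findall(r"\(\s*([01](?:\.0)?)\s*,\s*([01](?:\.0)?)\s*\)", reply)
print(pairs[-1] if pairs else ("0", "0"))  # -> ('1', '0')
```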
168
+
169
+ def self_complete(verification_prompt: str, verification_model: str, api: str = "openai"):
170
+ """Completion method: send a single prompt to the chosen API and return the reply text."""
171
+ print(verification_prompt)
172
+ if api == "openai":
173
+ client = OpenAI()
174
+ elif api == "hf":
175
+ client = InferenceClient()
176
+ elif api == "together":
177
+ client = Together()
178
+ else:
179
+ raise ValueError(f"Invalid API: {api}")
180
+
181
+ completion = client.chat.completions.create(
182
+ model=verification_model,
183
+ messages=[
184
+ {"role": "system", "content": "You are a helpful assistant."},
185
+ {"role": "user", "content": verification_prompt},
186
+ ],
187
+ temperature=0.0,
188
+ )
189
+ response = completion.choices[0].message.content.strip()
190
+ return response
src/global_edit_utils.py ADDED
@@ -0,0 +1,127 @@
1
+ from huggingface_hub import InferenceClient
2
+ from openai import OpenAI
3
+
4
+ bio_prompt = """
5
+ You are given a piece of text that is a part of a biography of an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
6
+ Then, remove any redundant information.
7
+ Text: {text}
8
+
9
+ If this is not possible because the text is just a fragment of a sentence, return "Abstain".
10
+ If the text already claims a lack of knowledge about the topic, return "Abstain".
11
+ Only return the cleaned up text. Do not include any other text:
12
+ """
13
+
14
+ fp_prompt = """
15
+ You are given a piece of text that is a part of a false presupposition task which includes outputting a list of items.
16
+ This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
17
+ Then, remove any redundant information.
18
+ Text: {text}
19
+
20
+ The resulting list of items should be separated by semicolons with no other text.
21
+ If it is not possible to generate this list, return "Abstain".
22
+ """
23
+
24
+ hist_prompt = """
25
+ You are given a piece of text that is a part of a historical event task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
26
+ Then, remove any redundant information.
27
+ Text: {text}
28
+
29
+ If this is not possible because the text is just a fragment of a sentence, return "Abstain".
30
+ If the text already claims a lack of knowledge about the topic, return "Abstain".
31
+ Only return the cleaned up text. Do not include any other text:
32
+ """
33
+
34
+ refs_prompt = """
35
+ You are given a piece of text that is a part of a reference task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
36
+ Then, remove any redundant information.
37
+ Text: {text}
38
+
39
+ If this is not possible because the text is just a fragment of a sentence, return "Abstain".
40
+ If the text already claims a lack of knowledge about the topic, return "Abstain".
41
+ Only return the cleaned up text. Do not include any other text:
42
+ """
43
+
44
+ gpqa_prompt = """
45
+ You are given a piece of text that is a part of a graduate level question answering task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
46
+ Then, remove any redundant information.
47
+ Text: {text}
48
+ Only return the cleaned up text. Do not include any other text:
49
+ """
50
+
51
+ popqa_prompt = """
52
+ You are given a piece of text that is a part of a paragraph which details facts related to an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
53
+ Then, remove any redundant information.
54
+ Text: {text}
55
+
56
+ If this is not possible because the text is just a fragment of a sentence, return "Abstain".
57
+ If the text already claims a lack of knowledge about the topic, return "Abstain".
58
+ Only return the cleaned up text. Do not include any other text:
59
+ """
60
+ task_to_prompt = {
61
+ "bio": bio_prompt,
62
+ "fp": fp_prompt,
63
+ "hist": hist_prompt,
64
+ "refs": refs_prompt,
65
+ "gpqa": gpqa_prompt,
66
+ "popqa": popqa_prompt
67
+ }
68
+
69
+ def clean_up_text(text: str, api: str, task: str, model: str = "gpt-4.1-mini", **kwargs):
70
+ """
71
+ Cleans up disfluencies in the draft response in consensus decoding.
72
+
73
+ Args:
74
+ text: The text to clean up.
75
+ api: The API to use for cleaning up the text ("openai" or "hf").
76
+ task: The task key: bio, fp, hist, refs, gpqa, or popqa (see task_to_prompt above).
77
+ model: The model to use for cleaning up the text (for "hf", pass "tokenizer" and "hf_model" via kwargs).
78
+
79
+ Returns:
80
+ A string of the cleaned up text.
81
+ """
83
+ if api == "openai":
84
+ client = OpenAI()
85
+ elif api == "hf":
86
+ tokenizer = kwargs.get("tokenizer")
87
+ model = kwargs.get("hf_model")
88
+
89
+ if tokenizer is None or model is None:
90
+ raise ValueError("For 'hf', both 'tokenizer' and 'model' must be provided.")
91
+
92
+ clean_up_prompt = task_to_prompt[task].format(text=text)
93
+
94
+ messages = [{"role": "user", "content": clean_up_prompt}]
95
+ input_ids = tokenizer.apply_chat_template(
96
+ messages,
97
+ add_generation_prompt=True,
98
+ return_tensors="pt"
99
+ ).to(model.device)
100
+
101
+ terminators = [tokenizer.eos_token_id]
102
+ outputs = model.generate(
103
+ input_ids,
104
+ max_new_tokens=500,
105
+ do_sample=False,
106
+ pad_token_id=tokenizer.eos_token_id,
107
+ eos_token_id=terminators,
108
+ )
109
+
110
+ return tokenizer.decode(
111
+ outputs[0][input_ids.shape[-1]:],
112
+ skip_special_tokens=True
113
+ ).strip()
114
+ else:
115
+ raise ValueError(f"Invalid API: {api}")
116
+
117
+ clean_up_prompt = task_to_prompt[task].format(text=text)
118
+
119
+ completion = client.chat.completions.create(
120
+ model=model,
121
+ messages=[
122
+ {"role": "system", "content": "You are a helpful assistant."},
123
+ {"role": "user", "content": clean_up_prompt},
124
+ ],
125
+ )
126
+
127
+ return completion.choices[0].message.content.strip()
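A minimal usage sketch for `clean_up_text` with the OpenAI backend (assumes `OPENAI_API_KEY` is set and network access is available; the draft text and printed output are illustrative):

```python
# Hypothetical usage; requires OPENAI_API_KEY in the environment.
from src.global_edit_utils import clean_up_text

draft = "Marie Curie was was a physicist and and chemist, born in in Warsaw."
cleaned = clean_up_text(draft, api="openai", task="bio")
print(cleaned)  # e.g. "Marie Curie was a physicist and chemist, born in Warsaw."
```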
src/new_alignment.py ADDED
@@ -0,0 +1,150 @@
1
+ import numpy
2
+
3
+
4
+ class ScoreParam:
5
+ def __init__(self, match, mismatch, gap_open, gap_extend):
6
+ self.match = match
7
+ self.mismatch = mismatch
8
+ self.gap_open = gap_open
9
+ self.gap_extend = gap_extend
10
+
11
+ def __str__(self):
12
+ return f"Match: {self.match}, Mismatch: {self.mismatch}, Gap Open: {self.gap_open}, Gap Extend: {self.gap_extend}"
13
+
14
+
15
+ class SeqGraphAlignment(object):
16
+ __default_score = ScoreParam(1, -3, -2, -1)
17
+
18
+ def __init__(
19
+ self,
20
+ sequence,
21
+ graph,
22
+ fastMethod=True,
23
+ globalAlign=False,
24
+ score_params=__default_score,
25
+ *args,
26
+ **kwargs,
27
+ ):
28
+ self.score = score_params
29
+ self.sequence = sequence
30
+ self.graph = graph
31
+ self.stringidxs = None
32
+ self.nodeidxs = None
33
+ self.globalAlign = globalAlign
34
+ if fastMethod:
35
+ matches = self.alignStringToGraphFast(*args, **kwargs)
36
+ else:
37
+ matches = self.alignStringToGraphSimple(*args, **kwargs)
38
+ self.stringidxs, self.nodeidxs = matches
39
+
40
+ def alignmentStrings(self):
41
+ return (
42
+ "".join(self.sequence[i] if i is not None else "-" for i in self.stringidxs),
43
+ "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs),
44
+ )
45
+
46
+ def matchscore(self, c1, c2):
47
+ if c1 == c2:
48
+ return self.score.match
49
+ else:
50
+ return self.score.mismatch
51
+
52
+ def matchscoreVec(self, c, v):
53
+ return numpy.where(v == c, self.score.match, self.score.mismatch)
54
+
55
+ def prevIndices(self, node, nodeIDtoIndex):
56
+ prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
57
+ if not prev:
58
+ prev = [-1]
59
+ return prev
60
+
61
+ def initializeDynamicProgrammingData(self):
62
+ l1 = self.graph.nNodes
63
+ l2 = len(self.sequence)
64
+
65
+ nodeIDtoIndex = {}
66
+ nodeIndexToID = {-1: None}
67
+ ni = self.graph.nodeiterator()
68
+ for index, node in enumerate(ni()):
69
+ nodeIDtoIndex[node.ID] = index
70
+ nodeIndexToID[index] = node.ID
71
+
72
+ scores = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.float64)  # float: word-level match scores can be fractional
73
+
74
+ if self.globalAlign:
75
+ # M[0, i] = -inf
76
+ scores[0, 0, :] = [
77
+ -1000000000 for i in range(l2+1)
78
+ ]
79
+ scores[0, 0, 0] = 0
80
+ # X[0, i] = gap_open + i * gap_extend
81
+ scores[1, 0, :] = [
82
+ self.score.gap_open + i * self.score.gap_extend for i in range(l2 + 1)
83
+ ]
84
+ scores[1, 0, 0] = -1000000000
85
+ # Y[0, i] = -inf
86
+ scores[2, 0, :] = [
87
+ -1000000000 for i in range(l2+1)
88
+ ]
89
+
90
+ ni = self.graph.nodeiterator()
91
+ # After the topological sort, the predecessors will have an index less than the current node
92
+ for index, node in enumerate(ni()):
93
+ scores[0, index + 1, 0] = -1000000000
94
+ scores[1, index + 1, 0] = -1000000000
95
+ prevIdxs = self.prevIndices(node, nodeIDtoIndex)
96
+ best = scores[2, prevIdxs[0] + 1, 0]
97
+ for prevIdx in prevIdxs:
98
+ best = max(best, scores[2, prevIdx + 1, 0])
99
+ # If we have no predecessors, we start the gap
100
+ if prevIdxs == [-1]:
101
+ scores[2, index + 1, 0] = self.score.gap_open + self.score.gap_extend
102
+ else:
103
+ scores[2, index + 1, 0] = best + self.score.gap_extend
104
+
105
+ # 3D Backtracking
106
+ backStrIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
107
+ backGrphIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
108
+ backMtxIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
109
+
110
+ return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx
111
+
112
+ def backtrack(self, scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID):
113
+ besti, bestj = scores.shape[1] - 1, scores.shape[2] - 1
114
+ # Store the best matrix index for each [i, j]
115
+ scores_arr = numpy.array(scores)
116
+ max_m = numpy.argmax(scores_arr, axis=0)
117
+
118
+ if self.globalAlign:
119
+ ni = self.graph.nodeiterator()
120
+ # Finding the best node to start from
121
+ terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
122
+ # print(terminalIndices)
123
+ besti = terminalIndices[0] + 1
124
+ bestscore = scores[max_m[besti, bestj], besti, bestj]
125
+ for i in terminalIndices[1:]:
126
+ score = scores[max_m[i + 1, bestj], i + 1, bestj]
127
+ if score > bestscore:
128
+ bestscore, besti = score, i + 1
129
+ bestm = max_m[besti, bestj]
130
+
131
+ matches = []
132
+ strindexes = []
133
+
134
+ while (besti != 0 or bestj != 0):
135
+ nextm, nexti, nextj = backMtxIdx[bestm, besti, bestj], backGrphIdx[bestm, besti, bestj], backStrIdx[bestm, besti, bestj]
136
+ curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]
137
+
138
+ if bestm == 0:
139
+ matches.insert(0, curnodeidx)
140
+ strindexes.insert(0, curstridx)
141
+ elif bestm == 1:
142
+ matches.insert(0, None)
143
+ strindexes.insert(0, curstridx)
144
+ else:
145
+ matches.insert(0, curnodeidx)
146
+ strindexes.insert(0, None)
147
+
148
+ bestm, besti, bestj = nextm, nexti, nextj
149
+
150
+ return strindexes, matches
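The three stacked score matrices implement a Gotoh-style affine-gap recurrence (`M` for matches/mismatches, `X`/`Y` for the two gap states) generalized to a DAG of predecessor nodes. For intuition, here is the plain string-vs-string version of the same recurrence and of the `gap_open + i * gap_extend` initialization, as a sketch rather than the graph-aware code above:

```python
# Sketch: Gotoh affine-gap DP between two strings, mirroring the M / X / Y
# layout and default scores above (match 1, mismatch -3, gap open -2, extend -1).
NEG = -10**9

def gotoh(s, t, match=1, mismatch=-3, gap_open=-2, gap_extend=-1):
    n, m = len(s), len(t)
    M = [[NEG] * (m + 1) for _ in range(n + 1)]
    X = [[NEG] * (m + 1) for _ in range(n + 1)]  # gap in s (consumes t)
    Y = [[NEG] * (m + 1) for _ in range(n + 1)]  # gap in t (consumes s)
    M[0][0] = 0
    for j in range(1, m + 1):
        X[0][j] = gap_open + j * gap_extend
    for i in range(1, n + 1):
        Y[i][0] = gap_open + i * gap_extend
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            sub = match if s[i - 1] == t[j - 1] else mismatch
            M[i][j] = sub + max(M[i - 1][j - 1], X[i - 1][j - 1], Y[i - 1][j - 1])
            X[i][j] = max(gap_open + gap_extend + M[i][j - 1],
                          gap_extend + X[i][j - 1],
                          gap_open + gap_extend + Y[i][j - 1])
            Y[i][j] = max(gap_open + gap_extend + M[i - 1][j],
                          gap_open + gap_extend + X[i - 1][j],
                          gap_extend + Y[i - 1][j])
    return max(M[n][m], X[n][m], Y[n][m])

print(gotoh("consensus", "consenses"))  # 8 matches + 1 mismatch: 8 - 3 = 5
```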
src/new_text_alignment.py ADDED
@@ -0,0 +1,134 @@
1
+ from difflib import SequenceMatcher
2
+
3
+ import numpy as np
4
+
5
+ from .new_alignment import ScoreParam, SeqGraphAlignment
6
+
7
+ PUNCTUATION_MARKS = [".", "!", "?", ",", ":", ";", "...", "(", ")"]
8
+
9
+ class TextSeqGraphAlignment(SeqGraphAlignment):
10
+ def __init__(
11
+ self,
12
+ text,
13
+ graph,
14
+ fastMethod=True,
15
+ globalAlign=True,
16
+ matchscore=1,
17
+ mismatchscore=-3,
18
+ gap_open=-2,
19
+ gap_extend=-1,
20
+ position_weight=0.1,
21
+ *args,
22
+ **kwargs,
23
+ ):
24
+ score_params = ScoreParam(
25
+ match=matchscore, mismatch=mismatchscore, gap_open=gap_open, gap_extend=gap_extend
26
+ )
27
+
28
+ if isinstance(text, str):
29
+ self.original_text = text
30
+ self.sequence = text.split()
31
+ else:
32
+ self.sequence = text
33
+ self.original_text = " ".join(text)
34
+ self.position_weight = position_weight
35
+
36
+ super().__init__(
37
+ self.sequence,
38
+ graph,
39
+ fastMethod,
40
+ globalAlign=globalAlign,
41
+ score_params=score_params,
42
+ *args,
43
+ **kwargs,
44
+ )
45
+
46
+ def string_similarity(self, s1, s2):
47
+ """Get edit-distance based similarity between two strings"""
48
+ return SequenceMatcher(None, s1, s2).ratio()
49
+
50
+ def matchscore(self, word1: str, word2: str) -> float:
51
+ """Enhanced scoring function that considers string similarity
52
+ and relative position"""
53
+ # Calculate basic string similarity
54
+ similarity = self.string_similarity(word1, word2)
55
+
56
+ # If words are very similar, treat as match
57
+ if similarity > 0.8: # Can tune this threshold
58
+ similarity = self.score.match
59
+ # For less similar words, scale score based on similarity
60
+ elif similarity > 0.5: # Can tune this threshold too
61
+ similarity = self.score.match * similarity
62
+ else:
63
+ similarity = self.score.mismatch
65
+
66
+ # add weight if any punctuation mark is present
67
+ if any(char in word1 for char in PUNCTUATION_MARKS) or any(
68
+ char in word2 for char in PUNCTUATION_MARKS
69
+ ):
70
+ similarity = similarity * 1.5
71
+
72
+ return similarity
73
+
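The 0.8 and 0.5 cutoffs above turn `SequenceMatcher`'s edit-distance ratio into a tiered word score. A standalone check of the same tiers (threshold and score values mirror the defaults above):

```python
from difflib import SequenceMatcher

def word_score(w1, w2, match=1, mismatch=-3):
    ratio = SequenceMatcher(None, w1, w2).ratio()
    if ratio > 0.8:
        return match           # near-identical words count as a full match
    if ratio > 0.5:
        return match * ratio   # partial credit for similar words
    return mismatch            # dissimilar words score as a mismatch

for pair in [("running", "running"), ("running", "runs"), ("graph", "cat")]:
    print(pair, word_score(*pair))
# ('running', 'running') -> 1; ('running', 'runs') -> ~0.545; ('graph', 'cat') -> -3
```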
74
+ def alignmentStrings(self):
75
+ """Override to handle word-based alignment"""
76
+ aligned_seq = [self.sequence[i] if i is not None else "-" for i in self.stringidxs]
77
+ aligned_graph = [
78
+ self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs
79
+ ]
80
+ return " ".join(aligned_seq), " ".join(aligned_graph)
81
+
82
+ def alignStringToGraphFast(self):
83
+ if not isinstance(self.sequence, list):
84
+ raise TypeError("Sequence must be a list of words")
85
+
86
+ nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx = (
87
+ self.initializeDynamicProgrammingData()
88
+ )
89
+ # M: Match at last indices, X: Gap at last index of graph, Y: gap at last index of sequence
90
+ M, X, Y = 0, 1, 2
91
+
92
+ ni = self.graph.nodeiterator()
93
+ for i, node in enumerate(ni()):
94
+ gbase = node.text
95
+
96
+ for j, sbase in enumerate(self.sequence):
97
+ candidates_X, candidates_Y, candidates_M = [], [], []
98
+ candidates_X += [
99
+ (self.score.gap_open + self.score.gap_extend + scores[0, i + 1, j], i + 1, j, M),
100
+ (self.score.gap_extend + scores[1, i + 1, j], i + 1, j, X),
101
+ (self.score.gap_open + self.score.gap_extend + scores[2, i + 1, j], i + 1, j, Y)
102
+ ]
103
+ for predIndex in self.prevIndices(node, nodeIDtoIndex):
104
+ candidates_Y += [
105
+ (self.score.gap_open + self.score.gap_extend + scores[0, predIndex + 1, j + 1] , predIndex + 1, j + 1, M),
106
+ (self.score.gap_open + self.score.gap_extend + scores[1, predIndex + 1, j + 1] , predIndex + 1, j + 1, X),
107
+ (self.score.gap_extend + scores[2, predIndex + 1, j + 1] , predIndex + 1, j + 1, Y)
108
+ ]
109
+ candidates_M += [
110
+ (self.matchscore(sbase, gbase) + scores[0, predIndex + 1, j], predIndex + 1, j, M),
111
+ (self.matchscore(sbase, gbase) + scores[1, predIndex + 1, j], predIndex + 1, j, X),
112
+ (self.matchscore(sbase, gbase) + scores[2, predIndex + 1, j], predIndex + 1, j, Y)
113
+ ]
114
+
115
+ (
116
+ scores[0, i + 1, j + 1],
117
+ backGrphIdx[0, i + 1, j + 1],
118
+ backStrIdx[0, i + 1, j + 1],
119
+ backMtxIdx[0, i + 1, j + 1],
120
+ ) = max(candidates_M)
121
+ (
122
+ scores[1, i + 1, j + 1],
123
+ backGrphIdx[1, i + 1, j + 1],
124
+ backStrIdx[1, i + 1, j + 1],
125
+ backMtxIdx[1, i + 1, j + 1],
126
+ ) = max(candidates_X)
127
+ (
128
+ scores[2, i + 1, j + 1],
129
+ backGrphIdx[2, i + 1, j + 1],
130
+ backStrIdx[2, i + 1, j + 1],
131
+ backMtxIdx[2, i + 1, j + 1],
132
+ ) = max(candidates_Y)
133
+
134
+ return self.backtrack(scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID)
src/poa_graph.py ADDED
@@ -0,0 +1,685 @@
1
+ """
2
+ Adapted from Jonathan Dursi
3
+ https://github.com/ljdursi/poapy
4
+ """
5
+
6
+ import collections
7
+ import textwrap
8
+ from typing import Dict, List, Optional, Union
9
+
10
+ import numpy
11
+
12
+ from .alignment import SeqGraphAlignment
13
+
14
+
15
+ class Node(object):
16
+ def __init__(self, nodeID: int = -1, text: str = ""):
17
+ self.ID = nodeID
18
+ self.text = text
19
+ self.inEdges = {}
20
+ self.outEdges = {}
21
+ self.alignedTo = []
22
+
23
+ def __str__(self):
24
+ return "(%d:%s)" % (self.ID, self.text)
25
+
26
+ def _add_edge(
27
+ self,
28
+ edgeset: Dict[int, "Node"],
29
+ neighbourID: int,
30
+ label: Union[int, List[int]],
31
+ from_neighbour: bool,
32
+ weight: int = 1,
33
+ ):
34
+ if neighbourID is None:
35
+ return
36
+ # already present? just update labels
37
+ # otherwise create appropriately-ordered edge and proceed
38
+ if neighbourID in edgeset:
39
+ edgeset[neighbourID].weight += weight
40
+ if isinstance(label, list):
41
+ edgeset[neighbourID].labels.extend(label)
42
+ else:
43
+ edgeset[neighbourID].labels.append(label)
44
+ # remove duplicates
45
+ edgeset[neighbourID].labels = list(set(edgeset[neighbourID].labels))
46
+ else:
47
+ if from_neighbour:
48
+ edge = Edge(outNodeID=neighbourID, inNodeID=self.ID, label=label, weight=weight)
49
+ else:
50
+ edge = Edge(outNodeID=self.ID, inNodeID=neighbourID, label=label, weight=weight)
51
+ edgeset[neighbourID] = edge
52
+
53
+ def addInEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
54
+ self._add_edge(self.inEdges, neighbourID, label, from_neighbour=True, weight=weight)
55
+
56
+ def addOutEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
57
+ self._add_edge(self.outEdges, neighbourID, label, from_neighbour=False, weight=weight)
58
+
59
+ def nextNode(self, label: int):
60
+ """Returns the first (presumably only) outward neighbour
61
+ having the given edge label"""
62
+ nextID = None
63
+ for e in self.outEdges:
64
+ if label in self.outEdges[e].labels:
65
+ nextID = e
66
+ return nextID
67
+
68
+ @property
69
+ def inDegree(self):
70
+ return len(self.inEdges)
71
+
72
+ @property
73
+ def outDegree(self):
74
+ return len(self.outEdges)
75
+
76
+ @property
77
+ def weightedInDegree(self):
78
+ return sum(edge.weight for edge in self.inEdges.values())
79
+
80
+ @property
81
+ def weightedOutDegree(self):
82
+ return sum(edge.weight for edge in self.outEdges.values())
83
+
84
+ @property
85
+ def labels(self):
86
+ """Returns all the labels associated with an in-edge or an out edge."""
87
+ labelset = set([])
88
+ for e in list(self.inEdges.values()):
89
+ labelset = labelset.union(e.labels)
90
+ for e in list(self.outEdges.values()):
91
+ labelset = labelset.union(e.labels)
92
+ return list(labelset)
93
+
94
+
95
+ class Edge(object):
96
+ def __init__(
97
+ self,
98
+ inNodeID: int = -1,
99
+ outNodeID: int = -1,
100
+ label: Optional[Union[int, List[int]]] = None,
101
+ weight: int = 1,
102
+ ):
103
+ self.inNodeID = inNodeID
104
+ self.outNodeID = outNodeID
105
+
106
+ self.weight = weight
107
+
108
+ if label is None:
109
+ self.labels = []
110
+ elif isinstance(label, list):
111
+ self.labels = label
112
+ else:
113
+ self.labels = [label]
114
+
115
+ def addLabel(self, newlabel):
116
+ self.labels.append(newlabel)
117
+
118
+ def __str__(self):
119
+ nodestr = "(%d) -> (%d) " % (self.inNodeID, self.outNodeID)
120
+ if self.labels is None:
121
+ return nodestr
122
+ else:
123
+ return nodestr + self.labels.__str__()
124
+
125
+
126
+ class POAGraph(object):
127
+ def addUnmatchedSeq(self, seq, label: int = -1, updateSequences=True):
128
+ """Add a completely independant (sub)string to the graph,
129
+ and return node index to initial and final node"""
130
+ if seq is None:
131
+ return
132
+
133
+ firstID, lastID = None, None
134
+ neededSort = self.needsSort
135
+
136
+ for text in seq:
137
+ nodeID = self.addNode(text)
138
+ if firstID is None:
139
+ firstID = nodeID
140
+ if lastID is not None:
141
+ self.addEdge(lastID, nodeID, label)
142
+ lastID = nodeID
143
+
144
+ self._needsort = neededSort # no new order problems introduced
145
+ if updateSequences:
146
+ self._seqs.append(seq)
147
+ self._labels.append(label)
148
+ self._starts.append(firstID)
149
+ return firstID, lastID
150
+
151
+ def __init__(self, seq=None, label: Optional[Union[int, List[int]]] = None):
152
+ self._nextnodeID = 0
153
+ self._nnodes = 0
154
+ self._nedges = 0
155
+ self.nodedict = {}
156
+ self.nodeidlist = [] # allows a (partial) order to be imposed on the nodes
157
+ self._needsort = False
158
+ self._labels = []
159
+ self._seqs = []
160
+ self._starts = []
+ self._seq_paths = {}  # label -> ordered node IDs for each incorporated sequence
161
+
162
+ if seq is not None:
163
+ self.addUnmatchedSeq(seq, label)
164
+
165
+ def nodeIdxToBase(self, idx):
166
+ return self.nodedict[self.nodeidlist[idx]].text
167
+
168
+ def addNode(self, text):
169
+ nid = self._nextnodeID
170
+ newnode = Node(nid, text)
171
+ self.nodedict[nid] = newnode
172
+ self.nodeidlist.append(nid)
173
+ self._nnodes += 1
174
+ self._nextnodeID += 1
175
+ self._needsort = True
176
+ return nid
177
+
178
+ def addEdge(self, start, end, label, weight: int = 1):
179
+ if start is None or end is None:
180
+ return
181
+
182
+ if start not in self.nodedict:
183
+ raise KeyError("addEdge: Start node not in graph: " + str(start))
184
+ if end not in self.nodedict:
185
+ raise KeyError("addEdge: End node not in graph: " + str(end))
186
+
187
+ oldNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree
188
+
189
+ self.nodedict[start].addOutEdge(end, label, weight)
190
+ self.nodedict[end].addInEdge(start, label, weight)
191
+
192
+ newNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree
193
+
194
+ if newNodeEdges != oldNodeEdges:
195
+ self._nedges += 1
196
+
197
+ self._needsort = True
198
+ return
199
+
200
+ @property
201
+ def needsSort(self):
202
+ return self._needsort
203
+
204
+ @property
205
+ def nNodes(self):
206
+ return self._nnodes
207
+
208
+ @property
209
+ def nEdges(self):
210
+ return self._nedges
211
+
212
+ @property
213
+ def num_sequences(self):
214
+ return len(self._seqs)
215
+
216
+ def get_sequences(self):
217
+ return self._seqs
218
+
219
+ def _simplified_graph_rep(self):
220
+
221
+ node_to_pn = {}
222
+ pn_to_nodes = {}
223
+
224
+ # Find the mappings from nodes to pseudonodes
225
+ cur_pnid = 0
226
+ for _, node in self.nodedict.items():
227
+ if node.ID not in node_to_pn:
228
+ node_ids = [node.ID] + node.alignedTo
229
+ pn_to_nodes[cur_pnid] = node_ids
230
+ for nid in node_ids:
231
+ node_to_pn[nid] = cur_pnid
232
+ cur_pnid += 1
233
+
234
+ # create the pseudonodes
235
+ Pseudonode = collections.namedtuple(
236
+ "Pseudonode", ["pnode_id", "predecessors", "successors", "node_ids"]
237
+ )
238
+ pseudonodes = []
239
+
240
+ for pnid in range(cur_pnid):
241
+ nids, preds, succs = pn_to_nodes[pnid], [], []
242
+ for nid in nids:
243
+ node = self.nodedict[nid]
244
+ preds += [node_to_pn[inEdge.outNodeID] for _, inEdge in node.inEdges.items()]
245
+ succs += [node_to_pn[outEdge.inNodeID] for _, outEdge in node.outEdges.items()]
246
+
247
+ pn = Pseudonode(pnode_id=pnid, predecessors=preds, successors=succs, node_ids=nids)
248
+ pseudonodes.append(pn)
249
+
250
+ return pseudonodes
251
+
252
+ def toposort(self):
253
+ """Sorts node list so that all incoming edges come from nodes earlier in the list."""
254
+ sortedlist = []
255
+ completed = set([])
256
+
257
+ #
258
+ # The topological sort of this graph is complicated by the alignedTo edges;
259
+ # we want the nodes connected by such edges to remain near each other in the
260
+ # topological sort.
261
+ #
262
+ # Here we'll create a simple version of the graph that merges nodes that
263
+ # are alignedTo each other, performs the sort, and then decomposes the
264
+ # 'pseudonodes'.
265
+ #
266
+ # The need for this suggests that the way the graph is currently represented
267
+ # isn't quite right and needs some rethinking.
268
+ #
269
+
270
+ pseudonodes = self._simplified_graph_rep()
271
+
272
+ def dfs(start, complete, sortedlist):
273
+ stack, started = [start], set()
274
+ while stack:
275
+ pnodeID = stack.pop()
276
+
277
+ if pnodeID in complete:
278
+ continue
279
+
280
+ if pnodeID in started:
281
+ complete.add(pnodeID)
282
+ for nid in pseudonodes[pnodeID].node_ids:
283
+ sortedlist.insert(0, nid)
284
+ started.remove(pnodeID)
285
+ continue
286
+
287
+ successors = pseudonodes[pnodeID].successors
288
+ started.add(pnodeID)
289
+ stack.append(pnodeID)
290
+ stack.extend(successors)
291
+
292
+ while len(sortedlist) < self.nNodes:
293
+ found = None
294
+ for pnid in range(len(pseudonodes)):
295
+ if pnid not in completed and len(pseudonodes[pnid].predecessors) == 0:
296
+ found = pnid
297
+ break
298
+ assert found is not None
299
+ dfs(found, completed, sortedlist)
300
+
301
+ assert len(sortedlist) == self.nNodes
302
+ self.nodeidlist = sortedlist
303
+ self._needsort = False
304
+ return
305
+
306
+ def testsort(self):
307
+ """Test the nodeidlist to make sure it is topologically sorted:
308
+ e.g., all predecessors of a node precede the node in the list"""
309
+ if self.nodeidlist is None:
310
+ return
311
+ seen_nodes = set()
312
+ for nodeidx in self.nodeidlist:
313
+ node = self.nodedict[nodeidx]
314
+ for in_neighbour in node.inEdges:
315
+ assert in_neighbour in seen_nodes
316
+ seen_nodes.add(nodeidx)
317
+ return
318
+
319
+ def nodeiterator(self):
320
+ if self.needsSort:
321
+ self.toposort()
322
+
323
+ def nodegenerator():
324
+ for nodeidx in self.nodeidlist:
325
+ yield self.nodedict[nodeidx]
326
+
327
+ return nodegenerator
328
+
329
+ def __str__(self):
330
+ selfstr = ""
331
+ ni = self.nodeiterator()
332
+ for node in ni():
333
+ selfstr += node.__str__() + "\n"
334
+ for outIdx in node.outEdges:
335
+ selfstr += " " + node.outEdges[outIdx].__str__() + "\n"
336
+ return selfstr
337
+
338
+ def incorporateSeqAlignment(self, alignment: SeqGraphAlignment, seq, label: int = -1):
339
+ """Incorporate a SeqGraphAlignment into the graph."""
340
+ newseq = alignment.sequence
341
+ stringidxs = alignment.stringidxs
342
+ nodeidxs = alignment.nodeidxs
343
+
344
+ firstID = None
345
+ headID = None
346
+ tailID = None
347
+
348
+ path = []
349
+ # head, tail of sequence may be unaligned; just add those into the
350
+ # graph directly
351
+ validstringidxs = [si for si in stringidxs if si is not None]
352
+ startSeqIdx, endSeqIdx = validstringidxs[0], validstringidxs[-1]
353
+ if startSeqIdx > 0:
354
+ firstID, headID = self.addUnmatchedSeq(
355
+ newseq[0:startSeqIdx], label, updateSequences=False
356
+ )
357
+ if endSeqIdx < len(newseq):
358
+ tailID, __ = self.addUnmatchedSeq(newseq[endSeqIdx + 1 :], label, updateSequences=False)
359
+
360
+ # now we march along the aligned part. For each text, we find or create
361
+ # a node in the graph:
362
+ # - if unmatched, the corresponding node is a new node
363
+ # - if matched:
364
+ # - if matched to a node with the same text, the node is that node
365
+ # - if matched to a node with a different text which is in turn
366
+ # aligned to a node with the same text, that aligned node is
367
+ # the node
368
+ # - otherwise, we create a new node.
369
+ # In all cases, we create edges (or add labels) threading through the
370
+ # nodes.
371
+ for sindex, matchID in zip(stringidxs, nodeidxs):
372
+ if sindex is None:
373
+ continue
374
+ text = newseq[sindex]
375
+ if matchID is None:
376
+ nodeID = self.addNode(text)
377
+ elif self.nodedict[matchID].text == text:
378
+ nodeID = matchID
379
+ else:
380
+ otherAligns = self.nodedict[matchID].alignedTo
381
+ foundNode = None
382
+ for otherNodeID in otherAligns:
383
+ if self.nodedict[otherNodeID].text == text:
384
+ foundNode = otherNodeID
385
+ if foundNode is None:
386
+ nodeID = self.addNode(text)
387
+ self.nodedict[nodeID].alignedTo = [matchID] + otherAligns
388
+ for otherNodeID in [matchID] + otherAligns:
389
+ self.nodedict[otherNodeID].alignedTo.append(nodeID)
390
+ else:
391
+ nodeID = foundNode
392
+
393
+ self.addEdge(headID, nodeID, label)
394
+ headID = nodeID
395
+ if firstID is None:
396
+ firstID = headID
397
+
398
+ path.append(nodeID)
399
+
400
+ # finished the unaligned portion: now add an edge from the current headID to the tailID.
401
+ self.addEdge(headID, tailID, label)
402
+
403
+ # resort
404
+ self.toposort()
405
+
406
+ self._seqs.append(seq)
407
+ self._labels.append(label)
408
+ self._starts.append(firstID)
409
+ self._seq_paths[label] = path
410
+ return
411
+
412
+ def consensus(self, excludeLabels=None):
413
+ if excludeLabels is None:
414
+ excludeLabels = []
415
+
416
+ if self.needsSort:
417
+ self.toposort()
418
+
419
+ nodesInReverse = self.nodeidlist[::-1]
420
+ maxnodeID = max(nodesInReverse) + 1
421
+ nextInPath = [-1] * maxnodeID
422
+ scores = numpy.zeros((maxnodeID))
423
+
424
+ for nodeID in nodesInReverse:
425
+ bestWeightScoreEdge = (-1, -1, None)
426
+ for neighbourID in self.nodedict[nodeID].outEdges:
427
+ # print(f"nodeID: {nodeID}, neighbourID: {neighbourID}")
428
+ e = self.nodedict[nodeID].outEdges[neighbourID]
429
+ weightScoreEdge = (e.weight, scores[neighbourID], neighbourID)
430
+
431
+ if weightScoreEdge > bestWeightScoreEdge:
432
+ bestWeightScoreEdge = weightScoreEdge
433
+
434
+ scores[nodeID] = sum(bestWeightScoreEdge[0:2])
435
+ nextInPath[nodeID] = bestWeightScoreEdge[2]
436
+
437
+ pos = numpy.argmax(scores)
438
+ path = []
439
+ bases = []
440
+ labels = []
441
+ while pos is not None and pos > -1:
442
+ path.append(pos)
443
+ bases.append(self.nodedict[pos].text)
444
+ labels.append(self.nodedict[pos].labels)
445
+ pos = nextInPath[pos]
446
+
447
+ # ignore END node
448
+ path = path[:-1]
449
+ bases = bases[:-1]
450
+ labels = labels[:-1]
451
+ return path, bases, labels
452
+
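`consensus()` is a heaviest-path dynamic program: walking nodes in reverse topological order, each node keeps its best outgoing choice, compared first by edge weight and then by downstream score. The same idea on a bare adjacency-list DAG (a hypothetical four-node fixture):

```python
# Sketch of the reverse-toposort heaviest-path DP used by consensus().
order = [0, 1, 2, 3]                      # a valid topological order
edges = {0: [(1, 3), (2, 1)], 1: [(3, 2)], 2: [(3, 5)], 3: []}

score = {n: 0 for n in order}
next_in_path = {n: None for n in order}
for node in reversed(order):
    best = (-1, -1, None)                 # (edge weight, downstream score, neighbour)
    for nbr, w in edges[node]:
        best = max(best, (w, score[nbr], nbr))
    if best[2] is not None:
        score[node] = best[0] + best[1]
        next_in_path[node] = best[2]

pos = max(score, key=score.get)           # best-scoring start node
path = []
while pos is not None:
    path.append(pos)
    pos = next_in_path[pos]
print(path)  # -> [0, 1, 3]: edge weight wins over downstream score at node 0
```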
453
+ def allConsenses(self, maxfraction=0.5):
454
+ allpaths = []
455
+ allbases = []
456
+ alllabels = []
457
+ exclusions = []
458
+
459
+ passno = 0
460
+ lastlen = 1000
461
+ maxpasses = 10
462
+
463
+ while len(exclusions) < len(self._labels) and lastlen >= 10 and passno < maxpasses:
464
+ path, bases, labellists = self.consensus(exclusions)
465
+ if len(path) > 0:
466
+ allpaths.append(path)
467
+ allbases.append(bases)
468
+ alllabels.append(labellists)
469
+
470
+ labelcounts = collections.defaultdict(int)
471
+ for ll in labellists:
472
+ for label in ll:
473
+ labelcounts[label] += 1
474
+
475
+ for label, seq in zip(self._labels, self._seqs):
476
+ if label in labelcounts and labelcounts[label] >= maxfraction * len(seq):
477
+ exclusions.append(label)
478
+
479
+ lastlen = len(path)
480
+ passno += 1
481
+
482
+ return list(zip(allpaths, allbases, alllabels))
483
+
484
+ def generateAlignmentStrings(self):
485
+ """Return a list of strings corresponding to the alignments in the graph"""
486
+
487
+ # Step 1: assign node IDs to columns in the output
488
+ # column_index[node.ID] is the position in the toposorted node list
489
+ # of the node itself, or the earliest node it is aligned to.
490
+ column_index = {}
491
+ current_column = 0
492
+
493
+ # go through nodes in toposort order
494
+ ni = self.nodeiterator()
495
+ for node in ni():
496
+ other_columns = [
497
+ column_index[other] for other in node.alignedTo if other in column_index
498
+ ]
499
+ if other_columns:
500
+ found_idx = min(other_columns)
501
+ else:
502
+ found_idx = current_column
503
+ current_column += 1
504
+
505
+ column_index[node.ID] = found_idx
506
+
507
+ ncolumns = current_column
508
+
509
+ # Step 2: given the column indexes, populate the strings
510
+ # corresponding to the sequences inserted in the graph
511
+ seqnames = []
512
+ alignstrings = []
513
+ for label, start in zip(self._labels, self._starts):
514
+ seqnames.append(label)
515
+ curnode_id = start
516
+ charlist = ["-"] * ncolumns
517
+ while curnode_id is not None:
518
+ node = self.nodedict[curnode_id]
519
+ charlist[column_index[curnode_id]] = node.text
520
+ curnode_id = node.nextNode(label)
521
+ alignstrings.append("".join(charlist))
522
+
523
+ # Step 3: Same as step 2, but with consensus sequences
524
+ consenses = self.allConsenses()
525
+ for i, consensus in enumerate(consenses):
526
+ seqnames.append("Consensus" + str(i))
527
+ charlist = ["-"] * ncolumns
528
+ for path, text in zip(consensus[0], consensus[1]):
529
+ charlist[column_index[path]] = text
530
+ alignstrings.append("".join(charlist))
531
+
532
+ return list(zip(seqnames, alignstrings))
533
+
534
+ def jsOutput(self, verbose: bool = False, annotate_consensus: bool = True):
535
+ """returns a list of strings containing a a description of the graph for viz.js, http://visjs.org"""
536
+
537
+ # get the consensus sequence, which we'll use as the "spine" of the
538
+ # graph
539
+ pathdict = {}
540
+ if annotate_consensus:
541
+ path, __, __ = self.consensus()
542
+ lines = ["var nodes = ["]
543
+
544
+ ni = self.nodeiterator()
545
+ count = 0
546
+ for node in ni():
547
+ line = " {id:" + str(node.ID) + ', label: "' + str(node.ID) + ": " + node.text + '"'
548
+ if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
549
+ line += (
550
+ ", x: "
551
+ + str(pathdict[node.ID])
552
+ + ", y: 0 , fixed: { x:true, y:false},"
553
+ + "color: '#7BE141', is_consensus:true},"
554
+ )
555
+ else:
556
+ line += "},"
557
+ lines.append(line)
558
+
559
+ lines[-1] = lines[-1][:-1]
560
+ lines.append("];")
561
+
562
+ lines.append(" ")
563
+
564
+ lines.append("var edges = [")
565
+ ni = self.nodeiterator()
566
+ for node in ni():
567
+ nodeID = str(node.ID)
568
+ for edge in node.outEdges:
569
+ target = str(edge)
570
+ weight = str(len(node.outEdges[edge].labels) + 1.5)
571
+ lines.append(
572
+ " {from: "
573
+ + nodeID
574
+ + ", to: "
575
+ + target
576
+ + ", value: "
577
+ + weight
578
+ + ", color: '#4b72b0', arrows: 'to'},"
579
+ )
580
+ if verbose:
581
+ for alignededge in node.alignedTo:
582
+ # These edges indicate alignment to different bases, and are
583
+ # undirected; thus make sure we only plot them once:
584
+ if node.ID > alignededge:
585
+ continue
586
+ target = str(alignededge)
587
+ lines.append(
588
+ " {from: "
589
+ + nodeID
590
+ + ", to: "
591
+ + target
592
+ + ', value: 1, style: "dash-line", color: "red"},'
593
+ )
594
+
595
+ lines[-1] = lines[-1][:-1]
596
+ lines.append("];")
597
+ return lines
598
+
599
+ def htmlOutput(self, outfile, verbose: bool = False, annotate_consensus: bool = True):
600
+ header = """
601
+ <!doctype html>
602
+ <html>
603
+ <head>
604
+ <title>POA Graph Alignment</title>
605
+
606
+ <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
607
+ </head>
608
+
609
+ <body>
610
+
611
+ <div id="loadingProgress">0%</div>
612
+
613
+ <div id="mynetwork"></div>
614
+
615
+ <script type="text/javascript">
616
+ // create a network
617
+ """
618
+ outfile.write(textwrap.dedent(header[1:]))
619
+ lines = self.jsOutput(verbose=verbose, annotate_consensus=annotate_consensus)
620
+ for line in lines:
621
+ outfile.write(line + "\n")
622
+ footer = """
623
+ var container = document.getElementById('mynetwork');
624
+ var data= {
625
+ nodes: nodes,
626
+ edges: edges,
627
+ };
628
+ var options = {
629
+ width: '100%',
630
+ height: '800px',
631
+ physics: {
632
+ enabled: false,
633
+ stabilization: {
634
+ updateInterval: 10,
635
+ },
636
+ hierarchicalRepulsion: {
637
+ avoidOverlap: 0.9,
638
+ },
639
+ },
640
+ edges: {
641
+ color: {
642
+ inherit: false
643
+ }
644
+ },
645
+ layout: {
646
+ hierarchical: {
647
+ direction: "UD",
648
+ sortMethod: "directed",
649
+ shakeTowards: "roots",
650
+ levelSeparation: 150, // Adjust as needed
651
+ nodeSpacing: 100, // Adjust as needed
652
+ treeSpacing: 200, // Adjust as needed
653
+ parentCentralization: true,
654
+ }
655
+ }
656
+ };
657
+ var network = new vis.Network(container, data, options);
658
+
659
+ network.on('beforeDrawing', function(ctx) {
660
+ nodes.forEach(function(node) {
661
+ if (node.isConsensus) {
662
+ // Set the level of spine nodes to the bottom
663
+ network.body.data.nodes.update({
664
+ id: node.id,
665
+ level: 0 // Set level to 0 for spine nodes
666
+ });
667
+ }
668
+ });
669
+ });
670
+
671
+ network.on("stabilizationProgress", function (params) {
672
+ document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
673
+ });
674
+ network.once("stabilizationIterationsDone", function () {
675
+ document.getElementById("loadingProgress").innerText = "100%";
676
+ setTimeout(function () {
677
+ document.getElementById("loadingProgress").style.display = "none";
678
+ }, 500);
679
+ });
680
+ </script>
681
+
682
+ </body>
683
+ </html>
684
+ """
685
+ outfile.write(textwrap.dedent(footer))
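For reference, a minimal end-to-end sketch of this module: seed a graph with one sequence, align and incorporate further sequences, then print the aligned strings and write the vis.js view. `src/alignment.py` is part of this commit but not shown in this diff, so the `SeqGraphAlignment` call below assumes the same interface as `new_alignment.py`:

```python
# Hypothetical usage sketch; assumes the src package layout of this commit.
from src.poa_graph import POAGraph
from src.alignment import SeqGraphAlignment

graph = POAGraph(seq="CATTAG", label=0)
for i, seq in enumerate(["CATAG", "CATTG"], start=1):
    aln = SeqGraphAlignment(seq, graph, globalAlign=True)
    graph.incorporateSeqAlignment(aln, seq, label=i)

for name, aligned in graph.generateAlignmentStrings():
    print(name, aligned)

with open("poa.html", "w") as f:   # interactive vis.js rendering
    graph.htmlOutput(f)
```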
src/text_poa_graph.py ADDED
@@ -0,0 +1,802 @@
1
+ """
2
+ Enhanced version of POAGraph for text alignment
3
+ """
4
+
5
+ import pickle
6
+ import textwrap
7
+ from typing import Dict, Optional
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from src.text_poa_graph_utils import path_sim_llm
13
+ from src.global_edit_utils import clean_up_text
14
+
15
+ from .new_text_alignment import TextSeqGraphAlignment
16
+ from .poa_graph import Node, POAGraph
17
+
18
+
19
+ class TextNode(Node):
20
+ def __init__(self, nodeID=-1, text=""):
21
+ super().__init__(nodeID, text)
22
+ self.variations = {} # Track alternate phrasings
23
+ self.sequences = [] # Track sequences that contain this node
24
+ self.influenceScore = 0
25
+ self.num_tokens_used = 0
26
+
27
+ def add_variation(self, text, sequence_id):
28
+ self.variations[sequence_id] = text
29
+
30
+ @property
31
+ def is_stable(self):
32
+ """A node is stable if it appears frequently enough relative to total sequences"""
33
+ return self.frequency >= self.graph.stability_threshold
34
+
35
+
36
+ class TextPOAGraph(POAGraph):
37
+ def __init__(self, text=None, label=-1):
38
+ self.consensus_node_ids = []
39
+ self._seq_paths = {}
40
+ self.end_id = -1
41
+ self.start_id = -1
42
+ self.failed = False
43
+ self.num_input_tokens_used = 0
44
+ self.num_output_tokens_used = 0
45
+ super().__init__(text, label)
46
+
47
+ def addNode(self, text):
48
+ """Override to use TextNode"""
49
+ nid = self._nextnodeID
50
+ newnode = TextNode(nid, text)
51
+ self.nodedict[nid] = newnode
52
+ self.nodeidlist.append(nid)
53
+ self._nnodes += 1
54
+ self._nextnodeID += 1
55
+ self._needsort = True
56
+ return nid
57
+
58
+ def addUnmatchedSeq(self, text, label=-1, updateSequences=True):
59
+ """Modified to handle text sequences"""
60
+ if text is None:
61
+ return
62
+
63
+ # Handle both string and list input
64
+ if isinstance(text, str):
65
+ words = text.split()
66
+ else:
67
+ words = text
68
+
69
+ firstID, lastID = None, None
70
+ neededSort = self.needsSort
71
+
72
+ path = []
73
+ for word in words:
74
+ nodeID = self.addNode(word)
75
+ if firstID is None:
76
+ firstID = nodeID
77
+ if lastID is not None:
78
+ self.addEdge(lastID, nodeID, label=label)
79
+ lastID = nodeID
80
+ path.append(nodeID)
81
+
82
+ self._needsort = neededSort
83
+ if updateSequences:
84
+ self._seqs.append(words)
85
+ self._labels.append(label)
86
+ self._starts.append(firstID)
87
+ self._seq_paths[label] = path
88
+
89
+ return firstID, lastID
90
+
91
+ def add_text(self, text, label=-1):
92
+ """Main method to add new text to the alignment"""
93
+ if len(self._seqs) == 0:
94
+ # First sequence - just add it
95
+ self.addUnmatchedSeq(text, label)
96
+ else:
97
+ # Align to existing graph
98
+ alignment = TextSeqGraphAlignment(
99
+ text, self, matchscore=2, mismatchscore=-1, gap_open=-2
100
+ )
101
+ self.incorporateSeqAlignment(alignment, text, label)
102
+
103
+ # Update node frequencies
104
+ self._update_frequencies()
105
+
106
+ def removeNode(self, nodeID):
107
+ """Override to handle text nodes"""
108
+ node = self.nodedict[nodeID]
109
+ if node is None:
110
+ return
111
+
112
+ # Remove all edges to this node
113
+ out_edges = node.outEdges.copy()
114
+ in_edges = node.inEdges.copy()
115
+
116
+ for edge in out_edges:
117
+ self.removeEdge(node.ID, edge)
118
+ for edge in in_edges:
119
+ self.removeEdge(edge, node.ID)
120
+
121
+ # Remove from graph
122
+ del self.nodedict[nodeID]
123
+ self.nodeidlist.remove(nodeID)
124
+
125
+ for path in self._seq_paths.values():
126
+ if nodeID in path:
127
+ path.remove(nodeID)
128
+
129
+ self._nnodes -= 1
130
+ self._needsort = True
131
+
132
+ def removeEdge(self, nodeID1, nodeID2):
133
+ """Override to handle text nodes"""
134
+ node1 = self.nodedict[nodeID1]
135
+ node2 = self.nodedict[nodeID2]
136
+
137
+ if node1 is None or node2 is None:
138
+ return
139
+
140
+ # Remove from graph
141
+ del node1.outEdges[nodeID2]
142
+ del node2.inEdges[nodeID1]
143
+
144
+ def merge_consensus_nodes(self, verbose: bool = False):
145
+ self.toposort()
146
+ # reset consensus node ids
147
+ self.consensus_node_ids = []
148
+ nodes = list(self.nodeiterator()())
149
+ consensus_segments = []
150
+ i = 0
151
+ while i < len(nodes):
152
+ node = nodes[i]
153
+ out_weight = sum(e.weight for e in node.outEdges.values())
154
+ in_weight = sum(e.weight for e in node.inEdges.values())
155
+
156
+ if out_weight in [0, self.num_sequences] and in_weight in [0, self.num_sequences]:
157
+ consensus_segment = [(node.ID, node.text)]
158
+ next_node = node
159
+ while (i + 1) < len(nodes) and len(next_node.outEdges) == 1:
160
+ next_node = nodes[i + 1]
161
+ next_out_weight = sum(e.weight for e in next_node.outEdges.values())
162
+ next_in_weight = sum(e.weight for e in next_node.inEdges.values())
163
+
164
+ if (
165
+ next_out_weight != self.num_sequences
166
+ or next_in_weight != self.num_sequences
167
+ ):
168
+ break
169
+
170
+ consensus_segment.append((next_node.ID, next_node.text))
171
+ i += 1
172
+ consensus_segments.append(consensus_segment)
173
+ i += 1
174
+ # merge consensus nodes into a single node
175
+ for segment in consensus_segments:
176
+ if len(segment) == 1:
177
+ self.consensus_node_ids.append(segment[0][0])
178
+ continue
179
+ merged_text = " ".join([text for _, text in segment])
180
+ first_node_id = segment[0][0]
181
+ last_node_id = segment[-1][0]
182
+
183
+ self.nodedict[last_node_id].text = merged_text
184
+ self.consensus_node_ids.append(last_node_id)
185
+
186
+ # attach all incoming edges to first node to last node
187
+ for id, edge in self.nodedict[first_node_id].inEdges.items():
188
+ weight = edge.weight
189
+ for _ in range(weight):
190
+ self.addEdge(id, last_node_id, label=edge.labels)
191
+
192
+ # delete all nodes except last node
193
+ for node_id, _ in segment[:-1]:
194
+ self.removeNode(node_id)
195
+
196
+
197
+
198
+ if verbose:
199
+ print(self.consensus_node_ids)
200
+
201
+ """
202
+ find all paths between start_node_id and end_node_id from original sequences
203
+ return a list of dictionaries with the following keys:
204
+ - path: list of node ids in the path (excluding start and including end)
205
+ - text: text of the path (excluding start and end)
206
+ - weight: minimal edge weight across all edges in the path
207
+ - labels: intersection of all edge labels in the path
208
+ """
209
+
210
+ def find_paths_between(self, start_node_id: int, end_node_id: int):
211
+ # find all paths between start_node_id and end_node_id from original sequences
212
+ path_dicts = []
213
+
214
+ # keep track of visited paths to avoid duplicates
215
+ visited_paths = set()
216
+
217
+ for _, path in self._seq_paths.items():
218
+ start_index = path.index(start_node_id) if start_node_id in path else None
219
+ end_index = path.index(end_node_id) if end_node_id in path else None
220
+
221
+ # print(start_index, end_index)
222
+ # print(path)
223
+
224
+ if (
225
+ start_index is not None
226
+ and end_index is not None
227
+ and end_index - start_index > 1
228
+ and tuple(path[start_index + 1 : end_index + 1]) not in visited_paths
229
+ ):
230
+ # intersection of all edge labels in the path
231
+ path_labels = set.intersection(
232
+ *[
233
+ set(self.nodedict[next_node_id].inEdges[node_id].labels)
234
+ for node_id, next_node_id in zip(
235
+ path[start_index:end_index], path[start_index + 1 : end_index + 1]
236
+ )
237
+ ]
238
+ )
239
+ path_weight = len(path_labels)
240
+ path_dicts.append(
241
+ {
242
+
243
+ "path": path[start_index + 1 : end_index + 1],
244
+ "body_text": " ".join(
245
+ [
246
+ self.nodedict[node_id].text
247
+ for node_id in path[start_index + 1 : end_index]
248
+ ]
249
+ ),
250
+ "begin_text": self.nodedict[path[start_index]].text,
251
+ "end_text": self.nodedict[path[end_index]].text,
252
+ "weight": path_weight,
253
+ "labels": path_labels,
254
+ }
255
+ )
256
+ visited_paths.add(tuple(path[start_index + 1 : end_index + 1]))
257
+
258
+ return path_dicts
259
+
260
+ def _follow_path(self, start_id):
261
+ """Follow all possible paths from a node"""
262
+ paths = []
263
+ visited = set()
264
+
265
+ def dfs(node_id, current_path):
266
+ if node_id in visited:
267
+ return
268
+ visited.add(node_id)
269
+ node = self.nodedict[node_id]
270
+
271
+ if not node.outEdges:
272
+ paths.append(current_path + [node_id])
273
+ return
274
+
275
+ for next_id in node.outEdges:
276
+ dfs(next_id, current_path + [node_id])
277
+
278
+ dfs(start_id, [])
279
+ return paths
280
+
281
+ def merge_paths_between(
282
+ self,
283
+ start_node_id: int,
284
+ end_node_id: int,
285
+ path_sim_type: str = "llm",
286
+ verbose: bool = False,
287
+ **kwargs,
288
+ ):
289
+ path_dicts = self.find_paths_between(start_node_id, end_node_id)
290
+
291
+ if path_sim_type == "llm":
292
+ api = kwargs.get("api", "openai")
293
+ model = kwargs.get("model", "gpt-4o-mini")
294
+ domain = kwargs.get("domain", None)
295
+ similarity_judge_prompt = kwargs.get("similarity_judge_prompt", None)
296
+
297
+ def path_sim_func(path1_text, path2_text):
298
+ return path_sim_llm(
299
+ path1_text,
300
+ path2_text,
301
+ api=api,
302
+ model=model,
303
+ domain=domain,
304
+ custom_similarity_judge_prompt=similarity_judge_prompt,
305
+ )
306
+
307
+ elif path_sim_type == "cosine":
308
+ pass
309
+ # embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
310
+ # threshold = kwargs.get("threshold", 0.9)
311
+ # path_sim_func = path_sim_cosine(embedding_model, threshold)
312
+ else:
313
+ raise ValueError(f"Invalid path similarity type: {path_sim_type}")
314
+
315
+ # merge paths based on semantic similarity
316
+ path_equivalence_classes = {}
317
+ class_count = 0
318
+
319
+ for path_dict in path_dicts:
320
+ if verbose:
321
+ print(path_dict)
322
+ found_class = False
323
+ for _, eq_class in path_equivalence_classes.items():
324
+ # check if path dict is already in an equivalence class
325
+ path1_text = (
326
+ path_dict["begin_text"]
327
+ + " "
328
+ + path_dict["body_text"]
329
+ + " "
330
+ + path_dict["end_text"]
331
+ )
332
+ path2_text = (
333
+ eq_class[0]["begin_text"]
334
+ + " "
335
+ + eq_class[0]["body_text"]
336
+ + " "
337
+ + eq_class[0]["end_text"]
338
+ )
339
+
340
+ judgement, num_input_tokens, num_output_tokens = path_sim_func(
341
+ path1_text, path2_text
342
+ )
343
+ self.num_input_tokens_used += num_input_tokens
344
+ self.num_output_tokens_used += num_output_tokens
345
+ if judgement:
346
+ eq_class.append(path_dict)
347
+ found_class = True
348
+ break
349
+ if not found_class:
350
+ class_count += 1
351
+ path_equivalence_classes[class_count] = [path_dict]
352
+
353
+ nodes_to_remove = set() # Track nodes to remove
354
+ for _, eq_class in path_equivalence_classes.items():
355
+ path_dict = eq_class[0]
356
+
357
+ if verbose:
358
+ print(eq_class)
359
+ # add new node with merged text
360
+ new_node_id = self.addNode(path_dict["body_text"])
361
+ for sequence_id in path_dict["labels"]:
362
+ self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
363
+
364
+ # collect nodes to remove from first path
365
+ nodes_to_remove.update(path_dict["path"][:-1])
366
+
367
+ # process data regarding weights and labels
368
+ labels = list(path_dict["labels"])
369
+ weight = path_dict["weight"]
370
+ self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)
371
+
372
+ # Update seq_paths for all labels to include new_node between start_node and end_node
373
+ for label in labels:
374
+ index = self._seq_paths[label].index(start_node_id)
375
+ if (
376
+ index + 1 < len(self._seq_paths[label])
377
+ and self._seq_paths[label][index + 1] != new_node_id
378
+ ):
379
+ self._seq_paths[label].insert(index + 1, new_node_id)
380
+
381
+ self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
382
+
383
+ self.nodedict[new_node_id].sequences = labels
384
+ # process additional paths
385
+ for path_dict in eq_class[1:]:
386
+ for sequence_id in path_dict["labels"]:
387
+ self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
388
+ nodes_to_remove.update(path_dict["path"][:-1])
389
+
390
+ # copy incoming edges to new node
391
+ labels = list(path_dict["labels"])
392
+ weight = path_dict["weight"]
393
+ self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)
394
+
395
+ # Updated seq_paths for all labels to include new_node betwwen start_node and end_node
396
+ for label in labels:
397
+ index = self._seq_paths[label].index(start_node_id)
398
+ if (
399
+ index + 1 < len(self._seq_paths[label])
400
+ and self._seq_paths[label][index + 1] != new_node_id
401
+ ):
402
+ self._seq_paths[label].insert(index + 1, new_node_id)
403
+
404
+ self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
405
+ self.nodedict[new_node_id].sequences.extend(labels)
406
+
407
+ self.nodedict[new_node_id].sequences = list(set(self.nodedict[new_node_id].sequences))
408
+
409
+ # Remove all collected nodes after processing
410
+ for node_id in nodes_to_remove:
411
+ if node_id in self.nodedict:
412
+ if verbose:
413
+ print(f"Removing node {node_id}")
414
+ self.removeNode(node_id)
415
+
416
+ def merge_divergent_paths(self, path_sim_type: str = "llm", verbose: bool = False, **kwargs):
417
+ # add dummy end node to the end of the graph
418
+ if not self.consensus_node_ids:
419
+ self.merge_consensus_nodes(verbose=verbose)
420
+
421
+ self.toposort()
422
+
423
+ if self.start_id == -1:
424
+ if verbose:
425
+ print("Adding start node")
426
+ self.start_id = self.addNode(text="START")
427
+ self._nextnodeID += 1
428
+ self.consensus_node_ids.insert(0, self.start_id)
429
+
430
+ for label, path in self._seq_paths.items():
431
+ self.addEdge(self.start_id, path[0], label=label, weight=1)
432
+ path.insert(0, self.start_id)
433
+
434
+ if self.end_id == -1:
435
+ if verbose:
436
+ print("Adding end node")
437
+ self.end_id = self.addNode(text="END")
438
+ self._nextnodeID += 1
439
+ self.consensus_node_ids = self.consensus_node_ids + [self.end_id]
440
+
441
+ for label, path in self._seq_paths.items():
442
+ self.addEdge(path[-1], self.end_id, label=label, weight=1)
443
+ path.append(self.end_id)
444
+
445
+ for i in tqdm(range(len(self.consensus_node_ids) - 1)):
446
+ if verbose:
447
+ print(self.consensus_node_ids[i], self.consensus_node_ids[i + 1])
448
+ self.merge_paths_between(
449
+ self.consensus_node_ids[i],
450
+ self.consensus_node_ids[i + 1],
451
+ path_sim_type=path_sim_type,
452
+ verbose=verbose,
453
+ **kwargs,
454
+ )
455
+
456
+ def get_variable_node_ids(self):
457
+ return [
458
+ node.ID for node in self.nodedict.values() if node.ID not in self.consensus_node_ids
459
+ ]
460
+
461
+ def compress_paths_between(self, start_node_id: int, end_node_id: int):
462
+ pass
463
+
464
+ def compress_graph(self):
465
+ pass
466
+
467
+ def update_influence_scores(self, outcome: Dict[int, float], discount_factor: float = 0.2):
468
+ self.toposort()
469
+ direct_scores = []
470
+ for node in self.nodedict.values():
471
+ next_out_weight = sum(e.weight for e in node.outEdges.values())
472
+ next_in_weight = sum(e.weight for e in node.inEdges.values())
473
+ if next_out_weight == self.num_sequences and next_in_weight == self.num_sequences:
474
+ out_list = []
475
+ for edge in node.outEdges.values():
476
+ for _ in range(len(set(edge.labels))):
477
+ out_list.append(np.mean([outcome[label] for label in set(edge.labels)]))
478
+ direct_scores.append((node.ID, np.var(out_list)))
479
+
480
+ scores = direct_scores.copy()
481
+
482
+ # Start from the end and propagate influence backward
483
+ for i in range(len(scores) - 2, -1, -1):
484
+ # Current node gets its direct influence plus discounted influence of next node
485
+ current_direct = scores[i][1]
486
+ next_total = scores[i + 1][1]
487
+ scores[i] = (scores[i][0], current_direct + discount_factor * next_total)
488
+
489
+ scores.sort(key=lambda x: x[1], reverse=True)
490
+ return scores
491
+
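To make the backward pass in update_influence_scores concrete, here is the same propagation loop run on made-up direct scores:

```python
# Each node keeps its direct variance score plus a discounted share of the
# next node's running total; the values below are illustrative only.
discount_factor = 0.2
scores = [(0, 0.10), (1, 0.40), (2, 0.05)]  # (node_id, direct score)

for i in range(len(scores) - 2, -1, -1):
    node_id, direct = scores[i]
    scores[i] = (node_id, direct + discount_factor * scores[i + 1][1])

print(scores)  # [(0, 0.182), (1, 0.41), (2, 0.05)]
```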
492
+ def jsOutput(
493
+ self,
494
+ verbose: bool = False,
495
+ annotate_consensus: bool = True,
496
+ color_annotations: Dict[int, str] = None,
497
+ ):
498
+ """returns a list of strings containing a a description of the graph for viz.js, http://visjs.org"""
499
+
500
+ # get the consensus sequence, which we'll use as the "spine" of the
501
+ # graph
502
+ pathdict = {}
503
+ if annotate_consensus:
504
+ path, __, __ = self.consensus()
505
+ lines = ["var nodes = ["]
506
+
507
+ ni = self.nodeiterator()
508
+ count = 0
509
+ for node in ni():
510
+ title_text = ""
511
+ if node.sequences:
512
+ title_text += f"Sequences: {node.sequences}"
513
+ if node.variations:
514
+ title_text += ";;;".join(
515
+ [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
516
+ )
517
+ title_text = title_text.replace('"', "'")
518
+ line = (
519
+ " {id:"
520
+ + str(node.ID)
521
+ + ', label: "'
522
+ + str(node.ID)
523
+ + ": "
524
+ + node.text.replace('"', "'")
525
+ + '", title: '
526
+ + '"'
527
+ + title_text
528
+ + '",'
529
+ )
530
+ if color_annotations and node.ID in color_annotations:
531
+ line += f" color: '{color_annotations[node.ID]}', "
532
+ if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
533
+ line += (
534
+ ", x: "
535
+ + str(pathdict[node.ID])
536
+ + ", y: 0 , fixed: { x:true, y:false},"
537
+ + "color: '#7BE141', is_consensus:true},"
538
+ )
539
+ else:
540
+ line += "},"
541
+ lines.append(line)
+ count += 1  # without this, count stays 0 and every consensus node is pinned
542
+
543
+ lines[-1] = lines[-1][:-1]
544
+ lines.append("];")
545
+
546
+ lines.append(" ")
547
+
548
+ lines.append("var edges = [ ")
549
+ ni = self.nodeiterator()
550
+ for node in ni():
551
+ nodeID = str(node.ID)
552
+ for edge in node.outEdges:
553
+ target = str(edge)
554
+ weight = str(node.outEdges[edge].weight + 1.5)
555
+ lines.append(
556
+ " {from: "
557
+ + nodeID
558
+ + ", to: "
559
+ + target
560
+ + ", value: "
561
+ + weight
562
+ + ", color: '#4b72b0', arrows: 'to'},"
563
+ )
564
+ if verbose:
565
+ for alignededge in node.alignedTo:
566
+ # These edges indicate alignment between different nodes, and are
567
+ # undirected; thus make sure we only plot them once:
568
+ if node.ID > alignededge:
569
+ continue
570
+ target = str(alignededge)
571
+ lines.append(
572
+ " {from: "
573
+ + nodeID
574
+ + ", to: "
575
+ + target
576
+ + ', value: 1, style: "dash-line", color: "red"},'
577
+ )
578
+
579
+ lines[-1] = lines[-1][:-1]
580
+ lines.append("];")
581
+ return lines
582
+
583
+ def htmlOutput(
584
+ self,
585
+ outfile,
586
+ verbose: bool = False,
587
+ annotate_consensus: bool = True,
588
+ color_annotations: Dict[int, str] = None,
589
+ ):
590
+ header = """
591
+ <!doctype html>
592
+ <html>
593
+ <head>
594
+ <title>POA Graph Alignment</title>
595
+
596
+ <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
597
+ </head>
598
+
599
+ <body>
600
+
601
+ <div id="loadingProgress">0%</div>
602
+
603
+ <div id="mynetwork"></div>
604
+
605
+ <script type="text/javascript">
606
+ // create a network
607
+ """
608
+ outfile.write(textwrap.dedent(header[1:]))
609
+ lines = self.jsOutput(
610
+ verbose=verbose,
611
+ annotate_consensus=annotate_consensus,
612
+ color_annotations=color_annotations,
613
+ )
614
+ for line in lines:
615
+ outfile.write(line + "\n")
616
+ footer = """
617
+ var container = document.getElementById('mynetwork');
618
+ var data= {
619
+ nodes: nodes,
620
+ edges: edges,
621
+ };
622
+ var options = {
623
+ width: '100%',
624
+ height: '800px',
625
+ physics: {
626
+ enabled: false,
627
+ stabilization: {
628
+ updateInterval: 10,
629
+ },
630
+ },
631
+ edges: {
632
+ color: {
633
+ inherit: false
634
+ }
635
+ },
636
+ layout: {
637
+ hierarchical: {
638
+ direction: "UD",
639
+ sortMethod: "directed",
640
+ shakeTowards: "roots",
641
+ levelSeparation: 150, // Adjust as needed
642
+ nodeSpacing: 800, // Adjust as needed
643
+ treeSpacing: 200, // Adjust as needed
644
+ parentCentralization: true,
645
+ }
646
+ }
647
+ };
648
+ var network = new vis.Network(container, data, options);
649
+
650
+ network.on('beforeDrawing', function(ctx) {
651
+ nodes.forEach(function(node) {
652
+ if (node.is_consensus) {  // matches the is_consensus flag emitted by jsOutput
653
+ // Set the level of spine nodes to the bottom
654
+ network.body.data.nodes.update({
655
+ id: node.id,
656
+ level: 0 // Set level to 0 for spine nodes
657
+ });
658
+ }
659
+ });
660
+ });
661
+
662
+ network.on("stabilizationProgress", function (params) {
663
+ document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
664
+ });
665
+ network.once("stabilizationIterationsDone", function () {
666
+ document.getElementById("loadingProgress").innerText = "100%";
667
+ setTimeout(function () {
668
+ document.getElementById("loadingProgress").style.display = "none";
669
+ }, 500);
670
+ });
671
+
672
+
673
+ </script>
674
+
675
+ </body>
676
+ </html>
677
+ """
678
+ outfile.write(textwrap.dedent(footer))
679
+
680
+
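A minimal usage sketch for htmlOutput, assuming `graph` is an already-populated TextPOAGraph (the output path is illustrative):

```python
# Writes a self-contained vis.js page for the current graph.
with open("graph_view.html", "w+") as f:
    graph.htmlOutput(f, annotate_consensus=True)
```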
681
+ def multi_consensus_response(self, abstention_threshold: Optional[float] = None, filter: bool = True):
682
+ self.toposort()
683
+ nodesInReverse = self.nodeidlist[::-1]
684
+ # next-pointer array is indexed by toposort position, so size it by node count
685
+ nextInPath = [-1] * len(self.nodeidlist)
686
+ scores = np.zeros(len(self.nodeidlist))
687
+
688
+ id_to_index = {node_id: index for index, node_id in enumerate(self.nodeidlist)}
689
+ index_to_id = {index: node_id for index, node_id in enumerate(self.nodeidlist)}
690
+
691
+ for nodeID in nodesInReverse:
692
+ bestWeightScoreEdges = [(-1, -1, None)]
693
+ for neighbourID in self.nodedict[nodeID].outEdges:
694
+ # print(f"nodeID: {nodeID}, neighbourID: {neighbourID}")
695
+ e = self.nodedict[nodeID].outEdges[neighbourID]
696
+ weightScoreEdge = (e.weight, scores[id_to_index[neighbourID]], neighbourID)
697
+
698
+
699
+ if weightScoreEdge > bestWeightScoreEdges[0]:
700
+ bestWeightScoreEdges = [weightScoreEdge]
701
+ elif weightScoreEdge == bestWeightScoreEdges[0] and filter:
702
+ bestWeightScoreEdges.append(weightScoreEdge)
703
+
704
+
705
+ scores[id_to_index[nodeID]] = sum(bestWeightScoreEdges[0][0:2])
706
+ if bestWeightScoreEdges[0][2] is not None:
707
+ nextInPath[id_to_index[nodeID]] = id_to_index[bestWeightScoreEdges[0][2]]
708
+ else:
709
+ nextInPath[id_to_index[nodeID]] = None
710
+
711
+ pos = np.argmax(scores)
712
+ path = []
713
+ text = []
714
+ labels = []
715
+
716
+ while pos is not None and pos > -1:
717
+ if abstention_threshold is not None and self.nodedict[index_to_id[pos]].variations:
718
+ if (
719
+ len(self.nodedict[index_to_id[pos]].labels) / self.num_sequences
720
+ >= abstention_threshold
721
+ ):
722
+ path.append(index_to_id[pos])
723
+ labels.append(self.nodedict[index_to_id[pos]].labels)
724
+ text.append(self.nodedict[index_to_id[pos]].text)
725
+ else:
726
+ path.append(index_to_id[pos])
727
+ labels.append(self.nodedict[index_to_id[pos]].labels)
728
+ text.append(self.nodedict[index_to_id[pos]].text)
729
+ pos = nextInPath[pos]
730
+
731
+ # ignore END node
732
+ path = path[:-1]
733
+ # ignore END node
734
+ text = text[:-1]
735
+ # ignore START in text
736
+ text[0] = text[0].replace("START", "")
737
+ labels = labels[:-1]
738
+
739
+ return " ".join(text)
740
+
741
+
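The best-edge choice in multi_consensus_response leans on Python's lexicographic tuple ordering: edge weight is compared first, and the downstream score only breaks ties. A standalone check:

```python
# (weight, downstream score, neighbour id) tuples, as built in the loop above.
a = (3, 0.7, 12)
b = (3, 0.9, 40)
c = (2, 5.0, 7)
print(max(a, b, c))  # (3, 0.9, 40): equal weights fall back to score; weight beats score
```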
742
+ def consensus_response(
743
+ self, selection_threshold: Optional[float] = 0.5, api: str = "openai", model: str = "gpt-4o-mini", task: str = "bio", **kwargs
744
+ ) -> str:
745
+ self.toposort()
746
+
747
+ consensus_node_ids = self.consensus_node_ids
748
+ print(consensus_node_ids)
749
+
750
+ selected_node_ids = []
751
+
752
+ for node_id in consensus_node_ids:
753
+ if node_id == self.start_id or node_id == self.end_id:
754
+ continue
755
+
756
+ selected_node_ids.append(node_id)
757
+
758
+ for neighbor_id in self.nodedict[node_id].outEdges:
759
+ if neighbor_id in consensus_node_ids:
760
+ continue
761
+
762
+ if (
763
+ len(self.nodedict[neighbor_id].labels) / self.num_sequences
764
+ >= selection_threshold
765
+ ):
766
+ selected_node_ids.append(neighbor_id)
767
+
768
+ text = " ".join([self.nodedict[node_id].text for node_id in selected_node_ids])
769
+ print(text)
770
+ cleaned_text = clean_up_text(text, task=task, api=api, model=model, **kwargs)
771
+ return cleaned_text
772
+
773
+ def save_to_pickle(self, filename):
774
+ with open(filename, "wb+") as f:
775
+ pickle.dump(self, f)
776
+
777
+ def refine_graph(
778
+ self,
779
+ verbose: bool = False,
780
+ save_intermediate_file: str = None,
781
+ final_merge: bool = True,
782
+ **kwargs,
783
+ ):
784
+ self.merge_consensus_nodes(verbose=verbose)
785
+
786
+ if save_intermediate_file:
787
+ with open(save_intermediate_file, "w+") as f:
788
+ self.htmlOutput(f, annotate_consensus=False)
789
+
790
+ if not self.consensus_node_ids:
791
+ self.failed = True
792
+ return
793
+
794
+ else:
795
+ self.merge_divergent_paths(verbose=verbose, **kwargs)
796
+
797
+ if final_merge:
798
+ try:
799
+ self.merge_consensus_nodes(verbose=verbose)
800
+ except Exception as e:
801
+ print(e)
802
+ self.failed = True
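Taken together, a typical refinement round might look like the sketch below; the file names are illustrative, and refine_graph forwards extra keyword arguments through to merge_divergent_paths:

```python
# Hypothetical end-to-end use of TextPOAGraph refinement; paths are made up.
import pickle

with open("graph.pkl", "rb") as f:
    graph = pickle.load(f)

graph.refine_graph(verbose=True, path_sim_type="llm")
if not getattr(graph, "failed", False):
    with open("refined.html", "w+") as f:
        graph.htmlOutput(f)
```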
src/text_poa_graph_utils.py ADDED
@@ -0,0 +1,126 @@
1
+ from typing import Optional
2
+
3
+ from huggingface_hub import InferenceClient
4
+ from openai import OpenAI
5
+
6
+
7
+ TEXT_SIMILARITY_JUDGE_PROMPT = """
8
+ You are given two pieces of text. Your task is to determine whether they are semantically equivalent based solely on their factual content.
9
+
10
+ Here are the specific guidelines:
11
+ - Texts are equivalent if they convey the same core information or concept, regardless of wording or structure
12
+ - If one text has information that is a subset of the other text, then the texts are equivalent
13
+ - Focus ONLY on the essential claims, not on:
14
+ * Stylistic differences or tone
15
+ * Level of detail (if the core facts remain the same)
16
+ * Connotative differences between words
17
+ * Implied significance or emphasis
18
+ * Presentation order (if all key information is present in both)
19
+ - Minor additions of non-contradictory information should not make texts non-equivalent
20
+ - For ambiguous cases, prioritize the central claim or purpose of the text
21
+
22
+ Examples of equivalent pairs:
23
+ - "The meeting starts at 3pm" and "The 3 o'clock meeting will begin on time"
24
+ - "Research indicates a 15% increase" and "Studies show a fifteen percent rise"
25
+ - "was influential in the field" and "had a significant impact on the community"
26
+
27
+ Examples of non-equivalent pairs:
28
+ - "The project might be completed by Friday" and "The project will be finished by Friday"
29
+ - "Most experts agree on the approach" and "All experts support the approach"
30
+
31
+ Strictly follow these guidelines and return ONLY:
32
+ - equivalent
33
+ - not equivalent
34
+ """
35
+ MATH_SIMILARITY_JUDGE_PROMPT = """
36
+ You are given two pieces of text from mathematical solutions. Your task is to determine whether the two solution segments are mathematically equivalent in their content, while allowing for stylistic variations.
37
+
38
+ Here are some important guidelines:
39
+ - Solutions should be considered equivalent if:
40
+ 1. They communicate the same mathematical content/approach, even if word choice or phrasing differs
41
+ 2. They contain the same key mathematical ideas, even if expressed differently
42
+ 3. The same mathematical steps are described, even if using different words
43
+ 4. They present the same final answer, regardless of wording style or formatting
44
+
45
+ - Allow for these variations while still considering solutions equivalent:
46
+ 1. Stylistic differences ("we will" vs. "we'll" or "I'll")
47
+ 2. Different levels of formality in the explanation
48
+ 3. Minor rephrasing that preserves the core mathematical content
49
+ 4. Use of synonyms or alternative mathematical terminology for the same concept
50
+
51
+ - Solutions are NOT equivalent if:
52
+ 1. They use fundamentally different mathematical approaches
53
+ 2. They work with different formulas or equations
54
+ 3. They present different mathematical steps or operations
55
+ 4. They reach different conclusions or answers
56
+ 5. One contains substantial mathematical content that the other lacks
57
+
58
+ - When examining final answers, focus on mathematical equivalence rather than stylistic presentation
59
+ - For solution steps, maintain the core mathematical approach while allowing for rephrasing
60
+
61
+ Examples of solutions that SHOULD be considered equivalent:
62
+ - "We will systematically evaluate each possible grouping" and "We'll evaluate each grouping"
63
+ - "The answer is x = 5" and "Therefore, x equals 5"
64
+ - "Using the quadratic formula" and "Applying the quadratic formula"
65
+
66
+ Strictly follow the guidelines above.
67
+ Return your judgment in the following format. Do not include any other text:
68
+ - equivalent
69
+ - not equivalent
70
+ """
71
+
72
+ def path_sim_llm(
73
+ path1_text: str,
74
+ path2_text: str,
75
+ api: str = "openai",
76
+ model: str = "gpt-4.1-mini",
77
+ verbose: bool = False,
78
+ domain: Optional[str] = "text",
79
+ custom_similarity_judge_prompt: str = None,
80
+ ):
81
+ if api == "openai":
82
+ client = OpenAI()
83
+ elif api == "hf":
84
+ client = InferenceClient()
85
+ else:
86
+ raise ValueError(f"Invalid API: {api}")
87
+
88
+ if domain == "text":
89
+ similarity_judge_prompt = (
90
+ f"{TEXT_SIMILARITY_JUDGE_PROMPT}\n\nText 1: {path1_text}\nText 2: {path2_text}"
91
+ )
92
+ elif domain == "math":
93
+ similarity_judge_prompt = (
94
+ f"{MATH_SIMILARITY_JUDGE_PROMPT}\n\nText 1: {path1_text}\nText 2: {path2_text}"
95
+ )
96
+ elif not domain and custom_similarity_judge_prompt:
97
+ similarity_judge_prompt = (
98
+ f"{custom_similarity_judge_prompt}\n\nText 1: {path1_text}\nText 2: {path2_text}"
99
+ )
100
+ else:
101
+ raise ValueError(f"Invalid domain: {domain} and no custom similarity judge prompt provided")
102
+
103
+ completion = client.chat.completions.create(
104
+ model=model,
105
+ temperature=0,
106
+ messages=[
107
+ {"role": "system", "content": "You are a helpful assistant."},
108
+ {"role": "user", "content": similarity_judge_prompt},
109
+ ],
110
+ )
111
+
112
+ judgement = completion.choices[0].message.content.strip()
113
+ judgement = "".join(c for c in judgement if c.isalpha() or c == " ")
114
+ judgement = judgement.strip()
115
+
116
+ if verbose:
117
+ print(f"{path1_text} \nand \n{path2_text} \nare {judgement}")
118
+
119
+ if judgement == "equivalent":
120
+ return 1, completion.usage.prompt_tokens, completion.usage.completion_tokens
121
+ elif judgement == "not equivalent":
122
+ return 0, completion.usage.prompt_tokens, completion.usage.completion_tokens
123
+ else:
124
+ if verbose:
125
+ print(f"Invalid judgement: {judgement}")
126
+ return 0, completion.usage.prompt_tokens, completion.usage.completion_tokens
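A hypothetical invocation of path_sim_llm with the OpenAI backend; the texts reuse an equivalent pair from the prompt above, and OPENAI_API_KEY must be set in the environment:

```python
# Returns (1 or 0, prompt tokens, completion tokens); 1 means "equivalent".
score, prompt_toks, completion_toks = path_sim_llm(
    "The meeting starts at 3pm",
    "The 3 o'clock meeting will begin on time",
    api="openai",
    model="gpt-4.1-mini",
    domain="text",
)
print(score)
```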
src/utils.py ADDED
@@ -0,0 +1,46 @@
1
+ from typing import List
2
+
3
+ from huggingface_hub import InferenceClient
4
+ from openai import OpenAI
5
+
6
+
7
+ def detect_abstain(text: str, api: str, model: str):
8
+ if api == "openai":
9
+ client = OpenAI()
10
+ elif api == "hf":
11
+ client = InferenceClient()
12
+ else:
13
+ raise ValueError(f"Invalid API: {api}")
14
+
15
+ detect_abstain_prompt = f"""
16
+ You are given a piece of text that is a part of a biography of an entity.
17
+ Text: {text}
18
+
19
+ If the text claims a lack of knowledge about the topic, return "Abstain".
20
+ Otherwise, return "Not abstain".
21
+ """
22
+
23
+ completion = client.chat.completions.create(
24
+ model=model,
25
+ messages=[
26
+ {"role": "system", "content": "You are a helpful assistant."},
27
+ {"role": "user", "content": detect_abstain_prompt},
28
+ ],
29
+ )
30
+
31
+ return completion.choices[0].message.content.strip()
32
+
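A hypothetical call to detect_abstain; it simply relays the model's verdict string, expected to be "Abstain" or "Not abstain" (the snippet and model choice below are assumptions, not repo defaults):

```python
# Requires the relevant API key in the environment.
verdict = detect_abstain(
    "I do not have enough information about this person to write a biography.",
    api="openai",
    model="gpt-4o-mini",
)
print(verdict)  # "Abstain" expected for a refusal like the one above
```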
33
+
34
+ def calculate_factf1_at_k(
35
+ supported_facts: List[str], unsupported_facts: List[str], k: int
36
+ ) -> float:
37
+ """
38
+ Calculate the F1 score at k for supported and unsupported facts
39
+ """
40
+ if len(supported_facts) == 0:
41
+ return 0
42
+
43
+ precision = len(supported_facts) / (len(supported_facts) + len(unsupported_facts))
44
+ recall = min(len(supported_facts) / k, 1)
45
+ f1 = 2 * precision * recall / (precision + recall)
46
+ return f1
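A worked example of the metric: with 8 supported and 2 unsupported facts at K = 10, precision = 8/10 and recall = min(8/10, 1), so F1 = 0.8.

```python
# Worked example for calculate_factf1_at_k (fact strings are placeholders).
supported = [f"fact_{i}" for i in range(8)]
unsupported = ["bad_1", "bad_2"]
print(calculate_factf1_at_k(supported, unsupported, k=10))  # 0.8
```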
web_interface/.DS_Store ADDED
Binary file (6.15 kB).
 
web_interface/README.md ADDED
@@ -0,0 +1,111 @@
1
+ # ConGr Visualizer Web Interface
2
+
3
+ A web-based interface for exploring and visualizing ConGrs (Consensus Graphs) from research datasets.
4
+
5
+ ## Features
6
+
7
+ ### Browse Existing Graphs
8
+ - **Dataset Selection**: Choose from available datasets (BIO, FP, HIST, REFS, MATH, AIME)
9
+ - **Entity Selection**: Browse entities within each dataset
10
+ - **Model Information**: See which language models were used for each graph
11
+ - **Graph Visualization**: Interactive network visualization using vis.js
12
+ - **Metadata Display**: View graph statistics and consensus text
13
+
14
+ ### Create New Graphs
15
+ - **Text Input**: Enter multiple text sequences to create new ConGrs
16
+ - **Real-time Visualization**: See the graph structure as it's created
17
+ - **Save Functionality**: Save created graphs to pickle files
18
+
19
+ ## Available Datasets
20
+
21
+ - **BIO**: Biography datasets with various public figures
22
+ - **FP**: False Presupposition datasets
23
+ - **HIST**: Historical events datasets
24
+ - **REFS**: Reference datasets
25
+ - **MATH**: Mathematical problem datasets
26
+ - **AIME**: American Invitational Mathematics Examination datasets
27
+
28
+ ## Models
29
+
30
+ The graphs are generated using various language models:
31
+ - olmo7b
32
+ - qwen72b
33
+ - llama70b
34
+ - llama8b
35
+
36
+ ## Installation
37
+
38
+ 1. Install dependencies:
39
+ ```bash
40
+ pip install -r requirements.txt
41
+ ```
42
+
43
+ 2. Start the server:
44
+ ```bash
45
+ python server.py
46
+ ```
47
+
48
+ 3. Open your browser and navigate to:
49
+ ```
50
+ http://localhost:8080
51
+ ```
52
+
53
+ ## Usage
54
+
55
+ ### Browsing Existing Graphs
56
+
57
+ 1. **Select Dataset**: Choose a dataset from the dropdown menu
58
+ 2. **Select Entity**: Choose an entity from the available options
59
+ 3. **View Graph**: The graph will be automatically loaded and displayed
60
+ 4. **View Information**: Graph metadata and consensus text will be shown
61
+
62
+ ### Creating New Graphs
63
+
64
+ 1. **Enter Text**: Input multiple text sequences (one per line)
65
+ 2. **Create Graph**: Click "Create Graph" to generate a new ConGr
66
+ 3. **Save Graph**: Optionally save the graph to a pickle file
67
+
68
+ ## API Endpoints
69
+
70
+ - `GET /api/datasets` - Get available datasets
71
+ - `GET /api/entities?dataset=<dataset>` - Get entities for a dataset
72
+ - `POST /api/load_existing_graph` - Load an existing graph
73
+ - `POST /api/create_graph` - Create a new graph from text sequences
74
+ - `POST /api/save_graph` - Save a graph to file
75
+
76
+ ## Testing
77
+
78
+ Run the test script to verify the server is working correctly:
79
+
80
+ ```bash
81
+ python test_server.py
82
+ ```
83
+
84
+ ## File Structure
85
+
86
+ ```
87
+ web_interface/
88
+ β”œβ”€β”€ server.py # Flask server
89
+ β”œβ”€β”€ index.html # Web interface
90
+ β”œβ”€β”€ requirements.txt # Python dependencies
91
+ β”œβ”€β”€ test_server.py # Test script
92
+ └── README.md # This file
93
+ ```
94
+
95
+ ## Graph Information
96
+
97
+ When viewing a graph, you can see:
98
+ - **Dataset**: The source dataset
99
+ - **Entity**: The specific entity or topic
100
+ - **Model**: The language model used
101
+ - **Sequences**: Number of input sequences
102
+ - **Nodes**: Number of nodes in the graph
103
+ - **Edges**: Number of edges in the graph
104
+ - **Consensus**: The consensus text generated from the graph
105
+
106
+ ## Visualization Features
107
+
108
+ - **Hierarchical Layout**: Graphs are displayed in a hierarchical structure
109
+ - **Color Coding**: Consensus nodes are highlighted in green
110
+ - **Interactive**: Zoom, pan, and hover for more information
111
+ - **Responsive**: Works on desktop and mobile devices
web_interface/index.html ADDED
@@ -0,0 +1,907 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>ConGr Visualizer</title>
7
+ <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
8
+ <style>
9
+ * {
10
+ margin: 0;
11
+ padding: 0;
12
+ box-sizing: border-box;
13
+ }
14
+
15
+ body {
16
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
17
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
18
+ min-height: 100vh;
19
+ color: #333;
20
+ }
21
+
22
+ .container {
23
+ max-width: 1400px;
24
+ margin: 0 auto;
25
+ padding: 20px;
26
+ }
27
+
28
+ .header {
29
+ text-align: center;
30
+ margin-bottom: 30px;
31
+ color: white;
32
+ }
33
+
34
+ .header h1 {
35
+ font-size: 2.5rem;
36
+ margin-bottom: 10px;
37
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
38
+ }
39
+
40
+ .header p {
41
+ font-size: 1.1rem;
42
+ opacity: 0.9;
43
+ }
44
+
45
+ .main-content {
46
+ display: grid;
47
+ grid-template-columns: 1fr 2fr;
48
+ gap: 30px;
49
+ background: white;
50
+ border-radius: 15px;
51
+ box-shadow: 0 20px 40px rgba(0,0,0,0.1);
52
+ overflow: hidden;
53
+ }
54
+
55
+ .sidebar {
56
+ background: #f8f9fa;
57
+ padding: 30px;
58
+ border-right: 1px solid #e9ecef;
59
+ }
60
+
61
+ .section {
62
+ margin-bottom: 30px;
63
+ }
64
+
65
+ .section h3 {
66
+ color: #495057;
67
+ margin-bottom: 15px;
68
+ font-size: 1.2rem;
69
+ border-bottom: 2px solid #667eea;
70
+ padding-bottom: 5px;
71
+ }
72
+
73
+ .input-group {
74
+ margin-bottom: 20px;
75
+ }
76
+
77
+ .input-group label {
78
+ display: block;
79
+ margin-bottom: 8px;
80
+ font-weight: 600;
81
+ color: #495057;
82
+ }
83
+
84
+ .input-group textarea {
85
+ width: 100%;
86
+ min-height: 120px;
87
+ padding: 12px;
88
+ border: 2px solid #e9ecef;
89
+ border-radius: 8px;
90
+ font-family: inherit;
91
+ font-size: 14px;
92
+ resize: vertical;
93
+ transition: border-color 0.3s ease;
94
+ }
95
+
96
+ .input-group textarea:focus {
97
+ outline: none;
98
+ border-color: #667eea;
99
+ }
100
+
101
+ .input-group select {
102
+ width: 100%;
103
+ padding: 12px;
104
+ border: 2px solid #e9ecef;
105
+ border-radius: 8px;
106
+ font-family: inherit;
107
+ font-size: 14px;
108
+ background: white;
109
+ cursor: pointer;
110
+ transition: border-color 0.3s ease;
111
+ }
112
+
113
+ .input-group select:focus {
114
+ outline: none;
115
+ border-color: #667eea;
116
+ }
117
+
118
+ .btn {
119
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
120
+ color: white;
121
+ border: none;
122
+ padding: 12px 24px;
123
+ border-radius: 8px;
124
+ cursor: pointer;
125
+ font-size: 14px;
126
+ font-weight: 600;
127
+ transition: transform 0.2s ease, box-shadow 0.2s ease;
128
+ width: 100%;
129
+ margin-bottom: 10px;
130
+ }
131
+
132
+ .btn:hover {
133
+ transform: translateY(-2px);
134
+ box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
135
+ }
136
+
137
+ .btn:active {
138
+ transform: translateY(0);
139
+ }
140
+
141
+ .btn-secondary {
142
+ background: linear-gradient(135deg, #6c757d 0%, #495057 100%);
143
+ }
144
+
145
+ .btn-secondary:hover {
146
+ box-shadow: 0 5px 15px rgba(108, 117, 125, 0.4);
147
+ }
148
+
149
+ .graph-container {
150
+ padding: 30px;
151
+ position: relative;
152
+ }
153
+
154
+ #mynetwork {
155
+ width: 100%;
156
+ height: 600px;
157
+ border: 2px solid #e9ecef;
158
+ border-radius: 10px;
159
+ background: white;
160
+ }
161
+
162
+ .loading {
163
+ position: absolute;
164
+ top: 50%;
165
+ left: 50%;
166
+ transform: translate(-50%, -50%);
167
+ background: rgba(255, 255, 255, 0.9);
168
+ padding: 20px;
169
+ border-radius: 10px;
170
+ box-shadow: 0 5px 15px rgba(0,0,0,0.1);
171
+ z-index: 1000;
172
+ }
173
+
174
+ .loading.hidden {
175
+ display: none;
176
+ }
177
+
178
+ .status {
179
+ margin-top: 15px;
180
+ padding: 10px;
181
+ border-radius: 5px;
182
+ font-size: 14px;
183
+ }
184
+
185
+ .status.success {
186
+ background: #d4edda;
187
+ color: #155724;
188
+ border: 1px solid #c3e6cb;
189
+ }
190
+
191
+ .status.error {
192
+ background: #f8d7da;
193
+ color: #721c24;
194
+ border: 1px solid #f5c6cb;
195
+ }
196
+
197
+ .status.info {
198
+ background: #d1ecf1;
199
+ color: #0c5460;
200
+ border: 1px solid #bee5eb;
201
+ }
202
+
203
+ .example-text {
204
+ background: #e9ecef;
205
+ padding: 15px;
206
+ border-radius: 8px;
207
+ font-size: 13px;
208
+ line-height: 1.4;
209
+ margin-top: 10px;
210
+ }
211
+
212
+ .example-text h4 {
213
+ margin-bottom: 8px;
214
+ color: #495057;
215
+ }
216
+
217
+ .example-text p {
218
+ margin-bottom: 8px;
219
+ }
220
+
221
+ .example-text ul {
222
+ margin-left: 20px;
223
+ }
224
+
225
+ .example-text li {
226
+ margin-bottom: 4px;
227
+ }
228
+
229
+ .entity-list {
230
+ max-height: 200px;
231
+ overflow-y: auto;
232
+ border: 1px solid #e9ecef;
233
+ border-radius: 8px;
234
+ background: white;
235
+ }
236
+
237
+ .entity-item {
238
+ padding: 10px 12px;
239
+ border-bottom: 1px solid #f8f9fa;
240
+ cursor: pointer;
241
+ transition: background-color 0.2s ease;
242
+ }
243
+
244
+ .entity-item:hover {
245
+ background-color: #f8f9fa;
246
+ }
247
+
248
+ .entity-item:last-child {
249
+ border-bottom: none;
250
+ }
251
+
252
+ .entity-name {
253
+ font-weight: 600;
254
+ color: #495057;
255
+ }
256
+
257
+ .entity-model {
258
+ font-size: 12px;
259
+ color: #6c757d;
260
+ margin-top: 2px;
261
+ }
262
+
263
+ .graph-info {
264
+ background: #e9ecef;
265
+ padding: 15px;
266
+ border-radius: 8px;
267
+ margin-top: 15px;
268
+ }
269
+
270
+ .graph-info h4 {
271
+ margin-bottom: 10px;
272
+ color: #495057;
273
+ }
274
+
275
+ .graph-info p {
276
+ margin-bottom: 5px;
277
+ font-size: 14px;
278
+ }
279
+
280
+ .consensus-text {
281
+ background: #d4edda;
282
+ padding: 10px;
283
+ border-radius: 5px;
284
+ margin-top: 10px;
285
+ font-style: italic;
286
+ }
287
+
288
+ .sequences-section {
289
+ background: #f8f9fa;
290
+ padding: 15px;
291
+ border-radius: 8px;
292
+ margin-bottom: 15px;
293
+ border: 1px solid #e9ecef;
294
+ }
295
+
296
+ .sequences-section h4 {
297
+ margin-bottom: 10px;
298
+ color: #495057;
299
+ font-size: 1rem;
300
+ }
301
+
302
+ .sequences-list {
303
+ max-height: 150px;
304
+ overflow-y: auto;
305
+ border: 1px solid #e9ecef;
306
+ border-radius: 6px;
307
+ background: white;
308
+ }
309
+
310
+ .sequences-list li {
311
+ padding: 6px 10px;
312
+ border-bottom: 1px solid #f8f9fa;
313
+ cursor: pointer;
314
+ transition: background-color 0.2s ease;
315
+ }
316
+
317
+ .sequences-list li:hover {
318
+ background-color: #f8f9fa;
319
+ }
320
+
321
+ .sequences-list li:last-child {
322
+ border-bottom: none;
323
+ }
324
+
325
+ .sequences-list .sequence-item {
326
+ font-family: 'Courier New', Courier, monospace;
327
+ font-size: 12px;
328
+ line-height: 1.3;
329
+ white-space: pre-wrap;
330
+ word-break: break-all;
331
+ }
332
+
333
+ .consensus-highlight {
334
+ background-color: #ceeab2;
335
+ color: #2d5016;
336
+ font-weight: bold;
337
+ padding: 1px 2px;
338
+ border-radius: 3px;
339
+ }
340
+
341
+ .consensus-section {
342
+ background: #f8f9fa;
343
+ padding: 15px;
344
+ border-radius: 8px;
345
+ margin-bottom: 15px;
346
+ border: 1px solid #e9ecef;
347
+ }
348
+
349
+ .consensus-text {
350
+ font-family: 'Courier New', Courier, monospace;
351
+ font-size: 14px;
352
+ line-height: 1.4;
353
+ white-space: pre-wrap;
354
+ word-break: break-word;
355
+ background: white;
356
+ padding: 10px;
357
+ border: 1px solid #e9ecef;
358
+ border-radius: 6px;
359
+ }
360
+
361
+ @media (max-width: 768px) {
362
+ .main-content {
363
+ grid-template-columns: 1fr;
364
+ }
365
+
366
+ .sidebar {
367
+ border-right: none;
368
+ border-bottom: 1px solid #e9ecef;
369
+ }
370
+
371
+ .header h1 {
372
+ font-size: 2rem;
373
+ }
374
+ }
375
+ </style>
376
+ </head>
377
+ <body>
378
+ <div class="container">
379
+ <div class="header">
380
+ <h1>ConGr Visualizer</h1>
381
+ <p>Explore and visualize ConGrs</p>
382
+ </div>
383
+
384
+ <div class="main-content">
385
+ <div class="sidebar">
386
+ <div class="section">
387
+ <h3>Browse Existing Graphs</h3>
388
+ <div class="input-group">
389
+ <label for="datasetSelect">Select Dataset:</label>
390
+ <select id="datasetSelect" onchange="loadEntities()">
391
+ <option value="">Choose a dataset...</option>
392
+ </select>
393
+ </div>
394
+ <div class="input-group">
395
+ <label for="entitySelect">Select Instance:</label>
396
+ <select id="entitySelect" onchange="loadModels()">
397
+ <option value="">Choose an instance...</option>
398
+ </select>
399
+ </div>
400
+ <div class="input-group">
401
+ <label for="modelSelect">Select Model:</label>
402
+ <select id="modelSelect" onchange="loadSelectedGraph()">
403
+ <option value="">Choose a model...</option>
404
+ </select>
405
+ </div>
406
+ <div id="graphInfo" class="graph-info hidden">
407
+ <h4>Graph Information</h4>
408
+ <div id="graphDetails"></div>
409
+ </div>
410
+ </div>
411
+
412
+ <div class="section">
413
+ <h3>Create New Graph</h3>
414
+ <div class="input-group">
415
+ <label for="textInput">Enter text sequences (one per line):</label>
416
+ <textarea id="textInput" placeholder="Enter your text sequences here..."></textarea>
417
+ </div>
418
+ <button class="btn" onclick="createGraph()">Create Graph</button>
419
+ <div class="input-group">
420
+ <label>
421
+ <input type="checkbox" id="computeConsensus" checked> Display consensus response using consensus decoding
422
+ </label>
423
+ </div>
424
+ </div>
425
+
426
+ <div class="section">
427
+ <h3>Graph Options</h3>
428
+ <div class="input-group">
429
+ <label for="saveFilename">Save filename:</label>
430
+ <input type="text" id="saveFilename" placeholder="graph.pkl" value="graph.pkl">
431
+ </div>
432
+ <button class="btn btn-secondary" onclick="saveGraph()">Save Graph</button>
433
+ <button class="btn btn-secondary" onclick="clearGraph()">Clear Graph</button>
434
+ </div>
435
+
436
+ <div id="status" class="status hidden"></div>
437
+ </div>
438
+
439
+ <div class="graph-container">
440
+ <div id="loadingProgress" class="loading hidden">Processing...</div>
441
+ <div id="originalSequences" class="sequences-section hidden">
442
+ <h4>Original Sequences</h4>
443
+ <div id="sequencesList"></div>
444
+ </div>
445
+ <div id="consensusResponse" class="consensus-section hidden">
446
+ <h4>Consensus Response</h4>
447
+ <div id="consensusText"></div>
448
+ </div>
449
+ <div id="mynetwork"></div>
450
+ </div>
451
+ </div>
452
+ </div>
453
+
454
+ <script>
455
+ let network = null;
456
+ let currentGraphData = null;
457
+ let availableEntities = [];
458
+
459
+ function showStatus(message, type = 'info') {
460
+ const status = document.getElementById('status');
461
+ status.textContent = message;
462
+ status.className = `status ${type}`;
463
+ status.classList.remove('hidden');
464
+
465
+ if (type === 'success') {
466
+ setTimeout(() => {
467
+ status.classList.add('hidden');
468
+ }, 3000);
469
+ }
470
+ }
471
+
472
+ function showLoading() {
473
+ document.getElementById('loadingProgress').classList.remove('hidden');
474
+ }
475
+
476
+ function hideLoading() {
477
+ document.getElementById('loadingProgress').classList.add('hidden');
478
+ }
479
+
480
+ async function loadDatasets() {
481
+ try {
482
+ const response = await fetch('/api/datasets');
483
+ const data = await response.json();
484
+
485
+ const datasetSelect = document.getElementById('datasetSelect');
486
+ datasetSelect.innerHTML = '<option value="">Choose a dataset...</option>';
487
+
488
+ if (data.datasets && data.datasets.length > 0) {
489
+ data.datasets.forEach(dataset => {
490
+ const option = document.createElement('option');
491
+ option.value = dataset.name;
492
+ option.textContent = `${dataset.display_name} (${dataset.count} graphs)`;
493
+ datasetSelect.appendChild(option);
494
+ });
495
+ }
496
+ } catch (error) {
497
+ showStatus('Error loading datasets: ' + error.message, 'error');
498
+ }
499
+ }
500
+
501
+ async function loadEntities() {
502
+ const datasetSelect = document.getElementById('datasetSelect');
503
+ const entitySelect = document.getElementById('entitySelect');
504
+ const modelSelect = document.getElementById('modelSelect');
505
+ const dataset = datasetSelect.value;
506
+
507
+ if (!dataset) {
508
+ entitySelect.innerHTML = '<option value="">Choose an entity...</option>';
509
+ modelSelect.innerHTML = '<option value="">Choose a model...</option>';
510
+ return;
511
+ }
512
+
513
+ showLoading();
514
+ showStatus('Loading entities...', 'info');
515
+
516
+ try {
517
+ const response = await fetch(`/api/entities?dataset=${dataset}`);
518
+ const data = await response.json();
519
+
520
+ if (data.error) {
521
+ showStatus('Error loading entities: ' + data.error, 'error');
522
+ return;
523
+ }
524
+
525
+ availableEntities = data.entities;
526
+ entitySelect.innerHTML = '<option value="">Choose an entity...</option>';
527
+ modelSelect.innerHTML = '<option value="">Choose a model...</option>';
528
+
529
+ if (data.entities && data.entities.length > 0) {
530
+ // Get unique entity names
531
+ const uniqueEntities = [...new Set(data.entities.map(e => e.entity))];
532
+
533
+ // Sort numerically for non-bio datasets
534
+ if (dataset !== 'bio') {
535
+ uniqueEntities.sort((a, b) => {
536
+ // Extract numbers from entity names for sorting
537
+ const numA = parseInt(a.match(/\d+/)?.[0] || '0');
538
+ const numB = parseInt(b.match(/\d+/)?.[0] || '0');
539
+ return numA - numB;
540
+ });
541
+ } else {
542
+ // Sort alphabetically for bio dataset
543
+ uniqueEntities.sort();
544
+ }
545
+
546
+ uniqueEntities.forEach(entityName => {
547
+ const option = document.createElement('option');
548
+ option.value = entityName;
549
+ option.textContent = entityName;
550
+ entitySelect.appendChild(option);
551
+ });
552
+ }
553
+
554
+ showStatus(`Loaded ${data.entities.length} entities from ${dataset} dataset`, 'success');
555
+ } catch (error) {
556
+ showStatus('Error loading entities: ' + error.message, 'error');
557
+ } finally {
558
+ hideLoading();
559
+ }
560
+ }
561
+
562
+ async function loadModels() {
563
+ const datasetSelect = document.getElementById('datasetSelect');
564
+ const entitySelect = document.getElementById('entitySelect');
565
+ const modelSelect = document.getElementById('modelSelect');
566
+ const dataset = datasetSelect.value;
567
+ const entityName = entitySelect.value;
568
+
569
+ if (!entityName) {
570
+ modelSelect.innerHTML = '<option value="">Choose a model...</option>';
571
+ return;
572
+ }
573
+
574
+ showLoading();
575
+ showStatus('Loading models...', 'info');
576
+
577
+ try {
578
+ const response = await fetch(`/api/models?dataset=${dataset}&entity=${encodeURIComponent(entityName)}`);
579
+ const data = await response.json();
580
+
581
+ if (data.error) {
582
+ showStatus('Error loading models: ' + data.error, 'error');
583
+ return;
584
+ }
585
+
586
+ modelSelect.innerHTML = '<option value="">Choose a model...</option>';
587
+ if (data.models && data.models.length > 0) {
588
+ // Sort models by name for consistency
589
+ data.models.sort((a, b) => a.model.localeCompare(b.model));
590
+
591
+ data.models.forEach(model => {
592
+ const option = document.createElement('option');
593
+ option.value = model.filepath;
594
+ option.textContent = model.model;
595
+ modelSelect.appendChild(option);
596
+ });
597
+ } else {
598
+ console.log('No models found for this entity');
599
+ }
600
+ showStatus(`Loaded ${data.models.length} models for ${entityName}`, 'success');
601
+ } catch (error) {
602
+ showStatus('Error loading models: ' + error.message, 'error');
603
+ } finally {
604
+ hideLoading();
605
+ }
606
+ }
607
+
608
+ function displayOriginalSequences(sequences, consensusText = null) {
609
+ const sequencesSection = document.getElementById('originalSequences');
610
+ const sequencesList = document.getElementById('sequencesList');
611
+
612
+ if (!sequences || sequences.length === 0) {
613
+ sequencesSection.classList.add('hidden');
614
+ return;
615
+ }
616
+
617
+ let html = '<ul class="sequences-list">';
618
+ sequences.forEach((sequence, index) => {
619
+ let highlightedSequence = sequence;
620
+
621
+ // Highlight consensus text in green if available
622
+ if (consensusText && consensusText.trim()) {
623
+ const consensusWords = consensusText.trim().split(/\s+/);
624
+ let currentSequence = sequence;
625
+
626
+ consensusWords.forEach(word => {
627
+ if (word.length > 2) { // Only highlight words longer than 2 characters
628
+ const regex = new RegExp(`\\b${word.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')}\\b`, 'gi');
629
+ currentSequence = currentSequence.replace(regex, `<span class="consensus-highlight">${word}</span>`);
630
+ }
631
+ });
632
+
633
+ highlightedSequence = currentSequence;
634
+ }
635
+
636
+ html += `<li><div class="sequence-item"><strong>Sequence ${index + 1}:</strong> ${highlightedSequence}</div></li>`;
637
+ });
638
+
639
+ html += '</ul>';
640
+
641
+ sequencesList.innerHTML = html;
642
+ sequencesSection.classList.remove('hidden');
643
+
644
+ // Display consensus response in separate box if available
645
+ displayConsensusResponse(consensusText);
646
+ }
647
+
648
+ function displayConsensusResponse(consensusText) {
649
+ const consensusSection = document.getElementById('consensusResponse');
650
+ const consensusTextDiv = document.getElementById('consensusText');
651
+
652
+ if (!consensusText || !consensusText.trim()) {
653
+ consensusSection.classList.add('hidden');
654
+ return;
655
+ }
656
+
657
+ consensusTextDiv.innerHTML = `<div class="consensus-text">${consensusText}</div>`;
658
+ consensusSection.classList.remove('hidden');
659
+ }
660
+
661
+ async function loadSelectedGraph() {
662
+ const modelSelect = document.getElementById('modelSelect');
663
+ const selectedModel = modelSelect.value;
664
+
665
+ if (!selectedModel) {
666
+ return;
667
+ }
668
+
669
+ showLoading();
670
+ showStatus('Loading graph...', 'info');
671
+
672
+ try {
673
+ const computeConsensus = document.getElementById('computeConsensus').checked;
674
+ const response = await fetch('/api/load_existing_graph', {
675
+ method: 'POST',
676
+ headers: {
677
+ 'Content-Type': 'application/json',
678
+ },
679
+ body: JSON.stringify({
680
+ filepath: selectedModel,
681
+ compute_consensus: computeConsensus
682
+ })
683
+ });
684
+
685
+ const data = await response.json();
686
+
687
+ if (data.success) {
688
+ displayGraph(data.nodes, data.edges);
689
+ displayOriginalSequences(data.original_sequences, data.consensus_text);
690
+ showGraphInfo(data);
691
+ showStatus(`Graph loaded successfully! ${data.num_sequences} sequences, ${data.num_nodes} nodes, ${data.num_edges} edges.`, 'success');
692
+ } else {
693
+ showStatus('Error loading graph: ' + data.error, 'error');
694
+ }
695
+ } catch (error) {
696
+ showStatus('Error loading graph: ' + error.message, 'error');
697
+ } finally {
698
+ hideLoading();
699
+ }
700
+ }
701
+
702
+ function showGraphInfo(data) {
703
+ const graphInfo = document.getElementById('graphInfo');
704
+ const graphDetails = document.getElementById('graphDetails');
705
+
706
+ let detailsHtml = '';
707
+
708
+ if (data.metadata) {
709
+ detailsHtml += `
710
+ <p><strong>Dataset:</strong> ${data.metadata.task}</p>
711
+ <p><strong>Entity:</strong> ${data.metadata.entity}</p>
712
+ <p><strong>Model:</strong> ${data.metadata.model}</p>
713
+ `;
714
+ } else {
715
+ }
716
+
717
+ detailsHtml += `
718
+ <p><strong>Sequences:</strong> ${data.num_sequences}</p>
719
+ <p><strong>Nodes:</strong> ${data.num_nodes}</p>
720
+ <p><strong>Edges:</strong> ${data.num_edges}</p>
721
+ `;
722
+
723
+ graphDetails.innerHTML = detailsHtml;
724
+ graphInfo.classList.remove('hidden');
725
+ }
726
+
727
+ async function createGraph() {
728
+ const textInput = document.getElementById('textInput').value.trim();
729
+ if (!textInput) {
730
+ showStatus('Please enter some text sequences.', 'error');
731
+ return;
732
+ }
733
+
734
+ const sequences = textInput.split('\n').filter(line => line.trim() !== '');
735
+ if (sequences.length < 2) {
736
+ showStatus('Please enter at least 2 text sequences.', 'error');
737
+ return;
738
+ }
739
+
740
+ showLoading();
741
+ showStatus('Creating graph...', 'info');
742
+
743
+ try {
744
+ const computeConsensus = document.getElementById('computeConsensus').checked;
745
+ const response = await fetch('/api/create_graph', {
746
+ method: 'POST',
747
+ headers: {
748
+ 'Content-Type': 'application/json',
749
+ },
750
+ body: JSON.stringify({
751
+ sequences: sequences,
752
+ compute_consensus: computeConsensus
753
+ })
754
+ });
755
+
756
+ const data = await response.json();
757
+
758
+ if (data.success) {
759
+ displayGraph(data.nodes, data.edges);
760
+ displayOriginalSequences(data.original_sequences, data.consensus_text);
761
+ showStatus(`Graph created with ${data.num_sequences} sequences, ${data.num_nodes} nodes, and ${data.num_edges} edges!`, 'success');
762
+ } else {
763
+ showStatus('Error creating graph: ' + data.error, 'error');
764
+ }
765
+ } catch (error) {
766
+ showStatus('Error creating graph: ' + error.message, 'error');
767
+ } finally {
768
+ hideLoading();
769
+ }
770
+ }
771
+
772
+ function displayGraph(nodes, edges) {
773
+ const container = document.getElementById('mynetwork');
774
+
775
+ if (!nodes || nodes.length === 0) {
776
+ console.error('No nodes provided to displayGraph');
777
+ return;
778
+ }
779
+
780
+ // Process nodes without manual level assignment
781
+ const processedNodes = nodes.map(node => ({ ...node }));
782
+
783
+ const data = {
784
+ nodes: new vis.DataSet(processedNodes),
785
+ edges: new vis.DataSet(edges)
786
+ };
787
+
788
+ const options = {
789
+ width: '100%',
790
+ height: '100%',
791
+ physics: {
792
+ enabled: false,
793
+ stabilization: {
794
+ updateInterval: 10,
795
+ },
796
+ },
797
+ edges: {
798
+ color: {
799
+ inherit: false
800
+ }
801
+ },
802
+ layout: {
803
+ hierarchical: {
804
+ direction: "UD",
805
+ sortMethod: "directed",
806
+ shakeTowards: "roots",
807
+ levelSeparation: 150,
808
+ nodeSpacing: 800,
809
+ treeSpacing: 200,
810
+ parentCentralization: true,
811
+ }
812
+ }
813
+ };
814
+
815
+ if (network) {
816
+ network.destroy();
817
+ }
818
+
819
+ try {
820
+ network = new vis.Network(container, data, options);
821
+ } catch (error) {
822
+ console.error('Error creating network:', error);
823
+ return;
824
+ }
825
+
826
+ network.on("stabilizationProgress", function (params) {
827
+ document.getElementById("loadingProgress").innerText =
828
+ "Stabilizing: " + Math.round(params.iterations / params.total * 100) + "%";
829
+ });
830
+
831
+ network.once("stabilizationIterationsDone", function () {
832
+ document.getElementById("loadingProgress").innerText = "100%";
833
+ setTimeout(function () {
834
+ document.getElementById("loadingProgress").classList.add("hidden");
835
+ }, 500);
836
+ });
837
+
838
+ currentGraphData = { nodes, edges };
839
+ }
840
+
841
+ async function saveGraph() {
842
+ const textInput = document.getElementById('textInput').value.trim();
843
+ if (!textInput) {
844
+ showStatus('Please enter some text sequences first.', 'error');
845
+ return;
846
+ }
847
+
848
+ const sequences = textInput.split('\n').filter(line => line.trim() !== '');
849
+ if (sequences.length < 2) {
850
+ showStatus('Please enter at least 2 text sequences.', 'error');
851
+ return;
852
+ }
853
+
854
+ const filename = document.getElementById('saveFilename').value || 'graph.pkl';
855
+
856
+ showLoading();
857
+ showStatus('Saving graph...', 'info');
858
+
859
+ try {
860
+ const response = await fetch('/api/save_graph', {
861
+ method: 'POST',
862
+ headers: {
863
+ 'Content-Type': 'application/json',
864
+ },
865
+ body: JSON.stringify({
866
+ sequences: sequences,
867
+ filename: filename
868
+ })
869
+ });
870
+
871
+ const data = await response.json();
872
+
873
+ if (data.success) {
874
+ showStatus(`Graph saved successfully to ${data.filename}!`, 'success');
875
+ } else {
876
+ showStatus('Error saving graph: ' + data.error, 'error');
877
+ }
878
+ } catch (error) {
879
+ showStatus('Error saving graph: ' + error.message, 'error');
880
+ } finally {
881
+ hideLoading();
882
+ }
883
+ }
884
+
885
+ function clearGraph() {
886
+ if (network) {
887
+ network.destroy();
888
+ network = null;
889
+ }
890
+ currentGraphData = null;
891
+ document.getElementById('textInput').value = '';
892
+ document.getElementById('datasetSelect').value = '';
893
+ document.getElementById('entitySelect').innerHTML = '<option value="">Choose an entity...</option>';
894
+ document.getElementById('modelSelect').innerHTML = '<option value="">Choose a model...</option>';
895
+ document.getElementById('graphInfo').classList.add('hidden');
896
+ document.getElementById('originalSequences').classList.add('hidden');
897
+ showStatus('Graph cleared.', 'info');
898
+ }
899
+
900
+ // Initialize
901
+ document.addEventListener('DOMContentLoaded', function() {
902
+ loadDatasets();
903
+ showStatus('Ready to explore existing graphs or create new ones!', 'info');
904
+ });
905
+ </script>
906
+ </body>
907
+ </html>
web_interface/requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ flask==2.3.3
2
+ flask-cors==4.0.0
3
+ numpy==1.24.3
4
+ tqdm==4.66.1
web_interface/server.py ADDED
@@ -0,0 +1,652 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flask server for POA Graph Web Interface
4
+ """
5
+
6
+ import glob
7
+ import os
8
+ import pickle
9
+ import re
10
+ import sys
11
+
12
+ from flask import Flask, jsonify, request, send_from_directory
13
+ from flask_cors import CORS
14
+
15
+ # Get the repository root directory (parent of web_interface)
16
+ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
+
18
+ # Add the repository root to the path so we can import the POA graph modules
19
+ sys.path.append(REPO_ROOT)
20
+
21
+ from src.new_text_alignment import TextSeqGraphAlignment
22
+ from src.text_poa_graph import TextPOAGraph
23
+
24
+ try:
25
+ from src.generation_methods import decode_consensus
26
+ except ImportError:
27
+ decode_consensus = None
28
+
29
+ app = Flask(__name__)
30
+ CORS(app) # Enable CORS for all routes
31
+
32
+ # Base paths for different datasets (relative to repo root)
33
+ GRAPH_PATHS = {
34
+ "bio": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/bio"),
35
+ "fp": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/fp"),
36
+ "hist": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/hist"),
37
+ "refs": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/refs"),
38
+ "math": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/MATH"),
39
+ "aime": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/AIME"),
40
+ }
41
+
42
+ MODELS = ["qwen72b", "qwen7b", "llama8b", "llama70b", "olmo7b", "olmo32b"]
43
+
44
+
45
+ @app.route("/")
46
+ def index():
47
+ """Serve the main HTML file"""
48
+ return send_from_directory(".", "index.html")
49
+
50
+
51
+ @app.route("/api/datasets", methods=["GET"])
52
+ def get_datasets():
53
+ """Get available datasets"""
54
+ datasets = []
55
+ for dataset_name, path in GRAPH_PATHS.items():
56
+ if os.path.exists(path):
57
+ # Count available graphs
58
+ pkl_files = glob.glob(os.path.join(path, "*.pkl"))
59
+ datasets.append(
60
+ {
61
+ "name": dataset_name,
62
+ "display_name": dataset_name.upper(),
63
+ "path": path,
64
+ "count": len(pkl_files),
65
+ }
66
+ )
67
+ return jsonify({"datasets": datasets})
68
+
69
+
70
+ @app.route("/api/models", methods=["GET"])
71
+ def get_models():
72
+ """Get available models for a specific entity"""
73
+ entity = request.args.get("entity")
74
+ dataset = request.args.get("dataset")
75
+
76
+ if not entity:
77
+ return jsonify({"error": "Entity parameter required"}), 400
78
+
79
+ if not dataset or dataset not in GRAPH_PATHS:
80
+ return jsonify({"error": "Invalid dataset"}), 400
81
+
82
+ path = GRAPH_PATHS[dataset]
83
+ if not os.path.exists(path):
84
+ return jsonify({"error": "Dataset path not found"}), 404
85
+
86
+ models = []
87
+ pkl_files = glob.glob(os.path.join(path, "*.pkl"))
88
+
89
+ for pkl_file in pkl_files:
90
+ filename = os.path.basename(pkl_file)
91
+
92
+ # Different filename patterns for different datasets
93
+ if dataset == "bio":
94
+ # Format: bio_graph_{entity}_merged_{model}.pkl
95
+ match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
96
+ if match:
97
+ entity_name, model = match.groups()
98
+ if entity_name == entity:
99
+ models.append({"model": model, "filename": filename, "filepath": pkl_file})
100
+ elif dataset == "fp":
101
+ # Format: fp_graph_{number}_merged_{model}.pkl
102
+ match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
103
+ if match:
104
+ entity_name, model = match.groups()
105
+ if f"Problem {entity_name}" == entity:
106
+ models.append({"model": model, "filename": filename, "filepath": pkl_file})
107
+ elif dataset == "math":
108
+ # Format: qwen72_math_{number}.pkl
109
+ match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
110
+ if match:
111
+ entity_name = match.group(1)
112
+ if f"Math Problem {entity_name}" == entity:
113
+ models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
114
+ elif dataset == "aime":
115
+ # Format: aime_qwen72b_{number}.pkl
116
+ match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
117
+ if match:
118
+ entity_name = match.group(1)
119
+ if f"AIME Problem {entity_name}" == entity:
120
+ models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
121
+ else:
122
+ # Generic pattern for other datasets
123
+ match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
124
+ if match:
125
+ task, entity_name, model = match.groups()
126
+ if entity_name == entity:
127
+ models.append({"model": model, "filename": filename, "filepath": pkl_file})
128
+
129
+ return jsonify({"models": models})
130
+
131
+
+@app.route("/api/entities", methods=["GET"])
+def get_entities():
+    """Get available entities for a dataset"""
+    dataset = request.args.get("dataset")
+    if not dataset or dataset not in GRAPH_PATHS:
+        return jsonify({"error": "Invalid dataset"}), 400
+
+    path = GRAPH_PATHS[dataset]
+    if not os.path.exists(path):
+        return jsonify({"error": "Dataset path not found"}), 404
+
+    entities = []
+    pkl_files = glob.glob(os.path.join(path, "*.pkl"))
+
+    for pkl_file in pkl_files:
+        filename = os.path.basename(pkl_file)
+
+        # Different filename patterns for different datasets
+        if dataset == "bio":
+            # Format: bio_graph_{entity}_merged_{model}.pkl
+            match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
+            if match:
+                entity_name, model = match.groups()
+                entities.append(
+                    {
+                        "entity": entity_name,
+                        "model": model,
+                        "filename": filename,
+                        "filepath": pkl_file,
+                    }
+                )
+        elif dataset == "fp":
+            # Format: fp_graph_{number}_merged_{model}.pkl
+            match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
+            if match:
+                entity_name, model = match.groups()
+                entities.append(
+                    {
+                        "entity": f"Problem {entity_name}",
+                        "model": model,
+                        "filename": filename,
+                        "filepath": pkl_file,
+                    }
+                )
+        elif dataset == "math":
+            # Format: qwen72_math_{number}.pkl
+            match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
+            if match:
+                entity_name = match.group(1)
+                entities.append(
+                    {
+                        "entity": f"Math Problem {entity_name}",
+                        "model": "qwen72b",
+                        "filename": filename,
+                        "filepath": pkl_file,
+                    }
+                )
+        elif dataset == "aime":
+            # Format: aime_qwen72b_{number}.pkl
+            match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
+            if match:
+                entity_name = match.group(1)
+                entities.append(
+                    {
+                        "entity": f"AIME Problem {entity_name}",
+                        "model": "qwen72b",
+                        "filename": filename,
+                        "filepath": pkl_file,
+                    }
+                )
+        else:
+            # Generic pattern for other datasets
+            match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
+            if match:
+                task, entity_name, model = match.groups()
+                entities.append(
+                    {
+                        "entity": entity_name,
+                        "model": model,
+                        "filename": filename,
+                        "filepath": pkl_file,
+                    }
+                )
+
+    return jsonify({"entities": entities})
+
+
+@app.route("/api/load_existing_graph", methods=["POST"])
+def load_existing_graph():
+    """Load an existing graph from the stored pickle files"""
+    try:
+        data = request.get_json()
+        filepath = data.get("filepath")
+
+        if not filepath or not os.path.exists(filepath):
+            return jsonify({"error": "Graph file not found"}), 404
+
+        # Read and load the pickle file. Note: unpickling executes arbitrary
+        # code, so only files from the trusted GRAPH_PATHS directories should
+        # ever be passed here.
+        try:
+            with open(filepath, "rb") as f:
+                graph = pickle.load(f)
+        except Exception as e:
+            return jsonify({"error": f"Error loading pickle file: {str(e)}"}), 500
+
+        if not isinstance(graph, TextPOAGraph):
+            return jsonify({"error": "File does not contain a valid POA graph"}), 400
+
+        # Convert to JSON format for vis.js
+        nodes = []
+        edges = []
+
+        try:
+            # Get consensus nodes for coloring
+            consensus_nodes = set(graph.consensus_node_ids)
+
+            # Create nodes using the same logic as jsOutput
+            # (nodeiterator() returns a generator function, hence the double call)
+            for node in graph.nodeiterator()():
+                title_text = ""
+                if node.sequences:
+                    title_text += f"Sequences: {node.sequences}"
+                if node.variations:
+                    title_text += ";;;".join(
+                        [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
+                    )
+                title_text = title_text.replace('"', "'")
+
+                # Use the same color logic as jsOutput
+                color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
+
+                node_data = {
+                    "id": node.ID,
+                    "label": f"{node.ID}: {node.text}",
+                    "title": title_text,
+                    "color": color,
+                }
+                nodes.append(node_data)
+
+            # Create edges using the same logic as jsOutput
+            for node in graph.nodeiterator()():
+                nodeID = node.ID  # Keep as integer
+                for edge in node.outEdges:
+                    target = edge  # Keep as integer
+                    weight = node.outEdges[edge].weight + 1.5
+                    edge_data = {
+                        "from": nodeID,
+                        "to": target,
+                        "value": weight,
+                        "color": "#cae0e6",
+                        "arrows": "to",
+                    }
+                    edges.append(edge_data)
+        except Exception as e:
+            return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
+
+        # Extract metadata from filename
+        filename = os.path.basename(filepath)
+        metadata = {}
+
+        try:
+            # Different filename patterns for different datasets
+            if filename.startswith("bio_graph_"):
+                # Format: bio_graph_{entity}_merged_{model}.pkl
+                match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
+                if match:
+                    entity_name, model = match.groups()
+                    metadata = {
+                        "task": "bio",
+                        "entity": entity_name,
+                        "model": model,
+                        "filename": filename,
+                    }
+            elif filename.startswith("fp_graph_"):
+                # Format: fp_graph_{number}_merged_{model}.pkl
+                match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
+                if match:
+                    entity_name, model = match.groups()
+                    metadata = {
+                        "task": "fp",
+                        "entity": f"Problem {entity_name}",
+                        "model": model,
+                        "filename": filename,
+                    }
+            elif filename.startswith("qwen72_math_"):
+                # Format: qwen72_math_{number}.pkl
+                match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
+                if match:
+                    entity_name = match.group(1)
+                    metadata = {
+                        "task": "math",
+                        "entity": f"Math Problem {entity_name}",
+                        "model": "qwen72b",
+                        "filename": filename,
+                    }
+            elif filename.startswith("aime_qwen72b_"):
+                # Format: aime_qwen72b_{number}.pkl
+                match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
+                if match:
+                    entity_name = match.group(1)
+                    metadata = {
+                        "task": "aime",
+                        "entity": f"AIME Problem {entity_name}",
+                        "model": "qwen72b",
+                        "filename": filename,
+                    }
+            else:
+                # Generic pattern for other datasets
+                match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
+                if match:
+                    task, entity_name, model = match.groups()
+                    metadata = {
+                        "task": task,
+                        "entity": entity_name,
+                        "model": model,
+                        "filename": filename,
+                    }
+        except Exception:
+            # Don't fail the request if metadata extraction fails
+            pass
+
+        # Extract text from consensus nodes
+        consensus_text = ""
+        try:
+            consensus_nodes = set(graph.consensus_node_ids)
+            consensus_node_texts = []
+            for node in graph.nodeiterator()():
+                if node.ID in consensus_nodes and node.text and node.text.strip():
+                    consensus_node_texts.append(node.text.strip())
+            consensus_text = " ".join(consensus_node_texts)
+        except Exception:
+            consensus_text = ""
+
+        # Check if we should compute consensus using decode_consensus
+        compute_consensus = data.get("compute_consensus", False)
+        if compute_consensus and decode_consensus:
+            try:
+                # Determine task from metadata or default to "bio"
+                task = metadata.get("task", "bio") if metadata else "bio"
+                consensus_text = decode_consensus(graph, selection_threshold=0.5, task=task)
+            except Exception as e:
+                print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
+                # Keep the original consensus text if decode_consensus fails
+
+        # Get original sequences
+        try:
+            raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
+            # Process sequences: join with spaces and remove "||"
+            print(f"DEBUG: Raw sequences: {raw_sequences}")
+            original_sequences = []
+            for seq in raw_sequences:
+                if isinstance(seq, list):
+                    # Join list elements with spaces
+                    processed_seq = " ".join(str(item) for item in seq)
+                else:
+                    processed_seq = str(seq)
+                # Remove "||" characters
+                processed_seq = processed_seq.replace("||", "")
+                print(f"DEBUG: Processed sequence: {processed_seq}")
+                original_sequences.append(processed_seq)
+        except Exception:
+            original_sequences = []
+
+        result = {
+            "success": True,
+            "nodes": nodes,
+            "edges": edges,
+            "num_sequences": graph.num_sequences,
+            "num_nodes": len(nodes),
+            "num_edges": len(edges),
+            "metadata": metadata,
+            "consensus_text": consensus_text,
+            "original_sequences": original_sequences,
+        }
+
+        return jsonify(result)
+
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+
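An editorial note: the filename-pattern dispatch above now appears three times, in `get_models`, `get_entities`, and `load_existing_graph`. A possible follow-up, not part of this commit, is to centralize it in one helper; a sketch under that assumption:

```python
import re

# Hypothetical shared table: regex -> (task, entity label, model) extractor.
# Specific patterns come before the generic fallback.
_FILENAME_PATTERNS = [
    (re.compile(r"bio_graph_(.+?)_merged_(\w+)\.pkl"), lambda m: ("bio", m.group(1), m.group(2))),
    (re.compile(r"fp_graph_(\d+)_merged_(\w+)\.pkl"), lambda m: ("fp", f"Problem {m.group(1)}", m.group(2))),
    (re.compile(r"qwen72_math_(\d+)\.pkl"), lambda m: ("math", f"Math Problem {m.group(1)}", "qwen72b")),
    (re.compile(r"aime_qwen72b_(\d+)\.pkl"), lambda m: ("aime", f"AIME Problem {m.group(1)}", "qwen72b")),
    (re.compile(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl"), lambda m: (m.group(1), m.group(2), m.group(3))),
]

def parse_graph_filename(filename):
    """Return (task, entity, model) for a recognized pickle filename, else None."""
    for pattern, extract in _FILENAME_PATTERNS:
        match = pattern.fullmatch(filename)
        if match:
            return extract(match)
    return None
```

Each endpoint could then reduce to a loop over the `glob` results that calls `parse_graph_filename` on the basename and filters or formats as needed.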
+@app.route("/api/create_graph", methods=["POST"])
+def create_graph():
+    """Create a POA graph from text sequences"""
+    try:
+        data = request.get_json()
+        sequences = data.get("sequences", [])
+
+        if len(sequences) < 2:
+            return jsonify({"error": "At least 2 sequences are required"}), 400
+
+        print(f"DEBUG: Creating graph with sequences: {sequences}")
+
+        # Create the graph with first sequence as string
+        graph = TextPOAGraph(sequences[0], label=0)
+        print("DEBUG: Initial graph created")
+
+        # Add remaining sequences
+        for i, sequence in enumerate(sequences[1:], 1):
+            print(f"DEBUG: Adding sequence {i}: {sequence}")
+            alignment = TextSeqGraphAlignment(
+                text=sequence,
+                graph=graph,
+                fastMethod=True,
+                globalAlign=True,
+                matchscore=1,
+                mismatchscore=-2,
+                gap_open=-1,
+            )
+            graph.incorporateSeqAlignment(alignment, sequence, label=i)
+
+        print("DEBUG: All sequences added")
+
+        # Refine the graph with proper domain and model parameters
+        graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini")
+        print("DEBUG: Graph refined")
+
+        # Convert to JSON format for vis.js
+        nodes = []
+        edges = []
+
+        try:
+            print("DEBUG: Starting to process graph data")
+            # Get consensus nodes for coloring (make it optional)
+            try:
+                consensus_nodes = set(graph.consensus_node_ids)
+                print(f"DEBUG: Consensus nodes: {consensus_nodes}")
+            except Exception as e:
+                print(f"DEBUG: Error getting consensus nodes: {e}")
+                consensus_nodes = set()  # Fall back to an empty set if consensus fails
+
+            # Create nodes using the same logic as jsOutput
+            for node in graph.nodeiterator()():
+                title_text = ""
+                if node.sequences:
+                    title_text += f"Sequences: {node.sequences}"
+                if node.variations:
+                    title_text += ";;;".join(
+                        [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
+                    )
+                title_text = title_text.replace('"', "'")
+
+                # Use the same color logic as jsOutput
+                color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
+
+                node_data = {
+                    "id": node.ID,
+                    "label": f"{node.ID}: {node.text}",
+                    "title": title_text,
+                    "color": color,
+                }
+                nodes.append(node_data)
+
+            print(f"DEBUG: Created {len(nodes)} nodes")
+
+            # Create edges using the same logic as jsOutput
+            for node in graph.nodeiterator()():
+                nodeID = node.ID  # Keep as integer
+                for edge in node.outEdges:
+                    target = edge  # Keep as integer
+                    weight = node.outEdges[edge].weight + 1.5
+                    edge_data = {
+                        "from": nodeID,
+                        "to": target,
+                        "value": weight,
+                        "color": "#cae0e6",
+                        "arrows": "to",
+                    }
+                    edges.append(edge_data)
+
+            print(f"DEBUG: Created {len(edges)} edges")
+        except Exception as e:
+            print(f"DEBUG: Error processing graph data: {e}")
+            return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
+
+        # Extract text from consensus nodes
+        consensus_text = ""
+        try:
+            consensus_node_texts = []
+            for node in graph.nodeiterator()():
+                if node.ID in consensus_nodes and node.text and node.text.strip():
+                    consensus_node_texts.append(node.text.strip())
+            consensus_text = " ".join(consensus_node_texts)
+        except Exception:
+            consensus_text = ""
+
+        # Check if we should compute consensus using decode_consensus
+        compute_consensus = data.get("compute_consensus", False)
+        if compute_consensus and decode_consensus:
+            try:
+                # Default to "bio" task for new graphs
+                consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio")
+            except Exception as e:
+                print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
+                # Keep the original consensus text if decode_consensus fails
+
+        # Get original sequences
+        try:
+            raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
+            # Process sequences: join with spaces and remove "||"
+            original_sequences = []
+            for seq in raw_sequences:
+                if isinstance(seq, list):
+                    # Join list elements with spaces
+                    processed_seq = " ".join(str(item) for item in seq)
+                else:
+                    processed_seq = str(seq)
+                # Remove "||" characters
+                processed_seq = processed_seq.replace("||", "")
+                original_sequences.append(processed_seq)
+        except Exception:
+            original_sequences = []
+
+        print("DEBUG: Returning success response")
+        return jsonify(
+            {
+                "success": True,
+                "nodes": nodes,
+                "edges": edges,
+                "num_sequences": len(sequences),
+                "num_nodes": len(nodes),
+                "num_edges": len(edges),
+                "original_sequences": original_sequences,
+                "consensus_text": consensus_text,
+            }
+        )
+
+    except Exception as e:
+        print(f"DEBUG: Main exception in create_graph: {e}")
+        return jsonify({"error": str(e)}), 500
+
+
561
+
562
+ @app.route("/api/save_graph", methods=["POST"])
563
+ def save_graph():
564
+ """Save a POA graph to a pickle file"""
565
+ try:
566
+ data = request.get_json()
567
+ sequences = data.get("sequences", [])
568
+ filename = data.get("filename", "graph.pkl")
569
+
570
+ if len(sequences) < 2:
571
+ return jsonify({"error": "At least 2 sequences are required"}), 400
572
+
573
+ # Create the graph
574
+ graph = TextPOAGraph(sequences[0], label=0)
575
+
576
+ # Add remaining sequences
577
+ for i, sequence in enumerate(sequences[1:], 1):
578
+ alignment = TextSeqGraphAlignment(
579
+ text=sequence,
580
+ graph=graph,
581
+ fastMethod=True,
582
+ globalAlign=True,
583
+ matchscore=1,
584
+ mismatchscore=-2,
585
+ gap_open=-1,
586
+ )
587
+ graph.incorporateSeqAlignment(alignment, sequence, label=i)
588
+
589
+ # Refine the graph
590
+ graph.refine_graph(verbose=False)
591
+
592
+ # Save to pickle file
593
+ graph.save_to_pickle(filename)
594
+
595
+ return jsonify(
596
+ {"success": True, "filename": filename, "message": f"Graph saved to {filename}"}
597
+ )
598
+
599
+ except Exception as e:
600
+ return jsonify({"error": str(e)}), 500
601
+
602
+
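Note that this endpoint rebuilds the graph from the submitted sequences rather than persisting one created earlier in the session. A minimal client call might look like this (the sequences and filename are illustrative):

```python
import requests

payload = {
    "sequences": [
        "the quick brown fox jumps over the lazy dog",
        "a quick brown fox jumped over a lazy dog",
    ],
    "filename": "example_graph.pkl",  # hypothetical output name
}
resp = requests.post("http://localhost:8080/api/save_graph", json=payload)
print(resp.json())  # {"success": True, "filename": "example_graph.pkl", ...}
```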
+@app.route("/api/graph_info", methods=["POST"])
+def graph_info():
+    """Get information about a graph without creating the full visualization"""
+    try:
+        data = request.get_json()
+        sequences = data.get("sequences", [])
+
+        if len(sequences) < 2:
+            return jsonify({"error": "At least 2 sequences are required"}), 400
+
+        # Create the graph
+        graph = TextPOAGraph(sequences[0], label=0)
+
+        # Add remaining sequences
+        for i, sequence in enumerate(sequences[1:], 1):
+            alignment = TextSeqGraphAlignment(
+                text=sequence,
+                graph=graph,
+                fastMethod=True,
+                globalAlign=True,
+                matchscore=1,
+                mismatchscore=-2,
+                gap_open=-1,
+            )
+            graph.incorporateSeqAlignment(alignment, sequence, label=i)
+
+        # Refine the graph
+        graph.refine_graph(verbose=False)
+
+        # Get consensus response
+        consensus_text = graph.consensus_response()
+
+        return jsonify(
+            {
+                "success": True,
+                "num_sequences": len(sequences),
+                "num_nodes": graph._nnodes,
+                "consensus_text": consensus_text,
+                "consensus_node_ids": graph.consensus_node_ids,
+            }
+        )
+
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+
+if __name__ == "__main__":
+    # Local development entry point; the Docker image serves the app with
+    # gunicorn on port 7860 instead (see Dockerfile).
+    print("Starting POA Graph Web Interface Server...")
+    print("Open http://localhost:8080 in your browser")
+    app.run(debug=True, host="0.0.0.0", port=8080)
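Taken together, the endpoints above can be driven end to end from a few lines of client code. A sketch, assuming a local debug run on port 8080 and illustrative input sequences; `compute_consensus` is the optional flag read by `/api/create_graph`:

```python
import requests

BASE_URL = "http://localhost:8080"  # gunicorn in the Docker image listens on 7860 instead

resp = requests.post(
    f"{BASE_URL}/api/create_graph",
    json={
        "sequences": [
            "Marie Curie was a physicist and chemist.",
            "Marie Curie was a Polish-born physicist.",
        ],
        "compute_consensus": True,  # ask the server to run decode_consensus, if available
    },
)
result = resp.json()
print(result["num_nodes"], result["num_edges"])
print(result["consensus_text"])
```

Since `create_graph` calls `refine_graph` with `model="gpt-4o-mini"`, the server environment also needs whatever credentials that refinement step requires.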
web_interface/start.sh ADDED
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+echo "Starting POA Graph Web Interface..."
+
+# Check if Python is installed
+if ! command -v python3 &> /dev/null; then
+    echo "Error: Python 3 is not installed or not in PATH"
+    exit 1
+fi
+
+# Check if we're in the right directory
+if [ ! -f "server.py" ]; then
+    echo "Error: server.py not found. Please run this script from the web_interface directory."
+    exit 1
+fi
+
+# Install dependencies if requirements.txt exists
+if [ -f "requirements.txt" ]; then
+    echo "Installing dependencies..."
+    pip3 install -r requirements.txt
+fi
+
+# Start the server
+echo "Starting server on http://localhost:5000"
+echo "Press Ctrl+C to stop the server"
+python3 server.py
web_interface/test_server.py ADDED
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+"""
+Test script for the POA Graph Web Interface Server
+"""
+
+import requests
+import json
+
+BASE_URL = "http://localhost:8080"
+
+def test_datasets():
+    """Test the datasets endpoint"""
+    print("Testing /api/datasets...")
+    try:
+        response = requests.get(f"{BASE_URL}/api/datasets")
+        if response.status_code == 200:
+            data = response.json()
+            print(f"✓ Success! Found {len(data['datasets'])} datasets:")
+            for dataset in data['datasets']:
+                print(f"  - {dataset['display_name']}: {dataset['count']} graphs")
+        else:
+            print(f"✗ Error: {response.status_code}")
+    except Exception as e:
+        print(f"✗ Exception: {e}")
+
+def test_entities():
+    """Test the entities endpoint"""
+    print("\nTesting /api/entities?dataset=bio...")
+    try:
+        response = requests.get(f"{BASE_URL}/api/entities?dataset=bio")
+        if response.status_code == 200:
+            data = response.json()
+            print(f"✓ Success! Found {len(data['entities'])} entities in bio dataset")
+            if data['entities']:
+                print(f"  Sample entity: {data['entities'][0]['entity']} ({data['entities'][0]['model']})")
+        else:
+            print(f"✗ Error: {response.status_code}")
+    except Exception as e:
+        print(f"✗ Exception: {e}")
+
+def test_load_graph():
+    """Test loading a specific graph"""
+    print("\nTesting /api/load_existing_graph...")
+    try:
+        # First get entities to find a valid filepath
+        response = requests.get(f"{BASE_URL}/api/entities?dataset=bio")
+        if response.status_code == 200:
+            data = response.json()
+            if data['entities']:
+                filepath = data['entities'][0]['filepath']
+                print(f"  Loading graph: {filepath}")
+
+                # Test loading the graph
+                response = requests.post(
+                    f"{BASE_URL}/api/load_existing_graph",
+                    json={"filepath": filepath}
+                )
+
+                if response.status_code == 200:
+                    graph_data = response.json()
+                    if graph_data['success']:
+                        print(f"✓ Success! Loaded graph with {graph_data['num_nodes']} nodes and {graph_data['num_edges']} edges")
+                        print(f"  Entity: {graph_data['metadata']['entity']}")
+                        print(f"  Model: {graph_data['metadata']['model']}")
+                    else:
+                        print(f"✗ Error: {graph_data['error']}")
+                else:
+                    print(f"✗ Error: {response.status_code}")
+            else:
+                print("✗ No entities found to test with")
+        else:
+            print(f"✗ Error getting entities: {response.status_code}")
+    except Exception as e:
+        print(f"✗ Exception: {e}")
+
+if __name__ == "__main__":
+    print("Testing POA Graph Web Interface Server...")
+    print("Make sure the server is running on http://localhost:8080")
+    print("=" * 50)
+
+    test_datasets()
+    test_entities()
+    test_load_graph()
+
+    print("\n" + "=" * 50)
+    print("Test completed!")
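The script covers the read-only endpoints; a check for graph creation could be added alongside the others in the same style. A sketch, not part of this commit, reusing the `BASE_URL` convention above:

```python
def test_create_graph():
    """Hypothetical test for the create_graph endpoint."""
    print("\nTesting /api/create_graph...")
    try:
        response = requests.post(
            f"{BASE_URL}/api/create_graph",
            json={"sequences": ["the cat sat on the mat", "a cat sat on a mat"]},
        )
        if response.status_code == 200 and response.json().get("success"):
            data = response.json()
            print(f"✓ Success! Created graph with {data['num_nodes']} nodes and {data['num_edges']} edges")
        else:
            print(f"✗ Error: {response.status_code}")
    except Exception as e:
        print(f"✗ Exception: {e}")
```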