Spaces:
Sleeping
Sleeping
Commit Β·
102ae18
1
Parent(s): 2cb9f34
initial commit
Browse files- Dockerfile +28 -0
- README.md +150 -7
- app.py +584 -0
- dockerignore +46 -0
- requirements.txt +13 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/alignment.cpython-312.pyc +0 -0
- src/__pycache__/generation_methods.cpython-312.pyc +0 -0
- src/__pycache__/generation_utils.cpython-312.pyc +0 -0
- src/__pycache__/global_edit_utils.cpython-312.pyc +0 -0
- src/__pycache__/new_alignment.cpython-312.pyc +0 -0
- src/__pycache__/new_text_alignment.cpython-312.pyc +0 -0
- src/__pycache__/poa_graph.cpython-312.pyc +0 -0
- src/__pycache__/text_poa_graph.cpython-312.pyc +0 -0
- src/__pycache__/text_poa_graph_utils.cpython-312.pyc +0 -0
- src/alignment.py +256 -0
- src/generation_methods.py +299 -0
- src/generation_utils.py +190 -0
- src/global_edit_utils.py +127 -0
- src/new_alignment.py +150 -0
- src/new_text_alignment.py +134 -0
- src/poa_graph.py +685 -0
- src/text_poa_graph.py +802 -0
- src/text_poa_graph_utils.py +126 -0
- src/utils.py +46 -0
- web_interface/.DS_Store +0 -0
- web_interface/README.md +111 -0
- web_interface/index.html +907 -0
- web_interface/requirements.txt +4 -0
- web_interface/server.py +652 -0
- web_interface/start.sh +26 -0
- web_interface/test_server.py +86 -0
Dockerfile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
# Set working directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Install system dependencies
|
| 7 |
+
RUN apt-get update && apt-get install -y \
|
| 8 |
+
build-essential \
|
| 9 |
+
git \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
# Copy requirements first for better caching
|
| 13 |
+
COPY requirements.txt .
|
| 14 |
+
|
| 15 |
+
# Install Python dependencies
|
| 16 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 17 |
+
|
| 18 |
+
# Download spaCy model
|
| 19 |
+
RUN python -m spacy download en_core_web_sm
|
| 20 |
+
|
| 21 |
+
# Copy the entire application
|
| 22 |
+
COPY . .
|
| 23 |
+
|
| 24 |
+
# Expose port 7860 (required by Hugging Face Spaces)
|
| 25 |
+
EXPOSE 7860
|
| 26 |
+
|
| 27 |
+
# Use gunicorn for production
|
| 28 |
+
CMD ["gunicorn", "-b", "0.0.0.0:7860", "--timeout", "120", "--workers", "2", "app:app"]
|
README.md
CHANGED
|
@@ -1,12 +1,155 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
| 8 |
-
license:
|
| 9 |
-
short_description: Explore and visualize ConGrs (https://www.google.com/search)
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: ConGr Visualizer
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
+
license: mit
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# ConGr Visualizer
|
| 13 |
+
|
| 14 |
+
A standalone web-based interface for exploring and visualizing ConGrs (Consensus Graphs) from research datasets.
|
| 15 |
+
|
| 16 |
+
## Overview
|
| 17 |
+
|
| 18 |
+
This repository contains the web interface and necessary dependencies for visualizing ConGrs. It has been separated from the main sample-fusion repository to provide a standalone visualization tool.
|
| 19 |
+
|
| 20 |
+
## Features
|
| 21 |
+
|
| 22 |
+
### Browse Existing Graphs
|
| 23 |
+
- **Dataset Selection**: Choose from available datasets (BIO, FP, HIST, REFS, MATH, AIME)
|
| 24 |
+
- **Entity Selection**: Browse entities within each dataset
|
| 25 |
+
- **Model Information**: See which language models were used for each graph
|
| 26 |
+
- **Graph Visualization**: Interactive network visualization using vis.js
|
| 27 |
+
- **Metadata Display**: View graph statistics and consensus text
|
| 28 |
+
|
| 29 |
+
### Create New Graphs
|
| 30 |
+
- **Text Input**: Enter multiple text sequences to create new ConGrs
|
| 31 |
+
- **Real-time Visualization**: See the graph structure as it's created
|
| 32 |
+
- **Save Functionality**: Save created graphs to pickle files
|
| 33 |
+
|
| 34 |
+
## Deployment on Hugging Face Spaces
|
| 35 |
+
|
| 36 |
+
This application is configured to run on Hugging Face Spaces using Docker.
|
| 37 |
+
|
| 38 |
+
### Project Structure
|
| 39 |
+
|
| 40 |
+
```
|
| 41 |
+
congr-visualizer/
|
| 42 |
+
βββ Dockerfile # Docker configuration for HF Spaces
|
| 43 |
+
βββ app.py # Main Flask application
|
| 44 |
+
βββ requirements.txt # Python dependencies
|
| 45 |
+
βββ README.md # This file
|
| 46 |
+
βββ web_interface/ # Web interface files
|
| 47 |
+
β βββ index.html # Web interface
|
| 48 |
+
βββ src/ # Source code modules
|
| 49 |
+
β βββ alignment.py
|
| 50 |
+
β βββ new_alignment.py
|
| 51 |
+
β βββ poa_graph.py
|
| 52 |
+
β βββ new_text_alignment.py
|
| 53 |
+
β βββ text_poa_graph.py
|
| 54 |
+
β βββ text_poa_graph_utils.py
|
| 55 |
+
β βββ global_edit_utils.py
|
| 56 |
+
β βββ generation_utils.py
|
| 57 |
+
β βββ generation_methods.py
|
| 58 |
+
β βββ utils.py
|
| 59 |
+
βββ results/ # Graph data
|
| 60 |
+
βββ graphs/
|
| 61 |
+
βββ HALoGEN/
|
| 62 |
+
βββ bio/
|
| 63 |
+
βββ fp/
|
| 64 |
+
βββ hist/
|
| 65 |
+
βββ refs/
|
| 66 |
+
βββ MATH/
|
| 67 |
+
βββ AIME/
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Local Development
|
| 71 |
+
|
| 72 |
+
If you want to run this locally:
|
| 73 |
+
|
| 74 |
+
1. Clone the repository:
|
| 75 |
+
```bash
|
| 76 |
+
git clone https://huggingface.co/spaces/YOUR_USERNAME/congr-visualizer
|
| 77 |
+
cd congr-visualizer
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
2. Install dependencies:
|
| 81 |
+
```bash
|
| 82 |
+
pip install -r requirements.txt
|
| 83 |
+
python -m spacy download en_core_web_sm
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
3. Run the application:
|
| 87 |
+
```bash
|
| 88 |
+
python app.py
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
The server will start on `http://localhost:7860`
|
| 92 |
+
|
| 93 |
+
## Available Datasets
|
| 94 |
+
|
| 95 |
+
- **BIO**: Biography datasets with various public figures
|
| 96 |
+
- **FP**: False Presupposition datasets
|
| 97 |
+
- **HIST**: Historical events datasets
|
| 98 |
+
- **REFS**: Reference datasets
|
| 99 |
+
- **MATH**: Mathematical problem datasets
|
| 100 |
+
- **AIME**: American Invitational Mathematics Examination datasets
|
| 101 |
+
|
| 102 |
+
## Models
|
| 103 |
+
|
| 104 |
+
The graphs are generated using various language models:
|
| 105 |
+
- olmo7b, olmo32b
|
| 106 |
+
- qwen72b, qwen7b
|
| 107 |
+
- llama70b, llama8b
|
| 108 |
+
|
| 109 |
+
## API Endpoints
|
| 110 |
+
|
| 111 |
+
- `GET /api/datasets` - Get available datasets
|
| 112 |
+
- `GET /api/entities?dataset=<dataset>` - Get entities for a dataset
|
| 113 |
+
- `GET /api/models?dataset=<dataset>&entity=<entity>` - Get models for an entity
|
| 114 |
+
- `POST /api/load_existing_graph` - Load an existing graph
|
| 115 |
+
- `POST /api/create_graph` - Create a new graph from text sequences
|
| 116 |
+
- `POST /api/save_graph` - Save a graph to file
|
| 117 |
+
- `POST /api/graph_info` - Get graph information without full visualization
|
| 118 |
+
|
| 119 |
+
## Graph Information
|
| 120 |
+
|
| 121 |
+
When viewing a graph, you can see:
|
| 122 |
+
- **Dataset**: The source dataset
|
| 123 |
+
- **Entity**: The specific entity or topic
|
| 124 |
+
- **Model**: The language model used
|
| 125 |
+
- **Sequences**: Number of input sequences
|
| 126 |
+
- **Nodes**: Number of nodes in the graph
|
| 127 |
+
- **Edges**: Number of edges in the graph
|
| 128 |
+
- **Consensus**: The consensus text generated from the graph
|
| 129 |
+
|
| 130 |
+
## Visualization Features
|
| 131 |
+
|
| 132 |
+
- **Hierarchical Layout**: Graphs are displayed in a hierarchical structure
|
| 133 |
+
- **Color Coding**: Consensus nodes are highlighted in green
|
| 134 |
+
- **Interactive**: Zoom, pan, and hover for more information
|
| 135 |
+
- **Responsive**: Works on desktop and mobile devices
|
| 136 |
+
|
| 137 |
+
## Environment Variables
|
| 138 |
+
|
| 139 |
+
For full functionality (especially consensus decoding), you may need to set:
|
| 140 |
+
- `OPENAI_API_KEY`: For OpenAI API calls
|
| 141 |
+
- `HUGGINGFACE_API_KEY`: For HuggingFace API calls
|
| 142 |
+
|
| 143 |
+
These can be set in the Hugging Face Spaces settings.
|
| 144 |
+
|
| 145 |
+
## Technical Details
|
| 146 |
+
|
| 147 |
+
- **Framework**: Flask with CORS enabled
|
| 148 |
+
- **Server**: Gunicorn for production
|
| 149 |
+
- **Port**: 7860 (required by Hugging Face Spaces)
|
| 150 |
+
- **Visualization**: vis.js for interactive graph rendering
|
| 151 |
+
- **Graph Format**: Pickle files for serialized POA graphs
|
| 152 |
+
|
| 153 |
+
## License
|
| 154 |
+
|
| 155 |
+
This is a standalone visualization tool extracted from the sample-fusion research project.
|
app.py
ADDED
|
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Flask server for POA Graph Web Interface
|
| 4 |
+
Modified for Hugging Face Spaces deployment
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import glob
|
| 8 |
+
import os
|
| 9 |
+
import pickle
|
| 10 |
+
import re
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
from flask import Flask, jsonify, request, send_from_directory
|
| 14 |
+
from flask_cors import CORS
|
| 15 |
+
|
| 16 |
+
# Get the directory where this script is located (should be project root)
|
| 17 |
+
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
|
| 18 |
+
|
| 19 |
+
# Add the repository root to the path so we can import the POA graph modules
|
| 20 |
+
sys.path.append(REPO_ROOT)
|
| 21 |
+
|
| 22 |
+
from src.new_text_alignment import TextSeqGraphAlignment
|
| 23 |
+
from src.text_poa_graph import TextPOAGraph
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from src.generation_methods import decode_consensus
|
| 27 |
+
except ImportError:
|
| 28 |
+
decode_consensus = None
|
| 29 |
+
|
| 30 |
+
app = Flask(__name__)
|
| 31 |
+
CORS(app) # Enable CORS for all routes
|
| 32 |
+
|
| 33 |
+
# Base paths for different datasets (relative to repo root)
|
| 34 |
+
GRAPH_PATHS = {
|
| 35 |
+
"bio": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/bio"),
|
| 36 |
+
"fp": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/fp"),
|
| 37 |
+
"hist": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/hist"),
|
| 38 |
+
"refs": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/refs"),
|
| 39 |
+
"math": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/MATH"),
|
| 40 |
+
"aime": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/AIME"),
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
MODELS = ["qwen72b", "qwen7b", "llama8b", "llama70b", "olmo7b", "olmo32b"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@app.route("/")
|
| 47 |
+
def index():
|
| 48 |
+
"""Serve the main HTML file from web_interface directory"""
|
| 49 |
+
web_interface_path = os.path.join(REPO_ROOT, "web_interface")
|
| 50 |
+
return send_from_directory(web_interface_path, "index.html")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@app.route("/<path:path>")
|
| 54 |
+
def serve_static(path):
|
| 55 |
+
"""Serve static files from web_interface directory"""
|
| 56 |
+
web_interface_path = os.path.join(REPO_ROOT, "web_interface")
|
| 57 |
+
return send_from_directory(web_interface_path, path)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@app.route("/api/datasets", methods=["GET"])
|
| 61 |
+
def get_datasets():
|
| 62 |
+
"""Get available datasets"""
|
| 63 |
+
datasets = []
|
| 64 |
+
for dataset_name, path in GRAPH_PATHS.items():
|
| 65 |
+
if os.path.exists(path):
|
| 66 |
+
# Count available graphs
|
| 67 |
+
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
|
| 68 |
+
datasets.append(
|
| 69 |
+
{
|
| 70 |
+
"name": dataset_name,
|
| 71 |
+
"display_name": dataset_name.upper(),
|
| 72 |
+
"path": path,
|
| 73 |
+
"count": len(pkl_files),
|
| 74 |
+
}
|
| 75 |
+
)
|
| 76 |
+
return jsonify({"datasets": datasets})
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@app.route("/api/models", methods=["GET"])
|
| 80 |
+
def get_models():
|
| 81 |
+
"""Get available models for a specific entity"""
|
| 82 |
+
entity = request.args.get("entity")
|
| 83 |
+
dataset = request.args.get("dataset")
|
| 84 |
+
|
| 85 |
+
if not entity:
|
| 86 |
+
return jsonify({"error": "Entity parameter required"}), 400
|
| 87 |
+
|
| 88 |
+
if not dataset or dataset not in GRAPH_PATHS:
|
| 89 |
+
return jsonify({"error": "Invalid dataset"}), 400
|
| 90 |
+
|
| 91 |
+
path = GRAPH_PATHS[dataset]
|
| 92 |
+
if not os.path.exists(path):
|
| 93 |
+
return jsonify({"error": "Dataset path not found"}), 404
|
| 94 |
+
|
| 95 |
+
models = []
|
| 96 |
+
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
|
| 97 |
+
|
| 98 |
+
for pkl_file in pkl_files:
|
| 99 |
+
filename = os.path.basename(pkl_file)
|
| 100 |
+
|
| 101 |
+
# Different filename patterns for different datasets
|
| 102 |
+
if dataset == "bio":
|
| 103 |
+
# Format: bio_graph_{entity}_merged_{model}.pkl
|
| 104 |
+
match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 105 |
+
if match:
|
| 106 |
+
entity_name, model = match.groups()
|
| 107 |
+
if entity_name == entity:
|
| 108 |
+
models.append({"model": model, "filename": filename, "filepath": pkl_file})
|
| 109 |
+
elif dataset == "fp":
|
| 110 |
+
# Format: fp_graph_{number}_merged_{model}.pkl
|
| 111 |
+
match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
|
| 112 |
+
if match:
|
| 113 |
+
entity_name, model = match.groups()
|
| 114 |
+
if f"Problem {entity_name}" == entity:
|
| 115 |
+
models.append({"model": model, "filename": filename, "filepath": pkl_file})
|
| 116 |
+
elif dataset == "math":
|
| 117 |
+
# Format: qwen72_math_{number}.pkl
|
| 118 |
+
match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
|
| 119 |
+
if match:
|
| 120 |
+
entity_name = match.group(1)
|
| 121 |
+
if f"Math Problem {entity_name}" == entity:
|
| 122 |
+
models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
|
| 123 |
+
elif dataset == "aime":
|
| 124 |
+
# Format: aime_qwen72b_{number}.pkl
|
| 125 |
+
match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
|
| 126 |
+
if match:
|
| 127 |
+
entity_name = match.group(1)
|
| 128 |
+
if f"AIME Problem {entity_name}" == entity:
|
| 129 |
+
models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
|
| 130 |
+
else:
|
| 131 |
+
# Generic pattern for other datasets
|
| 132 |
+
match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 133 |
+
if match:
|
| 134 |
+
task, entity_name, model = match.groups()
|
| 135 |
+
if entity_name == entity:
|
| 136 |
+
models.append({"model": model, "filename": filename, "filepath": pkl_file})
|
| 137 |
+
|
| 138 |
+
return jsonify({"models": models})
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@app.route("/api/entities", methods=["GET"])
|
| 142 |
+
def get_entities():
|
| 143 |
+
"""Get available entities for a dataset"""
|
| 144 |
+
dataset = request.args.get("dataset")
|
| 145 |
+
if not dataset or dataset not in GRAPH_PATHS:
|
| 146 |
+
return jsonify({"error": "Invalid dataset"}), 400
|
| 147 |
+
|
| 148 |
+
path = GRAPH_PATHS[dataset]
|
| 149 |
+
if not os.path.exists(path):
|
| 150 |
+
return jsonify({"error": "Dataset path not found"}), 404
|
| 151 |
+
|
| 152 |
+
entities = []
|
| 153 |
+
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
|
| 154 |
+
|
| 155 |
+
for pkl_file in pkl_files:
|
| 156 |
+
filename = os.path.basename(pkl_file)
|
| 157 |
+
|
| 158 |
+
# Different filename patterns for different datasets
|
| 159 |
+
if dataset == "bio":
|
| 160 |
+
# Format: bio_graph_{entity}_merged_{model}.pkl
|
| 161 |
+
match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 162 |
+
if match:
|
| 163 |
+
entity_name, model = match.groups()
|
| 164 |
+
entities.append(
|
| 165 |
+
{
|
| 166 |
+
"entity": entity_name,
|
| 167 |
+
"model": model,
|
| 168 |
+
"filename": filename,
|
| 169 |
+
"filepath": pkl_file,
|
| 170 |
+
}
|
| 171 |
+
)
|
| 172 |
+
elif dataset == "fp":
|
| 173 |
+
# Format: fp_graph_{number}_merged_{model}.pkl
|
| 174 |
+
match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
|
| 175 |
+
if match:
|
| 176 |
+
entity_name, model = match.groups()
|
| 177 |
+
entities.append(
|
| 178 |
+
{
|
| 179 |
+
"entity": f"Problem {entity_name}",
|
| 180 |
+
"model": model,
|
| 181 |
+
"filename": filename,
|
| 182 |
+
"filepath": pkl_file,
|
| 183 |
+
}
|
| 184 |
+
)
|
| 185 |
+
elif dataset == "math":
|
| 186 |
+
# Format: qwen72_math_{number}.pkl
|
| 187 |
+
match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
|
| 188 |
+
if match:
|
| 189 |
+
entity_name = match.group(1)
|
| 190 |
+
entities.append(
|
| 191 |
+
{
|
| 192 |
+
"entity": f"Math Problem {entity_name}",
|
| 193 |
+
"model": "qwen72b",
|
| 194 |
+
"filename": filename,
|
| 195 |
+
"filepath": pkl_file,
|
| 196 |
+
}
|
| 197 |
+
)
|
| 198 |
+
elif dataset == "aime":
|
| 199 |
+
# Format: aime_qwen72b_{number}.pkl
|
| 200 |
+
match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
|
| 201 |
+
if match:
|
| 202 |
+
entity_name = match.group(1)
|
| 203 |
+
entities.append(
|
| 204 |
+
{
|
| 205 |
+
"entity": f"AIME Problem {entity_name}",
|
| 206 |
+
"model": "qwen72b",
|
| 207 |
+
"filename": filename,
|
| 208 |
+
"filepath": pkl_file,
|
| 209 |
+
}
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
# Generic pattern for other datasets
|
| 213 |
+
match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 214 |
+
if match:
|
| 215 |
+
task, entity_name, model = match.groups()
|
| 216 |
+
entities.append(
|
| 217 |
+
{
|
| 218 |
+
"entity": entity_name,
|
| 219 |
+
"model": model,
|
| 220 |
+
"filename": filename,
|
| 221 |
+
"filepath": pkl_file,
|
| 222 |
+
}
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# Get unique entities
|
| 226 |
+
unique_entities = {}
|
| 227 |
+
for entity_data in entities:
|
| 228 |
+
entity_key = entity_data["entity"]
|
| 229 |
+
if entity_key not in unique_entities:
|
| 230 |
+
unique_entities[entity_key] = entity_data
|
| 231 |
+
|
| 232 |
+
return jsonify({"entities": list(unique_entities.values())})
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@app.route("/api/load_existing_graph", methods=["POST"])
|
| 236 |
+
def load_existing_graph():
|
| 237 |
+
"""Load an existing graph from pickle file"""
|
| 238 |
+
try:
|
| 239 |
+
data = request.get_json()
|
| 240 |
+
filepath = data.get("filepath")
|
| 241 |
+
|
| 242 |
+
if not filepath or not os.path.exists(filepath):
|
| 243 |
+
return jsonify({"error": "Invalid filepath"}), 400
|
| 244 |
+
|
| 245 |
+
# Load the graph from pickle
|
| 246 |
+
with open(filepath, "rb") as f:
|
| 247 |
+
graph = pickle.load(f)
|
| 248 |
+
|
| 249 |
+
# Convert to JSON format for vis.js
|
| 250 |
+
nodes = []
|
| 251 |
+
edges = []
|
| 252 |
+
|
| 253 |
+
# Get consensus nodes for coloring
|
| 254 |
+
try:
|
| 255 |
+
consensus_nodes = set(graph.consensus_node_ids)
|
| 256 |
+
except Exception:
|
| 257 |
+
consensus_nodes = set()
|
| 258 |
+
|
| 259 |
+
# Create nodes
|
| 260 |
+
for node in graph.nodeiterator()():
|
| 261 |
+
title_text = ""
|
| 262 |
+
if node.sequences:
|
| 263 |
+
title_text += f"Sequences: {node.sequences}"
|
| 264 |
+
if node.variations:
|
| 265 |
+
title_text += ";;;".join(
|
| 266 |
+
[f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
|
| 267 |
+
)
|
| 268 |
+
title_text = title_text.replace('"', "'")
|
| 269 |
+
|
| 270 |
+
color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
|
| 271 |
+
|
| 272 |
+
node_data = {
|
| 273 |
+
"id": node.ID,
|
| 274 |
+
"label": f"{node.ID}: {node.text}",
|
| 275 |
+
"title": title_text,
|
| 276 |
+
"color": color,
|
| 277 |
+
}
|
| 278 |
+
nodes.append(node_data)
|
| 279 |
+
|
| 280 |
+
# Create edges
|
| 281 |
+
for node in graph.nodeiterator()():
|
| 282 |
+
nodeID = node.ID
|
| 283 |
+
for edge in node.outEdges:
|
| 284 |
+
target = edge
|
| 285 |
+
weight = node.outEdges[edge].weight + 1.5
|
| 286 |
+
edge_data = {
|
| 287 |
+
"from": nodeID,
|
| 288 |
+
"to": target,
|
| 289 |
+
"value": weight,
|
| 290 |
+
"color": "#cae0e6",
|
| 291 |
+
"arrows": "to",
|
| 292 |
+
}
|
| 293 |
+
edges.append(edge_data)
|
| 294 |
+
|
| 295 |
+
# Get consensus text
|
| 296 |
+
consensus_text = ""
|
| 297 |
+
try:
|
| 298 |
+
consensus_node_texts = []
|
| 299 |
+
for node in graph.nodeiterator()():
|
| 300 |
+
if node.ID in consensus_nodes and node.text and node.text.strip():
|
| 301 |
+
consensus_node_texts.append(node.text.strip())
|
| 302 |
+
consensus_text = " ".join(consensus_node_texts)
|
| 303 |
+
except Exception:
|
| 304 |
+
consensus_text = ""
|
| 305 |
+
|
| 306 |
+
# Get original sequences
|
| 307 |
+
try:
|
| 308 |
+
raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
|
| 309 |
+
original_sequences = []
|
| 310 |
+
for seq in raw_sequences:
|
| 311 |
+
if isinstance(seq, list):
|
| 312 |
+
processed_seq = " ".join(str(item) for item in seq)
|
| 313 |
+
else:
|
| 314 |
+
processed_seq = str(seq)
|
| 315 |
+
processed_seq = processed_seq.replace("||", "")
|
| 316 |
+
original_sequences.append(processed_seq)
|
| 317 |
+
except Exception:
|
| 318 |
+
original_sequences = []
|
| 319 |
+
|
| 320 |
+
return jsonify(
|
| 321 |
+
{
|
| 322 |
+
"success": True,
|
| 323 |
+
"nodes": nodes,
|
| 324 |
+
"edges": edges,
|
| 325 |
+
"num_sequences": len(original_sequences),
|
| 326 |
+
"num_nodes": len(nodes),
|
| 327 |
+
"num_edges": len(edges),
|
| 328 |
+
"original_sequences": original_sequences,
|
| 329 |
+
"consensus_text": consensus_text,
|
| 330 |
+
}
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
except Exception as e:
|
| 334 |
+
return jsonify({"error": str(e)}), 500
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
@app.route("/api/create_graph", methods=["POST"])
|
| 338 |
+
def create_graph():
|
| 339 |
+
"""Create a new POA graph from text sequences"""
|
| 340 |
+
try:
|
| 341 |
+
print("DEBUG: Received create_graph request")
|
| 342 |
+
data = request.get_json()
|
| 343 |
+
sequences = data.get("sequences", [])
|
| 344 |
+
|
| 345 |
+
print(f"DEBUG: Number of sequences: {len(sequences)}")
|
| 346 |
+
|
| 347 |
+
if len(sequences) < 2:
|
| 348 |
+
return jsonify({"error": "At least 2 sequences are required"}), 400
|
| 349 |
+
|
| 350 |
+
print("DEBUG: Creating initial graph")
|
| 351 |
+
# Create the graph from first sequence
|
| 352 |
+
graph = TextPOAGraph(sequences[0], label=0)
|
| 353 |
+
print("DEBUG: Initial graph created")
|
| 354 |
+
|
| 355 |
+
# Add remaining sequences
|
| 356 |
+
for i, sequence in enumerate(sequences[1:], 1):
|
| 357 |
+
print(f"DEBUG: Adding sequence {i}")
|
| 358 |
+
alignment = TextSeqGraphAlignment(
|
| 359 |
+
text=sequence,
|
| 360 |
+
graph=graph,
|
| 361 |
+
fastMethod=True,
|
| 362 |
+
globalAlign=True,
|
| 363 |
+
matchscore=1,
|
| 364 |
+
mismatchscore=-2,
|
| 365 |
+
gap_open=-1,
|
| 366 |
+
)
|
| 367 |
+
graph.incorporateSeqAlignment(alignment, sequence, label=i)
|
| 368 |
+
|
| 369 |
+
print("DEBUG: All sequences added")
|
| 370 |
+
|
| 371 |
+
# Refine the graph with proper domain and model parameters
|
| 372 |
+
graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini")
|
| 373 |
+
print("DEBUG: Graph refined")
|
| 374 |
+
|
| 375 |
+
# Convert to JSON format for vis.js
|
| 376 |
+
nodes = []
|
| 377 |
+
edges = []
|
| 378 |
+
|
| 379 |
+
try:
|
| 380 |
+
print("DEBUG: Starting to process graph data")
|
| 381 |
+
# Get consensus nodes for coloring (make it optional)
|
| 382 |
+
try:
|
| 383 |
+
consensus_nodes = set(graph.consensus_node_ids)
|
| 384 |
+
print(f"DEBUG: Consensus nodes: {consensus_nodes}")
|
| 385 |
+
except Exception as e:
|
| 386 |
+
print(f"DEBUG: Error getting consensus nodes: {e}")
|
| 387 |
+
consensus_nodes = set() # Fallback to empty set if consensus fails
|
| 388 |
+
|
| 389 |
+
# Create nodes using the same logic as jsOutput
|
| 390 |
+
for node in graph.nodeiterator()():
|
| 391 |
+
title_text = ""
|
| 392 |
+
if node.sequences:
|
| 393 |
+
title_text += f"Sequences: {node.sequences}"
|
| 394 |
+
if node.variations:
|
| 395 |
+
title_text += ";;;".join(
|
| 396 |
+
[f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
|
| 397 |
+
)
|
| 398 |
+
title_text = title_text.replace('"', "'")
|
| 399 |
+
|
| 400 |
+
# Use the same color logic as jsOutput
|
| 401 |
+
color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
|
| 402 |
+
|
| 403 |
+
node_data = {
|
| 404 |
+
"id": node.ID,
|
| 405 |
+
"label": f"{node.ID}: {node.text}",
|
| 406 |
+
"title": title_text,
|
| 407 |
+
"color": color,
|
| 408 |
+
}
|
| 409 |
+
nodes.append(node_data)
|
| 410 |
+
|
| 411 |
+
print(f"DEBUG: Created {len(nodes)} nodes")
|
| 412 |
+
|
| 413 |
+
# Create edges using the same logic as jsOutput
|
| 414 |
+
for node in graph.nodeiterator()():
|
| 415 |
+
nodeID = node.ID # Keep as integer
|
| 416 |
+
for edge in node.outEdges:
|
| 417 |
+
target = edge # Keep as integer
|
| 418 |
+
weight = node.outEdges[edge].weight + 1.5
|
| 419 |
+
edge_data = {
|
| 420 |
+
"from": nodeID,
|
| 421 |
+
"to": target,
|
| 422 |
+
"value": weight,
|
| 423 |
+
"color": "#cae0e6",
|
| 424 |
+
"arrows": "to",
|
| 425 |
+
}
|
| 426 |
+
edges.append(edge_data)
|
| 427 |
+
|
| 428 |
+
print(f"DEBUG: Created {len(edges)} edges")
|
| 429 |
+
except Exception as e:
|
| 430 |
+
print(f"DEBUG: Error processing graph data: {e}")
|
| 431 |
+
return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
|
| 432 |
+
|
| 433 |
+
# Extract text from consensus nodes
|
| 434 |
+
consensus_text = ""
|
| 435 |
+
try:
|
| 436 |
+
consensus_node_texts = []
|
| 437 |
+
for node in graph.nodeiterator()():
|
| 438 |
+
if node.ID in consensus_nodes and node.text and node.text.strip():
|
| 439 |
+
consensus_node_texts.append(node.text.strip())
|
| 440 |
+
consensus_text = " ".join(consensus_node_texts)
|
| 441 |
+
except Exception:
|
| 442 |
+
consensus_text = ""
|
| 443 |
+
|
| 444 |
+
# Check if we should compute consensus using decode_consensus
|
| 445 |
+
compute_consensus = data.get("compute_consensus", False)
|
| 446 |
+
if compute_consensus and decode_consensus:
|
| 447 |
+
try:
|
| 448 |
+
# Default to "bio" task for new graphs
|
| 449 |
+
consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio")
|
| 450 |
+
except Exception as e:
|
| 451 |
+
print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
|
| 452 |
+
# Keep the original consensus text if decode_consensus fails
|
| 453 |
+
|
| 454 |
+
# Get original sequences
|
| 455 |
+
try:
|
| 456 |
+
raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
|
| 457 |
+
# Process sequences: join with spaces and remove "||"
|
| 458 |
+
original_sequences = []
|
| 459 |
+
for seq in raw_sequences:
|
| 460 |
+
if isinstance(seq, list):
|
| 461 |
+
# Join list elements with spaces
|
| 462 |
+
processed_seq = " ".join(str(item) for item in seq)
|
| 463 |
+
else:
|
| 464 |
+
processed_seq = str(seq)
|
| 465 |
+
# Remove "||" characters
|
| 466 |
+
processed_seq = processed_seq.replace("||", "")
|
| 467 |
+
original_sequences.append(processed_seq)
|
| 468 |
+
except Exception:
|
| 469 |
+
original_sequences = []
|
| 470 |
+
|
| 471 |
+
print("DEBUG: Returning success response")
|
| 472 |
+
return jsonify(
|
| 473 |
+
{
|
| 474 |
+
"success": True,
|
| 475 |
+
"nodes": nodes,
|
| 476 |
+
"edges": edges,
|
| 477 |
+
"num_sequences": len(sequences),
|
| 478 |
+
"num_nodes": len(nodes),
|
| 479 |
+
"num_edges": len(edges),
|
| 480 |
+
"original_sequences": original_sequences,
|
| 481 |
+
"consensus_text": consensus_text,
|
| 482 |
+
}
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
except Exception as e:
|
| 486 |
+
print(f"DEBUG: Main exception in create_graph: {e}")
|
| 487 |
+
return jsonify({"error": str(e)}), 500
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
@app.route("/api/save_graph", methods=["POST"])
|
| 491 |
+
def save_graph():
|
| 492 |
+
"""Save a POA graph to a pickle file"""
|
| 493 |
+
try:
|
| 494 |
+
data = request.get_json()
|
| 495 |
+
sequences = data.get("sequences", [])
|
| 496 |
+
filename = data.get("filename", "graph.pkl")
|
| 497 |
+
|
| 498 |
+
if len(sequences) < 2:
|
| 499 |
+
return jsonify({"error": "At least 2 sequences are required"}), 400
|
| 500 |
+
|
| 501 |
+
# Create the graph
|
| 502 |
+
graph = TextPOAGraph(sequences[0], label=0)
|
| 503 |
+
|
| 504 |
+
# Add remaining sequences
|
| 505 |
+
for i, sequence in enumerate(sequences[1:], 1):
|
| 506 |
+
alignment = TextSeqGraphAlignment(
|
| 507 |
+
text=sequence,
|
| 508 |
+
graph=graph,
|
| 509 |
+
fastMethod=True,
|
| 510 |
+
globalAlign=True,
|
| 511 |
+
matchscore=1,
|
| 512 |
+
mismatchscore=-2,
|
| 513 |
+
gap_open=-1,
|
| 514 |
+
)
|
| 515 |
+
graph.incorporateSeqAlignment(alignment, sequence, label=i)
|
| 516 |
+
|
| 517 |
+
# Refine the graph
|
| 518 |
+
graph.refine_graph(verbose=False)
|
| 519 |
+
|
| 520 |
+
# Save to pickle file
|
| 521 |
+
graph.save_to_pickle(filename)
|
| 522 |
+
|
| 523 |
+
return jsonify(
|
| 524 |
+
{"success": True, "filename": filename, "message": f"Graph saved to {filename}"}
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
except Exception as e:
|
| 528 |
+
return jsonify({"error": str(e)}), 500
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
@app.route("/api/graph_info", methods=["POST"])
|
| 532 |
+
def graph_info():
|
| 533 |
+
"""Get information about a graph without creating the full visualization"""
|
| 534 |
+
try:
|
| 535 |
+
data = request.get_json()
|
| 536 |
+
sequences = data.get("sequences", [])
|
| 537 |
+
|
| 538 |
+
if len(sequences) < 2:
|
| 539 |
+
return jsonify({"error": "At least 2 sequences are required"}), 400
|
| 540 |
+
|
| 541 |
+
# Create the graph
|
| 542 |
+
graph = TextPOAGraph(sequences[0], label=0)
|
| 543 |
+
|
| 544 |
+
# Add remaining sequences
|
| 545 |
+
for i, sequence in enumerate(sequences[1:], 1):
|
| 546 |
+
alignment = TextSeqGraphAlignment(
|
| 547 |
+
text=sequence,
|
| 548 |
+
graph=graph,
|
| 549 |
+
fastMethod=True,
|
| 550 |
+
globalAlign=True,
|
| 551 |
+
matchscore=1,
|
| 552 |
+
mismatchscore=-2,
|
| 553 |
+
gap_open=-1,
|
| 554 |
+
)
|
| 555 |
+
graph.incorporateSeqAlignment(alignment, sequence, label=i)
|
| 556 |
+
|
| 557 |
+
# Refine the graph
|
| 558 |
+
graph.refine_graph(verbose=False)
|
| 559 |
+
|
| 560 |
+
# Get consensus response
|
| 561 |
+
consensus_text = graph.consensus_response()
|
| 562 |
+
|
| 563 |
+
return jsonify(
|
| 564 |
+
{
|
| 565 |
+
"success": True,
|
| 566 |
+
"num_sequences": len(sequences),
|
| 567 |
+
"num_nodes": graph._nnodes,
|
| 568 |
+
"consensus_text": consensus_text,
|
| 569 |
+
"consensus_node_ids": graph.consensus_node_ids,
|
| 570 |
+
}
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
except Exception as e:
|
| 574 |
+
return jsonify({"error": str(e)}), 500
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
if __name__ == "__main__":
|
| 578 |
+
# For HF Spaces, port must be 7860
|
| 579 |
+
port = int(os.environ.get("PORT", 7860))
|
| 580 |
+
print("Starting POA Graph Web Interface Server...")
|
| 581 |
+
print(f"Repository root: {REPO_ROOT}")
|
| 582 |
+
print(f"Serving static files from: {os.path.join(REPO_ROOT, 'web_interface')}")
|
| 583 |
+
print(f"Open http://localhost:{port} in your browser")
|
| 584 |
+
app.run(debug=False, host="0.0.0.0", port=port)
|
dockerignore
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Git files
|
| 2 |
+
.git
|
| 3 |
+
.gitignore
|
| 4 |
+
.gitattributes
|
| 5 |
+
|
| 6 |
+
# Python cache
|
| 7 |
+
__pycache__
|
| 8 |
+
*.py[cod]
|
| 9 |
+
*$py.class
|
| 10 |
+
*.so
|
| 11 |
+
.Python
|
| 12 |
+
*.egg-info/
|
| 13 |
+
dist/
|
| 14 |
+
build/
|
| 15 |
+
|
| 16 |
+
# Virtual environments
|
| 17 |
+
venv/
|
| 18 |
+
env/
|
| 19 |
+
ENV/
|
| 20 |
+
|
| 21 |
+
# IDE
|
| 22 |
+
.vscode/
|
| 23 |
+
.idea/
|
| 24 |
+
*.swp
|
| 25 |
+
*.swo
|
| 26 |
+
*~
|
| 27 |
+
|
| 28 |
+
# OS files
|
| 29 |
+
.DS_Store
|
| 30 |
+
Thumbs.db
|
| 31 |
+
|
| 32 |
+
# Logs
|
| 33 |
+
*.log
|
| 34 |
+
|
| 35 |
+
# Testing
|
| 36 |
+
.pytest_cache/
|
| 37 |
+
.coverage
|
| 38 |
+
htmlcov/
|
| 39 |
+
|
| 40 |
+
# Documentation
|
| 41 |
+
docs/
|
| 42 |
+
*.md.backup
|
| 43 |
+
|
| 44 |
+
# Large unnecessary files
|
| 45 |
+
*.tar.gz
|
| 46 |
+
*.zip
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask==2.3.3
|
| 2 |
+
flask-cors==4.0.0
|
| 3 |
+
numpy==1.24.3
|
| 4 |
+
tqdm==4.66.1
|
| 5 |
+
huggingface_hub==0.25.1
|
| 6 |
+
openai==1.63.2
|
| 7 |
+
python-dotenv==1.0.1
|
| 8 |
+
sentence_transformers==3.3.1
|
| 9 |
+
torch==2.5.1
|
| 10 |
+
transformers==4.46.3
|
| 11 |
+
nltk==3.9.1
|
| 12 |
+
spacy==3.7.6
|
| 13 |
+
gunicorn==21.2.0
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (162 Bytes). View file
|
|
|
src/__pycache__/alignment.cpython-312.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
src/__pycache__/generation_methods.cpython-312.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
src/__pycache__/generation_utils.cpython-312.pyc
ADDED
|
Binary file (9.5 kB). View file
|
|
|
src/__pycache__/global_edit_utils.cpython-312.pyc
ADDED
|
Binary file (5.54 kB). View file
|
|
|
src/__pycache__/new_alignment.cpython-312.pyc
ADDED
|
Binary file (8.39 kB). View file
|
|
|
src/__pycache__/new_text_alignment.cpython-312.pyc
ADDED
|
Binary file (7.24 kB). View file
|
|
|
src/__pycache__/poa_graph.cpython-312.pyc
ADDED
|
Binary file (28.9 kB). View file
|
|
|
src/__pycache__/text_poa_graph.cpython-312.pyc
ADDED
|
Binary file (34.7 kB). View file
|
|
|
src/__pycache__/text_poa_graph_utils.cpython-312.pyc
ADDED
|
Binary file (6.21 kB). View file
|
|
|
src/alignment.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from Jonathan Dursi
|
| 3 |
+
https://github.com/ljdursi/poapy
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SeqGraphAlignment(object):
|
| 10 |
+
__matchscore = 1
|
| 11 |
+
__mismatchscore = -2
|
| 12 |
+
__gap = -1
|
| 13 |
+
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
sequence,
|
| 17 |
+
graph,
|
| 18 |
+
fastMethod=True,
|
| 19 |
+
globalAlign=False,
|
| 20 |
+
matchscore=__matchscore,
|
| 21 |
+
mismatchscore=__mismatchscore,
|
| 22 |
+
gapscore=__gap,
|
| 23 |
+
*args,
|
| 24 |
+
**kwargs,
|
| 25 |
+
):
|
| 26 |
+
self._mismatchscore = mismatchscore
|
| 27 |
+
self._matchscore = matchscore
|
| 28 |
+
self._gap = gapscore
|
| 29 |
+
self.sequence = sequence
|
| 30 |
+
self.graph = graph
|
| 31 |
+
self.stringidxs = None
|
| 32 |
+
self.nodeidxs = None
|
| 33 |
+
self.globalAlign = globalAlign
|
| 34 |
+
if fastMethod:
|
| 35 |
+
matches = self.alignStringToGraphFast(*args, **kwargs)
|
| 36 |
+
else:
|
| 37 |
+
matches = self.alignStringToGraphSimple(*args, **kwargs)
|
| 38 |
+
self.stringidxs, self.nodeidxs = matches
|
| 39 |
+
|
| 40 |
+
def alignmentStrings(self):
|
| 41 |
+
return "".join(
|
| 42 |
+
self.sequence[i] if i is not None else "-" for i in self.stringidxs
|
| 43 |
+
), "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs)
|
| 44 |
+
|
| 45 |
+
def matchscore(self, c1, c2):
|
| 46 |
+
if c1 == c2:
|
| 47 |
+
return self._matchscore
|
| 48 |
+
else:
|
| 49 |
+
return self._mismatchscore
|
| 50 |
+
|
| 51 |
+
def matchscoreVec(self, c, v):
|
| 52 |
+
return numpy.where(v == c, self._matchscore, self._mismatchscore)
|
| 53 |
+
|
| 54 |
+
def alignStringToGraphSimple(self):
|
| 55 |
+
"""Align string to graph, following same approach as smith waterman
|
| 56 |
+
example"""
|
| 57 |
+
if type(self.sequence) is not str:
|
| 58 |
+
raise TypeError("Invalid Type")
|
| 59 |
+
|
| 60 |
+
nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx = (
|
| 61 |
+
self.initializeDynamicProgrammingData()
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Dynamic Programming
|
| 65 |
+
ni = self.graph.nodeiterator()
|
| 66 |
+
for i, node in enumerate(ni()):
|
| 67 |
+
pbase = node.text
|
| 68 |
+
|
| 69 |
+
for j, sbase in enumerate(self.sequence):
|
| 70 |
+
# add all candidates to a list, pick the best
|
| 71 |
+
candidates = [(scores[i + 1, j] + self._gap, i + 1, j, "INS")]
|
| 72 |
+
for predIndex in self.prevIndices(node, nodeIDtoIndex):
|
| 73 |
+
candidates += [
|
| 74 |
+
(scores[predIndex + 1, j + 1] + self._gap, predIndex + 1, j + 1, "DEL")
|
| 75 |
+
]
|
| 76 |
+
candidates += [
|
| 77 |
+
(
|
| 78 |
+
scores[predIndex + 1, j] + self.matchscore(sbase, pbase),
|
| 79 |
+
predIndex + 1,
|
| 80 |
+
j,
|
| 81 |
+
"MATCH",
|
| 82 |
+
)
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
(
|
| 86 |
+
scores[i + 1, j + 1],
|
| 87 |
+
backGrphIdx[i + 1, j + 1],
|
| 88 |
+
backStrIdx[i + 1, j + 1],
|
| 89 |
+
movetype,
|
| 90 |
+
) = max(candidates)
|
| 91 |
+
|
| 92 |
+
if not self.globalAlign and scores[i + 1, j + 1] < 0:
|
| 93 |
+
scores[i + 1, j + 1] = 0.0
|
| 94 |
+
backGrphIdx[i + 1, j + 1] = -1
|
| 95 |
+
backStrIdx[i + 1, j + 1] = -1
|
| 96 |
+
|
| 97 |
+
return self.backtrack(scores, backStrIdx, backGrphIdx, nodeIndexToID)
|
| 98 |
+
|
| 99 |
+
def alignStringToGraphFast(self):
|
| 100 |
+
"""Align string to graph - using numpy to vectorize across the string
|
| 101 |
+
at each iteration."""
|
| 102 |
+
if type(self.sequence) is not str:
|
| 103 |
+
raise TypeError("Invalid Type")
|
| 104 |
+
|
| 105 |
+
l2 = len(self.sequence)
|
| 106 |
+
seqvec = numpy.array(list(self.sequence))
|
| 107 |
+
|
| 108 |
+
nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx = (
|
| 109 |
+
self.initializeDynamicProgrammingData()
|
| 110 |
+
)
|
| 111 |
+
inserted = numpy.zeros((l2), dtype=bool)
|
| 112 |
+
|
| 113 |
+
# having the inner loop as a function improves performance
|
| 114 |
+
# can use Cython, etc on this for significant further improvements
|
| 115 |
+
# can't vectorize this since there's a loop-carried dependency
|
| 116 |
+
# along the string
|
| 117 |
+
def insertions(i, l2, scores, inserted):
|
| 118 |
+
inserted[:] = False
|
| 119 |
+
for j in range(l2):
|
| 120 |
+
insscore = scores[i + 1, j] + self._gap
|
| 121 |
+
if insscore >= scores[i + 1, j + 1]:
|
| 122 |
+
scores[i + 1, j + 1] = insscore
|
| 123 |
+
inserted[j] = True
|
| 124 |
+
|
| 125 |
+
# Dynamic Programming
|
| 126 |
+
ni = self.graph.nodeiterator()
|
| 127 |
+
for i, node in enumerate(ni()):
|
| 128 |
+
gbase = node.text
|
| 129 |
+
predecessors = self.prevIndices(node, nodeIDtoIndex)
|
| 130 |
+
|
| 131 |
+
# calculate all best deletions, matches in one go over all
|
| 132 |
+
# predecessors.
|
| 133 |
+
|
| 134 |
+
# First calculate for the first predecessor, over all string posns:
|
| 135 |
+
deletescore = scores[predecessors[0] + 1, 1:] + self._gap
|
| 136 |
+
bestdelete = numpy.zeros((l2), dtype=numpy.int32) + predecessors[0] + 1
|
| 137 |
+
|
| 138 |
+
matchpoints = self.matchscoreVec(gbase, seqvec)
|
| 139 |
+
matchscore = scores[predecessors[0] + 1, 0:-1] + matchpoints
|
| 140 |
+
bestmatch = numpy.zeros((l2), dtype=numpy.int32) + predecessors[0] + 1
|
| 141 |
+
|
| 142 |
+
# then, the remaining
|
| 143 |
+
for predecessor in predecessors[1:]:
|
| 144 |
+
newdeletescore = scores[predecessor + 1, 1:] + self._gap
|
| 145 |
+
bestdelete = numpy.where(newdeletescore > deletescore, predecessor + 1, bestdelete)
|
| 146 |
+
deletescore = numpy.maximum(newdeletescore, deletescore)
|
| 147 |
+
|
| 148 |
+
gbase = self.graph.nodeIdxToBase(predecessor)
|
| 149 |
+
matchpoints = self.matchscoreVec(gbase, seqvec)
|
| 150 |
+
newmatchscore = scores[predecessor + 1, 0:-1] + matchpoints
|
| 151 |
+
bestmatch = numpy.where(newmatchscore > matchscore, predecessor + 1, bestmatch)
|
| 152 |
+
matchscore = numpy.maximum(newmatchscore, matchscore)
|
| 153 |
+
|
| 154 |
+
# choose best options available of match, delete
|
| 155 |
+
deleted = deletescore >= matchscore
|
| 156 |
+
backGrphIdx[i + 1, 1:] = numpy.where(deleted, bestdelete, bestmatch)
|
| 157 |
+
backStrIdx[i + 1, 1:] = numpy.where(
|
| 158 |
+
deleted, numpy.arange(1, l2 + 1), numpy.arange(0, l2)
|
| 159 |
+
)
|
| 160 |
+
scores[i + 1, 1:] = numpy.where(deleted, deletescore, matchscore)
|
| 161 |
+
|
| 162 |
+
# insertions: updated in place, don't depend on predecessors
|
| 163 |
+
insertions(i, l2, scores, inserted)
|
| 164 |
+
backGrphIdx[i + 1, 1:] = numpy.where(inserted, i + 1, backGrphIdx[i + 1, 1:])
|
| 165 |
+
backStrIdx[i + 1, 1:] = numpy.where(inserted, numpy.arange(l2), backStrIdx[i + 1, 1:])
|
| 166 |
+
|
| 167 |
+
# if we're doing local alignment, don't let bad global alignment
|
| 168 |
+
# drag us negative
|
| 169 |
+
if not self.globalAlign:
|
| 170 |
+
backGrphIdx[i + 1, :] = numpy.where(scores[i + 1, :] > 0, backGrphIdx[i + 1, :], -1)
|
| 171 |
+
backStrIdx[i + 1, :] = numpy.where(scores[i + 1, :] > 0, backStrIdx[i + 1, :], -1)
|
| 172 |
+
scores[i + 1, :] = numpy.maximum(scores[i + 1, :], 0)
|
| 173 |
+
|
| 174 |
+
return self.backtrack(scores, backStrIdx, backGrphIdx, nodeIndexToID)
|
| 175 |
+
|
| 176 |
+
def prevIndices(self, node, nodeIDtoIndex):
|
| 177 |
+
"""Return a list of the previous dynamic programming table indices
|
| 178 |
+
corresponding to predecessors of the current node."""
|
| 179 |
+
prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
|
| 180 |
+
# if no predecessors, point to just before the graph
|
| 181 |
+
if not prev:
|
| 182 |
+
prev = [-1]
|
| 183 |
+
return prev
|
| 184 |
+
|
| 185 |
+
def initializeDynamicProgrammingData(self):
|
| 186 |
+
"""Initalize the dynamic programming tables:
|
| 187 |
+
- set up scores array
|
| 188 |
+
- set up backtracking array
|
| 189 |
+
- create index to Node ID table and vice versa"""
|
| 190 |
+
l1 = self.graph.nNodes
|
| 191 |
+
l2 = len(self.sequence)
|
| 192 |
+
|
| 193 |
+
nodeIDtoIndex = {}
|
| 194 |
+
nodeIndexToID = {-1: None}
|
| 195 |
+
# generate a dict of (nodeID) -> (index into nodelist (and thus matrix))
|
| 196 |
+
ni = self.graph.nodeiterator()
|
| 197 |
+
for index, node in enumerate(ni()):
|
| 198 |
+
nodeIDtoIndex[node.ID] = index
|
| 199 |
+
nodeIndexToID[index] = node.ID
|
| 200 |
+
|
| 201 |
+
# Dynamic Programming data structures; scores matrix and backtracking
|
| 202 |
+
# matrix
|
| 203 |
+
scores = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)
|
| 204 |
+
|
| 205 |
+
# initialize insertion score
|
| 206 |
+
# if global align, penalty for starting at head != 0
|
| 207 |
+
if self.globalAlign:
|
| 208 |
+
scores[0, :] = numpy.arange(l2 + 1) * self._gap
|
| 209 |
+
|
| 210 |
+
ni = self.graph.nodeiterator()
|
| 211 |
+
for index, node in enumerate(ni()):
|
| 212 |
+
prevIdxs = self.prevIndices(node, nodeIDtoIndex)
|
| 213 |
+
best = scores[prevIdxs[0] + 1, 0]
|
| 214 |
+
for prevIdx in prevIdxs:
|
| 215 |
+
best = max(best, scores[prevIdx + 1, 0])
|
| 216 |
+
scores[index + 1, 0] = best + self._gap
|
| 217 |
+
|
| 218 |
+
# backtracking matrices
|
| 219 |
+
backStrIdx = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)
|
| 220 |
+
backGrphIdx = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)
|
| 221 |
+
|
| 222 |
+
return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx
|
| 223 |
+
|
| 224 |
+
def backtrack(self, scores, backStrIdx, backGrphIdx, nodeIndexToID):
|
| 225 |
+
"""Backtrack through the scores and backtrack arrays.
|
| 226 |
+
Return a list of sequence indices and node IDs (not indices, which
|
| 227 |
+
depend on ordering)."""
|
| 228 |
+
besti, bestj = scores.shape
|
| 229 |
+
besti -= 1
|
| 230 |
+
bestj -= 1
|
| 231 |
+
if not self.globalAlign:
|
| 232 |
+
besti, bestj = numpy.argwhere(scores == numpy.amax(scores))[-1]
|
| 233 |
+
else:
|
| 234 |
+
ni = self.graph.nodeiterator()
|
| 235 |
+
# still have to find best final index to start from
|
| 236 |
+
terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
|
| 237 |
+
print(terminalIndices)
|
| 238 |
+
besti = terminalIndices[0] + 1
|
| 239 |
+
bestscore = scores[besti, bestj]
|
| 240 |
+
for i in terminalIndices[1:]:
|
| 241 |
+
score = scores[i + 1, bestj]
|
| 242 |
+
if score > bestscore:
|
| 243 |
+
bestscore, besti = score, i + 1
|
| 244 |
+
|
| 245 |
+
matches = []
|
| 246 |
+
strindexes = []
|
| 247 |
+
while (self.globalAlign or scores[besti, bestj] > 0) and (besti != 0 or bestj != 0):
|
| 248 |
+
nexti, nextj = backGrphIdx[besti, bestj], backStrIdx[besti, bestj]
|
| 249 |
+
curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]
|
| 250 |
+
|
| 251 |
+
strindexes.insert(0, curstridx if nextj != bestj else None)
|
| 252 |
+
matches.insert(0, curnodeidx if nexti != besti else None)
|
| 253 |
+
|
| 254 |
+
besti, bestj = nexti, nextj
|
| 255 |
+
|
| 256 |
+
return strindexes, matches
|
src/generation_methods.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from src.generation_utils import (
|
| 4 |
+
extract_alternative_paths,
|
| 5 |
+
extract_context,
|
| 6 |
+
extract_equivalent_classes,
|
| 7 |
+
self_complete,
|
| 8 |
+
verify_correctness_pairwise,
|
| 9 |
+
)
|
| 10 |
+
from src.global_edit_utils import clean_up_text
|
| 11 |
+
from src.text_poa_graph import TextPOAGraph
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
Decodes from a TextPOAGraph object to a string by sequentially selecting nodes based on the selection threshold.
|
| 15 |
+
Only the primary variation of selected variable nodes are selected.
|
| 16 |
+
Text is edited using the global_edit_function (e.g. to clean up text by removing incoherencies, disfluencies, and redundancies).
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
text_poa_graph: The TextPOAGraph object to decode.
|
| 20 |
+
selection_threshold: The threshold for selecting nodes.
|
| 21 |
+
model: The model to use for decoding.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
A string of the decoded text.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def decode_consensus(
|
| 29 |
+
text_poa_graph: TextPOAGraph,
|
| 30 |
+
selection_threshold: Optional[float] = 0.5,
|
| 31 |
+
task: str = "bio",
|
| 32 |
+
verbose: bool = False,
|
| 33 |
+
**kwargs,
|
| 34 |
+
) -> str:
|
| 35 |
+
if text_poa_graph.failed:
|
| 36 |
+
return "Abstain"
|
| 37 |
+
|
| 38 |
+
text_poa_graph.toposort()
|
| 39 |
+
|
| 40 |
+
consensus_node_ids = text_poa_graph.consensus_node_ids
|
| 41 |
+
|
| 42 |
+
selected_node_ids = []
|
| 43 |
+
|
| 44 |
+
for node_id in consensus_node_ids:
|
| 45 |
+
if node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id:
|
| 46 |
+
continue
|
| 47 |
+
|
| 48 |
+
selected_node_ids.append(node_id)
|
| 49 |
+
|
| 50 |
+
for neighbor_id in text_poa_graph.nodedict[node_id].outEdges:
|
| 51 |
+
if neighbor_id in consensus_node_ids:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
if (
|
| 55 |
+
len(text_poa_graph.nodedict[neighbor_id].labels) / text_poa_graph.num_sequences
|
| 56 |
+
>= selection_threshold
|
| 57 |
+
):
|
| 58 |
+
selected_node_ids.append(neighbor_id)
|
| 59 |
+
|
| 60 |
+
texts = []
|
| 61 |
+
for node_id in selected_node_ids:
|
| 62 |
+
if not text_poa_graph.nodedict[node_id].variations:
|
| 63 |
+
texts.append(text_poa_graph.nodedict[node_id].text)
|
| 64 |
+
else:
|
| 65 |
+
all_texts = [v for v in text_poa_graph.nodedict[node_id].variations.values()]
|
| 66 |
+
all_texts.append(text_poa_graph.nodedict[node_id].text)
|
| 67 |
+
# select the variation that is longest
|
| 68 |
+
texts.append(max(all_texts, key=len))
|
| 69 |
+
text = " ".join(texts)
|
| 70 |
+
edited_text = clean_up_text(text=text, task=task, api="openai", **kwargs)
|
| 71 |
+
if verbose:
|
| 72 |
+
return text, edited_text
|
| 73 |
+
else:
|
| 74 |
+
return edited_text
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def decode_self_verified(
|
| 78 |
+
text_poa_graph: TextPOAGraph,
|
| 79 |
+
problem: str,
|
| 80 |
+
uncertainty_threshold: float = 0.6,
|
| 81 |
+
verification_api: str = "openai",
|
| 82 |
+
verification_model: str = "gpt-4o-mini",
|
| 83 |
+
grace_period: bool = True,
|
| 84 |
+
):
|
| 85 |
+
high_uncertainty_nodes = []
|
| 86 |
+
for node_id in text_poa_graph.consensus_node_ids:
|
| 87 |
+
if node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id:
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
outgoing_edges = text_poa_graph.nodedict[node_id].outEdges
|
| 91 |
+
branching_factor = len(outgoing_edges) / text_poa_graph.num_sequences
|
| 92 |
+
|
| 93 |
+
if branching_factor > uncertainty_threshold:
|
| 94 |
+
high_uncertainty_nodes.append(node_id)
|
| 95 |
+
|
| 96 |
+
selected_labels = list(text_poa_graph._seq_paths.keys())
|
| 97 |
+
masked_candidates = {}
|
| 98 |
+
uncertain_region = False
|
| 99 |
+
for label in selected_labels:
|
| 100 |
+
text = ""
|
| 101 |
+
for node_id in text_poa_graph._seq_paths[label]:
|
| 102 |
+
if uncertain_region:
|
| 103 |
+
text += f" *START_SEPARATOR*_{node_id} "
|
| 104 |
+
if node_id in high_uncertainty_nodes:
|
| 105 |
+
uncertain_region = True
|
| 106 |
+
|
| 107 |
+
if len(text_poa_graph.nodedict[node_id].variations) > 0:
|
| 108 |
+
text += text_poa_graph.nodedict[node_id].variations[label]
|
| 109 |
+
text += " "
|
| 110 |
+
else:
|
| 111 |
+
text += text_poa_graph.nodedict[node_id].text
|
| 112 |
+
text += " "
|
| 113 |
+
|
| 114 |
+
if uncertain_region and node_id not in high_uncertainty_nodes:
|
| 115 |
+
text += f" *END_SEPARATOR*_{node_id} "
|
| 116 |
+
uncertain_region = False
|
| 117 |
+
masked_candidates[label] = text
|
| 118 |
+
|
| 119 |
+
patch_start_node = None
|
| 120 |
+
uncertain_ids = []
|
| 121 |
+
|
| 122 |
+
# give a grace period for the first incorrect step
|
| 123 |
+
prev_step = {label: None for label in selected_labels}
|
| 124 |
+
|
| 125 |
+
for node_id in high_uncertainty_nodes:
|
| 126 |
+
uncertain_ids.append(node_id)
|
| 127 |
+
context_before = extract_context(text_poa_graph, node_id)
|
| 128 |
+
alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
|
| 129 |
+
equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)
|
| 130 |
+
new_labels = selected_labels.copy()
|
| 131 |
+
|
| 132 |
+
# Only do self-verifaction for labels from different sematically equivalent branches
|
| 133 |
+
if len(equivalent_classes) <= 1:
|
| 134 |
+
continue
|
| 135 |
+
i = 0
|
| 136 |
+
while i < len(equivalent_classes):
|
| 137 |
+
if i + 1 < len(equivalent_classes):
|
| 138 |
+
label_a = equivalent_classes[i][0]
|
| 139 |
+
label_b = equivalent_classes[i + 1][0]
|
| 140 |
+
full_a = context_before[label_a] + alternative_paths[label_a]
|
| 141 |
+
full_b = context_before[label_b] + alternative_paths[label_b]
|
| 142 |
+
|
| 143 |
+
score = verify_correctness_pairwise(
|
| 144 |
+
full_text_1=full_a,
|
| 145 |
+
full_text_2=full_b,
|
| 146 |
+
verification_model=verification_model,
|
| 147 |
+
problem=problem,
|
| 148 |
+
api=verification_api,
|
| 149 |
+
)
|
| 150 |
+
if float(score[0]) < 1.0:
|
| 151 |
+
print(f"Label {label_a} is incorrect at node {node_id}")
|
| 152 |
+
masked_candidates[label_a] = (
|
| 153 |
+
masked_candidates[label_a]
|
| 154 |
+
.replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
|
| 155 |
+
.replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
|
| 156 |
+
)
|
| 157 |
+
if not prev_step[label_a]:
|
| 158 |
+
prev_step[label_a] = True
|
| 159 |
+
if prev_step[label_a] and grace_period or not grace_period:
|
| 160 |
+
for label_i in equivalent_classes[i]:
|
| 161 |
+
new_labels.remove(label_i)
|
| 162 |
+
print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
|
| 163 |
+
if float(score[0]) == 1.0:
|
| 164 |
+
prev_step[label_a] = False
|
| 165 |
+
if float(score[1]) < 1.0:
|
| 166 |
+
print(f"Label {label_b} is incorrect at node {node_id}")
|
| 167 |
+
masked_candidates[label_b] = (
|
| 168 |
+
masked_candidates[label_b]
|
| 169 |
+
.replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
|
| 170 |
+
.replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
|
| 171 |
+
)
|
| 172 |
+
if not prev_step[label_b]:
|
| 173 |
+
prev_step[label_b] = True
|
| 174 |
+
if prev_step[label_b] and grace_period or not grace_period:
|
| 175 |
+
for label_i in equivalent_classes[i + 1]:
|
| 176 |
+
new_labels.remove(label_i)
|
| 177 |
+
print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
|
| 178 |
+
if float(score[1]) == 1.0:
|
| 179 |
+
prev_step[label_b] = False
|
| 180 |
+
i += 2
|
| 181 |
+
else:
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
+
if len(new_labels) == 0:
|
| 185 |
+
patch_start_node = node_id
|
| 186 |
+
break
|
| 187 |
+
|
| 188 |
+
selected_labels = new_labels.copy()
|
| 189 |
+
|
| 190 |
+
# These are the pruned approaches with masking
|
| 191 |
+
print(masked_candidates)
|
| 192 |
+
masked_approaches = "\n".join(
|
| 193 |
+
[
|
| 194 |
+
f"Approach {label}: {masked_candidates[label].replace('START_SEPARATOR', 'START_UNCERTAIN_REGION').replace('END_SEPARATOR', 'END_UNCERTAIN_REGION')}"
|
| 195 |
+
for label in selected_labels
|
| 196 |
+
]
|
| 197 |
+
)
|
| 198 |
+
# These are all approaches with masking
|
| 199 |
+
all_approaches = "\n".join(
|
| 200 |
+
[f"Approach {label}: {masked_candidates[label]}" for label in masked_candidates.keys()]
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
default_prompt = f"""
|
| 204 |
+
Solve the following math problem with mathematical precision and clarity.
|
| 205 |
+
|
| 206 |
+
Problem: {problem}
|
| 207 |
+
|
| 208 |
+
Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*).
|
| 209 |
+
These sections may contain conceptual or computational errors.
|
| 210 |
+
|
| 211 |
+
There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
|
| 212 |
+
A verification step indicated that these steps are highly likely to contain errors.
|
| 213 |
+
|
| 214 |
+
Potential Approaches:
|
| 215 |
+
{masked_approaches}
|
| 216 |
+
|
| 217 |
+
Your task:
|
| 218 |
+
1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
|
| 219 |
+
If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
|
| 220 |
+
2. Using the sections with special markers, identify potential errors.
|
| 221 |
+
3. Develop a rigorous, step-by-step solution based on sound mathematical principles
|
| 222 |
+
4. For uncertain regions:
|
| 223 |
+
- Verify each step using algebraic or numerical validation
|
| 224 |
+
- If correct, incorporate these steps with appropriate justification
|
| 225 |
+
- If incorrect, provide clear corrections with mathematical reasoning for your changes
|
| 226 |
+
5. Follow a comparative approach, using the differences between approaches to identify potential errors.
|
| 227 |
+
6. Do not blindly follow the approaches, but rather use them to identify potential errors.
|
| 228 |
+
|
| 229 |
+
Guidelines for your solution:
|
| 230 |
+
- Begin with a strategic overview of your chosen approach
|
| 231 |
+
- Present each mathematical step with clear notation and justification
|
| 232 |
+
- Pay special attention to areas that were previously marked uncertain
|
| 233 |
+
|
| 234 |
+
Conclude your solution with:
|
| 235 |
+
Therefore, the final answer is: $\\boxed{{answer}}$.
|
| 236 |
+
|
| 237 |
+
Solution:
|
| 238 |
+
"""
|
| 239 |
+
|
| 240 |
+
patch_prompt = f"""
|
| 241 |
+
Solve the following mathematical problem with precision and clarity.
|
| 242 |
+
|
| 243 |
+
Problem: {problem}
|
| 244 |
+
|
| 245 |
+
You have been provided with several partial solution approaches that attempted to solve this problem.
|
| 246 |
+
None of these approaches are correct, but may contain valuable insights.
|
| 247 |
+
Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
|
| 248 |
+
A verification step indicated that these steps are likely to contain errors.
|
| 249 |
+
|
| 250 |
+
INSTRUCTIONS:
|
| 251 |
+
1. Synthesize a correct solution using insights from the previous approaches
|
| 252 |
+
2. Pay special attention to fixing the problematic areas marked by separators
|
| 253 |
+
3. Develop your solution step-by-step, showing clear mathematical reasoning
|
| 254 |
+
4. Focus especially on mathematical correctness in areas where previous solutions diverged
|
| 255 |
+
5. Present your work in a logical, sequential manner suitable for an advanced reader
|
| 256 |
+
|
| 257 |
+
GUIDELINES FOR MATHEMATICAL RIGOR:
|
| 258 |
+
1. MAINTAIN MATHEMATICAL RIGOR
|
| 259 |
+
- Verify that all mathematical operations follow from established principles and definitions
|
| 260 |
+
- Ensure dimensional consistency throughout calculations
|
| 261 |
+
- Check that algebraic manipulations preserve equality and do not introduce errors
|
| 262 |
+
|
| 263 |
+
2. CONSIDER ALTERNATIVE PERSPECTIVES
|
| 264 |
+
- Even when approaches reach the same conclusion, examine their reasoning independently
|
| 265 |
+
- Look for more elegant or insightful connections that may be missed across all approaches
|
| 266 |
+
- Consider whether fundamental mathematical principles suggest a different path
|
| 267 |
+
|
| 268 |
+
3. CRITICAL VALIDATION
|
| 269 |
+
- Test conclusions using known mathematical properties and relationships
|
| 270 |
+
- When possible, verify results using alternative methods
|
| 271 |
+
- Be especially cautious when all approaches agree on a result but use similar reasoning
|
| 272 |
+
|
| 273 |
+
4. USE PRECISION IN CORRECTIONS
|
| 274 |
+
- When correcting uncertain regions, specify exactly what was incorrect and why
|
| 275 |
+
- Provide clear mathematical justification for any changes
|
| 276 |
+
- Ensure corrections align with standard mathematical principles and notations
|
| 277 |
+
|
| 278 |
+
Previous Approaches (for reference only):
|
| 279 |
+
{all_approaches}
|
| 280 |
+
|
| 281 |
+
Your Solution:
|
| 282 |
+
[Begin with a clear statement of your approach]
|
| 283 |
+
[Provide detailed mathematical steps]
|
| 284 |
+
[Ensure correct handling of complex mathematical operations]
|
| 285 |
+
[Verify your work at key points, especially in previously problematic areas]
|
| 286 |
+
|
| 287 |
+
Always conclude with:
|
| 288 |
+
Therefore, the final answer is: $\\boxed{{answer}}$
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
if patch_start_node is not None or len(masked_candidates.keys()) == 1:
|
| 292 |
+
print("None correct, patching")
|
| 293 |
+
prompt = patch_prompt
|
| 294 |
+
else:
|
| 295 |
+
prompt = default_prompt
|
| 296 |
+
|
| 297 |
+
return self_complete(
|
| 298 |
+
verification_prompt=prompt, verification_model=verification_model, api=verification_api
|
| 299 |
+
), masked_candidates
|
src/generation_utils.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import InferenceClient
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
from together import Together
|
| 6 |
+
|
| 7 |
+
from src.text_poa_graph import TextPOAGraph
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def extract_context(text_poa_graph, node_id):
|
| 11 |
+
"""Extract context up to and including the specified node_id."""
|
| 12 |
+
contexts = {}
|
| 13 |
+
for label, path in text_poa_graph._seq_paths.items():
|
| 14 |
+
idx = path.index(node_id)
|
| 15 |
+
context = path[: idx + 1]
|
| 16 |
+
contexts[label] = " ".join(
|
| 17 |
+
text_poa_graph.nodedict[nid].variations.get(label, text_poa_graph.nodedict[nid].text)
|
| 18 |
+
for nid in context
|
| 19 |
+
)
|
| 20 |
+
return contexts
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def extract_alternative_paths(text_poa_graph: TextPOAGraph, node_id):
|
| 24 |
+
"""Extract all alternative paths from this uncertainty point to the next consensus node."""
|
| 25 |
+
alternative_paths = {}
|
| 26 |
+
for label, path in text_poa_graph._seq_paths.items():
|
| 27 |
+
idx = path.index(node_id)
|
| 28 |
+
next_cn = None
|
| 29 |
+
for i in range(idx + 1, len(path)):
|
| 30 |
+
if path[i] in text_poa_graph.consensus_node_ids:
|
| 31 |
+
next_cn = path[i]
|
| 32 |
+
break
|
| 33 |
+
|
| 34 |
+
if next_cn:
|
| 35 |
+
next_cn_idx = path.index(next_cn)
|
| 36 |
+
alternative_segment = path[idx + 1 : next_cn_idx + 1]
|
| 37 |
+
else:
|
| 38 |
+
alternative_segment = []
|
| 39 |
+
|
| 40 |
+
alternative_paths[label] = " ".join(
|
| 41 |
+
text_poa_graph.nodedict[nid].variations.get(label, text_poa_graph.nodedict[nid].text)
|
| 42 |
+
for nid in alternative_segment
|
| 43 |
+
)
|
| 44 |
+
return alternative_paths
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def is_same_branch(text_poa_graph: TextPOAGraph, node_id, lable_1, label_2):
|
| 48 |
+
"""Check if the next vaiable nodes for two sequences are the same after node_id."""
|
| 49 |
+
path_1 = text_poa_graph._seq_paths[lable_1]
|
| 50 |
+
path_2 = text_poa_graph._seq_paths[label_2]
|
| 51 |
+
idx_1 = path_1.index(node_id)
|
| 52 |
+
idx_2 = path_2.index(node_id)
|
| 53 |
+
return path_1[idx_1 + 1] == path_2[idx_2 + 1]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def extract_equivalent_classes(text_poa_graph: TextPOAGraph, node_id, selected_labels):
|
| 57 |
+
"""Extract equivalent classes from the text POA graph."""
|
| 58 |
+
if not selected_labels:
|
| 59 |
+
return []
|
| 60 |
+
|
| 61 |
+
equivalent_classes = []
|
| 62 |
+
for label in selected_labels:
|
| 63 |
+
matched = False
|
| 64 |
+
for class_group in equivalent_classes:
|
| 65 |
+
if is_same_branch(text_poa_graph, node_id, class_group[0], label):
|
| 66 |
+
class_group.append(label)
|
| 67 |
+
matched = True
|
| 68 |
+
break
|
| 69 |
+
if not matched:
|
| 70 |
+
equivalent_classes.append([label])
|
| 71 |
+
return equivalent_classes
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def verify_correctness_pairwise(
|
| 75 |
+
full_text_1: str, full_text_2: str, verification_model: str, problem: str, api: str = "openai"
|
| 76 |
+
):
|
| 77 |
+
"""Pairwise verification of two partial solution paths."""
|
| 78 |
+
if api == "openai":
|
| 79 |
+
client = OpenAI()
|
| 80 |
+
elif api == "hf":
|
| 81 |
+
client = InferenceClient()
|
| 82 |
+
elif api == "together":
|
| 83 |
+
client = Together()
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError(f"Invalid API: {api}")
|
| 86 |
+
|
| 87 |
+
prompt = f"""
|
| 88 |
+
You will be given a problem and 2 partial solutions.
|
| 89 |
+
Your task is to use comparison as an EFFICIENCY TOOL to quickly identify potential errors.
|
| 90 |
+
You will be given guidelines to follow, and you will be penalized if you do not follow them.
|
| 91 |
+
|
| 92 |
+
Problem: {problem}
|
| 93 |
+
|
| 94 |
+
Partial Solution 1: {full_text_1}
|
| 95 |
+
Partial Solution 2: {full_text_2}
|
| 96 |
+
|
| 97 |
+
CRITICAL GUIDELINES:
|
| 98 |
+
- DO NOT penalize a solution for being incomplete or having missing steps
|
| 99 |
+
- DO NOT make a comparison of which solution is better
|
| 100 |
+
- DO NOT consider steps incorrect just because they differ between solutions
|
| 101 |
+
- DO NOT prematurely evaluate based on final answers or future steps
|
| 102 |
+
- DO NOT expect both solutions to be at the same stage of completion
|
| 103 |
+
- DO NOT consider a step incorrect just because it lacks sufficient detail or justification
|
| 104 |
+
|
| 105 |
+
KEY EFFICIENCY PRINCIPLE:
|
| 106 |
+
- Use agreement between solutions as evidence of correctness
|
| 107 |
+
- Use disagreement as a signal to investigate more deeply
|
| 108 |
+
- Only label a step as an error if it contains a specific mathematical mistake
|
| 109 |
+
- Incompleteness is not a mathematical error.
|
| 110 |
+
|
| 111 |
+
Here are the instructions for how to complete your task:
|
| 112 |
+
|
| 113 |
+
EFFICIENT VERIFICATION APPROACH:
|
| 114 |
+
|
| 115 |
+
1. QUICK COMPARISON (Use this to focus your attention):
|
| 116 |
+
- Immediately identify where the solutions differ in approach or results
|
| 117 |
+
- Use these differences as "error hotspots" to prioritize your verification
|
| 118 |
+
- When solutions agree, you can generally assume that part is correct
|
| 119 |
+
- When solutions disagree, investigate those specific points deeply
|
| 120 |
+
|
| 121 |
+
2. TARGETED VERIFICATION (Only where needed):
|
| 122 |
+
- Most important: Do not consider any incomplete steps as errors
|
| 123 |
+
- Focus your mathematical verification on the "hotspots" identified above
|
| 124 |
+
- Check mathematical validity only at points of difference or uncertainty
|
| 125 |
+
- Avoid line-by-line checking of steps where solutions agree
|
| 126 |
+
- For each potential error spot, verify if the mathematical reasoning is valid
|
| 127 |
+
- If an intermediate step is later corrected, do not penalize the solution for having the incorrect intermediate step
|
| 128 |
+
|
| 129 |
+
After your targeted verification, propose a score tuple (score_1, score_2):
|
| 130 |
+
- Score (1,1) if both partial solutions are valid
|
| 131 |
+
- Score (1,0) if only the first solution is valid
|
| 132 |
+
- Score (0,1) if only the second solution is valid
|
| 133 |
+
- Score (0,0) if both solutions are invalid
|
| 134 |
+
|
| 135 |
+
In case you score a solution as 0, you must give an explanation for each check below:
|
| 136 |
+
3. FINAL CHECKS:
|
| 137 |
+
- If you score a solution as 0, you MUST identify the specific mathematical error.
|
| 138 |
+
- You must also double check the problem statement. Reconsider your score and determine if you have misinterpreted the problem statement.
|
| 139 |
+
- You must also check whether you have penalized a solution for being incomplete or having missing steps.
|
| 140 |
+
|
| 141 |
+
Before outputting your final score, you must answer these questions:
|
| 142 |
+
STOP! Did you give a score of 0 to a solution that was incomplete?
|
| 143 |
+
STOP! Did you penalize a solution for being incomplete or having missing steps?
|
| 144 |
+
STOP! Did you make a comparison of which solution is better?
|
| 145 |
+
STOP! Did you consider steps incorrect just because they differ between solutions?
|
| 146 |
+
STOP! Did you prematurely evaluate based on final answers?
|
| 147 |
+
STOP! Did you consider a step incorrect just because it lacks sufficient detail or justification?
|
| 148 |
+
|
| 149 |
+
Now give your final score:
|
| 150 |
+
Final score:
|
| 151 |
+
"""
|
| 152 |
+
completion = client.chat.completions.create(
|
| 153 |
+
model=verification_model,
|
| 154 |
+
messages=[
|
| 155 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 156 |
+
{"role": "user", "content": prompt},
|
| 157 |
+
],
|
| 158 |
+
temperature=0.0,
|
| 159 |
+
)
|
| 160 |
+
response = completion.choices[0].message.content.strip()
|
| 161 |
+
print(full_text_1)
|
| 162 |
+
print(full_text_2)
|
| 163 |
+
print(f"Correctness score: {response} \n")
|
| 164 |
+
score_match = re.findall(r"\(\s*([01](?:\.0)?)\s*,\s*([01](?:\.0)?)\s*\)", response)
|
| 165 |
+
score = score_match[-1] if score_match else (0, 0)
|
| 166 |
+
return score
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def self_complete(verification_prompt: str, verification_model: str, api: str = "openai"):
|
| 170 |
+
print(verification_prompt)
|
| 171 |
+
"""Completetion method"""
|
| 172 |
+
if api == "openai":
|
| 173 |
+
client = OpenAI()
|
| 174 |
+
elif api == "hf":
|
| 175 |
+
client = InferenceClient()
|
| 176 |
+
elif api == "together":
|
| 177 |
+
client = Together()
|
| 178 |
+
else:
|
| 179 |
+
raise ValueError(f"Invalid API: {api}")
|
| 180 |
+
|
| 181 |
+
completion = client.chat.completions.create(
|
| 182 |
+
model=verification_model,
|
| 183 |
+
messages=[
|
| 184 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 185 |
+
{"role": "user", "content": verification_prompt},
|
| 186 |
+
],
|
| 187 |
+
temperature=0.0,
|
| 188 |
+
)
|
| 189 |
+
response = completion.choices[0].message.content.strip()
|
| 190 |
+
return response
|
src/global_edit_utils.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
from openai import OpenAI
|
| 3 |
+
|
| 4 |
+
bio_prompt = """
|
| 5 |
+
You are given a piece of text that is a part of a biography of an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
|
| 6 |
+
Then, remove any redundant information.
|
| 7 |
+
Text: {text}
|
| 8 |
+
|
| 9 |
+
If this is not possible because the text is just a fragment of a sentence, return "Abstain".
|
| 10 |
+
If the text already claims a lack of knowledge about the topic, return "Abstain".
|
| 11 |
+
Only return the cleaned up text. Do not include any other text:
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
fp_prompt = """
|
| 15 |
+
You are given a piece of text that is a part of a false presupposition task which includes outputting a list of items.
|
| 16 |
+
This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
|
| 17 |
+
Then, remove any redundant information.
|
| 18 |
+
Text: {text}
|
| 19 |
+
|
| 20 |
+
The resulting list of items should be separated by semicolons with no other text.
|
| 21 |
+
If this list it not possible to generate, return "Abstain".
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
hist_prompt = """
|
| 25 |
+
You are given a piece of text that is a part of a historical event task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
|
| 26 |
+
Then, remove any redundant information.
|
| 27 |
+
Text: {text}
|
| 28 |
+
|
| 29 |
+
If this is not possible because the text is just a fragment of a sentence, return "Abstain".
|
| 30 |
+
If the text already claims a lack of knowledge about the topic, return "Abstain".
|
| 31 |
+
Only return the cleaned up text. Do not include any other text:
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
refs_prompt = """
|
| 35 |
+
You are given a piece of text that is a part of a reference task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
|
| 36 |
+
Then, remove any redundant information.
|
| 37 |
+
Text: {text}
|
| 38 |
+
|
| 39 |
+
If this is not possible because the text is just a fragment of a sentence, return "Abstain".
|
| 40 |
+
If the text already claims a lack of knowledge about the topic, return "Abstain".
|
| 41 |
+
Only return the cleaned up text. Do not include any other text:
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
gpqa_prompt = """
|
| 45 |
+
You are given a piece of text that is a part of a graduate level question answering task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
|
| 46 |
+
Then, remove any redundant information.
|
| 47 |
+
Text: {text}
|
| 48 |
+
Only return the cleaned up text. Do not include any other text:
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
popqa_prompt = """
|
| 52 |
+
You are given a piece of text that is a part of a paragraph which details facts related to an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
|
| 53 |
+
Then, remove any redundant information.
|
| 54 |
+
Text: {text}
|
| 55 |
+
|
| 56 |
+
If this is not possible because the text is just a fragment of a sentence, return "Abstain".
|
| 57 |
+
If the text already claims a lack of knowledge about the topic, return "Abstain".
|
| 58 |
+
Only return the cleaned up text. Do not include any other text:
|
| 59 |
+
"""
|
| 60 |
+
task_to_prompt = {
|
| 61 |
+
"bio": bio_prompt,
|
| 62 |
+
"fp": fp_prompt,
|
| 63 |
+
"hist": hist_prompt,
|
| 64 |
+
"refs": refs_prompt,
|
| 65 |
+
"gpqa": gpqa_prompt,
|
| 66 |
+
"popqa": popqa_prompt
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
'''
|
| 70 |
+
Cleans up disfluencies in the draft response in consensus decoding.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
text: The text to clean up.
|
| 74 |
+
api: The API to use for cleaning up the text.
|
| 75 |
+
task: The task : biography, false presupposition, historical event, reference, graduate question answering, paragraph question answering.
|
| 76 |
+
model: The model to use for cleaning up the text.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
A string of the cleaned up text.
|
| 80 |
+
'''
|
| 81 |
+
|
| 82 |
+
def clean_up_text(text: str, api: str, task: str, model: str = "gpt-4.1-mini", **kwargs):
|
| 83 |
+
if api == "openai":
|
| 84 |
+
client = OpenAI()
|
| 85 |
+
elif api == "hf":
|
| 86 |
+
tokenizer = kwargs.get("tokenizer")
|
| 87 |
+
model = kwargs.get("hf_model")
|
| 88 |
+
|
| 89 |
+
if tokenizer is None or model is None:
|
| 90 |
+
raise ValueError("For 'hf', both 'tokenizer' and 'model' must be provided.")
|
| 91 |
+
|
| 92 |
+
clean_up_prompt = task_to_prompt[task].format(text=text)
|
| 93 |
+
|
| 94 |
+
messages = [{"role": "user", "content": clean_up_prompt}]
|
| 95 |
+
input_ids = tokenizer.apply_chat_template(
|
| 96 |
+
messages,
|
| 97 |
+
add_generation_prompt=True,
|
| 98 |
+
return_tensors="pt"
|
| 99 |
+
).to(model.device)
|
| 100 |
+
|
| 101 |
+
terminators = [ tokenizer.eos_token_id, ]
|
| 102 |
+
outputs = model.generate(
|
| 103 |
+
input_ids,
|
| 104 |
+
max_new_tokens=500,
|
| 105 |
+
do_sample=False,
|
| 106 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 107 |
+
eos_token_id=terminators,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
return tokenizer.decode(
|
| 111 |
+
outputs[0][input_ids.shape[-1]:],
|
| 112 |
+
skip_special_tokens=True
|
| 113 |
+
).strip()
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(f"Invalid API: {api}")
|
| 116 |
+
|
| 117 |
+
clean_up_prompt = task_to_prompt[task].format(text=text)
|
| 118 |
+
|
| 119 |
+
completion = client.chat.completions.create(
|
| 120 |
+
model=model,
|
| 121 |
+
messages=[
|
| 122 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 123 |
+
{"role": "user", "content": clean_up_prompt},
|
| 124 |
+
],
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
return completion.choices[0].message.content.strip()
|
src/new_alignment.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ScoreParam:
|
| 5 |
+
def __init__(self, match, mismatch, gap_open, gap_extend):
|
| 6 |
+
self.match = match
|
| 7 |
+
self.mismatch = mismatch
|
| 8 |
+
self.gap_open = gap_open
|
| 9 |
+
self.gap_extend = gap_extend
|
| 10 |
+
|
| 11 |
+
def __str__(self):
|
| 12 |
+
return f"Match: {self.match}, Mismatch: {self.mismatch}, Gap Open: {self.gap_open}, Gap Extend: {self.gap_extend}"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SeqGraphAlignment(object):
|
| 16 |
+
__default_score = ScoreParam(1, -3, -2, -1)
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
sequence,
|
| 21 |
+
graph,
|
| 22 |
+
fastMethod=True,
|
| 23 |
+
globalAlign=False,
|
| 24 |
+
score_params=__default_score,
|
| 25 |
+
*args,
|
| 26 |
+
**kwargs,
|
| 27 |
+
):
|
| 28 |
+
self.score = score_params
|
| 29 |
+
self.sequence = sequence
|
| 30 |
+
self.graph = graph
|
| 31 |
+
self.stringidxs = None
|
| 32 |
+
self.nodeidxs = None
|
| 33 |
+
self.globalAlign = globalAlign
|
| 34 |
+
if fastMethod:
|
| 35 |
+
matches = self.alignStringToGraphFast(*args, **kwargs)
|
| 36 |
+
else:
|
| 37 |
+
matches = self.alignStringToGraphSimple(*args, **kwargs)
|
| 38 |
+
self.stringidxs, self.nodeidxs = matches
|
| 39 |
+
|
| 40 |
+
def alignmentStrings(self):
|
| 41 |
+
return (
|
| 42 |
+
"".join(self.sequence[i] if i is not None else "-" for i in self.stringidxs),
|
| 43 |
+
"".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs),
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def matchscore(self, c1, c2):
|
| 47 |
+
if c1 == c2:
|
| 48 |
+
return self.score.match
|
| 49 |
+
else:
|
| 50 |
+
return self.score.mismatch
|
| 51 |
+
|
| 52 |
+
def matchscoreVec(self, c, v):
|
| 53 |
+
return numpy.where(v == c, self.score.match, self.score.mismatch)
|
| 54 |
+
|
| 55 |
+
def prevIndices(self, node, nodeIDtoIndex):
|
| 56 |
+
prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
|
| 57 |
+
if not prev:
|
| 58 |
+
prev = [-1]
|
| 59 |
+
return prev
|
| 60 |
+
|
| 61 |
+
def initializeDynamicProgrammingData(self):
|
| 62 |
+
l1 = self.graph.nNodes
|
| 63 |
+
l2 = len(self.sequence)
|
| 64 |
+
|
| 65 |
+
nodeIDtoIndex = {}
|
| 66 |
+
nodeIndexToID = {-1: None}
|
| 67 |
+
ni = self.graph.nodeiterator()
|
| 68 |
+
for index, node in enumerate(ni()):
|
| 69 |
+
nodeIDtoIndex[node.ID] = index
|
| 70 |
+
nodeIndexToID[index] = node.ID
|
| 71 |
+
|
| 72 |
+
scores = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
|
| 73 |
+
|
| 74 |
+
if self.globalAlign:
|
| 75 |
+
# M[0, i] = -inf
|
| 76 |
+
scores[0, 0, :] = [
|
| 77 |
+
-1000000000 for i in range(l2+1)
|
| 78 |
+
]
|
| 79 |
+
scores[0, 0, 0] = 0
|
| 80 |
+
# X[0, i] = gap_open + i * gap_extend
|
| 81 |
+
scores[1, 0, :] = [
|
| 82 |
+
self.score.gap_open + i * self.score.gap_extend for i in range(l2 + 1)
|
| 83 |
+
]
|
| 84 |
+
scores[1, 0, 0] = -1000000000
|
| 85 |
+
# Y[0, i] = -inf
|
| 86 |
+
scores[2, 0, :] = [
|
| 87 |
+
-1000000000 for i in range(l2+1)
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
ni = self.graph.nodeiterator()
|
| 91 |
+
# After topology sort, the predcessors will have index less than the current node
|
| 92 |
+
for index, node in enumerate(ni()):
|
| 93 |
+
scores[0, index + 1, 0] = -1000000000
|
| 94 |
+
scores[1, index + 1, 0] = -1000000000
|
| 95 |
+
prevIdxs = self.prevIndices(node, nodeIDtoIndex)
|
| 96 |
+
best = scores[2 ,prevIdxs[0] + 1, 0]
|
| 97 |
+
for prevIdx in prevIdxs:
|
| 98 |
+
best = max(best, scores[2, prevIdx + 1, 0])
|
| 99 |
+
# If we have no predecessors, we start the gap
|
| 100 |
+
if prevIdxs == [-1]:
|
| 101 |
+
scores[2, index + 1, 0] = self.score.gap_open + self.score.gap_extend
|
| 102 |
+
else:
|
| 103 |
+
scores[2, index + 1, 0] = best + self.score.gap_extend
|
| 104 |
+
|
| 105 |
+
# 3D Backtracking
|
| 106 |
+
backStrIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
|
| 107 |
+
backGrphIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
|
| 108 |
+
backMtxIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
|
| 109 |
+
|
| 110 |
+
return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx
|
| 111 |
+
|
| 112 |
+
def backtrack(self, scores, backStrIdx, backGrphIdx, backMtxIdx ,nodeIndexToID):
|
| 113 |
+
besti, bestj = scores.shape[1] - 1, scores.shape[2] - 1
|
| 114 |
+
#Storing best matrices for each [i,j]
|
| 115 |
+
scores_arr = numpy.array(scores)
|
| 116 |
+
max_m = numpy.argmax(scores_arr, axis=0)
|
| 117 |
+
|
| 118 |
+
if self.globalAlign:
|
| 119 |
+
ni = self.graph.nodeiterator()
|
| 120 |
+
# Finding the best node to start from
|
| 121 |
+
terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
|
| 122 |
+
print(terminalIndices)
|
| 123 |
+
besti = terminalIndices[0] + 1
|
| 124 |
+
bestscore = scores[max_m[besti, bestj], besti, bestj]
|
| 125 |
+
for i in terminalIndices[1:]:
|
| 126 |
+
score = scores[max_m[i + 1, bestj], i + 1, bestj]
|
| 127 |
+
if score > bestscore:
|
| 128 |
+
bestscore, besti = score, i + 1
|
| 129 |
+
bestm = max_m[besti, bestj]
|
| 130 |
+
|
| 131 |
+
matches = []
|
| 132 |
+
strindexes = []
|
| 133 |
+
|
| 134 |
+
while (besti != 0 or bestj != 0):
|
| 135 |
+
nextm, nexti, nextj, = backMtxIdx[bestm, besti, bestj], backGrphIdx[bestm, besti, bestj], backStrIdx[bestm, besti, bestj]
|
| 136 |
+
curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]
|
| 137 |
+
|
| 138 |
+
if bestm == 0:
|
| 139 |
+
matches.insert(0, curnodeidx)
|
| 140 |
+
strindexes.insert(0, curstridx)
|
| 141 |
+
elif bestm == 1:
|
| 142 |
+
matches.insert(0, None)
|
| 143 |
+
strindexes.insert(0, curstridx)
|
| 144 |
+
else:
|
| 145 |
+
matches.insert(0, curnodeidx)
|
| 146 |
+
strindexes.insert(0, None)
|
| 147 |
+
|
| 148 |
+
bestm, besti, bestj = nextm, nexti, nextj
|
| 149 |
+
|
| 150 |
+
return strindexes, matches
|
src/new_text_alignment.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from difflib import SequenceMatcher
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from .new_alignment import ScoreParam, SeqGraphAlignment
|
| 6 |
+
|
| 7 |
+
PUNCTUATION_MARKS = [".", "!", "?", ",", ":", ";", "...", "(", ")"]
|
| 8 |
+
|
| 9 |
+
class TextSeqGraphAlignment(SeqGraphAlignment):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
text,
|
| 13 |
+
graph,
|
| 14 |
+
fastMethod=True,
|
| 15 |
+
globalAlign=True,
|
| 16 |
+
matchscore=1,
|
| 17 |
+
mismatchscore=-3,
|
| 18 |
+
gap_open=-2,
|
| 19 |
+
gap_extend=-1,
|
| 20 |
+
position_weight=0.1,
|
| 21 |
+
*args,
|
| 22 |
+
**kwargs,
|
| 23 |
+
):
|
| 24 |
+
score_params = ScoreParam(
|
| 25 |
+
match=matchscore, mismatch=mismatchscore, gap_open=gap_open, gap_extend=gap_extend
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
if isinstance(text, str):
|
| 29 |
+
self.original_text = text
|
| 30 |
+
self.sequence = text.split()
|
| 31 |
+
else:
|
| 32 |
+
self.sequence = text
|
| 33 |
+
self.original_text = " ".join(text)
|
| 34 |
+
self.position_weight = position_weight
|
| 35 |
+
|
| 36 |
+
super().__init__(
|
| 37 |
+
self.sequence,
|
| 38 |
+
graph,
|
| 39 |
+
fastMethod,
|
| 40 |
+
globalAlign=globalAlign,
|
| 41 |
+
score_params=score_params,
|
| 42 |
+
*args,
|
| 43 |
+
**kwargs,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def string_similarity(self, s1, s2):
|
| 47 |
+
"""Get edit-distance based similarity between two strings"""
|
| 48 |
+
return SequenceMatcher(None, s1, s2).ratio()
|
| 49 |
+
|
| 50 |
+
def matchscore(self, word1: str, word2: str) -> float:
|
| 51 |
+
"""Enhanced scoring function that considers string similarity
|
| 52 |
+
and relative position"""
|
| 53 |
+
# Calculate basic string similarity
|
| 54 |
+
similarity = self.string_similarity(word1, word2)
|
| 55 |
+
|
| 56 |
+
# If words are very similar, treat as match
|
| 57 |
+
if similarity > 0.8: # Can tune this threshold
|
| 58 |
+
similarity = self.score.match
|
| 59 |
+
# For less similar words, scale score based on similarity
|
| 60 |
+
elif similarity > 0.5: # Can tune this threshold too
|
| 61 |
+
similarity = self.score.match * similarity
|
| 62 |
+
else:
|
| 63 |
+
similarity = self.score.mismatch
|
| 64 |
+
return similarity
|
| 65 |
+
|
| 66 |
+
# add weight if any punctuation mark is present
|
| 67 |
+
if any(char in word1 for char in PUNCTUATION_MARKS) or any(
|
| 68 |
+
char in word2 for char in PUNCTUATION_MARKS
|
| 69 |
+
):
|
| 70 |
+
similarity = similarity * 1.5
|
| 71 |
+
|
| 72 |
+
return similarity
|
| 73 |
+
|
| 74 |
+
def alignmentStrings(self):
|
| 75 |
+
"""Override to handle word-based alignment"""
|
| 76 |
+
aligned_seq = [self.sequence[i] if i is not None else "-" for i in self.stringidxs]
|
| 77 |
+
aligned_graph = [
|
| 78 |
+
self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs
|
| 79 |
+
]
|
| 80 |
+
return " ".join(aligned_seq), " ".join(aligned_graph)
|
| 81 |
+
|
| 82 |
+
def alignStringToGraphFast(self):
|
| 83 |
+
if not isinstance(self.sequence, list):
|
| 84 |
+
raise TypeError("Sequence must be a list of words")
|
| 85 |
+
|
| 86 |
+
nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx = (
|
| 87 |
+
self.initializeDynamicProgrammingData()
|
| 88 |
+
)
|
| 89 |
+
# M: Match at last indices, X: Gap at last index of graph, Y: gap at last index of sequence
|
| 90 |
+
M, X, Y = 0, 1, 2
|
| 91 |
+
|
| 92 |
+
ni = self.graph.nodeiterator()
|
| 93 |
+
for i, node in enumerate(ni()):
|
| 94 |
+
gbase = node.text
|
| 95 |
+
|
| 96 |
+
for j, sbase in enumerate(self.sequence):
|
| 97 |
+
candidates_X , candidates_Y , candidates_M = [], [], []
|
| 98 |
+
candidates_X += [
|
| 99 |
+
(self.score.gap_open + self.score.gap_extend + scores[0, i + 1, j], i + 1, j, M),
|
| 100 |
+
(self.score.gap_extend + scores[1, i + 1, j], i + 1, j, X),
|
| 101 |
+
(self.score.gap_open + self.score.gap_extend + scores[2, i + 1, j], i + 1, j, Y)
|
| 102 |
+
]
|
| 103 |
+
for predIndex in self.prevIndices(node, nodeIDtoIndex):
|
| 104 |
+
candidates_Y += [
|
| 105 |
+
(self.score.gap_open + self.score.gap_extend + scores[0, predIndex + 1, j + 1] , predIndex + 1, j + 1, M),
|
| 106 |
+
(self.score.gap_open + self.score.gap_extend + scores[1, predIndex + 1, j + 1] , predIndex + 1, j + 1, X),
|
| 107 |
+
(self.score.gap_extend + scores[2, predIndex + 1, j + 1] , predIndex + 1, j + 1, Y)
|
| 108 |
+
]
|
| 109 |
+
candidates_M += [
|
| 110 |
+
(self.matchscore(sbase, gbase) + scores[0, predIndex + 1, j], predIndex + 1, j, M),
|
| 111 |
+
(self.matchscore(sbase, gbase) + scores[1, predIndex + 1, j], predIndex + 1, j, X),
|
| 112 |
+
(self.matchscore(sbase, gbase) + scores[2, predIndex + 1, j], predIndex + 1, j, Y)
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
(
|
| 116 |
+
scores[0, i + 1, j + 1],
|
| 117 |
+
backGrphIdx[0, i + 1, j + 1],
|
| 118 |
+
backStrIdx[0, i + 1, j + 1],
|
| 119 |
+
backMtxIdx[0, i + 1, j + 1],
|
| 120 |
+
) = max(candidates_M)
|
| 121 |
+
(
|
| 122 |
+
scores[1, i + 1, j + 1],
|
| 123 |
+
backGrphIdx[1, i + 1, j + 1],
|
| 124 |
+
backStrIdx[1, i + 1, j + 1],
|
| 125 |
+
backMtxIdx[1, i + 1, j + 1],
|
| 126 |
+
) = max(candidates_X)
|
| 127 |
+
(
|
| 128 |
+
scores[2, i + 1, j + 1],
|
| 129 |
+
backGrphIdx[2, i + 1, j + 1],
|
| 130 |
+
backStrIdx[2, i + 1, j + 1],
|
| 131 |
+
backMtxIdx[2, i + 1, j + 1],
|
| 132 |
+
) = max(candidates_Y)
|
| 133 |
+
|
| 134 |
+
return self.backtrack(scores, backStrIdx, backGrphIdx, backMtxIdx ,nodeIndexToID)
|
src/poa_graph.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from Jonathan Dursi
|
| 3 |
+
https://github.com/ljdursi/poapy
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import collections
|
| 7 |
+
import textwrap
|
| 8 |
+
from typing import Dict, List, Optional, Union
|
| 9 |
+
|
| 10 |
+
import numpy
|
| 11 |
+
|
| 12 |
+
from .alignment import SeqGraphAlignment
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Node(object):
|
| 16 |
+
def __init__(self, nodeID: int = -1, text: str = ""):
|
| 17 |
+
self.ID = nodeID
|
| 18 |
+
self.text = text
|
| 19 |
+
self.inEdges = {}
|
| 20 |
+
self.outEdges = {}
|
| 21 |
+
self.alignedTo = []
|
| 22 |
+
|
| 23 |
+
def __str__(self):
|
| 24 |
+
return "(%d:%s)" % (self.ID, self.text)
|
| 25 |
+
|
| 26 |
+
def _add_edge(
|
| 27 |
+
self,
|
| 28 |
+
edgeset: Dict[int, "Node"],
|
| 29 |
+
neighbourID: int,
|
| 30 |
+
label: Union[int, List[int]],
|
| 31 |
+
from_neighbour: bool,
|
| 32 |
+
weight: int = 1,
|
| 33 |
+
):
|
| 34 |
+
if neighbourID is None:
|
| 35 |
+
return
|
| 36 |
+
# already present? just update labels
|
| 37 |
+
# otherwise create appropriately-ordered edge and proceed
|
| 38 |
+
if neighbourID in edgeset:
|
| 39 |
+
edgeset[neighbourID].weight += weight
|
| 40 |
+
if isinstance(label, list):
|
| 41 |
+
edgeset[neighbourID].labels.extend(label)
|
| 42 |
+
else:
|
| 43 |
+
edgeset[neighbourID].labels.append(label)
|
| 44 |
+
# remove duplicates
|
| 45 |
+
edgeset[neighbourID].labels = list(set(edgeset[neighbourID].labels))
|
| 46 |
+
else:
|
| 47 |
+
if from_neighbour:
|
| 48 |
+
edge = Edge(outNodeID=neighbourID, inNodeID=self.ID, label=label, weight=weight)
|
| 49 |
+
else:
|
| 50 |
+
edge = Edge(outNodeID=self.ID, inNodeID=neighbourID, label=label, weight=weight)
|
| 51 |
+
edgeset[neighbourID] = edge
|
| 52 |
+
|
| 53 |
+
def addInEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
|
| 54 |
+
self._add_edge(self.inEdges, neighbourID, label, from_neighbour=True, weight=weight)
|
| 55 |
+
|
| 56 |
+
def addOutEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
|
| 57 |
+
self._add_edge(self.outEdges, neighbourID, label, from_neighbour=False, weight=weight)
|
| 58 |
+
|
| 59 |
+
def nextNode(self, label: int):
|
| 60 |
+
"""Returns the first (presumably only) outward neighbour
|
| 61 |
+
having the given edge label"""
|
| 62 |
+
nextID = None
|
| 63 |
+
for e in self.outEdges:
|
| 64 |
+
if label in self.outEdges[e].labels:
|
| 65 |
+
nextID = e
|
| 66 |
+
return nextID
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def inDegree(self):
|
| 70 |
+
return len(self.inEdges)
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def outDegree(self):
|
| 74 |
+
return len(self.outEdges)
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def weightedInDegree(self):
|
| 78 |
+
return sum(edge.weight for edge in self.inEdges.values())
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
def weightedOutDegree(self):
|
| 82 |
+
return sum(edge.weight for edge in self.outEdges.values())
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def labels(self):
|
| 86 |
+
"""Returns all the labels associated with an in-edge or an out edge."""
|
| 87 |
+
labelset = set([])
|
| 88 |
+
for e in list(self.inEdges.values()):
|
| 89 |
+
labelset = labelset.union(e.labels)
|
| 90 |
+
for e in list(self.outEdges.values()):
|
| 91 |
+
labelset = labelset.union(e.labels)
|
| 92 |
+
return list(labelset)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Edge(object):
|
| 96 |
+
def __init__(
|
| 97 |
+
self,
|
| 98 |
+
inNodeID: int = -1,
|
| 99 |
+
outNodeID: int = -1,
|
| 100 |
+
label: Optional[Union[int, List[int]]] = None,
|
| 101 |
+
weight: int = 1,
|
| 102 |
+
):
|
| 103 |
+
self.inNodeID = inNodeID
|
| 104 |
+
self.outNodeID = outNodeID
|
| 105 |
+
|
| 106 |
+
self.weight = weight
|
| 107 |
+
|
| 108 |
+
if label is None:
|
| 109 |
+
self.labels = []
|
| 110 |
+
elif isinstance(label, list):
|
| 111 |
+
self.labels = label
|
| 112 |
+
else:
|
| 113 |
+
self.labels = [label]
|
| 114 |
+
|
| 115 |
+
def addLabel(self, newlabel):
|
| 116 |
+
self.labels.append(newlabel)
|
| 117 |
+
|
| 118 |
+
def __str__(self):
|
| 119 |
+
nodestr = "(%d) -> (%d) " % (self.inNodeID, self.outNodeID)
|
| 120 |
+
if self.labels is None:
|
| 121 |
+
return nodestr
|
| 122 |
+
else:
|
| 123 |
+
return nodestr + self.labels.__str__()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class POAGraph(object):
|
| 127 |
+
def addUnmatchedSeq(self, seq, label: int = -1, updateSequences=True):
|
| 128 |
+
"""Add a completely independant (sub)string to the graph,
|
| 129 |
+
and return node index to initial and final node"""
|
| 130 |
+
if seq is None:
|
| 131 |
+
return
|
| 132 |
+
|
| 133 |
+
firstID, lastID = None, None
|
| 134 |
+
neededSort = self.needsSort
|
| 135 |
+
|
| 136 |
+
for text in seq:
|
| 137 |
+
nodeID = self.addNode(text)
|
| 138 |
+
if firstID is None:
|
| 139 |
+
firstID = nodeID
|
| 140 |
+
if lastID is not None:
|
| 141 |
+
self.addEdge(lastID, nodeID, label)
|
| 142 |
+
lastID = nodeID
|
| 143 |
+
|
| 144 |
+
self._needsort = neededSort # no new order problems introduced
|
| 145 |
+
if updateSequences:
|
| 146 |
+
self._seqs.append(seq)
|
| 147 |
+
self._labels.append(label)
|
| 148 |
+
self._starts.append(firstID)
|
| 149 |
+
return firstID, lastID
|
| 150 |
+
|
| 151 |
+
def __init__(self, seq=None, label: Optional[Union[int, List[int]]] = None):
|
| 152 |
+
self._nextnodeID = 0
|
| 153 |
+
self._nnodes = 0
|
| 154 |
+
self._nedges = 0
|
| 155 |
+
self.nodedict = {}
|
| 156 |
+
self.nodeidlist = [] # allows a (partial) order to be imposed on the nodes
|
| 157 |
+
self._needsort = False
|
| 158 |
+
self._labels = []
|
| 159 |
+
self._seqs = []
|
| 160 |
+
self._starts = []
|
| 161 |
+
|
| 162 |
+
if seq is not None:
|
| 163 |
+
self.addUnmatchedSeq(seq, label)
|
| 164 |
+
|
| 165 |
+
def nodeIdxToBase(self, idx):
|
| 166 |
+
return self.nodedict[self.nodeidlist[idx]].text
|
| 167 |
+
|
| 168 |
+
def addNode(self, text):
|
| 169 |
+
nid = self._nextnodeID
|
| 170 |
+
newnode = Node(nid, text)
|
| 171 |
+
self.nodedict[nid] = newnode
|
| 172 |
+
self.nodeidlist.append(nid)
|
| 173 |
+
self._nnodes += 1
|
| 174 |
+
self._nextnodeID += 1
|
| 175 |
+
self._needsSort = True
|
| 176 |
+
return nid
|
| 177 |
+
|
| 178 |
+
def addEdge(self, start, end, label, weight: int = 1):
|
| 179 |
+
if start is None or end is None:
|
| 180 |
+
return
|
| 181 |
+
|
| 182 |
+
if start not in self.nodedict:
|
| 183 |
+
raise KeyError("addEdge: Start node not in graph: " + str(start))
|
| 184 |
+
if end not in self.nodedict:
|
| 185 |
+
raise KeyError("addEdge: End node not in graph: " + str(end))
|
| 186 |
+
|
| 187 |
+
oldNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree
|
| 188 |
+
|
| 189 |
+
self.nodedict[start].addOutEdge(end, label, weight)
|
| 190 |
+
self.nodedict[end].addInEdge(start, label, weight)
|
| 191 |
+
|
| 192 |
+
newNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree
|
| 193 |
+
|
| 194 |
+
if newNodeEdges != oldNodeEdges:
|
| 195 |
+
self._nedges += 1
|
| 196 |
+
|
| 197 |
+
self._needsSort = True
|
| 198 |
+
return
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def needsSort(self):
|
| 202 |
+
return self._needsort
|
| 203 |
+
|
| 204 |
+
@property
|
| 205 |
+
def nNodes(self):
|
| 206 |
+
return self._nnodes
|
| 207 |
+
|
| 208 |
+
@property
|
| 209 |
+
def nEdges(self):
|
| 210 |
+
return self._nedges
|
| 211 |
+
|
| 212 |
+
@property
|
| 213 |
+
def num_sequences(self):
|
| 214 |
+
return len(self._seqs)
|
| 215 |
+
|
| 216 |
+
def get_sequences(self):
|
| 217 |
+
return self._seqs
|
| 218 |
+
|
| 219 |
+
def _simplified_graph_rep(self):
|
| 220 |
+
|
| 221 |
+
node_to_pn = {}
|
| 222 |
+
pn_to_nodes = {}
|
| 223 |
+
|
| 224 |
+
# Find the mappings from nodes to pseudonodes
|
| 225 |
+
cur_pnid = 0
|
| 226 |
+
for _, node in self.nodedict.items():
|
| 227 |
+
if node.ID not in node_to_pn:
|
| 228 |
+
node_ids = [node.ID] + node.alignedTo
|
| 229 |
+
pn_to_nodes[cur_pnid] = node_ids
|
| 230 |
+
for nid in node_ids:
|
| 231 |
+
node_to_pn[nid] = cur_pnid
|
| 232 |
+
cur_pnid += 1
|
| 233 |
+
|
| 234 |
+
# create the pseudonodes
|
| 235 |
+
Pseudonode = collections.namedtuple(
|
| 236 |
+
"Pseudonode", ["pnode_id", "predecessors", "successors", "node_ids"]
|
| 237 |
+
)
|
| 238 |
+
pseudonodes = []
|
| 239 |
+
|
| 240 |
+
for pnid in range(cur_pnid):
|
| 241 |
+
nids, preds, succs = pn_to_nodes[pnid], [], []
|
| 242 |
+
for nid in nids:
|
| 243 |
+
node = self.nodedict[nid]
|
| 244 |
+
preds += [node_to_pn[inEdge.outNodeID] for _, inEdge in node.inEdges.items()]
|
| 245 |
+
succs += [node_to_pn[outEdge.inNodeID] for _, outEdge in node.outEdges.items()]
|
| 246 |
+
|
| 247 |
+
pn = Pseudonode(pnode_id=pnid, predecessors=preds, successors=succs, node_ids=nids)
|
| 248 |
+
pseudonodes.append(pn)
|
| 249 |
+
|
| 250 |
+
return pseudonodes
|
| 251 |
+
|
| 252 |
+
def toposort(self):
|
| 253 |
+
"""Sorts node list so that all incoming edges come from nodes earlier in the list."""
|
| 254 |
+
sortedlist = []
|
| 255 |
+
completed = set([])
|
| 256 |
+
|
| 257 |
+
#
|
| 258 |
+
# The topological sort of this graph is complicated by the alignedTo edges;
|
| 259 |
+
# we want to nodes connected by such edges to remain near each other in the
|
| 260 |
+
# topological sort.
|
| 261 |
+
#
|
| 262 |
+
# Here we'll create a simple version of the graph that merges nodes that
|
| 263 |
+
# are alignedTo each other, performs the sort, and then decomposes the
|
| 264 |
+
# 'pseudonodes'.
|
| 265 |
+
#
|
| 266 |
+
# The need for this suggests that the way the graph is currently represented
|
| 267 |
+
# isn't quite right and needs some rethinking.
|
| 268 |
+
#
|
| 269 |
+
|
| 270 |
+
pseudonodes = self._simplified_graph_rep()
|
| 271 |
+
|
| 272 |
+
def dfs(start, complete, sortedlist):
|
| 273 |
+
stack, started = [start], set()
|
| 274 |
+
while stack:
|
| 275 |
+
pnodeID = stack.pop()
|
| 276 |
+
|
| 277 |
+
if pnodeID in complete:
|
| 278 |
+
continue
|
| 279 |
+
|
| 280 |
+
if pnodeID in started:
|
| 281 |
+
complete.add(pnodeID)
|
| 282 |
+
for nid in pseudonodes[pnodeID].node_ids:
|
| 283 |
+
sortedlist.insert(0, nid)
|
| 284 |
+
started.remove(pnodeID)
|
| 285 |
+
continue
|
| 286 |
+
|
| 287 |
+
successors = pseudonodes[pnodeID].successors
|
| 288 |
+
started.add(pnodeID)
|
| 289 |
+
stack.append(pnodeID)
|
| 290 |
+
stack.extend(successors)
|
| 291 |
+
|
| 292 |
+
while len(sortedlist) < self.nNodes:
|
| 293 |
+
found = None
|
| 294 |
+
for pnid in range(len(pseudonodes)):
|
| 295 |
+
if pnid not in completed and len(pseudonodes[pnid].predecessors) == 0:
|
| 296 |
+
found = pnid
|
| 297 |
+
break
|
| 298 |
+
assert found is not None
|
| 299 |
+
dfs(found, completed, sortedlist)
|
| 300 |
+
|
| 301 |
+
assert len(sortedlist) == self.nNodes
|
| 302 |
+
self.nodeidlist = sortedlist
|
| 303 |
+
self._needsSort = False
|
| 304 |
+
return
|
| 305 |
+
|
| 306 |
+
def testsort(self):
|
| 307 |
+
"""Test the nodeidlist to make sure it is topologically sorted:
|
| 308 |
+
eg, all predecessors of a node preceed the node in the list"""
|
| 309 |
+
if self.nodeidlist is None:
|
| 310 |
+
return
|
| 311 |
+
seen_nodes = set()
|
| 312 |
+
for nodeidx in self.nodeidlist:
|
| 313 |
+
node = self.nodedict[nodeidx]
|
| 314 |
+
for in_neighbour in node.inEdges:
|
| 315 |
+
assert in_neighbour in seen_nodes
|
| 316 |
+
seen_nodes.add(nodeidx)
|
| 317 |
+
return
|
| 318 |
+
|
| 319 |
+
def nodeiterator(self):
|
| 320 |
+
if self.needsSort:
|
| 321 |
+
self.toposort()
|
| 322 |
+
|
| 323 |
+
def nodegenerator():
|
| 324 |
+
for nodeidx in self.nodeidlist:
|
| 325 |
+
yield self.nodedict[nodeidx]
|
| 326 |
+
|
| 327 |
+
return nodegenerator
|
| 328 |
+
|
| 329 |
+
def __str__(self):
|
| 330 |
+
selfstr = ""
|
| 331 |
+
ni = self.nodeiterator()
|
| 332 |
+
for node in ni():
|
| 333 |
+
selfstr += node.__str__() + "\n"
|
| 334 |
+
for outIdx in node.outEdges:
|
| 335 |
+
selfstr += " " + node.outEdges[outIdx].__str__() + "\n"
|
| 336 |
+
return selfstr
|
| 337 |
+
|
| 338 |
+
def incorporateSeqAlignment(self, alignment: SeqGraphAlignment, seq, label: int = -1):
|
| 339 |
+
"""Incorporate a SeqGraphAlignment into the graph."""
|
| 340 |
+
newseq = alignment.sequence
|
| 341 |
+
stringidxs = alignment.stringidxs
|
| 342 |
+
nodeidxs = alignment.nodeidxs
|
| 343 |
+
|
| 344 |
+
firstID = None
|
| 345 |
+
headID = None
|
| 346 |
+
tailID = None
|
| 347 |
+
|
| 348 |
+
path = []
|
| 349 |
+
# head, tail of sequence may be unaligned; just add those into the
|
| 350 |
+
# graph directly
|
| 351 |
+
validstringidxs = [si for si in stringidxs if si is not None]
|
| 352 |
+
startSeqIdx, endSeqIdx = validstringidxs[0], validstringidxs[-1]
|
| 353 |
+
if startSeqIdx > 0:
|
| 354 |
+
firstID, headID = self.addUnmatchedSeq(
|
| 355 |
+
newseq[0:startSeqIdx], label, updateSequences=False
|
| 356 |
+
)
|
| 357 |
+
if endSeqIdx < len(newseq):
|
| 358 |
+
tailID, __ = self.addUnmatchedSeq(newseq[endSeqIdx + 1 :], label, updateSequences=False)
|
| 359 |
+
|
| 360 |
+
# now we march along the aligned part. For each text, we find or create
|
| 361 |
+
# a node in the graph:
|
| 362 |
+
# - if unmatched, the corresponding node is a new node
|
| 363 |
+
# - if matched:
|
| 364 |
+
# - if matched to a node with the same text, the node is that node
|
| 365 |
+
# - if matched to a node with a different text whch is in turn
|
| 366 |
+
# aligned to a node with the same text, that aligned node is
|
| 367 |
+
# the node
|
| 368 |
+
# - otherwise, we create a new node.
|
| 369 |
+
# In all cases, we create edges (or add labels) threading through the
|
| 370 |
+
# nodes.
|
| 371 |
+
for sindex, matchID in zip(stringidxs, nodeidxs):
|
| 372 |
+
if sindex is None:
|
| 373 |
+
continue
|
| 374 |
+
text = newseq[sindex]
|
| 375 |
+
if matchID is None:
|
| 376 |
+
nodeID = self.addNode(text)
|
| 377 |
+
elif self.nodedict[matchID].text == text:
|
| 378 |
+
nodeID = matchID
|
| 379 |
+
else:
|
| 380 |
+
otherAligns = self.nodedict[matchID].alignedTo
|
| 381 |
+
foundNode = None
|
| 382 |
+
for otherNodeID in otherAligns:
|
| 383 |
+
if self.nodedict[otherNodeID].text == text:
|
| 384 |
+
foundNode = otherNodeID
|
| 385 |
+
if foundNode is None:
|
| 386 |
+
nodeID = self.addNode(text)
|
| 387 |
+
self.nodedict[nodeID].alignedTo = [matchID] + otherAligns
|
| 388 |
+
for otherNodeID in [matchID] + otherAligns:
|
| 389 |
+
self.nodedict[otherNodeID].alignedTo.append(nodeID)
|
| 390 |
+
else:
|
| 391 |
+
nodeID = foundNode
|
| 392 |
+
|
| 393 |
+
self.addEdge(headID, nodeID, label)
|
| 394 |
+
headID = nodeID
|
| 395 |
+
if firstID is None:
|
| 396 |
+
firstID = headID
|
| 397 |
+
|
| 398 |
+
path.append(nodeID)
|
| 399 |
+
|
| 400 |
+
# finished the unaligned portion: now add an edge from the current headID to the tailID.
|
| 401 |
+
self.addEdge(headID, tailID, label)
|
| 402 |
+
|
| 403 |
+
# resort
|
| 404 |
+
self.toposort()
|
| 405 |
+
|
| 406 |
+
self._seqs.append(seq)
|
| 407 |
+
self._labels.append(label)
|
| 408 |
+
self._starts.append(firstID)
|
| 409 |
+
self._seq_paths[label] = path
|
| 410 |
+
return
|
| 411 |
+
|
| 412 |
+
def consensus(self, excludeLabels=None):
|
| 413 |
+
if excludeLabels is None:
|
| 414 |
+
excludeLabels = []
|
| 415 |
+
|
| 416 |
+
if self.needsSort:
|
| 417 |
+
self.toposort()
|
| 418 |
+
|
| 419 |
+
nodesInReverse = self.nodeidlist[::-1]
|
| 420 |
+
maxnodeID = max(nodesInReverse) + 1
|
| 421 |
+
nextInPath = [-1] * maxnodeID
|
| 422 |
+
scores = numpy.zeros((maxnodeID))
|
| 423 |
+
|
| 424 |
+
for nodeID in nodesInReverse:
|
| 425 |
+
bestWeightScoreEdge = (-1, -1, None)
|
| 426 |
+
for neighbourID in self.nodedict[nodeID].outEdges:
|
| 427 |
+
# print(f"nodeID: {nodeID}, neighbourID: {neighbourID}")
|
| 428 |
+
e = self.nodedict[nodeID].outEdges[neighbourID]
|
| 429 |
+
weightScoreEdge = (e.weight, scores[neighbourID], neighbourID)
|
| 430 |
+
|
| 431 |
+
if weightScoreEdge > bestWeightScoreEdge:
|
| 432 |
+
bestWeightScoreEdge = weightScoreEdge
|
| 433 |
+
|
| 434 |
+
scores[nodeID] = sum(bestWeightScoreEdge[0:2])
|
| 435 |
+
nextInPath[nodeID] = bestWeightScoreEdge[2]
|
| 436 |
+
|
| 437 |
+
pos = numpy.argmax(scores)
|
| 438 |
+
path = []
|
| 439 |
+
bases = []
|
| 440 |
+
labels = []
|
| 441 |
+
while pos is not None and pos > -1:
|
| 442 |
+
path.append(pos)
|
| 443 |
+
bases.append(self.nodedict[pos].text)
|
| 444 |
+
labels.append(self.nodedict[pos].labels)
|
| 445 |
+
pos = nextInPath[pos]
|
| 446 |
+
|
| 447 |
+
# ignore END node
|
| 448 |
+
path = path[:-1]
|
| 449 |
+
bases = bases[:-1]
|
| 450 |
+
labels = labels[:-1]
|
| 451 |
+
return path, bases, labels
|
| 452 |
+
|
| 453 |
+
def allConsenses(self, maxfraction=0.5):
|
| 454 |
+
allpaths = []
|
| 455 |
+
allbases = []
|
| 456 |
+
alllabels = []
|
| 457 |
+
exclusions = []
|
| 458 |
+
|
| 459 |
+
passno = 0
|
| 460 |
+
lastlen = 1000
|
| 461 |
+
maxpasses = 10
|
| 462 |
+
|
| 463 |
+
while len(exclusions) < len(self._labels) and lastlen >= 10 and passno < maxpasses:
|
| 464 |
+
path, bases, labellists = self.consensus(exclusions)
|
| 465 |
+
if len(path) > 0:
|
| 466 |
+
allpaths.append(path)
|
| 467 |
+
allbases.append(bases)
|
| 468 |
+
alllabels.append(labellists)
|
| 469 |
+
|
| 470 |
+
labelcounts = collections.defaultdict(int)
|
| 471 |
+
for ll in labellists:
|
| 472 |
+
for label in ll:
|
| 473 |
+
labelcounts[label] += 1
|
| 474 |
+
|
| 475 |
+
for label, seq in zip(self._labels, self._seqs):
|
| 476 |
+
if label in labelcounts and labelcounts[label] >= maxfraction * len(seq):
|
| 477 |
+
exclusions.append(label)
|
| 478 |
+
|
| 479 |
+
lastlen = len(path)
|
| 480 |
+
passno += 1
|
| 481 |
+
|
| 482 |
+
return list(zip(allpaths, allbases, alllabels))
|
| 483 |
+
|
| 484 |
+
def generateAlignmentStrings(self):
|
| 485 |
+
"""Return a list of strings corresponding to the alignments in the graph"""
|
| 486 |
+
|
| 487 |
+
# Step 1: assign node IDs to columns in the output
|
| 488 |
+
# column_index[node.ID] is the position in the toposorted node list
|
| 489 |
+
# of the node itself, or the earliest node it is aligned to.
|
| 490 |
+
column_index = {}
|
| 491 |
+
current_column = 0
|
| 492 |
+
|
| 493 |
+
# go through nodes in toposort order
|
| 494 |
+
ni = self.nodeiterator()
|
| 495 |
+
for node in ni():
|
| 496 |
+
other_columns = [
|
| 497 |
+
column_index[other] for other in node.alignedTo if other in column_index
|
| 498 |
+
]
|
| 499 |
+
if other_columns:
|
| 500 |
+
found_idx = min(other_columns)
|
| 501 |
+
else:
|
| 502 |
+
found_idx = current_column
|
| 503 |
+
current_column += 1
|
| 504 |
+
|
| 505 |
+
column_index[node.ID] = found_idx
|
| 506 |
+
|
| 507 |
+
ncolumns = current_column
|
| 508 |
+
|
| 509 |
+
# Step 2: given the column indexes, populate the strings
|
| 510 |
+
# corresponding to the sequences inserted in the graph
|
| 511 |
+
seqnames = []
|
| 512 |
+
alignstrings = []
|
| 513 |
+
for label, start in zip(self._labels, self._starts):
|
| 514 |
+
seqnames.append(label)
|
| 515 |
+
curnode_id = start
|
| 516 |
+
charlist = ["-"] * ncolumns
|
| 517 |
+
while curnode_id is not None:
|
| 518 |
+
node = self.nodedict[curnode_id]
|
| 519 |
+
charlist[column_index[curnode_id]] = node.text
|
| 520 |
+
curnode_id = node.nextNode(label)
|
| 521 |
+
alignstrings.append("".join(charlist))
|
| 522 |
+
|
| 523 |
+
# Step 3: Same as step 2, but with consensus sequences
|
| 524 |
+
consenses = self.allConsenses()
|
| 525 |
+
for i, consensus in enumerate(consenses):
|
| 526 |
+
seqnames.append("Consensus" + str(i))
|
| 527 |
+
charlist = ["-"] * ncolumns
|
| 528 |
+
for path, text in zip(consensus[0], consensus[1]):
|
| 529 |
+
charlist[column_index[path]] = text
|
| 530 |
+
alignstrings.append("".join(charlist))
|
| 531 |
+
|
| 532 |
+
return list(zip(seqnames, alignstrings))
|
| 533 |
+
|
| 534 |
+
def jsOutput(self, verbose: bool = False, annotate_consensus: bool = True):
|
| 535 |
+
"""returns a list of strings containing a a description of the graph for viz.js, http://visjs.org"""
|
| 536 |
+
|
| 537 |
+
# get the consensus sequence, which we'll use as the "spine" of the
|
| 538 |
+
# graph
|
| 539 |
+
pathdict = {}
|
| 540 |
+
if annotate_consensus:
|
| 541 |
+
path, __, __ = self.consensus()
|
| 542 |
+
lines = ["var nodes = ["]
|
| 543 |
+
|
| 544 |
+
ni = self.nodeiterator()
|
| 545 |
+
count = 0
|
| 546 |
+
for node in ni():
|
| 547 |
+
line = " {id:" + str(node.ID) + ', label: "' + str(node.ID) + ": " + node.text + '"'
|
| 548 |
+
if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
|
| 549 |
+
line += (
|
| 550 |
+
", x: "
|
| 551 |
+
+ str(pathdict[node.ID])
|
| 552 |
+
+ ", y: 0 , fixed: { x:true, y:false},"
|
| 553 |
+
+ "color: '#7BE141', is_consensus:true},"
|
| 554 |
+
)
|
| 555 |
+
else:
|
| 556 |
+
line += "},"
|
| 557 |
+
lines.append(line)
|
| 558 |
+
|
| 559 |
+
lines[-1] = lines[-1][:-1]
|
| 560 |
+
lines.append("];")
|
| 561 |
+
|
| 562 |
+
lines.append(" ")
|
| 563 |
+
|
| 564 |
+
lines.append("var edges = [")
|
| 565 |
+
ni = self.nodeiterator()
|
| 566 |
+
for node in ni():
|
| 567 |
+
nodeID = str(node.ID)
|
| 568 |
+
for edge in node.outEdges:
|
| 569 |
+
target = str(edge)
|
| 570 |
+
weight = str(len(node.outEdges[edge].labels) + 1.5)
|
| 571 |
+
lines.append(
|
| 572 |
+
" {from: "
|
| 573 |
+
+ nodeID
|
| 574 |
+
+ ", to: "
|
| 575 |
+
+ target
|
| 576 |
+
+ ", value: "
|
| 577 |
+
+ weight
|
| 578 |
+
+ ", color: '#4b72b0', arrows: 'to'},"
|
| 579 |
+
)
|
| 580 |
+
if verbose:
|
| 581 |
+
for alignededge in node.alignedTo:
|
| 582 |
+
# These edges indicate alignment to different bases, and are
|
| 583 |
+
# undirected; thus make sure we only plot them once:
|
| 584 |
+
if node.ID > alignededge:
|
| 585 |
+
continue
|
| 586 |
+
target = str(alignededge)
|
| 587 |
+
lines.append(
|
| 588 |
+
" {from: "
|
| 589 |
+
+ nodeID
|
| 590 |
+
+ ", to: "
|
| 591 |
+
+ target
|
| 592 |
+
+ ', value: 1, style: "dash-line", color: "red"},'
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
lines[-1] = lines[-1][:-1]
|
| 596 |
+
lines.append("];")
|
| 597 |
+
return lines
|
| 598 |
+
|
| 599 |
+
def htmlOutput(self, outfile, verbose: bool = False, annotate_consensus: bool = True):
|
| 600 |
+
header = """
|
| 601 |
+
<!doctype html>
|
| 602 |
+
<html>
|
| 603 |
+
<head>
|
| 604 |
+
<title>POA Graph Alignment</title>
|
| 605 |
+
|
| 606 |
+
<script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
|
| 607 |
+
</head>
|
| 608 |
+
|
| 609 |
+
<body>
|
| 610 |
+
|
| 611 |
+
<div id="loadingProgress">0%</div>
|
| 612 |
+
|
| 613 |
+
<div id="mynetwork"></div>
|
| 614 |
+
|
| 615 |
+
<script type="text/javascript">
|
| 616 |
+
// create a network
|
| 617 |
+
"""
|
| 618 |
+
outfile.write(textwrap.dedent(header[1:]))
|
| 619 |
+
lines = self.jsOutput(verbose=verbose, annotate_consensus=annotate_consensus)
|
| 620 |
+
for line in lines:
|
| 621 |
+
outfile.write(line + "\n")
|
| 622 |
+
footer = """
|
| 623 |
+
var container = document.getElementById('mynetwork');
|
| 624 |
+
var data= {
|
| 625 |
+
nodes: nodes,
|
| 626 |
+
edges: edges,
|
| 627 |
+
};
|
| 628 |
+
var options = {
|
| 629 |
+
width: '100%',
|
| 630 |
+
height: '800px',
|
| 631 |
+
physics: {
|
| 632 |
+
enabled: false,
|
| 633 |
+
stabilization: {
|
| 634 |
+
updateInterval: 10,
|
| 635 |
+
},
|
| 636 |
+
hierarchicalRepulsion: {
|
| 637 |
+
avoidOverlap: 0.9,
|
| 638 |
+
},
|
| 639 |
+
},
|
| 640 |
+
edges: {
|
| 641 |
+
color: {
|
| 642 |
+
inherit: false
|
| 643 |
+
}
|
| 644 |
+
},
|
| 645 |
+
layout: {
|
| 646 |
+
hierarchical: {
|
| 647 |
+
direction: "UD",
|
| 648 |
+
sortMethod: "directed",
|
| 649 |
+
shakeTowards: "roots",
|
| 650 |
+
levelSeparation: 150, // Adjust as needed
|
| 651 |
+
nodeSpacing: 100, // Adjust as needed
|
| 652 |
+
treeSpacing: 200, // Adjust as needed
|
| 653 |
+
parentCentralization: true,
|
| 654 |
+
}
|
| 655 |
+
}
|
| 656 |
+
};
|
| 657 |
+
var network = new vis.Network(container, data, options);
|
| 658 |
+
|
| 659 |
+
network.on('beforeDrawing', function(ctx) {
|
| 660 |
+
nodes.forEach(function(node) {
|
| 661 |
+
if (node.isConsensus) {
|
| 662 |
+
// Set the level of spine nodes to the bottom
|
| 663 |
+
network.body.data.nodes.update({
|
| 664 |
+
id: node.id,
|
| 665 |
+
level: 0 // Set level to 0 for spine nodes
|
| 666 |
+
});
|
| 667 |
+
}
|
| 668 |
+
});
|
| 669 |
+
});
|
| 670 |
+
|
| 671 |
+
network.on("stabilizationProgress", function (params) {
|
| 672 |
+
document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
|
| 673 |
+
});
|
| 674 |
+
network.once("stabilizationIterationsDone", function () {
|
| 675 |
+
document.getElementById("loadingProgress").innerText = "100%";
|
| 676 |
+
setTimeout(function () {
|
| 677 |
+
document.getElementById("loadingProgress").style.display = "none";
|
| 678 |
+
}, 500);
|
| 679 |
+
});
|
| 680 |
+
</script>
|
| 681 |
+
|
| 682 |
+
</body>
|
| 683 |
+
</html>
|
| 684 |
+
"""
|
| 685 |
+
outfile.write(textwrap.dedent(footer))
|
src/text_poa_graph.py
ADDED
|
@@ -0,0 +1,802 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced version of POAGraph for text alignment
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pickle
|
| 6 |
+
import textwrap
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from src.text_poa_graph_utils import path_sim_llm
|
| 13 |
+
from src.global_edit_utils import clean_up_text
|
| 14 |
+
|
| 15 |
+
from .new_text_alignment import TextSeqGraphAlignment
|
| 16 |
+
from .poa_graph import Node, POAGraph
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TextNode(Node):
|
| 20 |
+
def __init__(self, nodeID=-1, text=""):
|
| 21 |
+
super().__init__(nodeID, text)
|
| 22 |
+
self.variations = {} # Track alternate phrasings
|
| 23 |
+
self.sequences = [] # Track sequences that contain this node
|
| 24 |
+
self.influenceScore = 0
|
| 25 |
+
self.num_tokens_used = 0
|
| 26 |
+
|
| 27 |
+
def add_variation(self, text, sequence_id):
|
| 28 |
+
self.variations[sequence_id] = text
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def is_stable(self):
|
| 32 |
+
"""A node is stable if it appears frequently enough relative to total sequences"""
|
| 33 |
+
return self.frequency >= self.graph.stability_threshold
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TextPOAGraph(POAGraph):
|
| 37 |
+
def __init__(self, text=None, label=-1):
|
| 38 |
+
self.consensus_node_ids = []
|
| 39 |
+
self._seq_paths = {}
|
| 40 |
+
self.end_id = -1
|
| 41 |
+
self.start_id = -1
|
| 42 |
+
self.failed = False
|
| 43 |
+
self.num_input_tokens_used = 0
|
| 44 |
+
self.num_output_tokens_used = 0
|
| 45 |
+
super().__init__(text, label)
|
| 46 |
+
|
| 47 |
+
def addNode(self, text):
|
| 48 |
+
"""Override to use TextNode"""
|
| 49 |
+
nid = self._nextnodeID
|
| 50 |
+
newnode = TextNode(nid, text)
|
| 51 |
+
self.nodedict[nid] = newnode
|
| 52 |
+
self.nodeidlist.append(nid)
|
| 53 |
+
self._nnodes += 1
|
| 54 |
+
self._nextnodeID += 1
|
| 55 |
+
self._needsSort = True
|
| 56 |
+
return nid
|
| 57 |
+
|
| 58 |
+
def addUnmatchedSeq(self, text, label=-1, updateSequences=True):
|
| 59 |
+
"""Modified to handle text sequences"""
|
| 60 |
+
if text is None:
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
# Handle both string and list input
|
| 64 |
+
if isinstance(text, str):
|
| 65 |
+
words = text.split()
|
| 66 |
+
else:
|
| 67 |
+
words = text
|
| 68 |
+
|
| 69 |
+
firstID, lastID = None, None
|
| 70 |
+
neededSort = self.needsSort
|
| 71 |
+
|
| 72 |
+
path = []
|
| 73 |
+
for word in words:
|
| 74 |
+
nodeID = self.addNode(word)
|
| 75 |
+
if firstID is None:
|
| 76 |
+
firstID = nodeID
|
| 77 |
+
if lastID is not None:
|
| 78 |
+
self.addEdge(lastID, nodeID, label=label)
|
| 79 |
+
lastID = nodeID
|
| 80 |
+
path.append(nodeID)
|
| 81 |
+
|
| 82 |
+
self._needsort = neededSort
|
| 83 |
+
if updateSequences:
|
| 84 |
+
self._seqs.append(words)
|
| 85 |
+
self._labels.append(label)
|
| 86 |
+
self._starts.append(firstID)
|
| 87 |
+
self._seq_paths[label] = path
|
| 88 |
+
|
| 89 |
+
return firstID, lastID
|
| 90 |
+
|
| 91 |
+
def add_text(self, text, label=-1):
|
| 92 |
+
"""Main method to add new text to the alignment"""
|
| 93 |
+
if len(self._seqs) == 0:
|
| 94 |
+
# First sequence - just add it
|
| 95 |
+
self.addUnmatchedSeq(text, label)
|
| 96 |
+
else:
|
| 97 |
+
# Align to existing graph
|
| 98 |
+
alignment = TextSeqGraphAlignment(
|
| 99 |
+
text, self, matchscore=2, mismatchscore=-1, gapscore=-2
|
| 100 |
+
)
|
| 101 |
+
self.incorporateSeqAlignment(alignment, text, label)
|
| 102 |
+
|
| 103 |
+
# Update node frequencies
|
| 104 |
+
self._update_frequencies()
|
| 105 |
+
|
| 106 |
+
def removeNode(self, nodeID):
|
| 107 |
+
"""Override to handle text nodes"""
|
| 108 |
+
node = self.nodedict[nodeID]
|
| 109 |
+
if node is None:
|
| 110 |
+
return
|
| 111 |
+
|
| 112 |
+
# Remove all edges to this node
|
| 113 |
+
out_edges = node.outEdges.copy()
|
| 114 |
+
in_edges = node.inEdges.copy()
|
| 115 |
+
|
| 116 |
+
for edge in out_edges:
|
| 117 |
+
self.removeEdge(node.ID, edge)
|
| 118 |
+
for edge in in_edges:
|
| 119 |
+
self.removeEdge(edge, node.ID)
|
| 120 |
+
|
| 121 |
+
# Remove from graph
|
| 122 |
+
del self.nodedict[nodeID]
|
| 123 |
+
self.nodeidlist.remove(nodeID)
|
| 124 |
+
|
| 125 |
+
for path in self._seq_paths.values():
|
| 126 |
+
if nodeID in path:
|
| 127 |
+
path.remove(nodeID)
|
| 128 |
+
|
| 129 |
+
self._nnodes -= 1
|
| 130 |
+
self._needsSort = True
|
| 131 |
+
|
| 132 |
+
def removeEdge(self, nodeID1, nodeID2):
|
| 133 |
+
"""Override to handle text nodes"""
|
| 134 |
+
node1 = self.nodedict[nodeID1]
|
| 135 |
+
node2 = self.nodedict[nodeID2]
|
| 136 |
+
|
| 137 |
+
if node1 is None or node2 is None:
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
# Remove from graph
|
| 141 |
+
del node1.outEdges[nodeID2]
|
| 142 |
+
del node2.inEdges[nodeID1]
|
| 143 |
+
|
| 144 |
+
def merge_consensus_nodes(self, verbose: bool = False):
|
| 145 |
+
self.toposort()
|
| 146 |
+
# reset consensus node ids
|
| 147 |
+
self.consensus_node_ids = []
|
| 148 |
+
nodes = list(self.nodeiterator()())
|
| 149 |
+
consensus_segments = []
|
| 150 |
+
i = 0
|
| 151 |
+
while i < len(nodes):
|
| 152 |
+
node = nodes[i]
|
| 153 |
+
out_weight = sum(e.weight for e in node.outEdges.values())
|
| 154 |
+
in_weight = sum(e.weight for e in node.inEdges.values())
|
| 155 |
+
|
| 156 |
+
if out_weight in [0, self.num_sequences] and in_weight in [0, self.num_sequences]:
|
| 157 |
+
consensus_segment = [(node.ID, node.text)]
|
| 158 |
+
next_node = node
|
| 159 |
+
while (i + 1) < len(nodes) and len(next_node.outEdges) == 1:
|
| 160 |
+
next_node = nodes[i + 1]
|
| 161 |
+
next_out_weight = sum(e.weight for e in next_node.outEdges.values())
|
| 162 |
+
next_in_weight = sum(e.weight for e in next_node.inEdges.values())
|
| 163 |
+
|
| 164 |
+
if (
|
| 165 |
+
next_out_weight != self.num_sequences
|
| 166 |
+
or next_in_weight != self.num_sequences
|
| 167 |
+
):
|
| 168 |
+
break
|
| 169 |
+
|
| 170 |
+
consensus_segment.append((next_node.ID, next_node.text))
|
| 171 |
+
i += 1
|
| 172 |
+
consensus_segments.append(consensus_segment)
|
| 173 |
+
i += 1
|
| 174 |
+
# merge consensus nodes into a single node
|
| 175 |
+
for segment in consensus_segments:
|
| 176 |
+
if len(segment) == 1:
|
| 177 |
+
self.consensus_node_ids.append(segment[0][0])
|
| 178 |
+
continue
|
| 179 |
+
merged_text = " ".join([text for _, text in segment])
|
| 180 |
+
first_node_id = segment[0][0]
|
| 181 |
+
last_node_id = segment[-1][0]
|
| 182 |
+
|
| 183 |
+
self.nodedict[last_node_id].text = merged_text
|
| 184 |
+
self.consensus_node_ids.append(last_node_id)
|
| 185 |
+
|
| 186 |
+
# attach all incoming edges to first node to last node
|
| 187 |
+
for id, edge in self.nodedict[first_node_id].inEdges.items():
|
| 188 |
+
weight = edge.weight
|
| 189 |
+
for _ in range(weight):
|
| 190 |
+
self.addEdge(id, last_node_id, label=edge.labels)
|
| 191 |
+
|
| 192 |
+
# delete all nodes except last node
|
| 193 |
+
for node_id, _ in segment[:-1]:
|
| 194 |
+
self.removeNode(node_id)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if verbose:
|
| 199 |
+
print(self.consensus_node_ids)
|
| 200 |
+
|
| 201 |
+
"""
|
| 202 |
+
find all paths between start_node_id and end_node_id from original sequences
|
| 203 |
+
return a list of dictionaries with the following keys:
|
| 204 |
+
- path: list of node ids in the path (excluding start and including end)
|
| 205 |
+
- text: text of the path (excluding start and end)
|
| 206 |
+
- weight: minimal edge weight across all edges in the path
|
| 207 |
+
- labels: intersection of all edge labels in the path
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
def find_paths_between(self, start_node_id: int, end_node_id: int):
|
| 211 |
+
# find all paths between start_node_id and end_node_id from original sequences
|
| 212 |
+
path_dicts = []
|
| 213 |
+
|
| 214 |
+
# keep track of visited paths to avoid duplicates
|
| 215 |
+
visited_paths = set()
|
| 216 |
+
|
| 217 |
+
for _, path in self._seq_paths.items():
|
| 218 |
+
start_index = path.index(start_node_id) if start_node_id in path else None
|
| 219 |
+
end_index = path.index(end_node_id) if end_node_id in path else None
|
| 220 |
+
|
| 221 |
+
# print(start_index, end_index)
|
| 222 |
+
# print(path)
|
| 223 |
+
|
| 224 |
+
if (
|
| 225 |
+
start_index is not None
|
| 226 |
+
and end_index is not None
|
| 227 |
+
and end_index - start_index > 1
|
| 228 |
+
and tuple(path[start_index + 1 : end_index + 1]) not in visited_paths
|
| 229 |
+
):
|
| 230 |
+
# intersection of all edge labels in the path
|
| 231 |
+
path_labels = set.intersection(
|
| 232 |
+
*[
|
| 233 |
+
set(self.nodedict[next_node_id].inEdges[node_id].labels)
|
| 234 |
+
for node_id, next_node_id in zip(
|
| 235 |
+
path[start_index:end_index], path[start_index + 1 : end_index + 1]
|
| 236 |
+
)
|
| 237 |
+
]
|
| 238 |
+
)
|
| 239 |
+
path_weight = len(path_labels)
|
| 240 |
+
path_dicts.append(
|
| 241 |
+
{
|
| 242 |
+
|
| 243 |
+
"path": path[start_index + 1 : end_index + 1],
|
| 244 |
+
"body_text": " ".join(
|
| 245 |
+
[
|
| 246 |
+
self.nodedict[node_id].text
|
| 247 |
+
for node_id in path[start_index + 1 : end_index]
|
| 248 |
+
]
|
| 249 |
+
),
|
| 250 |
+
"begin_text": self.nodedict[path[start_index]].text,
|
| 251 |
+
"end_text": self.nodedict[path[end_index]].text,
|
| 252 |
+
"weight": path_weight,
|
| 253 |
+
"labels": path_labels,
|
| 254 |
+
}
|
| 255 |
+
)
|
| 256 |
+
visited_paths.add(tuple(path[start_index + 1 : end_index + 1]))
|
| 257 |
+
|
| 258 |
+
return path_dicts
|
| 259 |
+
|
| 260 |
+
def _follow_path(self, start_id):
|
| 261 |
+
"""Follow all possible paths from a node"""
|
| 262 |
+
paths = []
|
| 263 |
+
visited = set()
|
| 264 |
+
|
| 265 |
+
def dfs(node_id, current_path):
|
| 266 |
+
if node_id in visited:
|
| 267 |
+
return
|
| 268 |
+
visited.add(node_id)
|
| 269 |
+
node = self.nodedict[node_id]
|
| 270 |
+
|
| 271 |
+
if not node.outEdges:
|
| 272 |
+
paths.append(current_path + [node_id])
|
| 273 |
+
return
|
| 274 |
+
|
| 275 |
+
for next_id in node.outEdges:
|
| 276 |
+
dfs(next_id, current_path + [node_id])
|
| 277 |
+
|
| 278 |
+
dfs(start_id, [])
|
| 279 |
+
return paths
|
| 280 |
+
|
| 281 |
+
def merge_paths_between(
|
| 282 |
+
self,
|
| 283 |
+
start_node_id: int,
|
| 284 |
+
end_node_id: int,
|
| 285 |
+
path_sim_type: str = "llm",
|
| 286 |
+
verbose: bool = False,
|
| 287 |
+
**kwargs,
|
| 288 |
+
):
|
| 289 |
+
path_dicts = self.find_paths_between(start_node_id, end_node_id)
|
| 290 |
+
|
| 291 |
+
if path_sim_type == "llm":
|
| 292 |
+
api = kwargs.get("api", "openai")
|
| 293 |
+
model = kwargs.get("model", "gpt-4o-mini")
|
| 294 |
+
domain = kwargs.get("domain", None)
|
| 295 |
+
similarity_judge_prompt = kwargs.get("similarity_judge_prompt", None)
|
| 296 |
+
|
| 297 |
+
def path_sim_func(path1_text, path2_text):
|
| 298 |
+
return path_sim_llm(
|
| 299 |
+
path1_text,
|
| 300 |
+
path2_text,
|
| 301 |
+
api=api,
|
| 302 |
+
model=model,
|
| 303 |
+
domain=domain,
|
| 304 |
+
custom_similarity_judge_prompt=similarity_judge_prompt,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
elif path_sim_type == "cosine":
|
| 308 |
+
pass
|
| 309 |
+
# embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 310 |
+
# threshold = kwargs.get("threshold", 0.9)
|
| 311 |
+
# path_sim_func = path_sim_cosine(embedding_model, threshold)
|
| 312 |
+
else:
|
| 313 |
+
raise ValueError(f"Invalid path similarity type: {path_sim_type}")
|
| 314 |
+
|
| 315 |
+
# merge paths based on semantic similarity
|
| 316 |
+
path_equivalence_classes = {}
|
| 317 |
+
class_count = 0
|
| 318 |
+
|
| 319 |
+
for path_dict in path_dicts:
|
| 320 |
+
if verbose:
|
| 321 |
+
print(path_dict)
|
| 322 |
+
found_class = False
|
| 323 |
+
for _, eq_class in path_equivalence_classes.items():
|
| 324 |
+
# check if path dict is already in an equivalence class
|
| 325 |
+
path1_text = (
|
| 326 |
+
path_dict["begin_text"]
|
| 327 |
+
+ " "
|
| 328 |
+
+ path_dict["body_text"]
|
| 329 |
+
+ " "
|
| 330 |
+
+ path_dict["end_text"]
|
| 331 |
+
)
|
| 332 |
+
path2_text = (
|
| 333 |
+
eq_class[0]["begin_text"]
|
| 334 |
+
+ " "
|
| 335 |
+
+ eq_class[0]["body_text"]
|
| 336 |
+
+ " "
|
| 337 |
+
+ eq_class[0]["end_text"]
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
judgement, num_input_tokens, num_output_tokens = path_sim_func(
|
| 341 |
+
path1_text, path2_text
|
| 342 |
+
)
|
| 343 |
+
self.num_input_tokens_used += num_input_tokens
|
| 344 |
+
self.num_output_tokens_used += num_output_tokens
|
| 345 |
+
if judgement:
|
| 346 |
+
eq_class.append(path_dict)
|
| 347 |
+
found_class = True
|
| 348 |
+
break
|
| 349 |
+
if not found_class:
|
| 350 |
+
class_count += 1
|
| 351 |
+
path_equivalence_classes[class_count] = [path_dict]
|
| 352 |
+
|
| 353 |
+
nodes_to_remove = set() # Track nodes to remove
|
| 354 |
+
for _, eq_class in path_equivalence_classes.items():
|
| 355 |
+
path_dict = eq_class[0]
|
| 356 |
+
|
| 357 |
+
if verbose:
|
| 358 |
+
print(eq_class)
|
| 359 |
+
# add new node with merged text
|
| 360 |
+
new_node_id = self.addNode(path_dict["body_text"])
|
| 361 |
+
for sequence_id in path_dict["labels"]:
|
| 362 |
+
self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
|
| 363 |
+
|
| 364 |
+
# collect nodes to remove from first path
|
| 365 |
+
nodes_to_remove.update(path_dict["path"][:-1])
|
| 366 |
+
|
| 367 |
+
# process data regarding weights and labels
|
| 368 |
+
labels = list(path_dict["labels"])
|
| 369 |
+
weight = path_dict["weight"]
|
| 370 |
+
self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)
|
| 371 |
+
|
| 372 |
+
# Updated seq_paths for all labels to include new_node betwwen start_node and end_node
|
| 373 |
+
for label in labels:
|
| 374 |
+
index = self._seq_paths[label].index(start_node_id)
|
| 375 |
+
if (
|
| 376 |
+
index + 1 < len(self._seq_paths[label])
|
| 377 |
+
and self._seq_paths[label][index + 1] != new_node_id
|
| 378 |
+
):
|
| 379 |
+
self._seq_paths[label].insert(index + 1, new_node_id)
|
| 380 |
+
|
| 381 |
+
self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
|
| 382 |
+
|
| 383 |
+
self.nodedict[new_node_id].sequences = labels
|
| 384 |
+
# process additional paths
|
| 385 |
+
for path_dict in eq_class[1:]:
|
| 386 |
+
for sequence_id in path_dict["labels"]:
|
| 387 |
+
self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
|
| 388 |
+
nodes_to_remove.update(path_dict["path"][:-1])
|
| 389 |
+
|
| 390 |
+
# copy incoming edges to new node
|
| 391 |
+
labels = list(path_dict["labels"])
|
| 392 |
+
weight = path_dict["weight"]
|
| 393 |
+
self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)
|
| 394 |
+
|
| 395 |
+
# Updated seq_paths for all labels to include new_node betwwen start_node and end_node
|
| 396 |
+
for label in labels:
|
| 397 |
+
index = self._seq_paths[label].index(start_node_id)
|
| 398 |
+
if (
|
| 399 |
+
index + 1 < len(self._seq_paths[label])
|
| 400 |
+
and self._seq_paths[label][index + 1] != new_node_id
|
| 401 |
+
):
|
| 402 |
+
self._seq_paths[label].insert(index + 1, new_node_id)
|
| 403 |
+
|
| 404 |
+
self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
|
| 405 |
+
self.nodedict[new_node_id].sequences.extend(labels)
|
| 406 |
+
|
| 407 |
+
self.nodedict[new_node_id].sequences = list(set(self.nodedict[new_node_id].sequences))
|
| 408 |
+
|
| 409 |
+
# Remove all collected nodes after processing
|
| 410 |
+
for node_id in nodes_to_remove:
|
| 411 |
+
if node_id in self.nodedict:
|
| 412 |
+
if verbose:
|
| 413 |
+
print(f"Removing node {node_id}")
|
| 414 |
+
self.removeNode(node_id)
|
| 415 |
+
|
| 416 |
+
def merge_divergent_paths(self, path_sim_type: str = "llm", verbose: bool = False, **kwargs):
|
| 417 |
+
# add dummy end node to the end of the graph
|
| 418 |
+
if not self.consensus_node_ids:
|
| 419 |
+
self.merge_consensus_nodes(verbose=verbose)
|
| 420 |
+
|
| 421 |
+
self.toposort()
|
| 422 |
+
|
| 423 |
+
if self.start_id == -1:
|
| 424 |
+
if verbose:
|
| 425 |
+
print("Adding start node")
|
| 426 |
+
self.start_id = self.addNode(text="START")
|
| 427 |
+
self._nextnodeID += 1
|
| 428 |
+
self.consensus_node_ids.insert(0, self.start_id)
|
| 429 |
+
|
| 430 |
+
for label, path in self._seq_paths.items():
|
| 431 |
+
self.addEdge(self.start_id, path[0], label=label, weight=1)
|
| 432 |
+
path.insert(0, self.start_id)
|
| 433 |
+
|
| 434 |
+
if self.end_id == -1:
|
| 435 |
+
if verbose:
|
| 436 |
+
print("Adding end node")
|
| 437 |
+
self.end_id = self.addNode(text="END")
|
| 438 |
+
self._nextnodeID += 1
|
| 439 |
+
self.consensus_node_ids = self.consensus_node_ids + [self.end_id]
|
| 440 |
+
|
| 441 |
+
for label, path in self._seq_paths.items():
|
| 442 |
+
self.addEdge(path[-1], self.end_id, label=label, weight=1)
|
| 443 |
+
path.append(self.end_id)
|
| 444 |
+
|
| 445 |
+
for i in tqdm(range(len(self.consensus_node_ids) - 1)):
|
| 446 |
+
if verbose:
|
| 447 |
+
print(self.consensus_node_ids[i], self.consensus_node_ids[i + 1])
|
| 448 |
+
self.merge_paths_between(
|
| 449 |
+
self.consensus_node_ids[i],
|
| 450 |
+
self.consensus_node_ids[i + 1],
|
| 451 |
+
path_sim_type=path_sim_type,
|
| 452 |
+
verbose=verbose,
|
| 453 |
+
**kwargs,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
def get_variable_node_ids(self):
|
| 457 |
+
return [
|
| 458 |
+
node.ID for node in self.nodedict.values() if node.ID not in self.consensus_node_ids
|
| 459 |
+
]
|
| 460 |
+
|
| 461 |
+
def compress_paths_between(self, start_node_id: int, end_node_id: int):
|
| 462 |
+
pass
|
| 463 |
+
|
| 464 |
+
def compress_graph(self):
|
| 465 |
+
pass
|
| 466 |
+
|
| 467 |
+
def update_influence_scores(self, outcome: Dict[int, float], discount_factor: float = 0.2):
|
| 468 |
+
self.toposort()
|
| 469 |
+
direct_scores = []
|
| 470 |
+
for node in self.nodedict.values():
|
| 471 |
+
next_out_weight = sum(e.weight for e in node.outEdges.values())
|
| 472 |
+
next_in_weight = sum(e.weight for e in node.inEdges.values())
|
| 473 |
+
if next_out_weight == self.num_sequences and next_in_weight == self.num_sequences:
|
| 474 |
+
out_list = []
|
| 475 |
+
for edge in node.outEdges.values():
|
| 476 |
+
for _ in range(len(set(edge.labels))):
|
| 477 |
+
out_list.append(np.mean([outcome[label] for label in set(edge.labels)]))
|
| 478 |
+
direct_scores.append((node.ID, np.var(out_list)))
|
| 479 |
+
|
| 480 |
+
scores = direct_scores.copy()
|
| 481 |
+
|
| 482 |
+
# Start from the end and propagate influence backward
|
| 483 |
+
for i in range(len(scores) - 2, -1, -1):
|
| 484 |
+
# Current node gets its direct influence plus discounted influence of next node
|
| 485 |
+
current_direct = scores[i][1]
|
| 486 |
+
next_total = scores[i + 1][1]
|
| 487 |
+
scores[i] = (scores[i][0], current_direct + discount_factor * next_total)
|
| 488 |
+
|
| 489 |
+
scores.sort(key=lambda x: x[1], reverse=True)
|
| 490 |
+
return scores
|
| 491 |
+
|
| 492 |
+
def jsOutput(
|
| 493 |
+
self,
|
| 494 |
+
verbose: bool = False,
|
| 495 |
+
annotate_consensus: bool = True,
|
| 496 |
+
color_annotations: Dict[int, str] = None,
|
| 497 |
+
):
|
| 498 |
+
"""returns a list of strings containing a a description of the graph for viz.js, http://visjs.org"""
|
| 499 |
+
|
| 500 |
+
# get the consensus sequence, which we'll use as the "spine" of the
|
| 501 |
+
# graph
|
| 502 |
+
pathdict = {}
|
| 503 |
+
if annotate_consensus:
|
| 504 |
+
path, __, __ = self.consensus()
|
| 505 |
+
lines = ["var nodes = ["]
|
| 506 |
+
|
| 507 |
+
ni = self.nodeiterator()
|
| 508 |
+
count = 0
|
| 509 |
+
for node in ni():
|
| 510 |
+
title_text = ""
|
| 511 |
+
if node.sequences:
|
| 512 |
+
title_text += f"Sequences: {node.sequences}"
|
| 513 |
+
if node.variations:
|
| 514 |
+
title_text += ";;;".join(
|
| 515 |
+
[f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
|
| 516 |
+
)
|
| 517 |
+
title_text = title_text.replace('"', "'")
|
| 518 |
+
line = (
|
| 519 |
+
" {id:"
|
| 520 |
+
+ str(node.ID)
|
| 521 |
+
+ ', label: "'
|
| 522 |
+
+ str(node.ID)
|
| 523 |
+
+ ": "
|
| 524 |
+
+ node.text.replace('"', "'")
|
| 525 |
+
+ '", title: '
|
| 526 |
+
+ '"'
|
| 527 |
+
+ title_text
|
| 528 |
+
+ '",'
|
| 529 |
+
)
|
| 530 |
+
if color_annotations and node.ID in color_annotations:
|
| 531 |
+
line += f" color: '{color_annotations[node.ID]}', "
|
| 532 |
+
if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
|
| 533 |
+
line += (
|
| 534 |
+
", x: "
|
| 535 |
+
+ str(pathdict[node.ID])
|
| 536 |
+
+ ", y: 0 , fixed: { x:true, y:false},"
|
| 537 |
+
+ "color: '#7BE141', is_consensus:true},"
|
| 538 |
+
)
|
| 539 |
+
else:
|
| 540 |
+
line += "},"
|
| 541 |
+
lines.append(line)
|
| 542 |
+
|
| 543 |
+
lines[-1] = lines[-1][:-1]
|
| 544 |
+
lines.append("];")
|
| 545 |
+
|
| 546 |
+
lines.append(" ")
|
| 547 |
+
|
| 548 |
+
lines.append("var edges = [ ")
|
| 549 |
+
ni = self.nodeiterator()
|
| 550 |
+
for node in ni():
|
| 551 |
+
nodeID = str(node.ID)
|
| 552 |
+
for edge in node.outEdges:
|
| 553 |
+
target = str(edge)
|
| 554 |
+
weight = str(node.outEdges[edge].weight + 1.5)
|
| 555 |
+
lines.append(
|
| 556 |
+
" {from: "
|
| 557 |
+
+ nodeID
|
| 558 |
+
+ ", to: "
|
| 559 |
+
+ target
|
| 560 |
+
+ ", value: "
|
| 561 |
+
+ weight
|
| 562 |
+
+ ", color: '#4b72b0', arrows: 'to'},"
|
| 563 |
+
)
|
| 564 |
+
if verbose:
|
| 565 |
+
for alignededge in node.alignedTo:
|
| 566 |
+
# These edges indicate alignment to different bases, and are
|
| 567 |
+
# undirected; thus make sure we only plot them once:
|
| 568 |
+
if node.ID > alignededge:
|
| 569 |
+
continue
|
| 570 |
+
target = str(alignededge)
|
| 571 |
+
lines.append(
|
| 572 |
+
" {from: "
|
| 573 |
+
+ nodeID
|
| 574 |
+
+ ", to: "
|
| 575 |
+
+ target
|
| 576 |
+
+ ', value: 1, style: "dash-line", color: "red"},'
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
lines[-1] = lines[-1][:-1]
|
| 580 |
+
lines.append("];")
|
| 581 |
+
return lines
|
| 582 |
+
|
| 583 |
+
def htmlOutput(
|
| 584 |
+
self,
|
| 585 |
+
outfile,
|
| 586 |
+
verbose: bool = False,
|
| 587 |
+
annotate_consensus: bool = True,
|
| 588 |
+
color_annotations: Dict[int, str] = None,
|
| 589 |
+
):
|
| 590 |
+
header = """
|
| 591 |
+
<!doctype html>
|
| 592 |
+
<html>
|
| 593 |
+
<head>
|
| 594 |
+
<title>POA Graph Alignment</title>
|
| 595 |
+
|
| 596 |
+
<script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
|
| 597 |
+
</head>
|
| 598 |
+
|
| 599 |
+
<body>
|
| 600 |
+
|
| 601 |
+
<div id="loadingProgress">0%</div>
|
| 602 |
+
|
| 603 |
+
<div id="mynetwork"></div>
|
| 604 |
+
|
| 605 |
+
<script type="text/javascript">
|
| 606 |
+
// create a network
|
| 607 |
+
"""
|
| 608 |
+
outfile.write(textwrap.dedent(header[1:]))
|
| 609 |
+
lines = self.jsOutput(
|
| 610 |
+
verbose=verbose,
|
| 611 |
+
annotate_consensus=annotate_consensus,
|
| 612 |
+
color_annotations=color_annotations,
|
| 613 |
+
)
|
| 614 |
+
for line in lines:
|
| 615 |
+
outfile.write(line + "\n")
|
| 616 |
+
footer = """
|
| 617 |
+
var container = document.getElementById('mynetwork');
|
| 618 |
+
var data= {
|
| 619 |
+
nodes: nodes,
|
| 620 |
+
edges: edges,
|
| 621 |
+
};
|
| 622 |
+
var options = {
|
| 623 |
+
width: '100%',
|
| 624 |
+
height: '800px',
|
| 625 |
+
physics: {
|
| 626 |
+
enabled: false,
|
| 627 |
+
stabilization: {
|
| 628 |
+
updateInterval: 10,
|
| 629 |
+
},
|
| 630 |
+
},
|
| 631 |
+
edges: {
|
| 632 |
+
color: {
|
| 633 |
+
inherit: false
|
| 634 |
+
}
|
| 635 |
+
},
|
| 636 |
+
layout: {
|
| 637 |
+
hierarchical: {
|
| 638 |
+
direction: "UD",
|
| 639 |
+
sortMethod: "directed",
|
| 640 |
+
shakeTowards: "roots",
|
| 641 |
+
levelSeparation: 150, // Adjust as needed
|
| 642 |
+
nodeSpacing: 800, // Adjust as needed
|
| 643 |
+
treeSpacing: 200, // Adjust as needed
|
| 644 |
+
parentCentralization: true,
|
| 645 |
+
}
|
| 646 |
+
}
|
| 647 |
+
};
|
| 648 |
+
var network = new vis.Network(container, data, options);
|
| 649 |
+
|
| 650 |
+
network.on('beforeDrawing', function(ctx) {
|
| 651 |
+
nodes.forEach(function(node) {
|
| 652 |
+
if (node.isConsensus) {
|
| 653 |
+
// Set the level of spine nodes to the bottom
|
| 654 |
+
network.body.data.nodes.update({
|
| 655 |
+
id: node.id,
|
| 656 |
+
level: 0 // Set level to 0 for spine nodes
|
| 657 |
+
});
|
| 658 |
+
}
|
| 659 |
+
});
|
| 660 |
+
});
|
| 661 |
+
|
| 662 |
+
network.on("stabilizationProgress", function (params) {
|
| 663 |
+
document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
|
| 664 |
+
});
|
| 665 |
+
network.once("stabilizationIterationsDone", function () {
|
| 666 |
+
document.getElementById("loadingProgress").innerText = "100%";
|
| 667 |
+
setTimeout(function () {
|
| 668 |
+
document.getElementById("loadingProgress").style.display = "none";
|
| 669 |
+
}, 500);
|
| 670 |
+
});
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
</script>
|
| 674 |
+
|
| 675 |
+
</body>
|
| 676 |
+
</html>
|
| 677 |
+
"""
|
| 678 |
+
outfile.write(textwrap.dedent(footer))
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def multi_consensus_response(self, abstention_threshold: Optional[float] = None, filter: bool = True):
|
| 682 |
+
self.toposort()
|
| 683 |
+
nodesInReverse = self.nodeidlist[::-1]
|
| 684 |
+
maxnodeID = self.end_id
|
| 685 |
+
nextInPath = [-1] * maxnodeID
|
| 686 |
+
scores = np.zeros(len(self.nodeidlist))
|
| 687 |
+
|
| 688 |
+
id_to_index = {node_id: index for index, node_id in enumerate(self.nodeidlist)}
|
| 689 |
+
index_to_id = {index: node_id for index, node_id in enumerate(self.nodeidlist)}
|
| 690 |
+
|
| 691 |
+
for nodeID in nodesInReverse:
|
| 692 |
+
bestWeightScoreEdges = [(-1, -1, None)]
|
| 693 |
+
for neighbourID in self.nodedict[nodeID].outEdges:
|
| 694 |
+
# print(f"nodeID: {nodeID}, neighbourID: {neighbourID}")
|
| 695 |
+
e = self.nodedict[nodeID].outEdges[neighbourID]
|
| 696 |
+
weightScoreEdge = (e.weight, scores[id_to_index[neighbourID]], neighbourID)
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
if weightScoreEdge > bestWeightScoreEdges[0]:
|
| 700 |
+
bestWeightScoreEdges = [weightScoreEdge]
|
| 701 |
+
elif weightScoreEdge == bestWeightScoreEdges[0] and filter:
|
| 702 |
+
bestWeightScoreEdges.append(weightScoreEdge)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
scores[id_to_index[nodeID]] = sum(bestWeightScoreEdges[0][0:2])
|
| 706 |
+
if bestWeightScoreEdges[0][2] is not None:
|
| 707 |
+
nextInPath[id_to_index[nodeID]] = id_to_index[bestWeightScoreEdges[0][2]]
|
| 708 |
+
else:
|
| 709 |
+
nextInPath[id_to_index[nodeID]] = None
|
| 710 |
+
|
| 711 |
+
pos = np.argmax(scores)
|
| 712 |
+
path = []
|
| 713 |
+
text = []
|
| 714 |
+
labels = []
|
| 715 |
+
|
| 716 |
+
while pos is not None and pos > -1:
|
| 717 |
+
if abstention_threshold is not None and self.nodedict[index_to_id[pos]].variations:
|
| 718 |
+
if (
|
| 719 |
+
len(self.nodedict[index_to_id[pos]].labels) / self.num_sequences
|
| 720 |
+
>= abstention_threshold
|
| 721 |
+
):
|
| 722 |
+
path.append(index_to_id[pos])
|
| 723 |
+
labels.append(self.nodedict[index_to_id[pos]].labels)
|
| 724 |
+
text.append(self.nodedict[index_to_id[pos]].text)
|
| 725 |
+
else:
|
| 726 |
+
path.append(index_to_id[pos])
|
| 727 |
+
labels.append(self.nodedict[index_to_id[pos]].labels)
|
| 728 |
+
text.append(self.nodedict[index_to_id[pos]].text)
|
| 729 |
+
pos = nextInPath[pos]
|
| 730 |
+
|
| 731 |
+
# ignore END node
|
| 732 |
+
path = path[:-1]
|
| 733 |
+
# ignore END node
|
| 734 |
+
text = text[:-1]
|
| 735 |
+
# ignore START in text
|
| 736 |
+
text[0] = text[0].replace("START", "")
|
| 737 |
+
labels = labels[:-1]
|
| 738 |
+
|
| 739 |
+
return " ".join(text)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def consensus_response(
|
| 743 |
+
self, selection_threshold: Optional[float] = 0.5, api: str = "openai" , model: str = "gpt-4o-mini", task: str = "bio", **kwargs
|
| 744 |
+
) -> str:
|
| 745 |
+
self.toposort()
|
| 746 |
+
|
| 747 |
+
consensus_node_ids = self.consensus_node_ids
|
| 748 |
+
print(consensus_node_ids)
|
| 749 |
+
|
| 750 |
+
selected_node_ids = []
|
| 751 |
+
|
| 752 |
+
for node_id in consensus_node_ids:
|
| 753 |
+
if node_id == self.start_id or node_id == self.end_id:
|
| 754 |
+
continue
|
| 755 |
+
|
| 756 |
+
selected_node_ids.append(node_id)
|
| 757 |
+
|
| 758 |
+
for neighbor_id in self.nodedict[node_id].outEdges:
|
| 759 |
+
if neighbor_id in consensus_node_ids:
|
| 760 |
+
continue
|
| 761 |
+
|
| 762 |
+
if (
|
| 763 |
+
len(self.nodedict[neighbor_id].labels) / self.num_sequences
|
| 764 |
+
>= selection_threshold
|
| 765 |
+
):
|
| 766 |
+
selected_node_ids.append(neighbor_id)
|
| 767 |
+
|
| 768 |
+
text = " ".join([self.nodedict[node_id].text for node_id in selected_node_ids])
|
| 769 |
+
print(text)
|
| 770 |
+
cleaned_text = clean_up_text(text, task=task, api=api, model=model, **kwargs)
|
| 771 |
+
return cleaned_text
|
| 772 |
+
|
| 773 |
+
def save_to_pickle(self, filename):
|
| 774 |
+
with open(filename, "wb+") as f:
|
| 775 |
+
pickle.dump(self, f)
|
| 776 |
+
|
| 777 |
+
def refine_graph(
|
| 778 |
+
self,
|
| 779 |
+
verbose: bool = False,
|
| 780 |
+
save_intermediate_file: str = None,
|
| 781 |
+
final_merge: bool = True,
|
| 782 |
+
**kwargs,
|
| 783 |
+
):
|
| 784 |
+
self.merge_consensus_nodes(verbose=verbose)
|
| 785 |
+
|
| 786 |
+
if save_intermediate_file:
|
| 787 |
+
with open(save_intermediate_file, "w+") as f:
|
| 788 |
+
self.htmlOutput(f, annotate_consensus=False)
|
| 789 |
+
|
| 790 |
+
if not self.consensus_node_ids:
|
| 791 |
+
self.failed = True
|
| 792 |
+
return
|
| 793 |
+
|
| 794 |
+
else:
|
| 795 |
+
self.merge_divergent_paths(verbose=verbose, **kwargs)
|
| 796 |
+
|
| 797 |
+
if final_merge:
|
| 798 |
+
try:
|
| 799 |
+
self.merge_consensus_nodes(verbose=verbose)
|
| 800 |
+
except Exception as e:
|
| 801 |
+
print(e)
|
| 802 |
+
self.failed = True
|
src/text_poa_graph_utils.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import InferenceClient
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
TEXT_SIMILARITY_JUDGE_PROMPT = """
|
| 8 |
+
You are given two pieces of text. Your task is to determine whether they are semantically equivalent based solely on their factual content.
|
| 9 |
+
|
| 10 |
+
Here are the specific guidelines:
|
| 11 |
+
- Texts are equivalent if they convey the same core information or concept, regardless of wording or structure
|
| 12 |
+
- If one text has information that is a subset of the other text, then the texts are equivalent
|
| 13 |
+
- Focus ONLY on the essential claims, not on:
|
| 14 |
+
* Stylistic differences or tone
|
| 15 |
+
* Level of detail (if the core facts remain the same)
|
| 16 |
+
* Connotative differences between words
|
| 17 |
+
* Implied significance or emphasis
|
| 18 |
+
* Presentation order (if all key information is present in both)
|
| 19 |
+
- Minor additions of non-contradictory information should not make texts non-equivalent
|
| 20 |
+
- For ambiguous cases, prioritize the central claim or purpose of the text
|
| 21 |
+
|
| 22 |
+
Examples of equivalent pairs:
|
| 23 |
+
- "The meeting starts at 3pm" and "The 3 o'clock meeting will begin on time"
|
| 24 |
+
- "Research indicates a 15% increase" and "Studies show a fifteen percent rise"
|
| 25 |
+
- "was influential in the field" and "had a significant impact on the community"
|
| 26 |
+
|
| 27 |
+
Examples of non-equivalent pairs:
|
| 28 |
+
- "The project might be completed by Friday" and "The project will be finished by Friday"
|
| 29 |
+
- "Most experts agree on the approach" and "All experts support the approach"
|
| 30 |
+
|
| 31 |
+
Strictly follow these guidelines and return ONLY:
|
| 32 |
+
- equivalent
|
| 33 |
+
- not equivalent
|
| 34 |
+
"""
|
| 35 |
+
MATH_SIMILARITY_JUDGE_PROMPT = """
|
| 36 |
+
You are given two pieces of text from mathematical solutions. Your task is to determine whether the two solution segments are mathematically equivalent in their content, while allowing for stylistic variations.
|
| 37 |
+
|
| 38 |
+
Here are some important guidelines:
|
| 39 |
+
- Solutions should be considered equivalent if:
|
| 40 |
+
1. They communicate the same mathematical content/approach, even if word choice or phrasing differs
|
| 41 |
+
2. They contain the same key mathematical ideas, even if expressed differently
|
| 42 |
+
3. The same mathematical steps are described, even if using different words
|
| 43 |
+
4. They present the same final answer, regardless of wording style or formatting
|
| 44 |
+
|
| 45 |
+
- Allow for these variations while still considering solutions equivalent:
|
| 46 |
+
1. Stylistic differences ("we will" vs. "we'll" or "I'll")
|
| 47 |
+
2. Different levels of formality in the explanation
|
| 48 |
+
3. Minor rephrasing that preserves the core mathematical content
|
| 49 |
+
4. Use of synonyms or alternative mathematical terminology for the same concept
|
| 50 |
+
|
| 51 |
+
- Solutions are NOT equivalent if:
|
| 52 |
+
1. They use fundamentally different mathematical approaches
|
| 53 |
+
2. They work with different formulas or equations
|
| 54 |
+
3. They present different mathematical steps or operations
|
| 55 |
+
4. They reach different conclusions or answers
|
| 56 |
+
5. One contains substantial mathematical content that the other lacks
|
| 57 |
+
|
| 58 |
+
- When examining final answers, focus on mathematical equivalence rather than stylistic presentation
|
| 59 |
+
- For solution steps, maintain the core mathematical approach while allowing for rephrasing
|
| 60 |
+
|
| 61 |
+
Examples of solutions that SHOULD be considered equivalent:
|
| 62 |
+
- "We will systematically evaluate each possible grouping" and "We'll evaluate each grouping"
|
| 63 |
+
- "The answer is x = 5" and "Therefore, x equals 5"
|
| 64 |
+
- "Using the quadratic formula" and "Applying the quadratic formula"
|
| 65 |
+
|
| 66 |
+
Strictly follow the guidelines above.
|
| 67 |
+
Return your judgment in the following format. Do not include any other text:
|
| 68 |
+
- equivalent
|
| 69 |
+
- not equivalent
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def path_sim_llm(
|
| 73 |
+
path1_text: str,
|
| 74 |
+
path2_text: str,
|
| 75 |
+
api: str = "openai",
|
| 76 |
+
model: str = "gpt-4.1-mini",
|
| 77 |
+
verbose: bool = False,
|
| 78 |
+
domain: Optional[str] = "text",
|
| 79 |
+
custom_similarity_judge_prompt: str = None,
|
| 80 |
+
):
|
| 81 |
+
if api == "openai":
|
| 82 |
+
client = OpenAI()
|
| 83 |
+
elif api == "hf":
|
| 84 |
+
client = InferenceClient()
|
| 85 |
+
else:
|
| 86 |
+
raise ValueError(f"Invalid API: {api}")
|
| 87 |
+
|
| 88 |
+
if domain == "text":
|
| 89 |
+
similarity_judge_prompt = (
|
| 90 |
+
f"{TEXT_SIMILARITY_JUDGE_PROMPT}\n\nText 1: {path1_text}\nText 2: {path2_text}"
|
| 91 |
+
)
|
| 92 |
+
elif domain == "math":
|
| 93 |
+
similarity_judge_prompt = (
|
| 94 |
+
f"{MATH_SIMILARITY_JUDGE_PROMPT}\n\nText 1: {path1_text}\nText 2: {path2_text}"
|
| 95 |
+
)
|
| 96 |
+
elif not domain and custom_similarity_judge_prompt:
|
| 97 |
+
similarity_judge_prompt = (
|
| 98 |
+
f"{custom_similarity_judge_prompt}\n\nText 1: {path1_text}\nText 2: {path2_text}"
|
| 99 |
+
)
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Invalid domain: {domain} and no custom similarity judge prompt provided")
|
| 102 |
+
|
| 103 |
+
completion = client.chat.completions.create(
|
| 104 |
+
model=model,
|
| 105 |
+
temperature=0,
|
| 106 |
+
messages=[
|
| 107 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 108 |
+
{"role": "user", "content": similarity_judge_prompt},
|
| 109 |
+
],
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
judgement = completion.choices[0].message.content.strip()
|
| 113 |
+
judgement = "".join(c for c in judgement if c.isalpha() or c == " ")
|
| 114 |
+
judgement = judgement.strip()
|
| 115 |
+
|
| 116 |
+
if verbose:
|
| 117 |
+
print(f"{path1_text} \nand \n{path2_text} \nare {judgement}")
|
| 118 |
+
|
| 119 |
+
if judgement == "equivalent":
|
| 120 |
+
return 1, completion.usage.prompt_tokens, completion.usage.completion_tokens
|
| 121 |
+
elif judgement == "not equivalent":
|
| 122 |
+
return 0, completion.usage.prompt_tokens, completion.usage.completion_tokens
|
| 123 |
+
else:
|
| 124 |
+
if verbose:
|
| 125 |
+
print(f"Invalid judgement: {judgement}")
|
| 126 |
+
return 0, completion.usage.prompt_tokens, completion.usage.completion_tokens
|
src/utils.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import InferenceClient
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def detect_abstain(text: str, api: str, model: str):
|
| 8 |
+
if api == "openai":
|
| 9 |
+
client = OpenAI()
|
| 10 |
+
elif api == "hf":
|
| 11 |
+
client = InferenceClient()
|
| 12 |
+
else:
|
| 13 |
+
raise ValueError(f"Invalid API: {api}")
|
| 14 |
+
|
| 15 |
+
detect_abstain_prompt = f"""
|
| 16 |
+
You are given a piece of text that is a part of a biography of an entity.
|
| 17 |
+
Text: {text}
|
| 18 |
+
|
| 19 |
+
If the text claims a lack of knowledge about the topic, return "Abstain".
|
| 20 |
+
Otherwise, return "Not abstain".
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
completion = client.chat.completions.create(
|
| 24 |
+
model=model,
|
| 25 |
+
messages=[
|
| 26 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 27 |
+
{"role": "user", "content": detect_abstain_prompt},
|
| 28 |
+
],
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
return completion.choices[0].message.content.strip()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def calculate_factf1_at_k(
|
| 35 |
+
supported_facts: List[str], unsupported_facts: List[str], k: int
|
| 36 |
+
) -> float:
|
| 37 |
+
"""
|
| 38 |
+
Calculate the F1 score at k for supported and unsupported facts
|
| 39 |
+
"""
|
| 40 |
+
if len(supported_facts) == 0:
|
| 41 |
+
return 0
|
| 42 |
+
|
| 43 |
+
precision = len(supported_facts) / (len(supported_facts) + len(unsupported_facts))
|
| 44 |
+
recall = min(len(supported_facts) / k, 1)
|
| 45 |
+
f1 = 2 * precision * recall / (precision + recall)
|
| 46 |
+
return f1
|
web_interface/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
web_interface/README.md
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ConGr Visualizer Web Interface
|
| 2 |
+
|
| 3 |
+
A web-based interface for exploring and visualizing ConGrs (Consensus Graphs) from research datasets.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
### Browse Existing Graphs
|
| 8 |
+
- **Dataset Selection**: Choose from available datasets (BIO, FP, HIST, REFS, MATH, AIME)
|
| 9 |
+
- **Entity Selection**: Browse entities within each dataset
|
| 10 |
+
- **Model Information**: See which language models were used for each graph
|
| 11 |
+
- **Graph Visualization**: Interactive network visualization using vis.js
|
| 12 |
+
- **Metadata Display**: View graph statistics and consensus text
|
| 13 |
+
|
| 14 |
+
### Create New Graphs
|
| 15 |
+
- **Text Input**: Enter multiple text sequences to create new ConGrs
|
| 16 |
+
- **Real-time Visualization**: See the graph structure as it's created
|
| 17 |
+
- **Save Functionality**: Save created graphs to pickle files
|
| 18 |
+
|
| 19 |
+
## Available Datasets
|
| 20 |
+
|
| 21 |
+
- **BIO**: Biography datasets with various public figures
|
| 22 |
+
- **FP**: False Presupposition datasets
|
| 23 |
+
- **HIST**: Historical events datasets
|
| 24 |
+
- **REFS**: Reference datasets
|
| 25 |
+
- **MATH**: Mathematical problem datasets
|
| 26 |
+
- **AIME**: American Invitational Mathematics Examination datasets
|
| 27 |
+
|
| 28 |
+
## Models
|
| 29 |
+
|
| 30 |
+
The graphs are generated using various language models:
|
| 31 |
+
- olmo7b
|
| 32 |
+
- qwen72b
|
| 33 |
+
- llama70b
|
| 34 |
+
- llama8b
|
| 35 |
+
|
| 36 |
+
## Installation
|
| 37 |
+
|
| 38 |
+
1. Install dependencies:
|
| 39 |
+
```bash
|
| 40 |
+
pip install -r requirements.txt
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
2. Start the server:
|
| 44 |
+
```bash
|
| 45 |
+
python server.py
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
3. Open your browser and navigate to:
|
| 49 |
+
```
|
| 50 |
+
http://localhost:8080
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Usage
|
| 54 |
+
|
| 55 |
+
### Browsing Existing Graphs
|
| 56 |
+
|
| 57 |
+
1. **Select Dataset**: Choose a dataset from the dropdown menu
|
| 58 |
+
2. **Select Entity**: Choose an entity from the available options
|
| 59 |
+
3. **View Graph**: The graph will be automatically loaded and displayed
|
| 60 |
+
4. **View Information**: Graph metadata and consensus text will be shown
|
| 61 |
+
|
| 62 |
+
### Creating New Graphs
|
| 63 |
+
|
| 64 |
+
1. **Enter Text**: Input multiple text sequences (one per line)
|
| 65 |
+
2. **Create Graph**: Click "Create Graph" to generate a new ConGr
|
| 66 |
+
3. **Save Graph**: Optionally save the graph to a pickle file
|
| 67 |
+
|
| 68 |
+
## API Endpoints
|
| 69 |
+
|
| 70 |
+
- `GET /api/datasets` - Get available datasets
|
| 71 |
+
- `GET /api/entities?dataset=<dataset>` - Get entities for a dataset
|
| 72 |
+
- `POST /api/load_existing_graph` - Load an existing graph
|
| 73 |
+
- `POST /api/create_graph` - Create a new graph from text sequences
|
| 74 |
+
- `POST /api/save_graph` - Save a graph to file
|
| 75 |
+
|
| 76 |
+
## Testing
|
| 77 |
+
|
| 78 |
+
Run the test script to verify the server is working correctly:
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
python test_server.py
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
## File Structure
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
web_interface/
|
| 88 |
+
βββ server.py # Flask server
|
| 89 |
+
βββ index.html # Web interface
|
| 90 |
+
βββ requirements.txt # Python dependencies
|
| 91 |
+
βββ test_server.py # Test script
|
| 92 |
+
βββ README.md # This file
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
## Graph Information
|
| 96 |
+
|
| 97 |
+
When viewing a graph, you can see:
|
| 98 |
+
- **Dataset**: The source dataset
|
| 99 |
+
- **Entity**: The specific entity or topic
|
| 100 |
+
- **Model**: The language model used
|
| 101 |
+
- **Sequences**: Number of input sequences
|
| 102 |
+
- **Nodes**: Number of nodes in the graph
|
| 103 |
+
- **Edges**: Number of edges in the graph
|
| 104 |
+
- **Consensus**: The consensus text generated from the graph
|
| 105 |
+
|
| 106 |
+
## Visualization Features
|
| 107 |
+
|
| 108 |
+
- **Hierarchical Layout**: Graphs are displayed in a hierarchical structure
|
| 109 |
+
- **Color Coding**: Consensus nodes are highlighted in green
|
| 110 |
+
- **Interactive**: Zoom, pan, and hover for more information
|
| 111 |
+
- **Responsive**: Works on desktop and mobile devices
|
web_interface/index.html
ADDED
|
@@ -0,0 +1,907 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>ConGr Visualizer</title>
|
| 7 |
+
<script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
|
| 8 |
+
<style>
|
| 9 |
+
* {
|
| 10 |
+
margin: 0;
|
| 11 |
+
padding: 0;
|
| 12 |
+
box-sizing: border-box;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
body {
|
| 16 |
+
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
| 17 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 18 |
+
min-height: 100vh;
|
| 19 |
+
color: #333;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
.container {
|
| 23 |
+
max-width: 1400px;
|
| 24 |
+
margin: 0 auto;
|
| 25 |
+
padding: 20px;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
.header {
|
| 29 |
+
text-align: center;
|
| 30 |
+
margin-bottom: 30px;
|
| 31 |
+
color: white;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
.header h1 {
|
| 35 |
+
font-size: 2.5rem;
|
| 36 |
+
margin-bottom: 10px;
|
| 37 |
+
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
.header p {
|
| 41 |
+
font-size: 1.1rem;
|
| 42 |
+
opacity: 0.9;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
.main-content {
|
| 46 |
+
display: grid;
|
| 47 |
+
grid-template-columns: 1fr 2fr;
|
| 48 |
+
gap: 30px;
|
| 49 |
+
background: white;
|
| 50 |
+
border-radius: 15px;
|
| 51 |
+
box-shadow: 0 20px 40px rgba(0,0,0,0.1);
|
| 52 |
+
overflow: hidden;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
.sidebar {
|
| 56 |
+
background: #f8f9fa;
|
| 57 |
+
padding: 30px;
|
| 58 |
+
border-right: 1px solid #e9ecef;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
.section {
|
| 62 |
+
margin-bottom: 30px;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
.section h3 {
|
| 66 |
+
color: #495057;
|
| 67 |
+
margin-bottom: 15px;
|
| 68 |
+
font-size: 1.2rem;
|
| 69 |
+
border-bottom: 2px solid #667eea;
|
| 70 |
+
padding-bottom: 5px;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
.input-group {
|
| 74 |
+
margin-bottom: 20px;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
.input-group label {
|
| 78 |
+
display: block;
|
| 79 |
+
margin-bottom: 8px;
|
| 80 |
+
font-weight: 600;
|
| 81 |
+
color: #495057;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
.input-group textarea {
|
| 85 |
+
width: 100%;
|
| 86 |
+
min-height: 120px;
|
| 87 |
+
padding: 12px;
|
| 88 |
+
border: 2px solid #e9ecef;
|
| 89 |
+
border-radius: 8px;
|
| 90 |
+
font-family: inherit;
|
| 91 |
+
font-size: 14px;
|
| 92 |
+
resize: vertical;
|
| 93 |
+
transition: border-color 0.3s ease;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
.input-group textarea:focus {
|
| 97 |
+
outline: none;
|
| 98 |
+
border-color: #667eea;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
.input-group select {
|
| 102 |
+
width: 100%;
|
| 103 |
+
padding: 12px;
|
| 104 |
+
border: 2px solid #e9ecef;
|
| 105 |
+
border-radius: 8px;
|
| 106 |
+
font-family: inherit;
|
| 107 |
+
font-size: 14px;
|
| 108 |
+
background: white;
|
| 109 |
+
cursor: pointer;
|
| 110 |
+
transition: border-color 0.3s ease;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
.input-group select:focus {
|
| 114 |
+
outline: none;
|
| 115 |
+
border-color: #667eea;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
.btn {
|
| 119 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 120 |
+
color: white;
|
| 121 |
+
border: none;
|
| 122 |
+
padding: 12px 24px;
|
| 123 |
+
border-radius: 8px;
|
| 124 |
+
cursor: pointer;
|
| 125 |
+
font-size: 14px;
|
| 126 |
+
font-weight: 600;
|
| 127 |
+
transition: transform 0.2s ease, box-shadow 0.2s ease;
|
| 128 |
+
width: 100%;
|
| 129 |
+
margin-bottom: 10px;
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
.btn:hover {
|
| 133 |
+
transform: translateY(-2px);
|
| 134 |
+
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
.btn:active {
|
| 138 |
+
transform: translateY(0);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
.btn-secondary {
|
| 142 |
+
background: linear-gradient(135deg, #6c757d 0%, #495057 100%);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
.btn-secondary:hover {
|
| 146 |
+
box-shadow: 0 5px 15px rgba(108, 117, 125, 0.4);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
.graph-container {
|
| 150 |
+
padding: 30px;
|
| 151 |
+
position: relative;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
#mynetwork {
|
| 155 |
+
width: 100%;
|
| 156 |
+
height: 600px;
|
| 157 |
+
border: 2px solid #e9ecef;
|
| 158 |
+
border-radius: 10px;
|
| 159 |
+
background: white;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.loading {
|
| 163 |
+
position: absolute;
|
| 164 |
+
top: 50%;
|
| 165 |
+
left: 50%;
|
| 166 |
+
transform: translate(-50%, -50%);
|
| 167 |
+
background: rgba(255, 255, 255, 0.9);
|
| 168 |
+
padding: 20px;
|
| 169 |
+
border-radius: 10px;
|
| 170 |
+
box-shadow: 0 5px 15px rgba(0,0,0,0.1);
|
| 171 |
+
z-index: 1000;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
.loading.hidden {
|
| 175 |
+
display: none;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
.status {
|
| 179 |
+
margin-top: 15px;
|
| 180 |
+
padding: 10px;
|
| 181 |
+
border-radius: 5px;
|
| 182 |
+
font-size: 14px;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
.status.success {
|
| 186 |
+
background: #d4edda;
|
| 187 |
+
color: #155724;
|
| 188 |
+
border: 1px solid #c3e6cb;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
.status.error {
|
| 192 |
+
background: #f8d7da;
|
| 193 |
+
color: #721c24;
|
| 194 |
+
border: 1px solid #f5c6cb;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
.status.info {
|
| 198 |
+
background: #d1ecf1;
|
| 199 |
+
color: #0c5460;
|
| 200 |
+
border: 1px solid #bee5eb;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.example-text {
|
| 204 |
+
background: #e9ecef;
|
| 205 |
+
padding: 15px;
|
| 206 |
+
border-radius: 8px;
|
| 207 |
+
font-size: 13px;
|
| 208 |
+
line-height: 1.4;
|
| 209 |
+
margin-top: 10px;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
.example-text h4 {
|
| 213 |
+
margin-bottom: 8px;
|
| 214 |
+
color: #495057;
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
.example-text p {
|
| 218 |
+
margin-bottom: 8px;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
.example-text ul {
|
| 222 |
+
margin-left: 20px;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
.example-text li {
|
| 226 |
+
margin-bottom: 4px;
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
.entity-list {
|
| 230 |
+
max-height: 200px;
|
| 231 |
+
overflow-y: auto;
|
| 232 |
+
border: 1px solid #e9ecef;
|
| 233 |
+
border-radius: 8px;
|
| 234 |
+
background: white;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
.entity-item {
|
| 238 |
+
padding: 10px 12px;
|
| 239 |
+
border-bottom: 1px solid #f8f9fa;
|
| 240 |
+
cursor: pointer;
|
| 241 |
+
transition: background-color 0.2s ease;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
.entity-item:hover {
|
| 245 |
+
background-color: #f8f9fa;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
.entity-item:last-child {
|
| 249 |
+
border-bottom: none;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
.entity-name {
|
| 253 |
+
font-weight: 600;
|
| 254 |
+
color: #495057;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
.entity-model {
|
| 258 |
+
font-size: 12px;
|
| 259 |
+
color: #6c757d;
|
| 260 |
+
margin-top: 2px;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
.graph-info {
|
| 264 |
+
background: #e9ecef;
|
| 265 |
+
padding: 15px;
|
| 266 |
+
border-radius: 8px;
|
| 267 |
+
margin-top: 15px;
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
.graph-info h4 {
|
| 271 |
+
margin-bottom: 10px;
|
| 272 |
+
color: #495057;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
.graph-info p {
|
| 276 |
+
margin-bottom: 5px;
|
| 277 |
+
font-size: 14px;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
.consensus-text {
|
| 281 |
+
background: #d4edda;
|
| 282 |
+
padding: 10px;
|
| 283 |
+
border-radius: 5px;
|
| 284 |
+
margin-top: 10px;
|
| 285 |
+
font-style: italic;
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
.sequences-section {
|
| 289 |
+
background: #f8f9fa;
|
| 290 |
+
padding: 15px;
|
| 291 |
+
border-radius: 8px;
|
| 292 |
+
margin-bottom: 15px;
|
| 293 |
+
border: 1px solid #e9ecef;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
.sequences-section h4 {
|
| 297 |
+
margin-bottom: 10px;
|
| 298 |
+
color: #495057;
|
| 299 |
+
font-size: 1rem;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
.sequences-list {
|
| 303 |
+
max-height: 150px;
|
| 304 |
+
overflow-y: auto;
|
| 305 |
+
border: 1px solid #e9ecef;
|
| 306 |
+
border-radius: 6px;
|
| 307 |
+
background: white;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
.sequences-list li {
|
| 311 |
+
padding: 6px 10px;
|
| 312 |
+
border-bottom: 1px solid #f8f9fa;
|
| 313 |
+
cursor: pointer;
|
| 314 |
+
transition: background-color 0.2s ease;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
.sequences-list li:hover {
|
| 318 |
+
background-color: #f8f9fa;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
.sequences-list li:last-child {
|
| 322 |
+
border-bottom: none;
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
.sequences-list .sequence-item {
|
| 326 |
+
font-family: 'Courier New', Courier, monospace;
|
| 327 |
+
font-size: 12px;
|
| 328 |
+
line-height: 1.3;
|
| 329 |
+
white-space: pre-wrap;
|
| 330 |
+
word-break: break-all;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
.consensus-highlight {
|
| 334 |
+
background-color: #ceeab2;
|
| 335 |
+
color: #2d5016;
|
| 336 |
+
font-weight: bold;
|
| 337 |
+
padding: 1px 2px;
|
| 338 |
+
border-radius: 3px;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
.consensus-section {
|
| 342 |
+
background: #f8f9fa;
|
| 343 |
+
padding: 15px;
|
| 344 |
+
border-radius: 8px;
|
| 345 |
+
margin-bottom: 15px;
|
| 346 |
+
border: 1px solid #e9ecef;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
.consensus-text {
|
| 350 |
+
font-family: 'Courier New', Courier, monospace;
|
| 351 |
+
font-size: 14px;
|
| 352 |
+
line-height: 1.4;
|
| 353 |
+
white-space: pre-wrap;
|
| 354 |
+
word-break: break-word;
|
| 355 |
+
background: white;
|
| 356 |
+
padding: 10px;
|
| 357 |
+
border: 1px solid #e9ecef;
|
| 358 |
+
border-radius: 6px;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
@media (max-width: 768px) {
|
| 362 |
+
.main-content {
|
| 363 |
+
grid-template-columns: 1fr;
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
.sidebar {
|
| 367 |
+
border-right: none;
|
| 368 |
+
border-bottom: 1px solid #e9ecef;
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
.header h1 {
|
| 372 |
+
font-size: 2rem;
|
| 373 |
+
}
|
| 374 |
+
}
|
| 375 |
+
</style>
|
| 376 |
+
</head>
|
| 377 |
+
<body>
|
| 378 |
+
<div class="container">
|
| 379 |
+
<div class="header">
|
| 380 |
+
<h1>ConGr Visualizer</h1>
|
| 381 |
+
<p>Explore and visualize ConGrs</p>
|
| 382 |
+
</div>
|
| 383 |
+
|
| 384 |
+
<div class="main-content">
|
| 385 |
+
<div class="sidebar">
|
| 386 |
+
<div class="section">
|
| 387 |
+
<h3>Browse Existing Graphs</h3>
|
| 388 |
+
<div class="input-group">
|
| 389 |
+
<label for="datasetSelect">Select Dataset:</label>
|
| 390 |
+
<select id="datasetSelect" onchange="loadEntities()">
|
| 391 |
+
<option value="">Choose a dataset...</option>
|
| 392 |
+
</select>
|
| 393 |
+
</div>
|
| 394 |
+
<div class="input-group">
|
| 395 |
+
<label for="entitySelect">Select Instance:</label>
|
| 396 |
+
<select id="entitySelect" onchange="loadModels()">
|
| 397 |
+
<option value="">Choose an instance...</option>
|
| 398 |
+
</select>
|
| 399 |
+
</div>
|
| 400 |
+
<div class="input-group">
|
| 401 |
+
<label for="modelSelect">Select Model:</label>
|
| 402 |
+
<select id="modelSelect" onchange="loadSelectedGraph()">
|
| 403 |
+
<option value="">Choose a model...</option>
|
| 404 |
+
</select>
|
| 405 |
+
</div>
|
| 406 |
+
<div id="graphInfo" class="graph-info hidden">
|
| 407 |
+
<h4>Graph Information</h4>
|
| 408 |
+
<div id="graphDetails"></div>
|
| 409 |
+
</div>
|
| 410 |
+
</div>
|
| 411 |
+
|
| 412 |
+
<div class="section">
|
| 413 |
+
<h3>Create New Graph</h3>
|
| 414 |
+
<div class="input-group">
|
| 415 |
+
<label for="textInput">Enter text sequences (one per line):</label>
|
| 416 |
+
<textarea id="textInput" placeholder="Enter your text sequences here..."></textarea>
|
| 417 |
+
</div>
|
| 418 |
+
<button class="btn" onclick="createGraph()">Create Graph</button>
|
| 419 |
+
<div class="input-group">
|
| 420 |
+
<label>
|
| 421 |
+
<input type="checkbox" id="computeConsensus" checked> Display consensus response using consensus decoding
|
| 422 |
+
</label>
|
| 423 |
+
</div>
|
| 424 |
+
</div>
|
| 425 |
+
|
| 426 |
+
<div class="section">
|
| 427 |
+
<h3>Graph Options</h3>
|
| 428 |
+
<div class="input-group">
|
| 429 |
+
<label for="saveFilename">Save filename:</label>
|
| 430 |
+
<input type="text" id="saveFilename" placeholder="graph.pkl" value="graph.pkl">
|
| 431 |
+
</div>
|
| 432 |
+
<button class="btn btn-secondary" onclick="saveGraph()">Save Graph</button>
|
| 433 |
+
<button class="btn btn-secondary" onclick="clearGraph()">Clear Graph</button>
|
| 434 |
+
</div>
|
| 435 |
+
|
| 436 |
+
<div id="status" class="status hidden"></div>
|
| 437 |
+
</div>
|
| 438 |
+
|
| 439 |
+
<div class="graph-container">
|
| 440 |
+
<div id="loadingProgress" class="loading hidden">Processing...</div>
|
| 441 |
+
<div id="originalSequences" class="sequences-section hidden">
|
| 442 |
+
<h4>Original Sequences</h4>
|
| 443 |
+
<div id="sequencesList"></div>
|
| 444 |
+
</div>
|
| 445 |
+
<div id="consensusResponse" class="consensus-section hidden">
|
| 446 |
+
<h4>Consensus Response</h4>
|
| 447 |
+
<div id="consensusText"></div>
|
| 448 |
+
</div>
|
| 449 |
+
<div id="mynetwork"></div>
|
| 450 |
+
</div>
|
| 451 |
+
</div>
|
| 452 |
+
</div>
|
| 453 |
+
|
| 454 |
+
<script>
|
| 455 |
+
let network = null;
|
| 456 |
+
let currentGraphData = null;
|
| 457 |
+
let availableEntities = [];
|
| 458 |
+
|
| 459 |
+
function showStatus(message, type = 'info') {
|
| 460 |
+
const status = document.getElementById('status');
|
| 461 |
+
status.textContent = message;
|
| 462 |
+
status.className = `status ${type}`;
|
| 463 |
+
status.classList.remove('hidden');
|
| 464 |
+
|
| 465 |
+
if (type === 'success') {
|
| 466 |
+
setTimeout(() => {
|
| 467 |
+
status.classList.add('hidden');
|
| 468 |
+
}, 3000);
|
| 469 |
+
}
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
function showLoading() {
|
| 473 |
+
document.getElementById('loadingProgress').classList.remove('hidden');
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
function hideLoading() {
|
| 477 |
+
document.getElementById('loadingProgress').classList.add('hidden');
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
async function loadDatasets() {
|
| 481 |
+
try {
|
| 482 |
+
const response = await fetch('/api/datasets');
|
| 483 |
+
const data = await response.json();
|
| 484 |
+
|
| 485 |
+
const datasetSelect = document.getElementById('datasetSelect');
|
| 486 |
+
datasetSelect.innerHTML = '<option value="">Choose a dataset...</option>';
|
| 487 |
+
|
| 488 |
+
if (data.datasets && data.datasets.length > 0) {
|
| 489 |
+
data.datasets.forEach(dataset => {
|
| 490 |
+
const option = document.createElement('option');
|
| 491 |
+
option.value = dataset.name;
|
| 492 |
+
option.textContent = `${dataset.display_name} (${dataset.count} graphs)`;
|
| 493 |
+
datasetSelect.appendChild(option);
|
| 494 |
+
});
|
| 495 |
+
}
|
| 496 |
+
} catch (error) {
|
| 497 |
+
showStatus('Error loading datasets: ' + error.message, 'error');
|
| 498 |
+
}
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
async function loadEntities() {
|
| 502 |
+
const datasetSelect = document.getElementById('datasetSelect');
|
| 503 |
+
const entitySelect = document.getElementById('entitySelect');
|
| 504 |
+
const modelSelect = document.getElementById('modelSelect');
|
| 505 |
+
const dataset = datasetSelect.value;
|
| 506 |
+
|
| 507 |
+
if (!dataset) {
|
| 508 |
+
entitySelect.innerHTML = '<option value="">Choose an entity...</option>';
|
| 509 |
+
modelSelect.innerHTML = '<option value="">Choose a model...</option>';
|
| 510 |
+
return;
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
showLoading();
|
| 514 |
+
showStatus('Loading entities...', 'info');
|
| 515 |
+
|
| 516 |
+
try {
|
| 517 |
+
const response = await fetch(`/api/entities?dataset=${dataset}`);
|
| 518 |
+
const data = await response.json();
|
| 519 |
+
|
| 520 |
+
if (data.error) {
|
| 521 |
+
showStatus('Error loading entities: ' + data.error, 'error');
|
| 522 |
+
return;
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
availableEntities = data.entities;
|
| 526 |
+
entitySelect.innerHTML = '<option value="">Choose an entity...</option>';
|
| 527 |
+
modelSelect.innerHTML = '<option value="">Choose a model...</option>';
|
| 528 |
+
|
| 529 |
+
if (data.entities && data.entities.length > 0) {
|
| 530 |
+
// Get unique entity names
|
| 531 |
+
const uniqueEntities = [...new Set(data.entities.map(e => e.entity))];
|
| 532 |
+
|
| 533 |
+
// Sort numerically for non-bio datasets
|
| 534 |
+
if (dataset !== 'bio') {
|
| 535 |
+
uniqueEntities.sort((a, b) => {
|
| 536 |
+
// Extract numbers from entity names for sorting
|
| 537 |
+
const numA = parseInt(a.match(/\d+/)?.[0] || '0');
|
| 538 |
+
const numB = parseInt(b.match(/\d+/)?.[0] || '0');
|
| 539 |
+
return numA - numB;
|
| 540 |
+
});
|
| 541 |
+
} else {
|
| 542 |
+
// Sort alphabetically for bio dataset
|
| 543 |
+
uniqueEntities.sort();
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
uniqueEntities.forEach(entityName => {
|
| 547 |
+
const option = document.createElement('option');
|
| 548 |
+
option.value = entityName;
|
| 549 |
+
option.textContent = entityName;
|
| 550 |
+
entitySelect.appendChild(option);
|
| 551 |
+
});
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
showStatus(`Loaded ${data.entities.length} entities from ${dataset} dataset`, 'success');
|
| 555 |
+
} catch (error) {
|
| 556 |
+
showStatus('Error loading entities: ' + error.message, 'error');
|
| 557 |
+
} finally {
|
| 558 |
+
hideLoading();
|
| 559 |
+
}
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
async function loadModels() {
|
| 563 |
+
const datasetSelect = document.getElementById('datasetSelect');
|
| 564 |
+
const entitySelect = document.getElementById('entitySelect');
|
| 565 |
+
const modelSelect = document.getElementById('modelSelect');
|
| 566 |
+
const dataset = datasetSelect.value;
|
| 567 |
+
const entityName = entitySelect.value;
|
| 568 |
+
|
| 569 |
+
if (!entityName) {
|
| 570 |
+
modelSelect.innerHTML = '<option value="">Choose a model...</option>';
|
| 571 |
+
return;
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
showLoading();
|
| 575 |
+
showStatus('Loading models...', 'info');
|
| 576 |
+
|
| 577 |
+
try {
|
| 578 |
+
const response = await fetch(`/api/models?dataset=${dataset}&entity=${encodeURIComponent(entityName)}`);
|
| 579 |
+
const data = await response.json();
|
| 580 |
+
|
| 581 |
+
if (data.error) {
|
| 582 |
+
showStatus('Error loading models: ' + data.error, 'error');
|
| 583 |
+
return;
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
modelSelect.innerHTML = '<option value="">Choose a model...</option>';
|
| 587 |
+
if (data.models && data.models.length > 0) {
|
| 588 |
+
// Sort models by name for consistency
|
| 589 |
+
data.models.sort((a, b) => a.model.localeCompare(b.model));
|
| 590 |
+
|
| 591 |
+
data.models.forEach(model => {
|
| 592 |
+
const option = document.createElement('option');
|
| 593 |
+
option.value = model.filepath;
|
| 594 |
+
option.textContent = model.model;
|
| 595 |
+
modelSelect.appendChild(option);
|
| 596 |
+
});
|
| 597 |
+
} else {
|
| 598 |
+
console.log('No models found for this entity');
|
| 599 |
+
}
|
| 600 |
+
showStatus(`Loaded ${data.models.length} models for ${entityName}`, 'success');
|
| 601 |
+
} catch (error) {
|
| 602 |
+
showStatus('Error loading models: ' + error.message, 'error');
|
| 603 |
+
} finally {
|
| 604 |
+
hideLoading();
|
| 605 |
+
}
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
function displayOriginalSequences(sequences, consensusText = null) {
|
| 609 |
+
const sequencesSection = document.getElementById('originalSequences');
|
| 610 |
+
const sequencesList = document.getElementById('sequencesList');
|
| 611 |
+
|
| 612 |
+
if (!sequences || sequences.length === 0) {
|
| 613 |
+
sequencesSection.classList.add('hidden');
|
| 614 |
+
return;
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
let html = '<ul class="sequences-list">';
|
| 618 |
+
sequences.forEach((sequence, index) => {
|
| 619 |
+
let highlightedSequence = sequence;
|
| 620 |
+
|
| 621 |
+
// Highlight consensus text in green if available
|
| 622 |
+
if (consensusText && consensusText.trim()) {
|
| 623 |
+
const consensusWords = consensusText.trim().split(/\s+/);
|
| 624 |
+
let currentSequence = sequence;
|
| 625 |
+
|
| 626 |
+
consensusWords.forEach(word => {
|
| 627 |
+
if (word.length > 2) { // Only highlight words longer than 2 characters
|
| 628 |
+
const regex = new RegExp(`\\b${word.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')}\\b`, 'gi');
|
| 629 |
+
currentSequence = currentSequence.replace(regex, `<span class="consensus-highlight">${word}</span>`);
|
| 630 |
+
}
|
| 631 |
+
});
|
| 632 |
+
|
| 633 |
+
highlightedSequence = currentSequence;
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
html += `<li><div class="sequence-item"><strong>Sequence ${index + 1}:</strong> ${highlightedSequence}</div></li>`;
|
| 637 |
+
});
|
| 638 |
+
|
| 639 |
+
html += '</ul>';
|
| 640 |
+
|
| 641 |
+
sequencesList.innerHTML = html;
|
| 642 |
+
sequencesSection.classList.remove('hidden');
|
| 643 |
+
|
| 644 |
+
// Display consensus response in separate box if available
|
| 645 |
+
displayConsensusResponse(consensusText);
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
function displayConsensusResponse(consensusText) {
|
| 649 |
+
const consensusSection = document.getElementById('consensusResponse');
|
| 650 |
+
const consensusTextDiv = document.getElementById('consensusText');
|
| 651 |
+
|
| 652 |
+
if (!consensusText || !consensusText.trim()) {
|
| 653 |
+
consensusSection.classList.add('hidden');
|
| 654 |
+
return;
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
consensusTextDiv.innerHTML = `<div class="consensus-text">${consensusText}</div>`;
|
| 658 |
+
consensusSection.classList.remove('hidden');
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
async function loadSelectedGraph() {
|
| 662 |
+
const modelSelect = document.getElementById('modelSelect');
|
| 663 |
+
const selectedModel = modelSelect.value;
|
| 664 |
+
|
| 665 |
+
if (!selectedModel) {
|
| 666 |
+
return;
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
showLoading();
|
| 670 |
+
showStatus('Loading graph...', 'info');
|
| 671 |
+
|
| 672 |
+
try {
|
| 673 |
+
const computeConsensus = document.getElementById('computeConsensus').checked;
|
| 674 |
+
const response = await fetch('/api/load_existing_graph', {
|
| 675 |
+
method: 'POST',
|
| 676 |
+
headers: {
|
| 677 |
+
'Content-Type': 'application/json',
|
| 678 |
+
},
|
| 679 |
+
body: JSON.stringify({
|
| 680 |
+
filepath: selectedModel,
|
| 681 |
+
compute_consensus: computeConsensus
|
| 682 |
+
})
|
| 683 |
+
});
|
| 684 |
+
|
| 685 |
+
const data = await response.json();
|
| 686 |
+
|
| 687 |
+
if (data.success) {
|
| 688 |
+
displayGraph(data.nodes, data.edges);
|
| 689 |
+
displayOriginalSequences(data.original_sequences, data.consensus_text);
|
| 690 |
+
showGraphInfo(data);
|
| 691 |
+
showStatus(`Graph loaded successfully! ${data.num_sequences} sequences, ${data.num_nodes} nodes, ${data.num_edges} edges.`, 'success');
|
| 692 |
+
} else {
|
| 693 |
+
showStatus('Error loading graph: ' + data.error, 'error');
|
| 694 |
+
}
|
| 695 |
+
} catch (error) {
|
| 696 |
+
showStatus('Error loading graph: ' + error.message, 'error');
|
| 697 |
+
} finally {
|
| 698 |
+
hideLoading();
|
| 699 |
+
}
|
| 700 |
+
}
|
| 701 |
+
|
| 702 |
+
function showGraphInfo(data) {
|
| 703 |
+
const graphInfo = document.getElementById('graphInfo');
|
| 704 |
+
const graphDetails = document.getElementById('graphDetails');
|
| 705 |
+
|
| 706 |
+
let detailsHtml = '';
|
| 707 |
+
|
| 708 |
+
if (data.metadata) {
|
| 709 |
+
detailsHtml += `
|
| 710 |
+
<p><strong>Dataset:</strong> ${data.metadata.task}</p>
|
| 711 |
+
<p><strong>Entity:</strong> ${data.metadata.entity}</p>
|
| 712 |
+
<p><strong>Model:</strong> ${data.metadata.model}</p>
|
| 713 |
+
`;
|
| 714 |
+
} else {
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
detailsHtml += `
|
| 718 |
+
<p><strong>Sequences:</strong> ${data.num_sequences}</p>
|
| 719 |
+
<p><strong>Nodes:</strong> ${data.num_nodes}</p>
|
| 720 |
+
<p><strong>Edges:</strong> ${data.num_edges}</p>
|
| 721 |
+
`;
|
| 722 |
+
|
| 723 |
+
graphDetails.innerHTML = detailsHtml;
|
| 724 |
+
graphInfo.classList.remove('hidden');
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
async function createGraph() {
|
| 728 |
+
const textInput = document.getElementById('textInput').value.trim();
|
| 729 |
+
if (!textInput) {
|
| 730 |
+
showStatus('Please enter some text sequences.', 'error');
|
| 731 |
+
return;
|
| 732 |
+
}
|
| 733 |
+
|
| 734 |
+
const sequences = textInput.split('\n').filter(line => line.trim() !== '');
|
| 735 |
+
if (sequences.length < 2) {
|
| 736 |
+
showStatus('Please enter at least 2 text sequences.', 'error');
|
| 737 |
+
return;
|
| 738 |
+
}
|
| 739 |
+
|
| 740 |
+
showLoading();
|
| 741 |
+
showStatus('Creating graph...', 'info');
|
| 742 |
+
|
| 743 |
+
try {
|
| 744 |
+
const computeConsensus = document.getElementById('computeConsensus').checked;
|
| 745 |
+
const response = await fetch('/api/create_graph', {
|
| 746 |
+
method: 'POST',
|
| 747 |
+
headers: {
|
| 748 |
+
'Content-Type': 'application/json',
|
| 749 |
+
},
|
| 750 |
+
body: JSON.stringify({
|
| 751 |
+
sequences: sequences,
|
| 752 |
+
compute_consensus: computeConsensus
|
| 753 |
+
})
|
| 754 |
+
});
|
| 755 |
+
|
| 756 |
+
const data = await response.json();
|
| 757 |
+
|
| 758 |
+
if (data.success) {
|
| 759 |
+
displayGraph(data.nodes, data.edges);
|
| 760 |
+
displayOriginalSequences(data.original_sequences, data.consensus_text);
|
| 761 |
+
showStatus(`Graph created with ${data.num_sequences} sequences, ${data.num_nodes} nodes, and ${data.num_edges} edges!`, 'success');
|
| 762 |
+
} else {
|
| 763 |
+
showStatus('Error creating graph: ' + data.error, 'error');
|
| 764 |
+
}
|
| 765 |
+
} catch (error) {
|
| 766 |
+
showStatus('Error creating graph: ' + error.message, 'error');
|
| 767 |
+
} finally {
|
| 768 |
+
hideLoading();
|
| 769 |
+
}
|
| 770 |
+
}
|
| 771 |
+
|
| 772 |
+
function displayGraph(nodes, edges) {
|
| 773 |
+
const container = document.getElementById('mynetwork');
|
| 774 |
+
|
| 775 |
+
if (!nodes || nodes.length === 0) {
|
| 776 |
+
console.error('No nodes provided to displayGraph');
|
| 777 |
+
return;
|
| 778 |
+
}
|
| 779 |
+
|
| 780 |
+
// Process nodes without manual level assignment
|
| 781 |
+
const processedNodes = nodes.map(node => ({ ...node }));
|
| 782 |
+
|
| 783 |
+
const data = {
|
| 784 |
+
nodes: new vis.DataSet(processedNodes),
|
| 785 |
+
edges: new vis.DataSet(edges)
|
| 786 |
+
};
|
| 787 |
+
|
| 788 |
+
const options = {
|
| 789 |
+
width: '100%',
|
| 790 |
+
height: '100%',
|
| 791 |
+
physics: {
|
| 792 |
+
enabled: false,
|
| 793 |
+
stabilization: {
|
| 794 |
+
updateInterval: 10,
|
| 795 |
+
},
|
| 796 |
+
},
|
| 797 |
+
edges: {
|
| 798 |
+
color: {
|
| 799 |
+
inherit: false
|
| 800 |
+
}
|
| 801 |
+
},
|
| 802 |
+
layout: {
|
| 803 |
+
hierarchical: {
|
| 804 |
+
direction: "UD",
|
| 805 |
+
sortMethod: "directed",
|
| 806 |
+
shakeTowards: "roots",
|
| 807 |
+
levelSeparation: 150,
|
| 808 |
+
nodeSpacing: 800,
|
| 809 |
+
treeSpacing: 200,
|
| 810 |
+
parentCentralization: true,
|
| 811 |
+
}
|
| 812 |
+
}
|
| 813 |
+
};
|
| 814 |
+
|
| 815 |
+
if (network) {
|
| 816 |
+
network.destroy();
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
try {
|
| 820 |
+
network = new vis.Network(container, data, options);
|
| 821 |
+
} catch (error) {
|
| 822 |
+
console.error('Error creating network:', error);
|
| 823 |
+
return;
|
| 824 |
+
}
|
| 825 |
+
|
| 826 |
+
network.on("stabilizationProgress", function (params) {
|
| 827 |
+
document.getElementById("loadingProgress").innerText =
|
| 828 |
+
"Stabilizing: " + Math.round(params.iterations / params.total * 100) + "%";
|
| 829 |
+
});
|
| 830 |
+
|
| 831 |
+
network.once("stabilizationIterationsDone", function () {
|
| 832 |
+
document.getElementById("loadingProgress").innerText = "100%";
|
| 833 |
+
setTimeout(function () {
|
| 834 |
+
document.getElementById("loadingProgress").classList.add("hidden");
|
| 835 |
+
}, 500);
|
| 836 |
+
});
|
| 837 |
+
|
| 838 |
+
currentGraphData = { nodes, edges };
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
async function saveGraph() {
|
| 842 |
+
const textInput = document.getElementById('textInput').value.trim();
|
| 843 |
+
if (!textInput) {
|
| 844 |
+
showStatus('Please enter some text sequences first.', 'error');
|
| 845 |
+
return;
|
| 846 |
+
}
|
| 847 |
+
|
| 848 |
+
const sequences = textInput.split('\n').filter(line => line.trim() !== '');
|
| 849 |
+
if (sequences.length < 2) {
|
| 850 |
+
showStatus('Please enter at least 2 text sequences.', 'error');
|
| 851 |
+
return;
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
const filename = document.getElementById('saveFilename').value || 'graph.pkl';
|
| 855 |
+
|
| 856 |
+
showLoading();
|
| 857 |
+
showStatus('Saving graph...', 'info');
|
| 858 |
+
|
| 859 |
+
try {
|
| 860 |
+
const response = await fetch('/api/save_graph', {
|
| 861 |
+
method: 'POST',
|
| 862 |
+
headers: {
|
| 863 |
+
'Content-Type': 'application/json',
|
| 864 |
+
},
|
| 865 |
+
body: JSON.stringify({
|
| 866 |
+
sequences: sequences,
|
| 867 |
+
filename: filename
|
| 868 |
+
})
|
| 869 |
+
});
|
| 870 |
+
|
| 871 |
+
const data = await response.json();
|
| 872 |
+
|
| 873 |
+
if (data.success) {
|
| 874 |
+
showStatus(`Graph saved successfully to ${data.filename}!`, 'success');
|
| 875 |
+
} else {
|
| 876 |
+
showStatus('Error saving graph: ' + data.error, 'error');
|
| 877 |
+
}
|
| 878 |
+
} catch (error) {
|
| 879 |
+
showStatus('Error saving graph: ' + error.message, 'error');
|
| 880 |
+
} finally {
|
| 881 |
+
hideLoading();
|
| 882 |
+
}
|
| 883 |
+
}
|
| 884 |
+
|
| 885 |
+
function clearGraph() {
|
| 886 |
+
if (network) {
|
| 887 |
+
network.destroy();
|
| 888 |
+
network = null;
|
| 889 |
+
}
|
| 890 |
+
currentGraphData = null;
|
| 891 |
+
document.getElementById('textInput').value = '';
|
| 892 |
+
document.getElementById('datasetSelect').value = '';
|
| 893 |
+
document.getElementById('entitySelect').innerHTML = '<option value="">Choose an entity...</option>';
|
| 894 |
+
document.getElementById('modelSelect').innerHTML = '<option value="">Choose a model...</option>';
|
| 895 |
+
document.getElementById('graphInfo').classList.add('hidden');
|
| 896 |
+
document.getElementById('originalSequences').classList.add('hidden');
|
| 897 |
+
showStatus('Graph cleared.', 'info');
|
| 898 |
+
}
|
| 899 |
+
|
| 900 |
+
// Initialize
|
| 901 |
+
document.addEventListener('DOMContentLoaded', function() {
|
| 902 |
+
loadDatasets();
|
| 903 |
+
showStatus('Ready to explore existing graphs or create new ones!', 'info');
|
| 904 |
+
});
|
| 905 |
+
</script>
|
| 906 |
+
</body>
|
| 907 |
+
</html>
|
web_interface/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask==2.3.3
|
| 2 |
+
flask-cors==4.0.0
|
| 3 |
+
numpy==1.24.3
|
| 4 |
+
tqdm==4.66.1
|
web_interface/server.py
ADDED
|
@@ -0,0 +1,652 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Flask server for POA Graph Web Interface
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import glob
|
| 7 |
+
import os
|
| 8 |
+
import pickle
|
| 9 |
+
import re
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
from flask import Flask, jsonify, request, send_from_directory
|
| 13 |
+
from flask_cors import CORS
|
| 14 |
+
|
| 15 |
+
# Get the repository root directory (parent of web_interface)
|
| 16 |
+
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 17 |
+
|
| 18 |
+
# Add the repository root to the path so we can import the POA graph modules
|
| 19 |
+
sys.path.append(REPO_ROOT)
|
| 20 |
+
|
| 21 |
+
from src.new_text_alignment import TextSeqGraphAlignment
|
| 22 |
+
from src.text_poa_graph import TextPOAGraph
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from src.generation_methods import decode_consensus
|
| 26 |
+
except ImportError:
|
| 27 |
+
decode_consensus = None
|
| 28 |
+
|
| 29 |
+
app = Flask(__name__)
|
| 30 |
+
CORS(app) # Enable CORS for all routes
|
| 31 |
+
|
| 32 |
+
# Base paths for different datasets (relative to repo root)
|
| 33 |
+
GRAPH_PATHS = {
|
| 34 |
+
"bio": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/bio"),
|
| 35 |
+
"fp": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/fp"),
|
| 36 |
+
"hist": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/hist"),
|
| 37 |
+
"refs": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/refs"),
|
| 38 |
+
"math": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/MATH"),
|
| 39 |
+
"aime": os.path.join(REPO_ROOT, "results/graphs/HALoGEN/AIME"),
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
MODELS = ["qwen72b", "qwen7b", "llama8b", "llama70b", "olmo7b", "olmo32b"]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@app.route("/")
|
| 46 |
+
def index():
|
| 47 |
+
"""Serve the main HTML file"""
|
| 48 |
+
return send_from_directory(".", "index.html")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@app.route("/api/datasets", methods=["GET"])
|
| 52 |
+
def get_datasets():
|
| 53 |
+
"""Get available datasets"""
|
| 54 |
+
datasets = []
|
| 55 |
+
for dataset_name, path in GRAPH_PATHS.items():
|
| 56 |
+
if os.path.exists(path):
|
| 57 |
+
# Count available graphs
|
| 58 |
+
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
|
| 59 |
+
datasets.append(
|
| 60 |
+
{
|
| 61 |
+
"name": dataset_name,
|
| 62 |
+
"display_name": dataset_name.upper(),
|
| 63 |
+
"path": path,
|
| 64 |
+
"count": len(pkl_files),
|
| 65 |
+
}
|
| 66 |
+
)
|
| 67 |
+
return jsonify({"datasets": datasets})
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@app.route("/api/models", methods=["GET"])
|
| 71 |
+
def get_models():
|
| 72 |
+
"""Get available models for a specific entity"""
|
| 73 |
+
entity = request.args.get("entity")
|
| 74 |
+
dataset = request.args.get("dataset")
|
| 75 |
+
|
| 76 |
+
if not entity:
|
| 77 |
+
return jsonify({"error": "Entity parameter required"}), 400
|
| 78 |
+
|
| 79 |
+
if not dataset or dataset not in GRAPH_PATHS:
|
| 80 |
+
return jsonify({"error": "Invalid dataset"}), 400
|
| 81 |
+
|
| 82 |
+
path = GRAPH_PATHS[dataset]
|
| 83 |
+
if not os.path.exists(path):
|
| 84 |
+
return jsonify({"error": "Dataset path not found"}), 404
|
| 85 |
+
|
| 86 |
+
models = []
|
| 87 |
+
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
|
| 88 |
+
|
| 89 |
+
for pkl_file in pkl_files:
|
| 90 |
+
filename = os.path.basename(pkl_file)
|
| 91 |
+
|
| 92 |
+
# Different filename patterns for different datasets
|
| 93 |
+
if dataset == "bio":
|
| 94 |
+
# Format: bio_graph_{entity}_merged_{model}.pkl
|
| 95 |
+
match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 96 |
+
if match:
|
| 97 |
+
entity_name, model = match.groups()
|
| 98 |
+
if entity_name == entity:
|
| 99 |
+
models.append({"model": model, "filename": filename, "filepath": pkl_file})
|
| 100 |
+
elif dataset == "fp":
|
| 101 |
+
# Format: fp_graph_{number}_merged_{model}.pkl
|
| 102 |
+
match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
|
| 103 |
+
if match:
|
| 104 |
+
entity_name, model = match.groups()
|
| 105 |
+
if f"Problem {entity_name}" == entity:
|
| 106 |
+
models.append({"model": model, "filename": filename, "filepath": pkl_file})
|
| 107 |
+
elif dataset == "math":
|
| 108 |
+
# Format: qwen72_math_{number}.pkl
|
| 109 |
+
match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
|
| 110 |
+
if match:
|
| 111 |
+
entity_name = match.group(1)
|
| 112 |
+
if f"Math Problem {entity_name}" == entity:
|
| 113 |
+
models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
|
| 114 |
+
elif dataset == "aime":
|
| 115 |
+
# Format: aime_qwen72b_{number}.pkl
|
| 116 |
+
match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
|
| 117 |
+
if match:
|
| 118 |
+
entity_name = match.group(1)
|
| 119 |
+
if f"AIME Problem {entity_name}" == entity:
|
| 120 |
+
models.append({"model": "qwen72b", "filename": filename, "filepath": pkl_file})
|
| 121 |
+
else:
|
| 122 |
+
# Generic pattern for other datasets
|
| 123 |
+
match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 124 |
+
if match:
|
| 125 |
+
task, entity_name, model = match.groups()
|
| 126 |
+
if entity_name == entity:
|
| 127 |
+
models.append({"model": model, "filename": filename, "filepath": pkl_file})
|
| 128 |
+
|
| 129 |
+
return jsonify({"models": models})
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@app.route("/api/entities", methods=["GET"])
|
| 133 |
+
def get_entities():
|
| 134 |
+
"""Get available entities for a dataset"""
|
| 135 |
+
dataset = request.args.get("dataset")
|
| 136 |
+
if not dataset or dataset not in GRAPH_PATHS:
|
| 137 |
+
return jsonify({"error": "Invalid dataset"}), 400
|
| 138 |
+
|
| 139 |
+
path = GRAPH_PATHS[dataset]
|
| 140 |
+
if not os.path.exists(path):
|
| 141 |
+
return jsonify({"error": "Dataset path not found"}), 404
|
| 142 |
+
|
| 143 |
+
entities = []
|
| 144 |
+
pkl_files = glob.glob(os.path.join(path, "*.pkl"))
|
| 145 |
+
|
| 146 |
+
for pkl_file in pkl_files:
|
| 147 |
+
filename = os.path.basename(pkl_file)
|
| 148 |
+
|
| 149 |
+
# Different filename patterns for different datasets
|
| 150 |
+
if dataset == "bio":
|
| 151 |
+
# Format: bio_graph_{entity}_merged_{model}.pkl
|
| 152 |
+
match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 153 |
+
if match:
|
| 154 |
+
entity_name, model = match.groups()
|
| 155 |
+
entities.append(
|
| 156 |
+
{
|
| 157 |
+
"entity": entity_name,
|
| 158 |
+
"model": model,
|
| 159 |
+
"filename": filename,
|
| 160 |
+
"filepath": pkl_file,
|
| 161 |
+
}
|
| 162 |
+
)
|
| 163 |
+
elif dataset == "fp":
|
| 164 |
+
# Format: fp_graph_{number}_merged_{model}.pkl
|
| 165 |
+
match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
|
| 166 |
+
if match:
|
| 167 |
+
entity_name, model = match.groups()
|
| 168 |
+
entities.append(
|
| 169 |
+
{
|
| 170 |
+
"entity": f"Problem {entity_name}",
|
| 171 |
+
"model": model,
|
| 172 |
+
"filename": filename,
|
| 173 |
+
"filepath": pkl_file,
|
| 174 |
+
}
|
| 175 |
+
)
|
| 176 |
+
elif dataset == "math":
|
| 177 |
+
# Format: qwen72_math_{number}.pkl
|
| 178 |
+
match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
|
| 179 |
+
if match:
|
| 180 |
+
entity_name = match.group(1)
|
| 181 |
+
entities.append(
|
| 182 |
+
{
|
| 183 |
+
"entity": f"Math Problem {entity_name}",
|
| 184 |
+
"model": "qwen72b",
|
| 185 |
+
"filename": filename,
|
| 186 |
+
"filepath": pkl_file,
|
| 187 |
+
}
|
| 188 |
+
)
|
| 189 |
+
elif dataset == "aime":
|
| 190 |
+
# Format: aime_qwen72b_{number}.pkl
|
| 191 |
+
match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
|
| 192 |
+
if match:
|
| 193 |
+
entity_name = match.group(1)
|
| 194 |
+
entities.append(
|
| 195 |
+
{
|
| 196 |
+
"entity": f"AIME Problem {entity_name}",
|
| 197 |
+
"model": "qwen72b",
|
| 198 |
+
"filename": filename,
|
| 199 |
+
"filepath": pkl_file,
|
| 200 |
+
}
|
| 201 |
+
)
|
| 202 |
+
else:
|
| 203 |
+
# Generic pattern for other datasets
|
| 204 |
+
match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 205 |
+
if match:
|
| 206 |
+
task, entity_name, model = match.groups()
|
| 207 |
+
entities.append(
|
| 208 |
+
{
|
| 209 |
+
"entity": entity_name,
|
| 210 |
+
"model": model,
|
| 211 |
+
"filename": filename,
|
| 212 |
+
"filepath": pkl_file,
|
| 213 |
+
}
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
return jsonify({"entities": entities})
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@app.route("/api/load_existing_graph", methods=["POST"])
|
| 220 |
+
def load_existing_graph():
|
| 221 |
+
"""Load an existing graph from the stored pickle files"""
|
| 222 |
+
try:
|
| 223 |
+
data = request.get_json()
|
| 224 |
+
filepath = data.get("filepath")
|
| 225 |
+
|
| 226 |
+
if not filepath or not os.path.exists(filepath):
|
| 227 |
+
return jsonify({"error": "Graph file not found"}), 404
|
| 228 |
+
|
| 229 |
+
# Read and load the pickle file
|
| 230 |
+
try:
|
| 231 |
+
with open(filepath, "rb") as f:
|
| 232 |
+
graph = pickle.load(f)
|
| 233 |
+
except Exception as e:
|
| 234 |
+
return jsonify({"error": f"Error loading pickle file: {str(e)}"}), 500
|
| 235 |
+
|
| 236 |
+
if not isinstance(graph, TextPOAGraph):
|
| 237 |
+
return jsonify({"error": "File does not contain a valid POA graph"}), 400
|
| 238 |
+
|
| 239 |
+
# Convert to JSON format for vis.js
|
| 240 |
+
nodes = []
|
| 241 |
+
edges = []
|
| 242 |
+
|
| 243 |
+
try:
|
| 244 |
+
# Get consensus nodes for coloring
|
| 245 |
+
consensus_nodes = set(graph.consensus_node_ids)
|
| 246 |
+
|
| 247 |
+
# Create nodes using the same logic as jsOutput
|
| 248 |
+
for node in graph.nodeiterator()():
|
| 249 |
+
title_text = ""
|
| 250 |
+
if node.sequences:
|
| 251 |
+
title_text += f"Sequences: {node.sequences}"
|
| 252 |
+
if node.variations:
|
| 253 |
+
title_text += ";;;".join(
|
| 254 |
+
[f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
|
| 255 |
+
)
|
| 256 |
+
title_text = title_text.replace('"', "'")
|
| 257 |
+
|
| 258 |
+
# Use the same color logic as jsOutput
|
| 259 |
+
color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
|
| 260 |
+
|
| 261 |
+
node_data = {
|
| 262 |
+
"id": node.ID,
|
| 263 |
+
"label": f"{node.ID}: {node.text}",
|
| 264 |
+
"title": title_text,
|
| 265 |
+
"color": color,
|
| 266 |
+
}
|
| 267 |
+
nodes.append(node_data)
|
| 268 |
+
|
| 269 |
+
# Create edges using the same logic as jsOutput
|
| 270 |
+
for node in graph.nodeiterator()():
|
| 271 |
+
nodeID = node.ID # Keep as integer
|
| 272 |
+
for edge in node.outEdges:
|
| 273 |
+
target = edge # Keep as integer
|
| 274 |
+
weight = node.outEdges[edge].weight + 1.5
|
| 275 |
+
edge_data = {
|
| 276 |
+
"from": nodeID,
|
| 277 |
+
"to": target,
|
| 278 |
+
"value": weight,
|
| 279 |
+
"color": "#cae0e6",
|
| 280 |
+
"arrows": "to",
|
| 281 |
+
}
|
| 282 |
+
edges.append(edge_data)
|
| 283 |
+
except Exception as e:
|
| 284 |
+
return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
|
| 285 |
+
|
| 286 |
+
# Extract metadata from filename
|
| 287 |
+
filename = os.path.basename(filepath)
|
| 288 |
+
metadata = {}
|
| 289 |
+
|
| 290 |
+
try:
|
| 291 |
+
# Different filename patterns for different datasets
|
| 292 |
+
if filename.startswith("bio_graph_"):
|
| 293 |
+
# Format: bio_graph_{entity}_merged_{model}.pkl
|
| 294 |
+
match = re.match(r"bio_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 295 |
+
if match:
|
| 296 |
+
entity_name, model = match.groups()
|
| 297 |
+
metadata = {
|
| 298 |
+
"task": "bio",
|
| 299 |
+
"entity": entity_name,
|
| 300 |
+
"model": model,
|
| 301 |
+
"filename": filename,
|
| 302 |
+
}
|
| 303 |
+
elif filename.startswith("fp_graph_"):
|
| 304 |
+
# Format: fp_graph_{number}_merged_{model}.pkl
|
| 305 |
+
match = re.match(r"fp_graph_(\d+)_merged_(\w+)\.pkl", filename)
|
| 306 |
+
if match:
|
| 307 |
+
entity_name, model = match.groups()
|
| 308 |
+
metadata = {
|
| 309 |
+
"task": "fp",
|
| 310 |
+
"entity": f"Problem {entity_name}",
|
| 311 |
+
"model": model,
|
| 312 |
+
"filename": filename,
|
| 313 |
+
}
|
| 314 |
+
elif filename.startswith("qwen72_math_"):
|
| 315 |
+
# Format: qwen72_math_{number}.pkl
|
| 316 |
+
match = re.match(r"qwen72_math_(\d+)\.pkl", filename)
|
| 317 |
+
if match:
|
| 318 |
+
entity_name = match.group(1)
|
| 319 |
+
metadata = {
|
| 320 |
+
"task": "math",
|
| 321 |
+
"entity": f"Math Problem {entity_name}",
|
| 322 |
+
"model": "qwen72b",
|
| 323 |
+
"filename": filename,
|
| 324 |
+
}
|
| 325 |
+
elif filename.startswith("aime_qwen72b_"):
|
| 326 |
+
# Format: aime_qwen72b_{number}.pkl
|
| 327 |
+
match = re.match(r"aime_qwen72b_(\d+)\.pkl", filename)
|
| 328 |
+
if match:
|
| 329 |
+
entity_name = match.group(1)
|
| 330 |
+
metadata = {
|
| 331 |
+
"task": "aime",
|
| 332 |
+
"entity": f"AIME Problem {entity_name}",
|
| 333 |
+
"model": "qwen72b",
|
| 334 |
+
"filename": filename,
|
| 335 |
+
}
|
| 336 |
+
else:
|
| 337 |
+
# Generic pattern for other datasets
|
| 338 |
+
match = re.match(r"(\w+)_graph_(.+?)_merged_(\w+)\.pkl", filename)
|
| 339 |
+
if match:
|
| 340 |
+
task, entity_name, model = match.groups()
|
| 341 |
+
metadata = {
|
| 342 |
+
"task": task,
|
| 343 |
+
"entity": entity_name,
|
| 344 |
+
"model": model,
|
| 345 |
+
"filename": filename,
|
| 346 |
+
}
|
| 347 |
+
except Exception:
|
| 348 |
+
# Don't fail the request if metadata extraction fails
|
| 349 |
+
pass
|
| 350 |
+
|
| 351 |
+
# Extract text from consensus nodes
|
| 352 |
+
consensus_text = ""
|
| 353 |
+
try:
|
| 354 |
+
consensus_nodes = set(graph.consensus_node_ids)
|
| 355 |
+
consensus_node_texts = []
|
| 356 |
+
for node in graph.nodeiterator()():
|
| 357 |
+
if node.ID in consensus_nodes and node.text and node.text.strip():
|
| 358 |
+
consensus_node_texts.append(node.text.strip())
|
| 359 |
+
consensus_text = " ".join(consensus_node_texts)
|
| 360 |
+
except Exception:
|
| 361 |
+
consensus_text = ""
|
| 362 |
+
|
| 363 |
+
# Check if we should compute consensus using decode_consensus
|
| 364 |
+
compute_consensus = data.get("compute_consensus", False)
|
| 365 |
+
if compute_consensus and decode_consensus:
|
| 366 |
+
try:
|
| 367 |
+
# Determine task from metadata or default to "bio"
|
| 368 |
+
task = metadata.get("task", "bio") if metadata else "bio"
|
| 369 |
+
consensus_text = decode_consensus(graph, selection_threshold=0.5, task=task)
|
| 370 |
+
except Exception as e:
|
| 371 |
+
print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
|
| 372 |
+
# Keep the original consensus text if decode_consensus fails
|
| 373 |
+
|
| 374 |
+
# Get original sequences
|
| 375 |
+
try:
|
| 376 |
+
raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
|
| 377 |
+
# Process sequences: join with spaces and remove "||"
|
| 378 |
+
print(f"DEBUG: Raw sequences: {raw_sequences}")
|
| 379 |
+
original_sequences = []
|
| 380 |
+
for seq in raw_sequences:
|
| 381 |
+
if isinstance(seq, list):
|
| 382 |
+
# Join list elements with spaces
|
| 383 |
+
processed_seq = " ".join(str(item) for item in seq)
|
| 384 |
+
else:
|
| 385 |
+
processed_seq = str(seq)
|
| 386 |
+
# Remove "||" characters
|
| 387 |
+
processed_seq = processed_seq.replace("||", "")
|
| 388 |
+
print(f"DEBUG: Processed sequence: {processed_seq}")
|
| 389 |
+
original_sequences.append(processed_seq)
|
| 390 |
+
except Exception:
|
| 391 |
+
original_sequences = []
|
| 392 |
+
|
| 393 |
+
result = {
|
| 394 |
+
"success": True,
|
| 395 |
+
"nodes": nodes,
|
| 396 |
+
"edges": edges,
|
| 397 |
+
"num_sequences": graph.num_sequences,
|
| 398 |
+
"num_nodes": len(nodes),
|
| 399 |
+
"num_edges": len(edges),
|
| 400 |
+
"metadata": metadata,
|
| 401 |
+
"consensus_text": consensus_text,
|
| 402 |
+
"original_sequences": original_sequences,
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
return jsonify(result)
|
| 406 |
+
|
| 407 |
+
except Exception as e:
|
| 408 |
+
return jsonify({"error": str(e)}), 500
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
@app.route("/api/create_graph", methods=["POST"])
|
| 412 |
+
def create_graph():
|
| 413 |
+
"""Create a POA graph from text sequences"""
|
| 414 |
+
try:
|
| 415 |
+
data = request.get_json()
|
| 416 |
+
sequences = data.get("sequences", [])
|
| 417 |
+
|
| 418 |
+
if len(sequences) < 2:
|
| 419 |
+
return jsonify({"error": "At least 2 sequences are required"}), 400
|
| 420 |
+
|
| 421 |
+
print(f"DEBUG: Creating graph with sequences: {sequences}")
|
| 422 |
+
|
| 423 |
+
# Create the graph with first sequence as string
|
| 424 |
+
graph = TextPOAGraph(sequences[0], label=0)
|
| 425 |
+
print("DEBUG: Initial graph created")
|
| 426 |
+
|
| 427 |
+
# Add remaining sequences
|
| 428 |
+
for i, sequence in enumerate(sequences[1:], 1):
|
| 429 |
+
print(f"DEBUG: Adding sequence {i}: {sequence}")
|
| 430 |
+
alignment = TextSeqGraphAlignment(
|
| 431 |
+
text=sequence,
|
| 432 |
+
graph=graph,
|
| 433 |
+
fastMethod=True,
|
| 434 |
+
globalAlign=True,
|
| 435 |
+
matchscore=1,
|
| 436 |
+
mismatchscore=-2,
|
| 437 |
+
gap_open=-1,
|
| 438 |
+
)
|
| 439 |
+
graph.incorporateSeqAlignment(alignment, sequence, label=i)
|
| 440 |
+
|
| 441 |
+
print("DEBUG: All sequences added")
|
| 442 |
+
|
| 443 |
+
# Refine the graph with proper domain and model parameters
|
| 444 |
+
graph.refine_graph(verbose=False, domain="text", model="gpt-4o-mini")
|
| 445 |
+
print("DEBUG: Graph refined")
|
| 446 |
+
|
| 447 |
+
# Convert to JSON format for vis.js
|
| 448 |
+
nodes = []
|
| 449 |
+
edges = []
|
| 450 |
+
|
| 451 |
+
try:
|
| 452 |
+
print("DEBUG: Starting to process graph data")
|
| 453 |
+
# Get consensus nodes for coloring (make it optional)
|
| 454 |
+
try:
|
| 455 |
+
consensus_nodes = set(graph.consensus_node_ids)
|
| 456 |
+
print(f"DEBUG: Consensus nodes: {consensus_nodes}")
|
| 457 |
+
except Exception as e:
|
| 458 |
+
print(f"DEBUG: Error getting consensus nodes: {e}")
|
| 459 |
+
consensus_nodes = set() # Fallback to empty set if consensus fails
|
| 460 |
+
|
| 461 |
+
# Create nodes using the same logic as jsOutput
|
| 462 |
+
for node in graph.nodeiterator()():
|
| 463 |
+
title_text = ""
|
| 464 |
+
if node.sequences:
|
| 465 |
+
title_text += f"Sequences: {node.sequences}"
|
| 466 |
+
if node.variations:
|
| 467 |
+
title_text += ";;;".join(
|
| 468 |
+
[f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
|
| 469 |
+
)
|
| 470 |
+
title_text = title_text.replace('"', "'")
|
| 471 |
+
|
| 472 |
+
# Use the same color logic as jsOutput
|
| 473 |
+
color = "#ceeab2" if node.ID in consensus_nodes else "#cae0e6"
|
| 474 |
+
|
| 475 |
+
node_data = {
|
| 476 |
+
"id": node.ID,
|
| 477 |
+
"label": f"{node.ID}: {node.text}",
|
| 478 |
+
"title": title_text,
|
| 479 |
+
"color": color,
|
| 480 |
+
}
|
| 481 |
+
nodes.append(node_data)
|
| 482 |
+
|
| 483 |
+
print(f"DEBUG: Created {len(nodes)} nodes")
|
| 484 |
+
|
| 485 |
+
# Create edges using the same logic as jsOutput
|
| 486 |
+
for node in graph.nodeiterator()():
|
| 487 |
+
nodeID = node.ID # Keep as integer
|
| 488 |
+
for edge in node.outEdges:
|
| 489 |
+
target = edge # Keep as integer
|
| 490 |
+
weight = node.outEdges[edge].weight + 1.5
|
| 491 |
+
edge_data = {
|
| 492 |
+
"from": nodeID,
|
| 493 |
+
"to": target,
|
| 494 |
+
"value": weight,
|
| 495 |
+
"color": "#cae0e6",
|
| 496 |
+
"arrows": "to",
|
| 497 |
+
}
|
| 498 |
+
edges.append(edge_data)
|
| 499 |
+
|
| 500 |
+
print(f"DEBUG: Created {len(edges)} edges")
|
| 501 |
+
except Exception as e:
|
| 502 |
+
print(f"DEBUG: Error processing graph data: {e}")
|
| 503 |
+
return jsonify({"error": f"Error processing graph data: {str(e)}"}), 500
|
| 504 |
+
|
| 505 |
+
# Extract text from consensus nodes
|
| 506 |
+
consensus_text = ""
|
| 507 |
+
try:
|
| 508 |
+
consensus_node_texts = []
|
| 509 |
+
for node in graph.nodeiterator()():
|
| 510 |
+
if node.ID in consensus_nodes and node.text and node.text.strip():
|
| 511 |
+
consensus_node_texts.append(node.text.strip())
|
| 512 |
+
consensus_text = " ".join(consensus_node_texts)
|
| 513 |
+
except Exception:
|
| 514 |
+
consensus_text = ""
|
| 515 |
+
|
| 516 |
+
# Check if we should compute consensus using decode_consensus
|
| 517 |
+
compute_consensus = data.get("compute_consensus", False)
|
| 518 |
+
if compute_consensus and decode_consensus:
|
| 519 |
+
try:
|
| 520 |
+
# Default to "bio" task for new graphs
|
| 521 |
+
consensus_text = decode_consensus(graph, selection_threshold=0.5, task="bio")
|
| 522 |
+
except Exception as e:
|
| 523 |
+
print(f"DEBUG: Error computing consensus with decode_consensus: {e}")
|
| 524 |
+
# Keep the original consensus text if decode_consensus fails
|
| 525 |
+
|
| 526 |
+
# Get original sequences
|
| 527 |
+
try:
|
| 528 |
+
raw_sequences = graph._seqs if hasattr(graph, "_seqs") else []
|
| 529 |
+
# Process sequences: join with spaces and remove "||"
|
| 530 |
+
original_sequences = []
|
| 531 |
+
for seq in raw_sequences:
|
| 532 |
+
if isinstance(seq, list):
|
| 533 |
+
# Join list elements with spaces
|
| 534 |
+
processed_seq = " ".join(str(item) for item in seq)
|
| 535 |
+
else:
|
| 536 |
+
processed_seq = str(seq)
|
| 537 |
+
# Remove "||" characters
|
| 538 |
+
processed_seq = processed_seq.replace("||", "")
|
| 539 |
+
original_sequences.append(processed_seq)
|
| 540 |
+
except Exception:
|
| 541 |
+
original_sequences = []
|
| 542 |
+
|
| 543 |
+
print("DEBUG: Returning success response")
|
| 544 |
+
return jsonify(
|
| 545 |
+
{
|
| 546 |
+
"success": True,
|
| 547 |
+
"nodes": nodes,
|
| 548 |
+
"edges": edges,
|
| 549 |
+
"num_sequences": len(sequences),
|
| 550 |
+
"num_nodes": len(nodes),
|
| 551 |
+
"num_edges": len(edges),
|
| 552 |
+
"original_sequences": original_sequences,
|
| 553 |
+
"consensus_text": consensus_text,
|
| 554 |
+
}
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
except Exception as e:
|
| 558 |
+
print(f"DEBUG: Main exception in create_graph: {e}")
|
| 559 |
+
return jsonify({"error": str(e)}), 500
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
@app.route("/api/save_graph", methods=["POST"])
|
| 563 |
+
def save_graph():
|
| 564 |
+
"""Save a POA graph to a pickle file"""
|
| 565 |
+
try:
|
| 566 |
+
data = request.get_json()
|
| 567 |
+
sequences = data.get("sequences", [])
|
| 568 |
+
filename = data.get("filename", "graph.pkl")
|
| 569 |
+
|
| 570 |
+
if len(sequences) < 2:
|
| 571 |
+
return jsonify({"error": "At least 2 sequences are required"}), 400
|
| 572 |
+
|
| 573 |
+
# Create the graph
|
| 574 |
+
graph = TextPOAGraph(sequences[0], label=0)
|
| 575 |
+
|
| 576 |
+
# Add remaining sequences
|
| 577 |
+
for i, sequence in enumerate(sequences[1:], 1):
|
| 578 |
+
alignment = TextSeqGraphAlignment(
|
| 579 |
+
text=sequence,
|
| 580 |
+
graph=graph,
|
| 581 |
+
fastMethod=True,
|
| 582 |
+
globalAlign=True,
|
| 583 |
+
matchscore=1,
|
| 584 |
+
mismatchscore=-2,
|
| 585 |
+
gap_open=-1,
|
| 586 |
+
)
|
| 587 |
+
graph.incorporateSeqAlignment(alignment, sequence, label=i)
|
| 588 |
+
|
| 589 |
+
# Refine the graph
|
| 590 |
+
graph.refine_graph(verbose=False)
|
| 591 |
+
|
| 592 |
+
# Save to pickle file
|
| 593 |
+
graph.save_to_pickle(filename)
|
| 594 |
+
|
| 595 |
+
return jsonify(
|
| 596 |
+
{"success": True, "filename": filename, "message": f"Graph saved to {filename}"}
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
except Exception as e:
|
| 600 |
+
return jsonify({"error": str(e)}), 500
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
@app.route("/api/graph_info", methods=["POST"])
|
| 604 |
+
def graph_info():
|
| 605 |
+
"""Get information about a graph without creating the full visualization"""
|
| 606 |
+
try:
|
| 607 |
+
data = request.get_json()
|
| 608 |
+
sequences = data.get("sequences", [])
|
| 609 |
+
|
| 610 |
+
if len(sequences) < 2:
|
| 611 |
+
return jsonify({"error": "At least 2 sequences are required"}), 400
|
| 612 |
+
|
| 613 |
+
# Create the graph
|
| 614 |
+
graph = TextPOAGraph(sequences[0], label=0)
|
| 615 |
+
|
| 616 |
+
# Add remaining sequences
|
| 617 |
+
for i, sequence in enumerate(sequences[1:], 1):
|
| 618 |
+
alignment = TextSeqGraphAlignment(
|
| 619 |
+
text=sequence,
|
| 620 |
+
graph=graph,
|
| 621 |
+
fastMethod=True,
|
| 622 |
+
globalAlign=True,
|
| 623 |
+
matchscore=1,
|
| 624 |
+
mismatchscore=-2,
|
| 625 |
+
gap_open=-1,
|
| 626 |
+
)
|
| 627 |
+
graph.incorporateSeqAlignment(alignment, sequence, label=i)
|
| 628 |
+
|
| 629 |
+
# Refine the graph
|
| 630 |
+
graph.refine_graph(verbose=False)
|
| 631 |
+
|
| 632 |
+
# Get consensus response
|
| 633 |
+
consensus_text = graph.consensus_response()
|
| 634 |
+
|
| 635 |
+
return jsonify(
|
| 636 |
+
{
|
| 637 |
+
"success": True,
|
| 638 |
+
"num_sequences": len(sequences),
|
| 639 |
+
"num_nodes": graph._nnodes,
|
| 640 |
+
"consensus_text": consensus_text,
|
| 641 |
+
"consensus_node_ids": graph.consensus_node_ids,
|
| 642 |
+
}
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
except Exception as e:
|
| 646 |
+
return jsonify({"error": str(e)}), 500
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
if __name__ == "__main__":
|
| 650 |
+
print("Starting POA Graph Web Interface Server...")
|
| 651 |
+
print("Open http://localhost:8080 in your browser")
|
| 652 |
+
app.run(debug=True, host="0.0.0.0", port=8080)
|
web_interface/start.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
echo "Starting POA Graph Web Interface..."
|
| 4 |
+
|
| 5 |
+
# Check if Python is installed
|
| 6 |
+
if ! command -v python3 &> /dev/null; then
|
| 7 |
+
echo "Error: Python 3 is not installed or not in PATH"
|
| 8 |
+
exit 1
|
| 9 |
+
fi
|
| 10 |
+
|
| 11 |
+
# Check if we're in the right directory
|
| 12 |
+
if [ ! -f "server.py" ]; then
|
| 13 |
+
echo "Error: server.py not found. Please run this script from the web_interface directory."
|
| 14 |
+
exit 1
|
| 15 |
+
fi
|
| 16 |
+
|
| 17 |
+
# Install dependencies if requirements.txt exists
|
| 18 |
+
if [ -f "requirements.txt" ]; then
|
| 19 |
+
echo "Installing dependencies..."
|
| 20 |
+
pip3 install -r requirements.txt
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
# Start the server
|
| 24 |
+
echo "Starting server on http://localhost:5000"
|
| 25 |
+
echo "Press Ctrl+C to stop the server"
|
| 26 |
+
python3 server.py
|
web_interface/test_server.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script for the POA Graph Web Interface Server
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
BASE_URL = "http://localhost:8080"
|
| 10 |
+
|
| 11 |
+
def test_datasets():
|
| 12 |
+
"""Test the datasets endpoint"""
|
| 13 |
+
print("Testing /api/datasets...")
|
| 14 |
+
try:
|
| 15 |
+
response = requests.get(f"{BASE_URL}/api/datasets")
|
| 16 |
+
if response.status_code == 200:
|
| 17 |
+
data = response.json()
|
| 18 |
+
print(f"β Success! Found {len(data['datasets'])} datasets:")
|
| 19 |
+
for dataset in data['datasets']:
|
| 20 |
+
print(f" - {dataset['display_name']}: {dataset['count']} graphs")
|
| 21 |
+
else:
|
| 22 |
+
print(f"β Error: {response.status_code}")
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"β Exception: {e}")
|
| 25 |
+
|
| 26 |
+
def test_entities():
|
| 27 |
+
"""Test the entities endpoint"""
|
| 28 |
+
print("\nTesting /api/entities?dataset=bio...")
|
| 29 |
+
try:
|
| 30 |
+
response = requests.get(f"{BASE_URL}/api/entities?dataset=bio")
|
| 31 |
+
if response.status_code == 200:
|
| 32 |
+
data = response.json()
|
| 33 |
+
print(f"β Success! Found {len(data['entities'])} entities in bio dataset")
|
| 34 |
+
if data['entities']:
|
| 35 |
+
print(f" Sample entity: {data['entities'][0]['entity']} ({data['entities'][0]['model']})")
|
| 36 |
+
else:
|
| 37 |
+
print(f"β Error: {response.status_code}")
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f"β Exception: {e}")
|
| 40 |
+
|
| 41 |
+
def test_load_graph():
|
| 42 |
+
"""Test loading a specific graph"""
|
| 43 |
+
print("\nTesting /api/load_existing_graph...")
|
| 44 |
+
try:
|
| 45 |
+
# First get entities to find a valid filepath
|
| 46 |
+
response = requests.get(f"{BASE_URL}/api/entities?dataset=bio")
|
| 47 |
+
if response.status_code == 200:
|
| 48 |
+
data = response.json()
|
| 49 |
+
if data['entities']:
|
| 50 |
+
filepath = data['entities'][0]['filepath']
|
| 51 |
+
print(f" Loading graph: {filepath}")
|
| 52 |
+
|
| 53 |
+
# Test loading the graph
|
| 54 |
+
response = requests.post(
|
| 55 |
+
f"{BASE_URL}/api/load_existing_graph",
|
| 56 |
+
json={"filepath": filepath}
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
if response.status_code == 200:
|
| 60 |
+
graph_data = response.json()
|
| 61 |
+
if graph_data['success']:
|
| 62 |
+
print(f"β Success! Loaded graph with {graph_data['num_nodes']} nodes and {graph_data['num_edges']} edges")
|
| 63 |
+
print(f" Entity: {graph_data['metadata']['entity']}")
|
| 64 |
+
print(f" Model: {graph_data['metadata']['model']}")
|
| 65 |
+
else:
|
| 66 |
+
print(f"β Error: {graph_data['error']}")
|
| 67 |
+
else:
|
| 68 |
+
print(f"β Error: {response.status_code}")
|
| 69 |
+
else:
|
| 70 |
+
print("β No entities found to test with")
|
| 71 |
+
else:
|
| 72 |
+
print(f"β Error getting entities: {response.status_code}")
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f"β Exception: {e}")
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
print("Testing POA Graph Web Interface Server...")
|
| 78 |
+
print("Make sure the server is running on http://localhost:8080")
|
| 79 |
+
print("=" * 50)
|
| 80 |
+
|
| 81 |
+
test_datasets()
|
| 82 |
+
test_entities()
|
| 83 |
+
test_load_graph()
|
| 84 |
+
|
| 85 |
+
print("\n" + "=" * 50)
|
| 86 |
+
print("Test completed!")
|