Spaces:
Running
Running
eigentom commited on
Commit ·
90c099b
1
Parent(s): 37d42f7
Initial Update
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- README copy.md +313 -0
- README.md +4 -6
- app.py +470 -0
- example.py +54 -0
- gradio_app/__init__.py +9 -0
- gradio_app/app.py +466 -0
- gradio_app/components/__init__.py +73 -0
- gradio_app/components/formatters.py +504 -0
- gradio_app/components/header.py +39 -0
- gradio_app/components/results_panel.py +193 -0
- gradio_app/components/settings.py +82 -0
- gradio_app/components/styles.py +592 -0
- gradio_app/components/upload_section.py +117 -0
- gradio_app/utils_single_paper_inference.py +276 -0
- requirements.txt +26 -0
- scripts/gpt_oss_start_vllm_service.sh +47 -0
- scripts/start_load_balancer.sh +87 -0
- scripts/start_reranker_service.sh +116 -0
- scripts/start_vllm_with_balancer.sh +216 -0
- scripts/stop_reranker_services.sh +106 -0
- scripts/stop_vllm_services.sh +267 -0
- shared/configs/config.yaml +97 -0
- shared/configs/llm_service_config.yaml +57 -0
- shared/configs/prompts.yaml +580 -0
- shared/configs/reranker_endpoint_pool.txt +8 -0
- shared/configs/vllm_endpoint_pool.txt +7 -0
- shared/utils/__init__.py +113 -0
- shared/utils/asta_api_key_pool.py +205 -0
- shared/utils/gpt_service.py +210 -0
- shared/utils/json_parser.py +428 -0
- shared/utils/llm_service.py +64 -0
- shared/utils/llm_service_factory.py +191 -0
- shared/utils/load_balancer.py +382 -0
- shared/utils/mock_llm_service.py +280 -0
- shared/utils/prompt_loader.py +220 -0
- shared/utils/reranker.py +275 -0
- shared/utils/reranker_api_service.py +221 -0
- shared/utils/reranker_endpoint_pool.py +160 -0
- shared/utils/reranker_pool.py +78 -0
- shared/utils/review_logger.py +306 -0
- shared/utils/vllm_endpoint_pool.py +257 -0
- shared/utils/vllm_service.py +314 -0
- shared/utils/vllm_service_simple.py +314 -0
- src/__init__.py +6 -0
- src/evaluator/1_get_rubrics.py +601 -0
- src/evaluator/2_evaluate.py +1730 -0
- src/evaluator/2_evaluate_agenticreview.py +1866 -0
- src/evaluator/2_evaluate_aiscientist.py +1866 -0
- src/evaluator/2_evaluate_cyclereviewer.py +1837 -0
- src/evaluator/configs.yaml +38 -0
README copy.md
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReviewGrounder: Improving Review Substantiveness with Rubric-Guided, Tool-Integrated Agents
|
| 2 |
+
|
| 3 |
+
This repository accompanies the paper: *"ReviewGrounder: Improving Review Substantiveness with Rubric-Guided, Tool-Integrated Agents"*. It contains the implementation of **ReviewGrounder**, a rubric-guided, tool-integrated multi-agent framework for generating substantive, evidence-grounded academic paper reviews.
|
| 4 |
+
|
| 5 |
+
ReviewGrounder addresses the key limitation of existing LLM-based reviewers—their tendency to produce superficial, formulaic comments lacking substantive feedback—by explicitly leveraging reviewer rubrics and contextual grounding in existing work.
|
| 6 |
+
|
| 7 |
+
## System Architecture
|
| 8 |
+
|
| 9 |
+
ReviewGrounder implements a multi-agent framework with clear role separation:
|
| 10 |
+
|
| 11 |
+
### Drafting Agent (`paper_reviewer.py`)
|
| 12 |
+
The **drafter** generates an initial review draft based solely on the paper content. This stage produces a structured review with strengths, weaknesses, suggestions, and questions, but may lack deep contextual grounding.
|
| 13 |
+
|
| 14 |
+
### Grounding Agents
|
| 15 |
+
|
| 16 |
+
1. **Related Work Searcher** (`related_work_searcher.py`):
|
| 17 |
+
- Generates search keywords from paper content
|
| 18 |
+
- Retrieves relevant papers via academic APIs
|
| 19 |
+
- Summarizes and analyzes related work
|
| 20 |
+
- Provides context for novelty assessment
|
| 21 |
+
|
| 22 |
+
2. **Paper Results Analyzer** (`paper_results_analyzer.py`):
|
| 23 |
+
- Extracts and analyzes experimental sections
|
| 24 |
+
- Summarizes experimental setup, results, and findings
|
| 25 |
+
- Identifies limitations and gaps
|
| 26 |
+
|
| 27 |
+
3. **Paper Insight Miner** (`paper_insight_miner.py`):
|
| 28 |
+
- Extracts key insights and contributions
|
| 29 |
+
- Identifies technical strengths and weaknesses
|
| 30 |
+
|
| 31 |
+
4. **Review Refiner** (`review_refiner.py`):
|
| 32 |
+
- Synthesizes information from all grounding agents
|
| 33 |
+
- Refines the initial draft with evidence-based critiques
|
| 34 |
+
- Ensures suggestions are actionable and well-justified
|
| 35 |
+
- Maintains consistency across review sections
|
| 36 |
+
|
| 37 |
+
### Evaluation System (`src/evaluator/`)
|
| 38 |
+
The **ReviewBench** evaluation framework:
|
| 39 |
+
- **Rubric Generation**: Creates paper-specific rubrics from venue guidelines, paper content, and human reviews
|
| 40 |
+
- **LLM-based Evaluation**: Deep qualitative assessment aligned with rubrics
|
| 41 |
+
- **Rule-based Metrics**: Quantitative metrics (MSE, MAE, Spearman correlation)
|
| 42 |
+
|
| 43 |
+
## Installation
|
| 44 |
+
|
| 45 |
+
### Prerequisites
|
| 46 |
+
|
| 47 |
+
- Python >= 3.8
|
| 48 |
+
- CUDA-capable GPU (for local vLLM deployment, optional if using OpenAI API)
|
| 49 |
+
- Sufficient GPU memory for your chosen model (if using vLLM)
|
| 50 |
+
|
| 51 |
+
### Setup
|
| 52 |
+
|
| 53 |
+
1. Clone the repository:
|
| 54 |
+
```bash
|
| 55 |
+
git clone <repository-url>
|
| 56 |
+
cd ReviewGrounder
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
2. Install dependencies:
|
| 60 |
+
```bash
|
| 61 |
+
uv venv
|
| 62 |
+
source .venv/bin/activate
|
| 63 |
+
uv pip install -r requirements.txt
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
3. Configure your API keys and settings:
|
| 67 |
+
- Copy `shared/configs/config.yaml` and customize as needed
|
| 68 |
+
- Set environment variables:
|
| 69 |
+
- `ASTA_API_KEY`: For paper search via Asta API (recommended)
|
| 70 |
+
- `OPENAI_API_KEY`: If using OpenAI API instead of vLLM
|
| 71 |
+
- `S2_API_KEY`: Alternative paper search API (optional)
|
| 72 |
+
|
| 73 |
+
4. (Optional) If using local vLLM, start your vLLM service:
|
| 74 |
+
```bash
|
| 75 |
+
# Start vLLM service on a single port
|
| 76 |
+
bash scripts/gpt_oss_start_vllm_service.sh
|
| 77 |
+
|
| 78 |
+
# Or start multiple services with load balancing
|
| 79 |
+
bash scripts/start_vllm_with_balancer.sh
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## Quick Start
|
| 83 |
+
|
| 84 |
+
### Basic Usage
|
| 85 |
+
|
| 86 |
+
Generate a review using the command-line interface:
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
python -m src.reviewer_agent.cli --paper paper.json --output review.json
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Where `paper.json` contains your paper data in JSON format with fields like `title`, `abstract`, `text`, etc.
|
| 93 |
+
|
| 94 |
+
### Using the Python API
|
| 95 |
+
|
| 96 |
+
For programmatic access:
|
| 97 |
+
|
| 98 |
+
```python
|
| 99 |
+
from src.reviewer_agent import review_paper_with_refiner
|
| 100 |
+
|
| 101 |
+
# Load your paper data
|
| 102 |
+
paper_data = {
|
| 103 |
+
"title": "Your Paper Title",
|
| 104 |
+
"abstract": "Paper abstract...",
|
| 105 |
+
"text": "Full paper text...",
|
| 106 |
+
# ... other fields
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
# Generate review (drafting + grounding stages)
|
| 110 |
+
review = review_paper_with_refiner(paper_data=paper_data)
|
| 111 |
+
print(review)
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
The `review_paper_with_refiner` function implements the full ReviewGrounder pipeline:
|
| 115 |
+
1. **Drafting**: Generates initial review draft
|
| 116 |
+
2. **Grounding**: Retrieves related work, analyzes results, extracts insights
|
| 117 |
+
3. **Refinement**: Synthesizes all information into a refined, evidence-grounded review
|
| 118 |
+
|
| 119 |
+
## Usage Examples
|
| 120 |
+
|
| 121 |
+
### Generate a Review with Related Work Context
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
python -m src.reviewer_agent.cli \
|
| 125 |
+
--paper paper.json \
|
| 126 |
+
--max-related-papers 15 \
|
| 127 |
+
--review-format detailed \
|
| 128 |
+
--output review.json
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
### Filter Related Work by Date and Venue
|
| 132 |
+
|
| 133 |
+
```bash
|
| 134 |
+
python -m src.reviewer_agent.cli \
|
| 135 |
+
--paper paper.json \
|
| 136 |
+
--publication-date-range "2020:" \
|
| 137 |
+
--venues "ICLR,NeurIPS,ICML" \
|
| 138 |
+
--output review.json
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
### Use Custom vLLM Endpoint
|
| 142 |
+
|
| 143 |
+
```bash
|
| 144 |
+
python -m src.reviewer_agent.cli \
|
| 145 |
+
--paper paper.json \
|
| 146 |
+
--vllm-url "http://your-server:8000/v1" \
|
| 147 |
+
--output review.json
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
### Evaluate Reviews on ReviewBench
|
| 151 |
+
|
| 152 |
+
```python
|
| 153 |
+
# 1. Generate reviews
|
| 154 |
+
from src.reviewer_agent import review_paper_with_refiner
|
| 155 |
+
review = review_paper_with_refiner(paper_data={...})
|
| 156 |
+
|
| 157 |
+
# 2. Evaluate reviews using ReviewBench
|
| 158 |
+
from src.evaluator import evaluate_reviews
|
| 159 |
+
results = evaluate_reviews(parquet_path="reviews.parquet")
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
## Directory Structure
|
| 163 |
+
|
| 164 |
+
```
|
| 165 |
+
anonymize_codebase/
|
| 166 |
+
├── src/
|
| 167 |
+
│ ├── reviewer_agent/ # ReviewGrounder implementation
|
| 168 |
+
│ │ ├── __init__.py
|
| 169 |
+
│ │ ├── paper_reviewer.py # Drafting agent
|
| 170 |
+
│ │ ├── review_refiner.py # Grounding agent: review refinement
|
| 171 |
+
│ │ ├── related_work_searcher.py # Grounding agent: literature search
|
| 172 |
+
│ │ ├── paper_results_summarizer.py # Grounding agent: results analysis
|
| 173 |
+
│ │ ├── paper_insight_miner.py # Grounding agent: insight extraction
|
| 174 |
+
│ │ ├── main_pipeline.py # Full pipeline orchestration
|
| 175 |
+
│ │ ├── cli.py # Command-line interface
|
| 176 |
+
│ │ └── paper_search/ # Paper search APIs
|
| 177 |
+
│ │ ├── asta_api.py
|
| 178 |
+
│ │ ├── semantic_scholar_api.py
|
| 179 |
+
│ │ └── paper_retriever.py
|
| 180 |
+
│ │
|
| 181 |
+
│ └── evaluator/ # ReviewBench evaluation framework
|
| 182 |
+
│ ├── 1_get_rubrics.py # Rubric generation
|
| 183 |
+
│ ├── 2_evaluate.py # Review evaluation
|
| 184 |
+
│ └── ...
|
| 185 |
+
│
|
| 186 |
+
├── shared/
|
| 187 |
+
│ ├── utils/ # Shared utilities
|
| 188 |
+
│ │ ├── llm_service.py # LLM service abstraction
|
| 189 |
+
│ │ ├── load_balancer.py # Load balancing for vLLM
|
| 190 |
+
│ │ ├── reranker.py # Paper reranking
|
| 191 |
+
│ │ └── ...
|
| 192 |
+
│ │
|
| 193 |
+
│ └── configs/ # Configuration files
|
| 194 |
+
│ ├── config.yaml # Main config
|
| 195 |
+
│ ├── llm_service_config.yaml # LLM service settings
|
| 196 |
+
│ └── prompts.yaml # Review generation prompts
|
| 197 |
+
│
|
| 198 |
+
├── scripts/ # Utility scripts
|
| 199 |
+
│ ├── start_vllm_with_balancer.sh
|
| 200 |
+
│ ├── start_load_balancer.sh
|
| 201 |
+
│ └── ...
|
| 202 |
+
│
|
| 203 |
+
├── requirements.txt # Python dependencies
|
| 204 |
+
└── README.md # This file
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
## Configuration Guide
|
| 208 |
+
|
| 209 |
+
### LLM Service Configuration
|
| 210 |
+
|
| 211 |
+
ReviewGrounder supports two LLM backends:
|
| 212 |
+
|
| 213 |
+
1. **vLLM** (recommended for local deployment): Fast inference with local GPU
|
| 214 |
+
- Default: GPT-OSS-120B for grounding stage
|
| 215 |
+
- Can use smaller models (e.g., Phi-4-14B) for drafting stage
|
| 216 |
+
|
| 217 |
+
2. **OpenAI API**: Cloud-based, no local GPU required
|
| 218 |
+
|
| 219 |
+
Configure in `shared/configs/llm_service_config.yaml`:
|
| 220 |
+
|
| 221 |
+
```yaml
|
| 222 |
+
vllm:
|
| 223 |
+
base_url: "http://localhost:8000/"
|
| 224 |
+
model_name: "openai/gpt-oss-120b"
|
| 225 |
+
max_tokens: 16384
|
| 226 |
+
|
| 227 |
+
gpt:
|
| 228 |
+
enabled: false
|
| 229 |
+
api_key: "your-api-key-here"
|
| 230 |
+
model_name: "gpt-4o"
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
We offer the option of assigning different backends for each agent.
|
| 234 |
+
```yaml
|
| 235 |
+
llm_assignments:
|
| 236 |
+
keyword_generator: "vllm" # For related work search
|
| 237 |
+
paper_summarizer: "vllm" # For results summarization
|
| 238 |
+
reviewer: "vllm" # For drafting stage
|
| 239 |
+
refiner: "vllm" # For grounding/refinement stage
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
### Paper Search Configuration
|
| 243 |
+
|
| 244 |
+
Configure paper search APIs in `shared/configs/config.yaml`:
|
| 245 |
+
|
| 246 |
+
```yaml
|
| 247 |
+
paper_search:
|
| 248 |
+
asta:
|
| 249 |
+
api_key: null # Set via ASTA_API_KEY env var
|
| 250 |
+
endpoint: "https://asta-tools.allen.ai/mcp/v1"
|
| 251 |
+
|
| 252 |
+
semantic_scholar:
|
| 253 |
+
api_key: null # Set via S2_API_KEY env var
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
### Review Format Options
|
| 257 |
+
|
| 258 |
+
Choose from different review formats:
|
| 259 |
+
- `detailed`: Comprehensive review with all sections (default)
|
| 260 |
+
- `summary`: Concise review summary
|
| 261 |
+
- `structured`: Structured format with specific sections
|
| 262 |
+
- `strict_detailed`: Strict adherence to detailed format requirements
|
| 263 |
+
|
| 264 |
+
## Load Balancing for vLLM
|
| 265 |
+
|
| 266 |
+
For production use with multiple GPUs, you can set up load balancing:
|
| 267 |
+
|
| 268 |
+
```bash
|
| 269 |
+
# Start 4 vLLM services on ports 8000-8003
|
| 270 |
+
bash scripts/gpt_oss_start_vllm_service.sh
|
| 271 |
+
|
| 272 |
+
# Start load balancer on port 8004
|
| 273 |
+
python -m shared.utils.load_balancer \
|
| 274 |
+
--backends http://localhost:8000/v1 http://localhost:8001/v1 http://localhost:8002/v1 http://localhost:8003/v1 \
|
| 275 |
+
--port 8004 \
|
| 276 |
+
--strategy round_robin
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
Then point your config to `http://localhost:8004/v1`.
|
| 280 |
+
|
| 281 |
+
## Evaluation: ReviewBench
|
| 282 |
+
|
| 283 |
+
ReviewGrounder is evaluated on **ReviewBench**, a benchmark that:
|
| 284 |
+
|
| 285 |
+
- Leverages paper-specific rubrics derived from:
|
| 286 |
+
- Official venue guidelines (e.g., ACL, ICML, NeurIPS, ICLR)
|
| 287 |
+
- Paper content
|
| 288 |
+
- Human-written reviews
|
| 289 |
+
|
| 290 |
+
- Evaluates reviews across diverse dimensions:
|
| 291 |
+
- Evidence-based critique
|
| 292 |
+
- Constructive tone
|
| 293 |
+
- Technical depth
|
| 294 |
+
- And more...
|
| 295 |
+
|
| 296 |
+
- Measures both:
|
| 297 |
+
- Alignment with human judgments (scores, decisions)
|
| 298 |
+
- Rubric-based quality (beyond just outcome prediction)
|
| 299 |
+
|
| 300 |
+
See `src/evaluator/` for the evaluation framework implementation.
|
| 301 |
+
|
| 302 |
+
## Citation
|
| 303 |
+
|
| 304 |
+
If you use ReviewGrounder in your research, please cite:
|
| 305 |
+
|
| 306 |
+
```bibtex
|
| 307 |
+
@inproceedings{reviewgrounder2026,
|
| 308 |
+
title={ReviewGrounder: Improving Review Substantiveness with Rubric-Guided, Tool-Integrated Agents},
|
| 309 |
+
author={Anonymous},
|
| 310 |
+
booktitle={Proceedings of ACL 2026},
|
| 311 |
+
year={2026}
|
| 312 |
+
}
|
| 313 |
+
```
|
README.md
CHANGED
|
@@ -1,14 +1,12 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.5.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
-
short_description: This is the interactive demo of Review Grounder
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Test Reviewgrounder
|
| 3 |
+
emoji: 💻
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.5.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Review Grounder - Gradio App
|
| 3 |
+
|
| 4 |
+
Main entry point for the Hugging Face Space.
|
| 5 |
+
This module orchestrates the UI components and handles the review pipeline.
|
| 6 |
+
|
| 7 |
+
The app allows users to:
|
| 8 |
+
1. Upload a research paper in PDF format
|
| 9 |
+
2. Configure LLM settings (optional, uses OpenAI defaults)
|
| 10 |
+
3. Generate a comprehensive AI-powered review
|
| 11 |
+
4. View intermediate results from each pipeline stage
|
| 12 |
+
|
| 13 |
+
Components are organized in the `components/` directory for maintainability.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
import tempfile
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Tuple, Iterator
|
| 24 |
+
|
| 25 |
+
import gradio as gr
|
| 26 |
+
|
| 27 |
+
# Import utility for running the review pipeline
|
| 28 |
+
from gradio_app.utils_single_paper_inference import (
|
| 29 |
+
run_single_paper_review_from_pdf_stepwise,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# Import UI components
|
| 33 |
+
from gradio_app.components import (
|
| 34 |
+
get_custom_css,
|
| 35 |
+
create_header,
|
| 36 |
+
create_upload_section,
|
| 37 |
+
create_advanced_settings,
|
| 38 |
+
create_results_panel,
|
| 39 |
+
format_initial_review_html,
|
| 40 |
+
format_related_work_html,
|
| 41 |
+
format_results_html,
|
| 42 |
+
format_insights_html,
|
| 43 |
+
format_final_review,
|
| 44 |
+
format_raw_json,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ============================================================================
|
| 49 |
+
# App Configuration
|
| 50 |
+
# ============================================================================
|
| 51 |
+
|
| 52 |
+
APP_TITLE = "Review Grounder"
|
| 53 |
+
APP_DESCRIPTION = "AI-Powered Research Paper Review"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _raw_json_md_to_file(raw_json_md: str) -> str:
|
| 57 |
+
"""
|
| 58 |
+
Extract JSON from Raw JSON markdown (```json ... ```) and write to a temp file.
|
| 59 |
+
Returns the file path for gr.DownloadButton.
|
| 60 |
+
"""
|
| 61 |
+
if not raw_json_md or not raw_json_md.strip():
|
| 62 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
|
| 63 |
+
f.write("{}")
|
| 64 |
+
return f.name
|
| 65 |
+
text = raw_json_md.strip()
|
| 66 |
+
match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
|
| 67 |
+
if match:
|
| 68 |
+
text = match.group(1).strip()
|
| 69 |
+
fd, path = tempfile.mkstemp(suffix=".json", prefix="review_")
|
| 70 |
+
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
| 71 |
+
f.write(text)
|
| 72 |
+
return path
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ============================================================================
|
| 76 |
+
# Environment Check
|
| 77 |
+
# ============================================================================
|
| 78 |
+
|
| 79 |
+
def _check_env() -> Tuple[bool, str]:
|
| 80 |
+
"""
|
| 81 |
+
Check for required environment variables.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
Tuple of (success, message)
|
| 85 |
+
"""
|
| 86 |
+
missing = []
|
| 87 |
+
if not os.environ.get("ASTA_API_KEY"):
|
| 88 |
+
missing.append("ASTA_API_KEY")
|
| 89 |
+
|
| 90 |
+
if missing:
|
| 91 |
+
return False, (
|
| 92 |
+
"Missing environment variables: "
|
| 93 |
+
+ ", ".join(missing)
|
| 94 |
+
+ ".\nPlease configure them in your Hugging Face Space settings."
|
| 95 |
+
)
|
| 96 |
+
return True, "Environment variables detected correctly."
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# ============================================================================
|
| 100 |
+
# Review Pipeline Handler
|
| 101 |
+
# ============================================================================
|
| 102 |
+
|
| 103 |
+
def review_pdf_file(
|
| 104 |
+
file_obj,
|
| 105 |
+
api_base_url: str,
|
| 106 |
+
api_key: str,
|
| 107 |
+
model_name: str,
|
| 108 |
+
show_log: bool,
|
| 109 |
+
show_raw_json: bool,
|
| 110 |
+
) -> Iterator[Tuple[str, str, str, str, str, str, str, gr.update, gr.update, gr.update]]:
|
| 111 |
+
"""
|
| 112 |
+
Main callback: process PDF through the review pipeline with real-time updates.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
file_obj: Uploaded PDF file
|
| 116 |
+
api_base_url: LLM API endpoint URL
|
| 117 |
+
api_key: API key for LLM provider
|
| 118 |
+
model_name: Model identifier
|
| 119 |
+
show_log: Whether to display the execution log
|
| 120 |
+
show_raw_json: Whether to display raw JSON output
|
| 121 |
+
|
| 122 |
+
Yields:
|
| 123 |
+
Tuple of all output component updates (no overview)
|
| 124 |
+
"""
|
| 125 |
+
log_lines: list[str] = []
|
| 126 |
+
|
| 127 |
+
def _log(msg: str) -> None:
|
| 128 |
+
log_lines.append(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
|
| 129 |
+
|
| 130 |
+
def _log_text() -> str:
|
| 131 |
+
return "\n".join(log_lines) if log_lines else ""
|
| 132 |
+
|
| 133 |
+
# Validate file upload
|
| 134 |
+
if file_obj is None:
|
| 135 |
+
gr.Warning("Please upload a PDF file to start the review.")
|
| 136 |
+
_log("⚠️ Please upload a PDF file to start the review.")
|
| 137 |
+
yield (
|
| 138 |
+
_log_text(), "", "", "", "", "", "",
|
| 139 |
+
gr.update(interactive=True),
|
| 140 |
+
gr.update(visible=show_log),
|
| 141 |
+
gr.update(visible=show_raw_json),
|
| 142 |
+
)
|
| 143 |
+
return
|
| 144 |
+
|
| 145 |
+
# Check environment
|
| 146 |
+
ok, msg = _check_env()
|
| 147 |
+
if not ok:
|
| 148 |
+
gr.Error(msg)
|
| 149 |
+
_log(f"❌ {msg}")
|
| 150 |
+
yield (
|
| 151 |
+
_log_text(), "", "", "", "", "", "",
|
| 152 |
+
gr.update(interactive=True),
|
| 153 |
+
gr.update(visible=show_log),
|
| 154 |
+
gr.update(visible=show_raw_json),
|
| 155 |
+
)
|
| 156 |
+
return
|
| 157 |
+
|
| 158 |
+
# Start pipeline
|
| 159 |
+
_log("🚀 Pipeline started.")
|
| 160 |
+
yield (
|
| 161 |
+
_log_text(), "", "", "", "", "", "",
|
| 162 |
+
gr.update(interactive=False),
|
| 163 |
+
gr.update(visible=show_log),
|
| 164 |
+
gr.update(visible=show_raw_json),
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
try:
|
| 168 |
+
# Normalize file path
|
| 169 |
+
if isinstance(file_obj, dict) and "name" in file_obj:
|
| 170 |
+
src_path = Path(file_obj["name"])
|
| 171 |
+
else:
|
| 172 |
+
src_path = Path(getattr(file_obj, "name", "") or str(file_obj))
|
| 173 |
+
|
| 174 |
+
if not src_path or not src_path.exists():
|
| 175 |
+
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
|
| 176 |
+
tmp_path = Path(tmp.name)
|
| 177 |
+
if hasattr(file_obj, "read"):
|
| 178 |
+
tmp.write(file_obj.read())
|
| 179 |
+
src_path = tmp_path
|
| 180 |
+
|
| 181 |
+
# Initialize output variables
|
| 182 |
+
status = f"📄 Extracting text from PDF: {src_path.name}..."
|
| 183 |
+
_log(status)
|
| 184 |
+
yield (
|
| 185 |
+
_log_text(), "", "", "", "", "", "",
|
| 186 |
+
gr.update(interactive=False),
|
| 187 |
+
gr.update(visible=show_log),
|
| 188 |
+
gr.update(visible=show_raw_json),
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
initial = ""
|
| 192 |
+
related_html = ""
|
| 193 |
+
results_html = ""
|
| 194 |
+
insights_html = ""
|
| 195 |
+
final_md = ""
|
| 196 |
+
raw_json = ""
|
| 197 |
+
|
| 198 |
+
# Run the stepwise pipeline
|
| 199 |
+
for ev in run_single_paper_review_from_pdf_stepwise(
|
| 200 |
+
str(src_path),
|
| 201 |
+
api_base_url=api_base_url or None,
|
| 202 |
+
api_key=api_key or None,
|
| 203 |
+
model_name=model_name or None,
|
| 204 |
+
enable_logging=True,
|
| 205 |
+
verbose=True,
|
| 206 |
+
):
|
| 207 |
+
stage = ev.get("stage")
|
| 208 |
+
|
| 209 |
+
# Handle step-level errors
|
| 210 |
+
if stage == "results_analysis_error":
|
| 211 |
+
err = ev.get("error", "Unknown error")
|
| 212 |
+
gr.Warning(f"Results analysis failed: {err}")
|
| 213 |
+
_log(f"⚠️ Results analysis failed: {err}")
|
| 214 |
+
yield (
|
| 215 |
+
_log_text(), initial, related_html, results_html,
|
| 216 |
+
insights_html, final_md, raw_json,
|
| 217 |
+
gr.update(interactive=False),
|
| 218 |
+
gr.update(visible=show_log),
|
| 219 |
+
gr.update(visible=show_raw_json),
|
| 220 |
+
)
|
| 221 |
+
continue
|
| 222 |
+
|
| 223 |
+
if stage == "insights_error":
|
| 224 |
+
err = ev.get("error", "Unknown error")
|
| 225 |
+
gr.Warning(f"Insight mining failed: {err}")
|
| 226 |
+
_log(f"⚠️ Insight mining failed: {err}")
|
| 227 |
+
yield (
|
| 228 |
+
_log_text(), initial, related_html, results_html,
|
| 229 |
+
insights_html, final_md, raw_json,
|
| 230 |
+
gr.update(interactive=False),
|
| 231 |
+
gr.update(visible=show_log),
|
| 232 |
+
gr.update(visible=show_raw_json),
|
| 233 |
+
)
|
| 234 |
+
continue
|
| 235 |
+
|
| 236 |
+
if stage == "related_work_error":
|
| 237 |
+
err = ev.get("error", "Unknown error")
|
| 238 |
+
gr.Warning(f"Related work search failed: {err}")
|
| 239 |
+
_log(f"⚠️ Related work search failed: {err}")
|
| 240 |
+
yield (
|
| 241 |
+
_log_text(), initial, related_html, results_html,
|
| 242 |
+
insights_html, final_md, raw_json,
|
| 243 |
+
gr.update(interactive=False),
|
| 244 |
+
gr.update(visible=show_log),
|
| 245 |
+
gr.update(visible=show_raw_json),
|
| 246 |
+
)
|
| 247 |
+
continue
|
| 248 |
+
|
| 249 |
+
# Process each pipeline stage
|
| 250 |
+
if stage == "extract_pdf":
|
| 251 |
+
status = f"📄 Extracting text from PDF: {src_path.name}..."
|
| 252 |
+
_log(status)
|
| 253 |
+
|
| 254 |
+
elif stage == "parsed_pdf_text":
|
| 255 |
+
_log("✅ Step 0: Extracting PDF text — done")
|
| 256 |
+
_log("⏳ Step 1: Initial review draft — started")
|
| 257 |
+
yield (
|
| 258 |
+
_log_text(), initial, related_html, results_html,
|
| 259 |
+
insights_html, final_md, raw_json,
|
| 260 |
+
gr.update(interactive=False),
|
| 261 |
+
gr.update(visible=show_log),
|
| 262 |
+
gr.update(visible=show_raw_json),
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
elif stage == "initial_review":
|
| 266 |
+
tmp = {"initial_review": ev.get("initial_review", {})}
|
| 267 |
+
tmp["title"] = ev.get("title") or tmp["initial_review"].get("title")
|
| 268 |
+
tmp["abstract"] = ev.get("abstract") or tmp["initial_review"].get("abstract")
|
| 269 |
+
initial = format_initial_review_html(tmp)
|
| 270 |
+
_log("✅ Step 1: Initial review draft — done")
|
| 271 |
+
_log("⏳ Step 2: Results analysis — started")
|
| 272 |
+
yield (
|
| 273 |
+
_log_text(), initial, related_html, results_html,
|
| 274 |
+
insights_html, final_md, raw_json,
|
| 275 |
+
gr.update(interactive=False),
|
| 276 |
+
gr.update(visible=show_log),
|
| 277 |
+
gr.update(visible=show_raw_json),
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
elif stage == "results_analysis":
|
| 281 |
+
tmp = {"results_analyzer_json": ev.get("results_analyzer_json")}
|
| 282 |
+
results_html = format_results_html(tmp)
|
| 283 |
+
_log("✅ Step 2: Results analysis — done")
|
| 284 |
+
_log("⏳ Step 3: Insight mining — started")
|
| 285 |
+
yield (
|
| 286 |
+
_log_text(), initial, related_html, results_html,
|
| 287 |
+
insights_html, final_md, raw_json,
|
| 288 |
+
gr.update(interactive=False),
|
| 289 |
+
gr.update(visible=show_log),
|
| 290 |
+
gr.update(visible=show_raw_json),
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
elif stage == "insights":
|
| 294 |
+
tmp = {"insight_miner_json": ev.get("insight_miner_json")}
|
| 295 |
+
insights_html = format_insights_html(tmp)
|
| 296 |
+
_log("✅ Step 3: Insight mining — done")
|
| 297 |
+
_log("⏳ Step 4: Related work — started")
|
| 298 |
+
yield (
|
| 299 |
+
_log_text(), initial, related_html, results_html,
|
| 300 |
+
insights_html, final_md, raw_json,
|
| 301 |
+
gr.update(interactive=False),
|
| 302 |
+
gr.update(visible=show_log),
|
| 303 |
+
gr.update(visible=show_raw_json),
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
elif stage == "related_work":
|
| 307 |
+
tmp = {
|
| 308 |
+
"related_work_json_list": ev.get("related_work_json_list"),
|
| 309 |
+
"search_keywords": ev.get("search_keywords"),
|
| 310 |
+
}
|
| 311 |
+
related_html = format_related_work_html(tmp)
|
| 312 |
+
_log("✅ Step 4: Related work — done")
|
| 313 |
+
_log("⏳ Step 5: Final refinement — started")
|
| 314 |
+
yield (
|
| 315 |
+
_log_text(), initial, related_html, results_html,
|
| 316 |
+
insights_html, final_md, raw_json,
|
| 317 |
+
gr.update(interactive=False),
|
| 318 |
+
gr.update(visible=show_log),
|
| 319 |
+
gr.update(visible=show_raw_json),
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
elif stage == "final":
|
| 323 |
+
review = ev.get("review", {}) or {}
|
| 324 |
+
initial = format_initial_review_html(review)
|
| 325 |
+
related_html = format_related_work_html(review) if not related_html else related_html
|
| 326 |
+
results_html = format_results_html(review) if not results_html else results_html
|
| 327 |
+
insights_html = format_insights_html(review) if not insights_html else insights_html
|
| 328 |
+
final_md = format_final_review(review)
|
| 329 |
+
raw_json = format_raw_json(review)
|
| 330 |
+
_log("✅ Step 5: Final refinement — done")
|
| 331 |
+
_log(f"🎉 Review complete for: {src_path.name}")
|
| 332 |
+
|
| 333 |
+
else:
|
| 334 |
+
_log(f"⏳ Working... ({stage})")
|
| 335 |
+
|
| 336 |
+
yield (
|
| 337 |
+
_log_text(), initial, related_html, results_html,
|
| 338 |
+
insights_html, final_md, raw_json,
|
| 339 |
+
gr.update(interactive=False),
|
| 340 |
+
gr.update(visible=show_log),
|
| 341 |
+
gr.update(visible=show_raw_json),
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Re-enable button at end
|
| 345 |
+
yield (
|
| 346 |
+
_log_text(), initial, related_html, results_html,
|
| 347 |
+
insights_html, final_md, raw_json,
|
| 348 |
+
gr.update(interactive=True),
|
| 349 |
+
gr.update(visible=show_log),
|
| 350 |
+
gr.update(visible=show_raw_json),
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
import traceback
|
| 355 |
+
error_msg = f"❌ Error during review: {str(e)}"
|
| 356 |
+
error_details = traceback.format_exc()
|
| 357 |
+
gr.Error(f"{error_msg}\n\nDetails: {error_details[:500]}")
|
| 358 |
+
_log(error_msg)
|
| 359 |
+
yield (
|
| 360 |
+
_log_text(), "", "", "", "", "", "",
|
| 361 |
+
gr.update(interactive=True),
|
| 362 |
+
gr.update(visible=show_log),
|
| 363 |
+
gr.update(visible=show_raw_json),
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# ============================================================================
|
| 368 |
+
# Build the Gradio App
|
| 369 |
+
# ============================================================================
|
| 370 |
+
|
| 371 |
+
with gr.Blocks(
|
| 372 |
+
title=APP_TITLE,
|
| 373 |
+
css=get_custom_css(),
|
| 374 |
+
theme=gr.themes.Soft(),
|
| 375 |
+
) as demo:
|
| 376 |
+
|
| 377 |
+
# Header section
|
| 378 |
+
create_header()
|
| 379 |
+
|
| 380 |
+
# Main content: two-column layout
|
| 381 |
+
with gr.Row():
|
| 382 |
+
# Left column: Upload and settings
|
| 383 |
+
with gr.Column(scale=2, elem_classes=["panel-card"]):
|
| 384 |
+
pdf_input, run_button = create_upload_section()
|
| 385 |
+
|
| 386 |
+
# Advanced settings (collapsed by default)
|
| 387 |
+
(
|
| 388 |
+
api_base_url_in,
|
| 389 |
+
api_key_in,
|
| 390 |
+
model_name_in,
|
| 391 |
+
show_log_toggle,
|
| 392 |
+
show_raw_json_toggle,
|
| 393 |
+
) = create_advanced_settings()
|
| 394 |
+
|
| 395 |
+
# Right column: Results (built from component)
|
| 396 |
+
with gr.Column(scale=3, elem_classes=["panel-card", "results-panel"]):
|
| 397 |
+
(
|
| 398 |
+
initial_html,
|
| 399 |
+
results_html,
|
| 400 |
+
insights_html,
|
| 401 |
+
related_html,
|
| 402 |
+
final_md,
|
| 403 |
+
status_output,
|
| 404 |
+
raw_json_md,
|
| 405 |
+
log_accordion,
|
| 406 |
+
raw_json_tab,
|
| 407 |
+
download_json_btn,
|
| 408 |
+
) = create_results_panel(show_log=False, show_raw_json=False)
|
| 409 |
+
|
| 410 |
+
# Toggle visibility of log accordion
|
| 411 |
+
show_log_toggle.change(
|
| 412 |
+
fn=lambda x: gr.update(visible=x),
|
| 413 |
+
inputs=[show_log_toggle],
|
| 414 |
+
outputs=[log_accordion],
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# Toggle visibility of raw JSON tab
|
| 418 |
+
show_raw_json_toggle.change(
|
| 419 |
+
fn=lambda x: gr.update(visible=x),
|
| 420 |
+
inputs=[show_raw_json_toggle],
|
| 421 |
+
outputs=[raw_json_tab],
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Download raw JSON as file
|
| 425 |
+
download_json_btn.click(
|
| 426 |
+
fn=_raw_json_md_to_file,
|
| 427 |
+
inputs=[raw_json_md],
|
| 428 |
+
outputs=[download_json_btn],
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
# Main review button click handler
|
| 432 |
+
run_button.click(
|
| 433 |
+
fn=review_pdf_file,
|
| 434 |
+
inputs=[
|
| 435 |
+
pdf_input,
|
| 436 |
+
api_base_url_in,
|
| 437 |
+
api_key_in,
|
| 438 |
+
model_name_in,
|
| 439 |
+
show_log_toggle,
|
| 440 |
+
show_raw_json_toggle,
|
| 441 |
+
],
|
| 442 |
+
outputs=[
|
| 443 |
+
status_output,
|
| 444 |
+
initial_html,
|
| 445 |
+
related_html,
|
| 446 |
+
results_html,
|
| 447 |
+
insights_html,
|
| 448 |
+
final_md,
|
| 449 |
+
raw_json_md,
|
| 450 |
+
run_button,
|
| 451 |
+
log_accordion,
|
| 452 |
+
raw_json_tab,
|
| 453 |
+
],
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
# Footer
|
| 457 |
+
gr.HTML("""
|
| 458 |
+
<div class="app-footer">
|
| 459 |
+
<p>🔬 Review Grounder · AI-Powered Research Paper Review</p>
|
| 460 |
+
<p>© 2026 ReviewGrounder. All rights reserved.</p>
|
| 461 |
+
</div>
|
| 462 |
+
""")
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
# ============================================================================
|
| 466 |
+
# Entry Point
|
| 467 |
+
# ============================================================================
|
| 468 |
+
|
| 469 |
+
if __name__ == "__main__":
|
| 470 |
+
demo.launch()
|
example.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example script for running a single-paper review.
|
| 3 |
+
|
| 4 |
+
This example is intentionally thin and delegates to the reusable
|
| 5 |
+
`review_single_paper_from_text` helper, which is what a Hugging Face
|
| 6 |
+
Space backend would call after performing PDF-to-text conversion.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
# Add project root to path for imports
|
| 15 |
+
project_root = Path(__file__).parent.parent
|
| 16 |
+
if str(project_root) not in sys.path:
|
| 17 |
+
sys.path.insert(0, str(project_root))
|
| 18 |
+
|
| 19 |
+
from src.reviewer_agent.single_paper_inference import review_single_paper_from_text
|
| 20 |
+
|
| 21 |
+
logging.basicConfig(level=logging.INFO)
|
| 22 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main() -> None:
|
| 27 |
+
"""Run a minimal single-paper review using the first example paper."""
|
| 28 |
+
json_path = project_root / "examples" / "example_papers.json"
|
| 29 |
+
|
| 30 |
+
with open(json_path, "r") as f:
|
| 31 |
+
data = json.load(f)
|
| 32 |
+
|
| 33 |
+
# For demonstration we take only the first paper
|
| 34 |
+
first_paper = data[0]
|
| 35 |
+
paper_text = first_paper.get("paper_context", "")
|
| 36 |
+
|
| 37 |
+
logger.info("Running single-paper review from example_papers.json...")
|
| 38 |
+
review = review_single_paper_from_text(
|
| 39 |
+
paper_text,
|
| 40 |
+
keywords=first_paper.get("keywords"),
|
| 41 |
+
# Use config defaults for review_format and verbosity
|
| 42 |
+
enable_logging=True,
|
| 43 |
+
verbose=False,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# Save the review content to a JSON file
|
| 47 |
+
with open("review_content.json", "w") as f:
|
| 48 |
+
json.dump([review], f, indent=2)
|
| 49 |
+
|
| 50 |
+
logger.info("Review saved to review_content.json")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
main()
|
gradio_app/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio integration package for the anonymized review system.
|
| 3 |
+
|
| 4 |
+
This package contains:
|
| 5 |
+
- Lightweight utilities that wrap the single-paper review pipeline for UI use
|
| 6 |
+
- The Gradio app definition used for Hugging Face Spaces deployment
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
__all__ = ["app", "utils_single_paper_inference"]
|
gradio_app/app.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Review Grounder - Gradio App
|
| 3 |
+
|
| 4 |
+
Main entry point for the Hugging Face Space.
|
| 5 |
+
This module orchestrates the UI components and handles the review pipeline.
|
| 6 |
+
|
| 7 |
+
The app allows users to:
|
| 8 |
+
1. Upload a research paper in PDF format
|
| 9 |
+
2. Configure LLM settings (optional, uses OpenAI defaults)
|
| 10 |
+
3. Generate a comprehensive AI-powered review
|
| 11 |
+
4. View intermediate results from each pipeline stage
|
| 12 |
+
|
| 13 |
+
Components are organized in the `components/` directory for maintainability.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
import tempfile
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Tuple, Iterator
|
| 24 |
+
|
| 25 |
+
import gradio as gr
|
| 26 |
+
|
| 27 |
+
# Import utility for running the review pipeline
|
| 28 |
+
from utils_single_paper_inference import (
|
| 29 |
+
run_single_paper_review_from_pdf_stepwise,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# Import UI components
|
| 33 |
+
from components import (
|
| 34 |
+
get_custom_css,
|
| 35 |
+
create_header,
|
| 36 |
+
create_upload_section,
|
| 37 |
+
create_advanced_settings,
|
| 38 |
+
create_results_panel,
|
| 39 |
+
format_initial_review_html,
|
| 40 |
+
format_related_work_html,
|
| 41 |
+
format_results_html,
|
| 42 |
+
format_insights_html,
|
| 43 |
+
format_final_review,
|
| 44 |
+
format_raw_json,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ============================================================================
|
| 49 |
+
# App Configuration
|
| 50 |
+
# ============================================================================
|
| 51 |
+
|
| 52 |
+
APP_TITLE = "Review Grounder"
|
| 53 |
+
APP_DESCRIPTION = "AI-Powered Research Paper Review"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _raw_json_md_to_file(raw_json_md: str) -> str:
|
| 57 |
+
"""
|
| 58 |
+
Extract JSON from Raw JSON markdown (```json ... ```) and write to a temp file.
|
| 59 |
+
Returns the file path for gr.DownloadButton.
|
| 60 |
+
"""
|
| 61 |
+
if not raw_json_md or not raw_json_md.strip():
|
| 62 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
|
| 63 |
+
f.write("{}")
|
| 64 |
+
return f.name
|
| 65 |
+
text = raw_json_md.strip()
|
| 66 |
+
match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
|
| 67 |
+
if match:
|
| 68 |
+
text = match.group(1).strip()
|
| 69 |
+
fd, path = tempfile.mkstemp(suffix=".json", prefix="review_")
|
| 70 |
+
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
| 71 |
+
f.write(text)
|
| 72 |
+
return path
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ============================================================================
|
| 76 |
+
# Environment Check
|
| 77 |
+
# ============================================================================
|
| 78 |
+
|
| 79 |
+
def _check_env() -> Tuple[bool, str]:
|
| 80 |
+
"""
|
| 81 |
+
Check for required environment variables.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
Tuple of (success, message)
|
| 85 |
+
"""
|
| 86 |
+
missing = []
|
| 87 |
+
if not os.environ.get("ASTA_API_KEY"):
|
| 88 |
+
missing.append("ASTA_API_KEY")
|
| 89 |
+
|
| 90 |
+
if missing:
|
| 91 |
+
return False, (
|
| 92 |
+
"Missing environment variables: "
|
| 93 |
+
+ ", ".join(missing)
|
| 94 |
+
+ ".\nPlease configure them in your Hugging Face Space settings."
|
| 95 |
+
)
|
| 96 |
+
return True, "Environment variables detected correctly."
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# ============================================================================
|
| 100 |
+
# Review Pipeline Handler
|
| 101 |
+
# ============================================================================
|
| 102 |
+
|
| 103 |
+
def review_pdf_file(
|
| 104 |
+
file_obj,
|
| 105 |
+
api_base_url: str,
|
| 106 |
+
api_key: str,
|
| 107 |
+
model_name: str,
|
| 108 |
+
show_log: bool,
|
| 109 |
+
show_raw_json: bool,
|
| 110 |
+
) -> Iterator[Tuple[str, str, str, str, str, str, str, gr.update, gr.update, gr.update]]:
|
| 111 |
+
"""
|
| 112 |
+
Main callback: process PDF through the review pipeline with real-time updates.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
file_obj: Uploaded PDF file
|
| 116 |
+
api_base_url: LLM API endpoint URL
|
| 117 |
+
api_key: API key for LLM provider
|
| 118 |
+
model_name: Model identifier
|
| 119 |
+
show_log: Whether to display the execution log
|
| 120 |
+
show_raw_json: Whether to display raw JSON output
|
| 121 |
+
|
| 122 |
+
Yields:
|
| 123 |
+
Tuple of all output component updates (no overview)
|
| 124 |
+
"""
|
| 125 |
+
log_lines: list[str] = []
|
| 126 |
+
|
| 127 |
+
def _log(msg: str) -> None:
|
| 128 |
+
log_lines.append(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
|
| 129 |
+
|
| 130 |
+
def _log_text() -> str:
|
| 131 |
+
return "\n".join(log_lines) if log_lines else ""
|
| 132 |
+
|
| 133 |
+
# Validate file upload
|
| 134 |
+
if file_obj is None:
|
| 135 |
+
gr.Warning("Please upload a PDF file to start the review.")
|
| 136 |
+
_log("⚠️ Please upload a PDF file to start the review.")
|
| 137 |
+
yield (
|
| 138 |
+
_log_text(), "", "", "", "", "", "",
|
| 139 |
+
gr.update(interactive=True),
|
| 140 |
+
gr.update(visible=show_log),
|
| 141 |
+
gr.update(visible=show_raw_json),
|
| 142 |
+
)
|
| 143 |
+
return
|
| 144 |
+
|
| 145 |
+
# Check environment
|
| 146 |
+
ok, msg = _check_env()
|
| 147 |
+
if not ok:
|
| 148 |
+
gr.Error(msg)
|
| 149 |
+
_log(f"❌ {msg}")
|
| 150 |
+
yield (
|
| 151 |
+
_log_text(), "", "", "", "", "", "",
|
| 152 |
+
gr.update(interactive=True),
|
| 153 |
+
gr.update(visible=show_log),
|
| 154 |
+
gr.update(visible=show_raw_json),
|
| 155 |
+
)
|
| 156 |
+
return
|
| 157 |
+
|
| 158 |
+
# Start pipeline
|
| 159 |
+
_log("🚀 Pipeline started.")
|
| 160 |
+
yield (
|
| 161 |
+
_log_text(), "", "", "", "", "", "",
|
| 162 |
+
gr.update(interactive=False),
|
| 163 |
+
gr.update(visible=show_log),
|
| 164 |
+
gr.update(visible=show_raw_json),
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
try:
|
| 168 |
+
# Normalize file path
|
| 169 |
+
if isinstance(file_obj, dict) and "name" in file_obj:
|
| 170 |
+
src_path = Path(file_obj["name"])
|
| 171 |
+
else:
|
| 172 |
+
src_path = Path(getattr(file_obj, "name", "") or str(file_obj))
|
| 173 |
+
|
| 174 |
+
if not src_path or not src_path.exists():
|
| 175 |
+
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
|
| 176 |
+
tmp_path = Path(tmp.name)
|
| 177 |
+
if hasattr(file_obj, "read"):
|
| 178 |
+
tmp.write(file_obj.read())
|
| 179 |
+
src_path = tmp_path
|
| 180 |
+
|
| 181 |
+
# Initialize output variables
|
| 182 |
+
status = f"📄 Extracting text from PDF: {src_path.name}..."
|
| 183 |
+
_log(status)
|
| 184 |
+
yield (
|
| 185 |
+
_log_text(), "", "", "", "", "", "",
|
| 186 |
+
gr.update(interactive=False),
|
| 187 |
+
gr.update(visible=show_log),
|
| 188 |
+
gr.update(visible=show_raw_json),
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
initial = ""
|
| 192 |
+
related_html = ""
|
| 193 |
+
results_html = ""
|
| 194 |
+
insights_html = ""
|
| 195 |
+
final_md = ""
|
| 196 |
+
raw_json = ""
|
| 197 |
+
|
| 198 |
+
# Run the stepwise pipeline
|
| 199 |
+
for ev in run_single_paper_review_from_pdf_stepwise(
|
| 200 |
+
str(src_path),
|
| 201 |
+
api_base_url=api_base_url or None,
|
| 202 |
+
api_key=api_key or None,
|
| 203 |
+
model_name=model_name or None,
|
| 204 |
+
enable_logging=True,
|
| 205 |
+
verbose=True,
|
| 206 |
+
):
|
| 207 |
+
stage = ev.get("stage")
|
| 208 |
+
|
| 209 |
+
# Handle step-level errors
|
| 210 |
+
if stage == "results_analysis_error":
|
| 211 |
+
err = ev.get("error", "Unknown error")
|
| 212 |
+
gr.Warning(f"Results analysis failed: {err}")
|
| 213 |
+
_log(f"⚠️ Results analysis failed: {err}")
|
| 214 |
+
yield (
|
| 215 |
+
_log_text(), initial, related_html, results_html,
|
| 216 |
+
insights_html, final_md, raw_json,
|
| 217 |
+
gr.update(interactive=False),
|
| 218 |
+
gr.update(visible=show_log),
|
| 219 |
+
gr.update(visible=show_raw_json),
|
| 220 |
+
)
|
| 221 |
+
continue
|
| 222 |
+
|
| 223 |
+
if stage == "insights_error":
|
| 224 |
+
err = ev.get("error", "Unknown error")
|
| 225 |
+
gr.Warning(f"Insight mining failed: {err}")
|
| 226 |
+
_log(f"⚠️ Insight mining failed: {err}")
|
| 227 |
+
yield (
|
| 228 |
+
_log_text(), initial, related_html, results_html,
|
| 229 |
+
insights_html, final_md, raw_json,
|
| 230 |
+
gr.update(interactive=False),
|
| 231 |
+
gr.update(visible=show_log),
|
| 232 |
+
gr.update(visible=show_raw_json),
|
| 233 |
+
)
|
| 234 |
+
continue
|
| 235 |
+
|
| 236 |
+
if stage == "related_work_error":
|
| 237 |
+
err = ev.get("error", "Unknown error")
|
| 238 |
+
gr.Warning(f"Related work search failed: {err}")
|
| 239 |
+
_log(f"⚠️ Related work search failed: {err}")
|
| 240 |
+
yield (
|
| 241 |
+
_log_text(), initial, related_html, results_html,
|
| 242 |
+
insights_html, final_md, raw_json,
|
| 243 |
+
gr.update(interactive=False),
|
| 244 |
+
gr.update(visible=show_log),
|
| 245 |
+
gr.update(visible=show_raw_json),
|
| 246 |
+
)
|
| 247 |
+
continue
|
| 248 |
+
|
| 249 |
+
# Process each pipeline stage
|
| 250 |
+
if stage == "extract_pdf":
|
| 251 |
+
status = f"📄 Extracting text from PDF: {src_path.name}..."
|
| 252 |
+
_log(status)
|
| 253 |
+
|
| 254 |
+
elif stage == "parsed_pdf_text":
|
| 255 |
+
_log("✅ Step 0: Extracting PDF text — done")
|
| 256 |
+
_log("⏳ Step 1: Initial review draft — started")
|
| 257 |
+
yield (
|
| 258 |
+
_log_text(), initial, related_html, results_html,
|
| 259 |
+
insights_html, final_md, raw_json,
|
| 260 |
+
gr.update(interactive=False),
|
| 261 |
+
gr.update(visible=show_log),
|
| 262 |
+
gr.update(visible=show_raw_json),
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
elif stage == "initial_review":
|
| 266 |
+
tmp = {"initial_review": ev.get("initial_review", {})}
|
| 267 |
+
tmp["title"] = ev.get("title") or tmp["initial_review"].get("title")
|
| 268 |
+
tmp["abstract"] = ev.get("abstract") or tmp["initial_review"].get("abstract")
|
| 269 |
+
initial = format_initial_review_html(tmp)
|
| 270 |
+
_log("✅ Step 1: Initial review draft — done")
|
| 271 |
+
_log("⏳ Step 2: Results analysis — started")
|
| 272 |
+
yield (
|
| 273 |
+
_log_text(), initial, related_html, results_html,
|
| 274 |
+
insights_html, final_md, raw_json,
|
| 275 |
+
gr.update(interactive=False),
|
| 276 |
+
gr.update(visible=show_log),
|
| 277 |
+
gr.update(visible=show_raw_json),
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
elif stage == "results_analysis":
|
| 281 |
+
tmp = {"results_analyzer_json": ev.get("results_analyzer_json")}
|
| 282 |
+
results_html = format_results_html(tmp)
|
| 283 |
+
_log("✅ Step 2: Results analysis — done")
|
| 284 |
+
_log("⏳ Step 3: Insight mining — started")
|
| 285 |
+
yield (
|
| 286 |
+
_log_text(), initial, related_html, results_html,
|
| 287 |
+
insights_html, final_md, raw_json,
|
| 288 |
+
gr.update(interactive=False),
|
| 289 |
+
gr.update(visible=show_log),
|
| 290 |
+
gr.update(visible=show_raw_json),
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
elif stage == "insights":
|
| 294 |
+
tmp = {"insight_miner_json": ev.get("insight_miner_json")}
|
| 295 |
+
insights_html = format_insights_html(tmp)
|
| 296 |
+
_log("✅ Step 3: Insight mining — done")
|
| 297 |
+
_log("⏳ Step 4: Related work — started")
|
| 298 |
+
yield (
|
| 299 |
+
_log_text(), initial, related_html, results_html,
|
| 300 |
+
insights_html, final_md, raw_json,
|
| 301 |
+
gr.update(interactive=False),
|
| 302 |
+
gr.update(visible=show_log),
|
| 303 |
+
gr.update(visible=show_raw_json),
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
elif stage == "related_work":
|
| 307 |
+
tmp = {
|
| 308 |
+
"related_work_json_list": ev.get("related_work_json_list"),
|
| 309 |
+
"search_keywords": ev.get("search_keywords"),
|
| 310 |
+
}
|
| 311 |
+
related_html = format_related_work_html(tmp)
|
| 312 |
+
_log("✅ Step 4: Related work — done")
|
| 313 |
+
_log("⏳ Step 5: Final refinement — started")
|
| 314 |
+
yield (
|
| 315 |
+
_log_text(), initial, related_html, results_html,
|
| 316 |
+
insights_html, final_md, raw_json,
|
| 317 |
+
gr.update(interactive=False),
|
| 318 |
+
gr.update(visible=show_log),
|
| 319 |
+
gr.update(visible=show_raw_json),
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
elif stage == "final":
|
| 323 |
+
review = ev.get("review", {}) or {}
|
| 324 |
+
initial = format_initial_review_html(review)
|
| 325 |
+
related_html = format_related_work_html(review) if not related_html else related_html
|
| 326 |
+
results_html = format_results_html(review) if not results_html else results_html
|
| 327 |
+
insights_html = format_insights_html(review) if not insights_html else insights_html
|
| 328 |
+
final_md = format_final_review(review)
|
| 329 |
+
raw_json = format_raw_json(review)
|
| 330 |
+
_log("✅ Step 5: Final refinement — done")
|
| 331 |
+
_log(f"🎉 Review complete for: {src_path.name}")
|
| 332 |
+
|
| 333 |
+
else:
|
| 334 |
+
_log(f"⏳ Working... ({stage})")
|
| 335 |
+
|
| 336 |
+
yield (
|
| 337 |
+
_log_text(), initial, related_html, results_html,
|
| 338 |
+
insights_html, final_md, raw_json,
|
| 339 |
+
gr.update(interactive=False),
|
| 340 |
+
gr.update(visible=show_log),
|
| 341 |
+
gr.update(visible=show_raw_json),
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Re-enable button at end
|
| 345 |
+
yield (
|
| 346 |
+
_log_text(), initial, related_html, results_html,
|
| 347 |
+
insights_html, final_md, raw_json,
|
| 348 |
+
gr.update(interactive=True),
|
| 349 |
+
gr.update(visible=show_log),
|
| 350 |
+
gr.update(visible=show_raw_json),
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
import traceback
|
| 355 |
+
error_msg = f"❌ Error during review: {str(e)}"
|
| 356 |
+
error_details = traceback.format_exc()
|
| 357 |
+
gr.Error(f"{error_msg}\n\nDetails: {error_details[:500]}")
|
| 358 |
+
_log(error_msg)
|
| 359 |
+
yield (
|
| 360 |
+
_log_text(), "", "", "", "", "", "",
|
| 361 |
+
gr.update(interactive=True),
|
| 362 |
+
gr.update(visible=show_log),
|
| 363 |
+
gr.update(visible=show_raw_json),
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# ============================================================================
|
| 368 |
+
# Build the Gradio App
|
| 369 |
+
# ============================================================================
|
| 370 |
+
|
| 371 |
+
with gr.Blocks(title=APP_TITLE) as demo:
|
| 372 |
+
|
| 373 |
+
# Header section
|
| 374 |
+
create_header()
|
| 375 |
+
|
| 376 |
+
# Main content: two-column layout
|
| 377 |
+
with gr.Row():
|
| 378 |
+
# Left column: Upload and settings
|
| 379 |
+
with gr.Column(scale=2, elem_classes=["panel-card"]):
|
| 380 |
+
pdf_input, run_button = create_upload_section()
|
| 381 |
+
|
| 382 |
+
# Advanced settings (collapsed by default)
|
| 383 |
+
(
|
| 384 |
+
api_base_url_in,
|
| 385 |
+
api_key_in,
|
| 386 |
+
model_name_in,
|
| 387 |
+
show_log_toggle,
|
| 388 |
+
show_raw_json_toggle,
|
| 389 |
+
) = create_advanced_settings()
|
| 390 |
+
|
| 391 |
+
# Right column: Results (built from component)
|
| 392 |
+
with gr.Column(scale=3, elem_classes=["panel-card", "results-panel"]):
|
| 393 |
+
(
|
| 394 |
+
initial_html,
|
| 395 |
+
results_html,
|
| 396 |
+
insights_html,
|
| 397 |
+
related_html,
|
| 398 |
+
final_md,
|
| 399 |
+
status_output,
|
| 400 |
+
raw_json_md,
|
| 401 |
+
log_accordion,
|
| 402 |
+
raw_json_tab,
|
| 403 |
+
download_json_btn,
|
| 404 |
+
) = create_results_panel(show_log=False, show_raw_json=False)
|
| 405 |
+
|
| 406 |
+
# Toggle visibility of log accordion
|
| 407 |
+
show_log_toggle.change(
|
| 408 |
+
fn=lambda x: gr.update(visible=x),
|
| 409 |
+
inputs=[show_log_toggle],
|
| 410 |
+
outputs=[log_accordion],
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Toggle visibility of raw JSON tab
|
| 414 |
+
show_raw_json_toggle.change(
|
| 415 |
+
fn=lambda x: gr.update(visible=x),
|
| 416 |
+
inputs=[show_raw_json_toggle],
|
| 417 |
+
outputs=[raw_json_tab],
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# Download raw JSON as file
|
| 421 |
+
download_json_btn.click(
|
| 422 |
+
fn=_raw_json_md_to_file,
|
| 423 |
+
inputs=[raw_json_md],
|
| 424 |
+
outputs=[download_json_btn],
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
# Main review button click handler
|
| 428 |
+
run_button.click(
|
| 429 |
+
fn=review_pdf_file,
|
| 430 |
+
inputs=[
|
| 431 |
+
pdf_input,
|
| 432 |
+
api_base_url_in,
|
| 433 |
+
api_key_in,
|
| 434 |
+
model_name_in,
|
| 435 |
+
show_log_toggle,
|
| 436 |
+
show_raw_json_toggle,
|
| 437 |
+
],
|
| 438 |
+
outputs=[
|
| 439 |
+
status_output,
|
| 440 |
+
initial_html,
|
| 441 |
+
related_html,
|
| 442 |
+
results_html,
|
| 443 |
+
insights_html,
|
| 444 |
+
final_md,
|
| 445 |
+
raw_json_md,
|
| 446 |
+
run_button,
|
| 447 |
+
log_accordion,
|
| 448 |
+
raw_json_tab,
|
| 449 |
+
],
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Footer
|
| 453 |
+
gr.HTML("""
|
| 454 |
+
<div class="app-footer">
|
| 455 |
+
<p>🔬 Review Grounder · AI-Powered Research Paper Review</p>
|
| 456 |
+
<p>© 2026 ReviewGrounder. All rights reserved.</p>
|
| 457 |
+
</div>
|
| 458 |
+
""")
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
# ============================================================================
|
| 462 |
+
# Entry Point
|
| 463 |
+
# ============================================================================
|
| 464 |
+
|
| 465 |
+
if __name__ == "__main__":
|
| 466 |
+
demo.launch(css=get_custom_css(), theme=gr.themes.Soft())
|
gradio_app/components/__init__.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Components package for the Review Grounder Gradio app.
|
| 3 |
+
|
| 4 |
+
This package contains modular UI components for building
|
| 5 |
+
the Review Grounder interface.
|
| 6 |
+
|
| 7 |
+
Modules:
|
| 8 |
+
- styles: Custom CSS styles
|
| 9 |
+
- formatters: Data formatting utilities
|
| 10 |
+
- header: App header component
|
| 11 |
+
- upload_section: PDF upload and instructions
|
| 12 |
+
- settings: Advanced settings panel
|
| 13 |
+
- results_panel: Review results display
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from .styles import get_custom_css
|
| 17 |
+
from .formatters import (
|
| 18 |
+
safe_json_parse,
|
| 19 |
+
format_overview,
|
| 20 |
+
format_initial_review,
|
| 21 |
+
format_initial_review_html,
|
| 22 |
+
format_related_work_html,
|
| 23 |
+
format_results_html,
|
| 24 |
+
format_insights_html,
|
| 25 |
+
format_final_review,
|
| 26 |
+
format_raw_json,
|
| 27 |
+
)
|
| 28 |
+
from .header import create_header
|
| 29 |
+
from .upload_section import (
|
| 30 |
+
create_how_it_works,
|
| 31 |
+
create_upload_area,
|
| 32 |
+
create_action_buttons,
|
| 33 |
+
create_upload_section,
|
| 34 |
+
)
|
| 35 |
+
from .settings import (
|
| 36 |
+
create_advanced_settings,
|
| 37 |
+
DEFAULT_API_ENDPOINT,
|
| 38 |
+
DEFAULT_MODEL_NAME,
|
| 39 |
+
)
|
| 40 |
+
from .results_panel import (
|
| 41 |
+
create_results_placeholder,
|
| 42 |
+
create_results_panel,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
__all__ = [
|
| 47 |
+
# Styles
|
| 48 |
+
"get_custom_css",
|
| 49 |
+
# Formatters
|
| 50 |
+
"safe_json_parse",
|
| 51 |
+
"format_overview",
|
| 52 |
+
"format_initial_review",
|
| 53 |
+
"format_initial_review_html",
|
| 54 |
+
"format_related_work_html",
|
| 55 |
+
"format_results_html",
|
| 56 |
+
"format_insights_html",
|
| 57 |
+
"format_final_review",
|
| 58 |
+
"format_raw_json",
|
| 59 |
+
# Header
|
| 60 |
+
"create_header",
|
| 61 |
+
# Upload section
|
| 62 |
+
"create_how_it_works",
|
| 63 |
+
"create_upload_area",
|
| 64 |
+
"create_action_buttons",
|
| 65 |
+
"create_upload_section",
|
| 66 |
+
# Settings
|
| 67 |
+
"create_advanced_settings",
|
| 68 |
+
"DEFAULT_API_ENDPOINT",
|
| 69 |
+
"DEFAULT_MODEL_NAME",
|
| 70 |
+
# Results panel
|
| 71 |
+
"create_results_placeholder",
|
| 72 |
+
"create_results_panel",
|
| 73 |
+
]
|
gradio_app/components/formatters.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Formatting utilities for the Review Grounder Gradio app.
|
| 3 |
+
|
| 4 |
+
This module contains all functions for formatting review data
|
| 5 |
+
into displayable HTML or Markdown for the UI components.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def safe_json_parse(value: Any) -> Any:
|
| 15 |
+
"""
|
| 16 |
+
Safely parse a JSON string or return the value if already parsed.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
value: A JSON string or already-parsed object
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
Parsed JSON object or None if parsing fails
|
| 23 |
+
"""
|
| 24 |
+
if value is None:
|
| 25 |
+
return None
|
| 26 |
+
try:
|
| 27 |
+
if isinstance(value, str):
|
| 28 |
+
return json.loads(value)
|
| 29 |
+
return value
|
| 30 |
+
except Exception as e:
|
| 31 |
+
# print out the exact error
|
| 32 |
+
print(f"Error parsing JSON: {e}")
|
| 33 |
+
|
| 34 |
+
return None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def format_overview(review: dict) -> str:
|
| 38 |
+
"""
|
| 39 |
+
Format high-level overview: scores and keywords only.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
review: The review dictionary containing scores and metadata
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Formatted Markdown string with scores and search keywords
|
| 46 |
+
"""
|
| 47 |
+
if not review:
|
| 48 |
+
return "No review data."
|
| 49 |
+
|
| 50 |
+
scores = review.get("scores", {}) or {}
|
| 51 |
+
rating = scores.get("rating") or review.get("rating")
|
| 52 |
+
confidence = scores.get("confidence") or review.get("confidence")
|
| 53 |
+
decision = scores.get("decision") or review.get("decision")
|
| 54 |
+
|
| 55 |
+
parts = [
|
| 56 |
+
"### Scores",
|
| 57 |
+
f"- **Rating**: {rating if rating is not None else 'N/A'}",
|
| 58 |
+
f"- **Confidence**: {confidence if confidence is not None else 'N/A'}",
|
| 59 |
+
f"- **Decision**: {decision or 'N/A'}",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
keywords = review.get("search_keywords")
|
| 63 |
+
if keywords:
|
| 64 |
+
parts.append("")
|
| 65 |
+
parts.append("### Search Keywords")
|
| 66 |
+
parts.append("".join(f"- {k}\n" for k in keywords).rstrip())
|
| 67 |
+
|
| 68 |
+
return "\n".join(parts)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _escape_html(text: str) -> str:
|
| 72 |
+
"""Escape HTML special characters for safe display."""
|
| 73 |
+
if not text:
|
| 74 |
+
return ""
|
| 75 |
+
return (
|
| 76 |
+
str(text)
|
| 77 |
+
.replace("&", "&")
|
| 78 |
+
.replace("<", "<")
|
| 79 |
+
.replace(">", ">")
|
| 80 |
+
.replace('"', """)
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def format_initial_review(review: dict) -> str:
|
| 85 |
+
"""
|
| 86 |
+
Format the initial draft review as plain text (legacy).
|
| 87 |
+
Prefer format_initial_review_html for UI display.
|
| 88 |
+
"""
|
| 89 |
+
initial = review.get("initial_review")
|
| 90 |
+
if not initial:
|
| 91 |
+
return "Initial draft review not available (pipeline may have failed early)."
|
| 92 |
+
text = initial.get("review") or ""
|
| 93 |
+
if not text:
|
| 94 |
+
return json.dumps(initial, indent=2, ensure_ascii=False)
|
| 95 |
+
return text
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def format_initial_review_html(review: dict) -> str:
|
| 99 |
+
"""
|
| 100 |
+
Format the initial draft review as styled HTML cards (never raw JSON string).
|
| 101 |
+
|
| 102 |
+
Renders Summary, scores, Strengths, Weaknesses, Questions in card/section
|
| 103 |
+
layout. If initial_review comes as a JSON string, parses it first then renders cards.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
review: The review dictionary containing initial_review
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
HTML string for display in gr.HTML
|
| 110 |
+
"""
|
| 111 |
+
initial = review.get("initial_review")
|
| 112 |
+
if not initial:
|
| 113 |
+
return "<p class='review-message'>Initial draft review not available (pipeline may have failed early).</p>"
|
| 114 |
+
|
| 115 |
+
# If backend passed a JSON string, parse to dict so we can render cards
|
| 116 |
+
if isinstance(initial, str):
|
| 117 |
+
initial = safe_json_parse(initial) or {}
|
| 118 |
+
if not initial:
|
| 119 |
+
return "<p class='review-message'>Initial draft data could not be parsed.</p>"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Prefer structured HTML when we have typical draft fields
|
| 124 |
+
if _looks_like_raw_json(initial):
|
| 125 |
+
pass # fall through to structured format below
|
| 126 |
+
else:
|
| 127 |
+
print(f"[WARN] Failed to parse initial draft as JSON, treating it as single text.")
|
| 128 |
+
# Single "review" text only
|
| 129 |
+
text = initial.get("review") or ""
|
| 130 |
+
if text:
|
| 131 |
+
return f'<div class="review-draft-content"><div class="review-text">{_nl2br(text)}</div></div>'
|
| 132 |
+
|
| 133 |
+
# Structured format: summary, scores, strengths, weaknesses, questions, etc.
|
| 134 |
+
html = '<div class="initial-draft-card card-grid"><div class="card"><h4>📝 Initial Draft</h4>'
|
| 135 |
+
|
| 136 |
+
summary = initial.get("summary") or ""
|
| 137 |
+
if summary:
|
| 138 |
+
html += f'<div class="kv"><div class="k">Summary</div><div class="v">{_escape_html(summary)}</div></div>'
|
| 139 |
+
|
| 140 |
+
score_fields = [
|
| 141 |
+
("soundness", "Soundness"),
|
| 142 |
+
("presentation", "Presentation"),
|
| 143 |
+
("contribution", "Contribution"),
|
| 144 |
+
("rating", "Rating"),
|
| 145 |
+
("confidence", "Confidence"),
|
| 146 |
+
("decision", "Decision"),
|
| 147 |
+
]
|
| 148 |
+
for key, label in score_fields:
|
| 149 |
+
val = initial.get(key)
|
| 150 |
+
if val is not None and val != "":
|
| 151 |
+
html += f'<div class="kv"><div class="k">{label}</div><div class="v">{_escape_html(str(val))}</div></div>'
|
| 152 |
+
|
| 153 |
+
strengths = initial.get("strengths") or ""
|
| 154 |
+
if strengths:
|
| 155 |
+
html += f'<div class="kv"><div class="k">Strengths</div><div class="v">{_nl2br(strengths)}</div></div>'
|
| 156 |
+
|
| 157 |
+
weaknesses = initial.get("weaknesses") or ""
|
| 158 |
+
if weaknesses:
|
| 159 |
+
html += f'<div class="kv"><div class="k">Weaknesses</div><div class="v">{_nl2br(weaknesses)}</div></div>'
|
| 160 |
+
|
| 161 |
+
questions = initial.get("questions")
|
| 162 |
+
if questions:
|
| 163 |
+
if isinstance(questions, list):
|
| 164 |
+
q_text = "\n".join(f"• {q}" for q in questions if q)
|
| 165 |
+
else:
|
| 166 |
+
q_text = str(questions)
|
| 167 |
+
if q_text:
|
| 168 |
+
html += f'<div class="kv"><div class="k">Questions</div><div class="v">{_nl2br(q_text)}</div></div>'
|
| 169 |
+
|
| 170 |
+
html += "</div></div>"
|
| 171 |
+
return html
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _looks_like_raw_json(obj: Any) -> bool:
|
| 175 |
+
"""Heuristic: dict has typical review keys (summary, strengths) then treat as structured."""
|
| 176 |
+
if not isinstance(obj, dict):
|
| 177 |
+
return False
|
| 178 |
+
return any(k in obj for k in ("summary", "strengths", "weaknesses", "soundness", "rating"))
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _nl2br(text: str) -> str:
|
| 182 |
+
"""Escape HTML and convert newlines to <br> for safe display."""
|
| 183 |
+
if not text:
|
| 184 |
+
return ""
|
| 185 |
+
return _escape_html(text).replace("\n", "<br>\n")
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def format_related_work_html(review: dict) -> str:
|
| 189 |
+
"""
|
| 190 |
+
Format related work as HTML with styled cards.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
review: The review dictionary containing related_work_json_list
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
HTML string with related work cards
|
| 197 |
+
"""
|
| 198 |
+
rw = review.get("related_work_json_list")
|
| 199 |
+
if not rw:
|
| 200 |
+
return "<p>No related work information available.</p>"
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
data = json.loads(rw) if isinstance(rw, str) else rw
|
| 204 |
+
except Exception:
|
| 205 |
+
return f"<p>Error parsing related work data: {str(rw)[:200]}</p>"
|
| 206 |
+
|
| 207 |
+
if not data:
|
| 208 |
+
return "<p>No related work summaries found.</p>"
|
| 209 |
+
|
| 210 |
+
html = '<div class="related-work-container"><h3>Related Work Summaries</h3>'
|
| 211 |
+
|
| 212 |
+
for idx, item in enumerate(data, start=1):
|
| 213 |
+
summary = item.get("summary", "").strip()
|
| 214 |
+
main_methods = item.get("main_methods", "").strip()
|
| 215 |
+
key_findings = item.get("key_findings", "").strip()
|
| 216 |
+
relation = item.get("relation", "").strip()
|
| 217 |
+
|
| 218 |
+
html += f'<div class="related-paper-card">'
|
| 219 |
+
html += f'<div class="paper-header">{idx}. {summary[:100] or "Related paper"}...</div>'
|
| 220 |
+
|
| 221 |
+
if summary:
|
| 222 |
+
html += f'''
|
| 223 |
+
<div class="paper-field">
|
| 224 |
+
<div class="paper-field-label">Summary</div>
|
| 225 |
+
<div class="paper-field-value">{summary}</div>
|
| 226 |
+
</div>
|
| 227 |
+
'''
|
| 228 |
+
if main_methods:
|
| 229 |
+
html += f'''
|
| 230 |
+
<div class="paper-field">
|
| 231 |
+
<div class="paper-field-label">Main Methods</div>
|
| 232 |
+
<div class="paper-field-value">{main_methods}</div>
|
| 233 |
+
</div>
|
| 234 |
+
'''
|
| 235 |
+
if key_findings:
|
| 236 |
+
html += f'''
|
| 237 |
+
<div class="paper-field">
|
| 238 |
+
<div class="paper-field-label">Key Findings</div>
|
| 239 |
+
<div class="paper-field-value">{key_findings}</div>
|
| 240 |
+
</div>
|
| 241 |
+
'''
|
| 242 |
+
if relation:
|
| 243 |
+
html += f'''
|
| 244 |
+
<div class="paper-field">
|
| 245 |
+
<div class="paper-field-label">Relation</div>
|
| 246 |
+
<div class="paper-field-value">{relation}</div>
|
| 247 |
+
</div>
|
| 248 |
+
'''
|
| 249 |
+
html += "</div>"
|
| 250 |
+
|
| 251 |
+
html += "</div>"
|
| 252 |
+
return html
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _render_review_issues_html(issues: dict) -> str:
|
| 256 |
+
"""
|
| 257 |
+
Render review issues (incorrect, missing, needs specificity) as HTML.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
issues: Dictionary containing issue categories
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
HTML string with formatted issues
|
| 264 |
+
"""
|
| 265 |
+
if not issues:
|
| 266 |
+
return ""
|
| 267 |
+
|
| 268 |
+
def render_issue_list(title: str, items: list) -> str:
|
| 269 |
+
if not items:
|
| 270 |
+
return ""
|
| 271 |
+
blocks = []
|
| 272 |
+
for it in items:
|
| 273 |
+
if isinstance(it, dict):
|
| 274 |
+
head = (
|
| 275 |
+
it.get("review_claim")
|
| 276 |
+
or it.get("what_missing")
|
| 277 |
+
or it.get("review_text")
|
| 278 |
+
or "Issue"
|
| 279 |
+
)
|
| 280 |
+
body_parts = []
|
| 281 |
+
for k in ["why_wrong", "why_important", "how_to_fix", "evidence"]:
|
| 282 |
+
if it.get(k):
|
| 283 |
+
body_parts.append(
|
| 284 |
+
f"<div class='k'>{k.replace('_', ' ').title()}</div>"
|
| 285 |
+
f"<div class='v'>{it.get(k)}</div>"
|
| 286 |
+
)
|
| 287 |
+
body = "".join(body_parts) or (
|
| 288 |
+
f"<div class='v mono'>{json.dumps(it, indent=2, ensure_ascii=False)}</div>"
|
| 289 |
+
)
|
| 290 |
+
blocks.append(f"<details><summary>{head}</summary>{body}</details>")
|
| 291 |
+
else:
|
| 292 |
+
blocks.append(f"<div class='v'>{str(it)}</div>")
|
| 293 |
+
return (
|
| 294 |
+
f'<div class="kv"><div class="k">{title}</div><div class="v">'
|
| 295 |
+
+ "".join(blocks)
|
| 296 |
+
+ "</div></div>"
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
html = ""
|
| 300 |
+
html += render_issue_list("Incorrect / Hallucinated", issues.get("incorrect_or_hallucinated", []))
|
| 301 |
+
html += render_issue_list("Missing Key Points", issues.get("missing_key_points", []))
|
| 302 |
+
html += render_issue_list("Needs Specificity", issues.get("needs_specificity", []))
|
| 303 |
+
return html
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _render_rewrite_suggestions_html(suggestions: list) -> str:
|
| 307 |
+
"""
|
| 308 |
+
Render rewrite suggestions as collapsible HTML blocks.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
suggestions: List of rewrite suggestion dictionaries
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
HTML string with formatted suggestions
|
| 315 |
+
"""
|
| 316 |
+
if not suggestions:
|
| 317 |
+
return ""
|
| 318 |
+
|
| 319 |
+
blocks = []
|
| 320 |
+
for s in suggestions:
|
| 321 |
+
if isinstance(s, dict):
|
| 322 |
+
head = f"{s.get('apply_to', 'Rewrite')} · {s.get('target', '')}".strip(" ·")
|
| 323 |
+
suggested = s.get("suggested_text", "")
|
| 324 |
+
evidence = s.get("evidence", "")
|
| 325 |
+
body = ""
|
| 326 |
+
if suggested:
|
| 327 |
+
body += f"<div class='k'>Suggested Text</div><div class='v'>{suggested}</div>"
|
| 328 |
+
if evidence:
|
| 329 |
+
body += f"<div class='k'>Evidence</div><div class='v'>{evidence}</div>"
|
| 330 |
+
blocks.append(
|
| 331 |
+
f"<details><summary>{head or 'Rewrite suggestion'}</summary>{body}</details>"
|
| 332 |
+
)
|
| 333 |
+
else:
|
| 334 |
+
blocks.append(f"<div class='v'>{str(s)}</div>")
|
| 335 |
+
|
| 336 |
+
return (
|
| 337 |
+
'<div class="kv"><div class="k">Rewrite Suggestions</div><div class="v">'
|
| 338 |
+
+ "".join(blocks)
|
| 339 |
+
+ "</div></div>"
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def format_results_html(review: dict) -> str:
|
| 344 |
+
"""
|
| 345 |
+
Format results analyzer output as structured HTML cards.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
review: The review dictionary containing results_analyzer_json
|
| 349 |
+
|
| 350 |
+
Returns:
|
| 351 |
+
HTML string with formatted results analysis
|
| 352 |
+
"""
|
| 353 |
+
parsed = safe_json_parse(review.get("results_analyzer_json"))
|
| 354 |
+
if not parsed:
|
| 355 |
+
return "<p>Unable to parse the results analysis data.</p>"
|
| 356 |
+
|
| 357 |
+
facts = parsed.get("facts", {}) if isinstance(parsed, dict) else {}
|
| 358 |
+
datasets = facts.get("datasets", [])
|
| 359 |
+
metrics = facts.get("metrics", [])
|
| 360 |
+
baselines = facts.get("baselines", [])
|
| 361 |
+
key_results = facts.get("key_results", [])
|
| 362 |
+
review_issues = parsed.get("review_issues", {}) if isinstance(parsed, dict) else {}
|
| 363 |
+
rewrite_suggestions = parsed.get("rewrite_suggestions", []) if isinstance(parsed, dict) else []
|
| 364 |
+
|
| 365 |
+
html = '<div class="card-grid">'
|
| 366 |
+
html += '<div class="card"><h4>Results Analysis <span class="pill">structured</span></h4>'
|
| 367 |
+
|
| 368 |
+
if datasets:
|
| 369 |
+
html += (
|
| 370 |
+
'<div class="kv"><div class="k">Datasets</div><div class="v">'
|
| 371 |
+
+ "\n".join(f"- {x}" for x in datasets)
|
| 372 |
+
+ "</div></div>"
|
| 373 |
+
)
|
| 374 |
+
if metrics:
|
| 375 |
+
html += (
|
| 376 |
+
'<div class="kv"><div class="k">Metrics</div><div class="v">'
|
| 377 |
+
+ "\n".join(f"- {x}" for x in metrics)
|
| 378 |
+
+ "</div></div>"
|
| 379 |
+
)
|
| 380 |
+
if baselines:
|
| 381 |
+
html += (
|
| 382 |
+
'<div class="kv"><div class="k">Baselines</div><div class="v">'
|
| 383 |
+
+ "\n".join(f"- {x}" for x in baselines)
|
| 384 |
+
+ "</div></div>"
|
| 385 |
+
)
|
| 386 |
+
if key_results:
|
| 387 |
+
items = []
|
| 388 |
+
for kr in key_results:
|
| 389 |
+
if isinstance(kr, dict):
|
| 390 |
+
claim = kr.get("claim", "")
|
| 391 |
+
evidence = kr.get("evidence", "")
|
| 392 |
+
block = (
|
| 393 |
+
f"<details><summary>{claim or 'Key result'}</summary>"
|
| 394 |
+
f"<div class='v'>{evidence}</div></details>"
|
| 395 |
+
)
|
| 396 |
+
items.append(block)
|
| 397 |
+
if items:
|
| 398 |
+
html += (
|
| 399 |
+
'<div class="kv"><div class="k">Key Results (claim → evidence)</div>'
|
| 400 |
+
'<div class="v">' + "".join(items) + "</div></div>"
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
if review_issues:
|
| 404 |
+
html += _render_review_issues_html(review_issues)
|
| 405 |
+
if rewrite_suggestions:
|
| 406 |
+
html += _render_rewrite_suggestions_html(rewrite_suggestions)
|
| 407 |
+
|
| 408 |
+
html += "</div></div>"
|
| 409 |
+
return html
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def format_insights_html(review: dict) -> str:
|
| 413 |
+
"""
|
| 414 |
+
Format insight miner output as structured HTML cards.
|
| 415 |
+
|
| 416 |
+
Args:
|
| 417 |
+
review: The review dictionary containing insight_miner_json
|
| 418 |
+
|
| 419 |
+
Returns:
|
| 420 |
+
HTML string with formatted insights
|
| 421 |
+
"""
|
| 422 |
+
parsed = safe_json_parse(review.get("insight_miner_json"))
|
| 423 |
+
if not parsed:
|
| 424 |
+
return "<p>Unable to parse the insights data.</p>"
|
| 425 |
+
|
| 426 |
+
facts = parsed.get("facts", {}) if isinstance(parsed, dict) else {}
|
| 427 |
+
review_issues = parsed.get("review_issues", {}) if isinstance(parsed, dict) else {}
|
| 428 |
+
rewrite_suggestions = parsed.get("rewrite_suggestions", []) if isinstance(parsed, dict) else []
|
| 429 |
+
|
| 430 |
+
def render_list(title: str, items: list) -> str:
|
| 431 |
+
if not items:
|
| 432 |
+
return ""
|
| 433 |
+
blocks = []
|
| 434 |
+
for it in items:
|
| 435 |
+
if isinstance(it, dict):
|
| 436 |
+
head = (
|
| 437 |
+
it.get("claim")
|
| 438 |
+
or it.get("point")
|
| 439 |
+
or it.get("item")
|
| 440 |
+
or it.get("what_missing")
|
| 441 |
+
or "Item"
|
| 442 |
+
)
|
| 443 |
+
evidence = (
|
| 444 |
+
it.get("evidence")
|
| 445 |
+
or it.get("why_important")
|
| 446 |
+
or it.get("why_wrong")
|
| 447 |
+
or it.get("how_to_fix")
|
| 448 |
+
or ""
|
| 449 |
+
)
|
| 450 |
+
blocks.append(
|
| 451 |
+
f"<details><summary>{head}</summary><div class='v'>{evidence}</div></details>"
|
| 452 |
+
)
|
| 453 |
+
else:
|
| 454 |
+
blocks.append(f"<div class='v'>{str(it)}</div>")
|
| 455 |
+
return (
|
| 456 |
+
f'<div class="kv"><div class="k">{title}</div><div class="v">'
|
| 457 |
+
+ "".join(blocks)
|
| 458 |
+
+ "</div></div>"
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
html = '<div class="card-grid">'
|
| 462 |
+
html += '<div class="card"><h4>Paper Insights <span class="pill">structured</span></h4>'
|
| 463 |
+
|
| 464 |
+
html += render_list("Core Contributions", facts.get("core_contributions", []))
|
| 465 |
+
html += render_list("Method Summary", facts.get("method_summary", []))
|
| 466 |
+
html += render_list("Assumptions & Scope", facts.get("assumptions_and_scope", []))
|
| 467 |
+
html += render_list("Novelty Claims (paper)", facts.get("novelty_claims_in_paper", []))
|
| 468 |
+
|
| 469 |
+
if review_issues:
|
| 470 |
+
html += _render_review_issues_html(review_issues)
|
| 471 |
+
if rewrite_suggestions:
|
| 472 |
+
html += _render_rewrite_suggestions_html(rewrite_suggestions)
|
| 473 |
+
|
| 474 |
+
html += "</div></div>"
|
| 475 |
+
return html
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def format_final_review(review: dict) -> str:
|
| 479 |
+
"""
|
| 480 |
+
Extract and return the final review markdown.
|
| 481 |
+
|
| 482 |
+
Args:
|
| 483 |
+
review: The review dictionary containing the final review
|
| 484 |
+
|
| 485 |
+
Returns:
|
| 486 |
+
The final review markdown string
|
| 487 |
+
"""
|
| 488 |
+
return review.get("review_markdown") or review.get("review") or "Final review markdown missing."
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def format_raw_json(review: dict) -> str:
|
| 492 |
+
"""
|
| 493 |
+
Format the complete review as a JSON code block.
|
| 494 |
+
|
| 495 |
+
Args:
|
| 496 |
+
review: The complete review dictionary
|
| 497 |
+
|
| 498 |
+
Returns:
|
| 499 |
+
JSON formatted as a Markdown code block
|
| 500 |
+
"""
|
| 501 |
+
try:
|
| 502 |
+
return "```json\n" + json.dumps(review, indent=2, ensure_ascii=False) + "\n```"
|
| 503 |
+
except Exception:
|
| 504 |
+
return str(review)
|
gradio_app/components/header.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Header component for the Review Grounder Gradio app.
|
| 3 |
+
|
| 4 |
+
This module provides the app header with gradient background,
|
| 5 |
+
title, BETA badge, and privacy notice banner.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_header() -> None:
|
| 12 |
+
"""
|
| 13 |
+
Create the app header with gradient background and privacy notice.
|
| 14 |
+
|
| 15 |
+
Renders:
|
| 16 |
+
- Purple gradient header with title and BETA badge
|
| 17 |
+
- Privacy notice banner explaining the demo mode
|
| 18 |
+
"""
|
| 19 |
+
# Main header with gradient background
|
| 20 |
+
gr.HTML("""
|
| 21 |
+
<div class="app-header">
|
| 22 |
+
<div class="app-header-content">
|
| 23 |
+
<h1 class="app-title" style="margin-bottom: 12px;">
|
| 24 |
+
🔬 Review Grounder
|
| 25 |
+
<span class="beta-badge">BETA</span>
|
| 26 |
+
</h1>
|
| 27 |
+
</div>
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
<div class="privacy-notice">
|
| 31 |
+
<span class="privacy-icon">🔒</span>
|
| 32 |
+
<div class="privacy-content">
|
| 33 |
+
<h4 style="color: white;">Privacy Notice: Anonymous Demo</h4>
|
| 34 |
+
<p style="color: white;">This is an anonymous demonstration. We do not save your PDF file, paper information, or any uploaded content.</p>
|
| 35 |
+
</div>
|
| 36 |
+
</div>
|
| 37 |
+
|
| 38 |
+
</div>
|
| 39 |
+
""")
|
gradio_app/components/results_panel.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Results panel component for the Review Grounder Gradio app.
|
| 3 |
+
|
| 4 |
+
This module provides the right panel with tabs for displaying
|
| 5 |
+
different stages of the review pipeline results.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_results_placeholder() -> str:
|
| 13 |
+
"""
|
| 14 |
+
Return the HTML for the initial placeholder state.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
HTML string for the "Ready to Review" placeholder
|
| 18 |
+
"""
|
| 19 |
+
return """
|
| 20 |
+
<div class="results-placeholder">
|
| 21 |
+
<div class="results-placeholder-icon">📋</div>
|
| 22 |
+
<h3>Ready to Review</h3>
|
| 23 |
+
<p>Upload your PDF and click "🚀 Generate AI Review" to get comprehensive feedback on your research.</p>
|
| 24 |
+
</div>
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def create_initial_draft_placeholder() -> str:
|
| 28 |
+
"""
|
| 29 |
+
Return the HTML for the initial draft placeholder state.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
HTML string for the initial draft placeholder
|
| 33 |
+
"""
|
| 34 |
+
return """
|
| 35 |
+
<div class="results-placeholder">
|
| 36 |
+
<div class="results-placeholder-icon">📝</div>
|
| 37 |
+
<h3>Initial Draft</h3>
|
| 38 |
+
<p>The draft of the paper review will appear here.</p>
|
| 39 |
+
</div>
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def create_results_analyzer_placeholder() -> str:
|
| 43 |
+
"""
|
| 44 |
+
Return the HTML for the results analyzer placeholder state.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
HTML string for the results analyzer placeholder
|
| 48 |
+
"""
|
| 49 |
+
return """
|
| 50 |
+
<div class="results-placeholder">
|
| 51 |
+
<div class="results-placeholder-icon">📈</div>
|
| 52 |
+
<h3>Results Analyzer</h3>
|
| 53 |
+
<p>The analysis of the paper's experiments and results will appear here.</p>
|
| 54 |
+
</div>
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def create_insights_miner_placeholder() -> str:
|
| 58 |
+
"""
|
| 59 |
+
Return the HTML for the insights miner placeholder state.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
HTML string for the insights miner placeholder
|
| 63 |
+
"""
|
| 64 |
+
return """
|
| 65 |
+
<div class="results-placeholder">
|
| 66 |
+
<div class="results-placeholder-icon">💡</div>
|
| 67 |
+
<h3>Insight Miner</h3>
|
| 68 |
+
<p>The insights retrieved from the paper's content will appear here.</p>
|
| 69 |
+
</div>
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def create_related_work_placeholder() -> str:
|
| 73 |
+
"""
|
| 74 |
+
Return the HTML for the related work placeholder state.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
HTML string for the related work placeholder
|
| 78 |
+
"""
|
| 79 |
+
return """
|
| 80 |
+
<div class="results-placeholder">
|
| 81 |
+
<div class="results-placeholder-icon">📚</div>
|
| 82 |
+
<h3>Related Work</h3>
|
| 83 |
+
<p>The curated research papers and their summaries related to your uploaded paper will appear here.</p>
|
| 84 |
+
</div>
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def create_final_review_placeholder() -> str:
|
| 88 |
+
"""
|
| 89 |
+
Return the HTML for the final review placeholder state.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
HTML string for the final review placeholder
|
| 93 |
+
"""
|
| 94 |
+
return """
|
| 95 |
+
<div class="results-placeholder">
|
| 96 |
+
<div class="results-placeholder-icon">🎯</div>
|
| 97 |
+
<h3>Final Review</h3>
|
| 98 |
+
<p>The final refined review of the paper will appear here. It is the refinement of the initial draft based on the joint information from the results analyzer, insight miner, and related work.</p>
|
| 99 |
+
</div>
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def create_results_panel(
|
| 104 |
+
show_log: bool = False,
|
| 105 |
+
show_raw_json: bool = False,
|
| 106 |
+
) -> Tuple[gr.HTML, gr.HTML, gr.HTML, gr.HTML, gr.Markdown, gr.Textbox, gr.Markdown, gr.Accordion, gr.Tab, gr.DownloadButton]:
|
| 107 |
+
"""
|
| 108 |
+
Create the results panel with tabs for different pipeline stages.
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
Tuple including download_json_btn for downloading raw JSON.
|
| 112 |
+
"""
|
| 113 |
+
gr.HTML('<div class="panel-title">📝 AI Review Results</div>')
|
| 114 |
+
|
| 115 |
+
# Status log (conditionally visible)
|
| 116 |
+
with gr.Accordion("📋 Pipeline Log", open=False, visible=show_log) as log_accordion:
|
| 117 |
+
status_output = gr.Textbox(
|
| 118 |
+
value="Ready. Upload a PDF and click '🚀 Generate AI Review' to start.",
|
| 119 |
+
lines=8,
|
| 120 |
+
max_lines=20,
|
| 121 |
+
interactive=False,
|
| 122 |
+
autoscroll=True,
|
| 123 |
+
elem_classes=["status-log"],
|
| 124 |
+
show_label=False,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Main results tabs
|
| 128 |
+
with gr.Tabs():
|
| 129 |
+
with gr.Tab("🎯 Final Review", id="final"):
|
| 130 |
+
gr.HTML("""
|
| 131 |
+
<div class="final-review-toolbar">
|
| 132 |
+
<button type="button" class="copy-final-btn" onclick="(function(){
|
| 133 |
+
var el = document.getElementById('final-review-md');
|
| 134 |
+
if (!el) el = document.querySelector('[id=\\'final-review-md\\']');
|
| 135 |
+
if (!el) el = document.querySelector('.final-review-wrap .gr-markdown, .final-review-wrap [class*=\\'markdown\\']');
|
| 136 |
+
var text = el ? (el.innerText || el.textContent || '') : '';
|
| 137 |
+
if (text && navigator.clipboard && navigator.clipboard.writeText) {
|
| 138 |
+
navigator.clipboard.writeText(text).then(function(){ alert('Copied to clipboard'); }).catch(function(){ alert('Copy failed'); });
|
| 139 |
+
} else { alert('Nothing to copy'); }
|
| 140 |
+
})();" title="Copy full review text">📋 Copy to clipboard</button>
|
| 141 |
+
</div>
|
| 142 |
+
""")
|
| 143 |
+
with gr.Group(elem_classes=["final-review-wrap"]):
|
| 144 |
+
final_md = gr.Markdown(
|
| 145 |
+
value=create_results_placeholder(),
|
| 146 |
+
label="Final Review",
|
| 147 |
+
elem_id="final-review-md",
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
with gr.Tab("📝 Initial Draft", id="initial"):
|
| 151 |
+
initial_html = gr.HTML(
|
| 152 |
+
value=create_initial_draft_placeholder(),
|
| 153 |
+
label="Initial Draft",
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
with gr.Tab("📈 Results Analyzer", id="results"):
|
| 157 |
+
results_html = gr.HTML(
|
| 158 |
+
value=create_results_analyzer_placeholder(),
|
| 159 |
+
label="Results Analyzer",
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
with gr.Tab("💡 Insight Miner", id="insights"):
|
| 163 |
+
insights_html = gr.HTML(
|
| 164 |
+
value=create_insights_miner_placeholder(),
|
| 165 |
+
label="Insight Miner",
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
with gr.Tab("📚 Related Work", id="related"):
|
| 169 |
+
related_html = gr.HTML(
|
| 170 |
+
value=create_related_work_placeholder(),
|
| 171 |
+
label="Related Work",
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# Raw JSON tab (conditionally visible based on toggle)
|
| 175 |
+
with gr.Tab("🔧 Raw JSON", id="raw_json", visible=show_raw_json) as raw_json_tab:
|
| 176 |
+
raw_json_md = gr.Markdown(
|
| 177 |
+
value="Raw JSON output for debugging will appear here.",
|
| 178 |
+
label="Raw JSON",
|
| 179 |
+
)
|
| 180 |
+
download_json_btn = gr.DownloadButton("⬇️ Download as JSON")
|
| 181 |
+
|
| 182 |
+
return (
|
| 183 |
+
initial_html,
|
| 184 |
+
results_html,
|
| 185 |
+
insights_html,
|
| 186 |
+
related_html,
|
| 187 |
+
final_md,
|
| 188 |
+
status_output,
|
| 189 |
+
raw_json_md,
|
| 190 |
+
log_accordion,
|
| 191 |
+
raw_json_tab,
|
| 192 |
+
download_json_btn,
|
| 193 |
+
)
|
gradio_app/components/settings.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Settings component for the Review Grounder Gradio app.
|
| 3 |
+
|
| 4 |
+
This module provides the advanced settings panel with:
|
| 5 |
+
- LLM endpoint configuration
|
| 6 |
+
- API key input
|
| 7 |
+
- Model name selection
|
| 8 |
+
- Toggle for showing/hiding log and raw JSON
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
from typing import Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Default values for OpenAI API
|
| 16 |
+
DEFAULT_API_ENDPOINT = "https://api.openai.com/v1"
|
| 17 |
+
DEFAULT_MODEL_NAME = "gpt-4o"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def create_advanced_settings() -> Tuple[gr.Textbox, gr.Textbox, gr.Textbox, gr.Checkbox, gr.Checkbox]:
|
| 21 |
+
"""
|
| 22 |
+
Create the advanced settings accordion with LLM configuration.
|
| 23 |
+
|
| 24 |
+
The accordion is collapsed by default to hide technical details
|
| 25 |
+
from casual users, while allowing power users to customize.
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
Tuple containing:
|
| 29 |
+
- api_base_url_in: Textbox for LLM endpoint URL
|
| 30 |
+
- api_key_in: Textbox for API key (password masked)
|
| 31 |
+
- model_name_in: Textbox for model name
|
| 32 |
+
- show_log_toggle: Checkbox for showing/hiding log
|
| 33 |
+
- show_raw_json_toggle: Checkbox for showing/hiding raw JSON
|
| 34 |
+
"""
|
| 35 |
+
with gr.Accordion(
|
| 36 |
+
"⚙️ Advanced Settings",
|
| 37 |
+
open=False,
|
| 38 |
+
elem_classes=["advanced-settings"]
|
| 39 |
+
):
|
| 40 |
+
gr.Markdown("""
|
| 41 |
+
Configure your LLM provider and display preferences.
|
| 42 |
+
Leave fields empty to use environment variables.
|
| 43 |
+
""")
|
| 44 |
+
|
| 45 |
+
api_base_url_in = gr.Textbox(
|
| 46 |
+
label="🔗 LLM Endpoint (base_url)",
|
| 47 |
+
placeholder="https://api.openai.com/v1",
|
| 48 |
+
value=DEFAULT_API_ENDPOINT,
|
| 49 |
+
info="The base URL for your LLM API (OpenAI, OpenRouter, local, etc.)",
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
api_key_in = gr.Textbox(
|
| 53 |
+
label="🔑 API Key",
|
| 54 |
+
type="password",
|
| 55 |
+
placeholder="sk-...",
|
| 56 |
+
value="",
|
| 57 |
+
info="Your API key (leave empty to use OPENAI_API_KEY env var)",
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
model_name_in = gr.Textbox(
|
| 61 |
+
label="🤖 Model Name",
|
| 62 |
+
placeholder="gpt-4o",
|
| 63 |
+
value=DEFAULT_MODEL_NAME,
|
| 64 |
+
info="Model identifier (e.g., gpt-4o, gpt-4-turbo, claude-3-opus)",
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
gr.Markdown("---")
|
| 68 |
+
gr.Markdown("**Display Options**")
|
| 69 |
+
|
| 70 |
+
show_log_toggle = gr.Checkbox(
|
| 71 |
+
label="📋 Show Pipeline Log",
|
| 72 |
+
value=False,
|
| 73 |
+
info="Display detailed execution log during processing",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
show_raw_json_toggle = gr.Checkbox(
|
| 77 |
+
label="📄 Show Raw JSON Output",
|
| 78 |
+
value=False,
|
| 79 |
+
info="Display raw JSON data in results (for debugging)",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
return api_base_url_in, api_key_in, model_name_in, show_log_toggle, show_raw_json_toggle
|
gradio_app/components/styles.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom CSS styles for the Review Grounder Gradio app.
|
| 3 |
+
|
| 4 |
+
This module provides all custom CSS needed for the modern UI design,
|
| 5 |
+
including the gradient header, card layouts, buttons, and other visual elements.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_custom_css() -> str:
|
| 10 |
+
"""
|
| 11 |
+
Return the complete custom CSS for the app.
|
| 12 |
+
|
| 13 |
+
Includes styles for:
|
| 14 |
+
- Gradient header with purple theme
|
| 15 |
+
- Privacy notice banner
|
| 16 |
+
- How it works section with numbered steps
|
| 17 |
+
- Card-based layouts
|
| 18 |
+
- Custom button styling
|
| 19 |
+
- Results panel tabs
|
| 20 |
+
"""
|
| 21 |
+
return """
|
| 22 |
+
/* ===== Global Styles ===== */
|
| 23 |
+
.gradio-container {
|
| 24 |
+
max-width: 1400px !important;
|
| 25 |
+
margin: 0 auto !important;
|
| 26 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif !important;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
/* ===== Header Styles ===== */
|
| 30 |
+
.app-header {
|
| 31 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 32 |
+
padding: 20px 30px;
|
| 33 |
+
border-radius: 12px;
|
| 34 |
+
margin-bottom: 20px;
|
| 35 |
+
position: relative;
|
| 36 |
+
overflow: hidden;
|
| 37 |
+
display: flex;
|
| 38 |
+
flex-direction: column;
|
| 39 |
+
width: calc(100% + 24px) !important;
|
| 40 |
+
margin-left: -12px;
|
| 41 |
+
margin-right: -12px;
|
| 42 |
+
align-items: stretch; /* Stretch children to fill container width */
|
| 43 |
+
justify-content: flex-start; /* Align children from top to bottom */
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
.app-header::before {
|
| 47 |
+
content: '';
|
| 48 |
+
position: absolute;
|
| 49 |
+
top: 0;
|
| 50 |
+
left: 0;
|
| 51 |
+
right: 0;
|
| 52 |
+
bottom: 0;
|
| 53 |
+
background-image: radial-gradient(circle at 20% 50%, rgba(255,255,255,0.1) 1px, transparent 1px),
|
| 54 |
+
radial-gradient(circle at 80% 50%, rgba(255,255,255,0.1) 1px, transparent 1px);
|
| 55 |
+
background-size: 40px 40px;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
.app-header-content {
|
| 59 |
+
position: relative;
|
| 60 |
+
z-index: 1;
|
| 61 |
+
display: flex;
|
| 62 |
+
justify-content: space-between;
|
| 63 |
+
align-items: flex-start;
|
| 64 |
+
gap: 12px;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
.app-title {
|
| 70 |
+
color: white;
|
| 71 |
+
font-size: 1.8em;
|
| 72 |
+
font-weight: 700;
|
| 73 |
+
margin: 0;
|
| 74 |
+
display: flex;
|
| 75 |
+
align-items: center;
|
| 76 |
+
gap: 12px;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
.beta-badge {
|
| 80 |
+
background: rgba(255,255,255,0.2);
|
| 81 |
+
color: white;
|
| 82 |
+
padding: 4px 12px;
|
| 83 |
+
border-radius: 20px;
|
| 84 |
+
font-size: 0.5em;
|
| 85 |
+
font-weight: 600;
|
| 86 |
+
text-transform: uppercase;
|
| 87 |
+
letter-spacing: 0.5px;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
.back-button {
|
| 91 |
+
background: rgba(255,255,255,0.2);
|
| 92 |
+
color: white;
|
| 93 |
+
padding: 8px 16px;
|
| 94 |
+
border-radius: 8px;
|
| 95 |
+
text-decoration: none;
|
| 96 |
+
font-size: 0.9em;
|
| 97 |
+
transition: background 0.2s;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
.back-button:hover {
|
| 101 |
+
background: rgba(255,255,255,0.3);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
/* ===== Privacy Notice Banner ===== */
|
| 105 |
+
.privacy-notice {
|
| 106 |
+
background: linear-gradient(90deg, #5b4fa8 0%, #7c3aed 100%);
|
| 107 |
+
color: white;
|
| 108 |
+
|
| 109 |
+
padding: 12px 20px;
|
| 110 |
+
border-radius: 8px;
|
| 111 |
+
margin-top: 20px;
|
| 112 |
+
margin-bottom: 20px;
|
| 113 |
+
display: flex;
|
| 114 |
+
align-items: flex-start;
|
| 115 |
+
gap: 12px;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
.privacy-icon {
|
| 119 |
+
font-size: 1.2em;
|
| 120 |
+
margin-top: 2px;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
.privacy-content h4 {
|
| 124 |
+
margin: 0 0 4px 0;
|
| 125 |
+
font-size: 1em;
|
| 126 |
+
font-weight: 600;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
.privacy-content p {
|
| 130 |
+
margin: 0;
|
| 131 |
+
font-size: 0.85em;
|
| 132 |
+
opacity: 0.9;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
/* ===== Main Panel Cards ===== */
|
| 136 |
+
.panel-card {
|
| 137 |
+
background: white;
|
| 138 |
+
border-radius: 12px;
|
| 139 |
+
padding: 24px;
|
| 140 |
+
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
|
| 141 |
+
border: 1px solid #e5e7eb;
|
| 142 |
+
overflow: hidden;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
# /* Unify width: Gradio theme often gives HTML/Accordion blocks a max-width.
|
| 146 |
+
# Force all blocks in the panel to span full width like the button. */
|
| 147 |
+
# .panel-card .gr-block,
|
| 148 |
+
# .panel-card .block,
|
| 149 |
+
# .panel-card > div {
|
| 150 |
+
# max-width: none !important;
|
| 151 |
+
# width: 100% !important;
|
| 152 |
+
# box-sizing: border-box;
|
| 153 |
+
# }
|
| 154 |
+
|
| 155 |
+
# .panel-card .how-it-works {
|
| 156 |
+
# width: 100% !important;
|
| 157 |
+
# box-sizing: border-box;
|
| 158 |
+
# }
|
| 159 |
+
|
| 160 |
+
.panel-title {
|
| 161 |
+
font-size: 1.1em;
|
| 162 |
+
font-weight: 600;
|
| 163 |
+
color: #1f2937;
|
| 164 |
+
margin: 0 0 20px 0;
|
| 165 |
+
display: flex;
|
| 166 |
+
align-items: center;
|
| 167 |
+
gap: 8px;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
/* ===== How It Works Section ===== */
|
| 171 |
+
/* Stretch to same width as button: cancel panel padding with negative margin, then add inner padding */
|
| 172 |
+
.how-it-works {
|
| 173 |
+
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
|
| 174 |
+
border-radius: 10px;
|
| 175 |
+
padding: 16px 24px;
|
| 176 |
+
margin-left: -12px;
|
| 177 |
+
margin-right: -12px;
|
| 178 |
+
margin-bottom: 20px;
|
| 179 |
+
width: calc(100% + 24px) !important;
|
| 180 |
+
max-width: none;
|
| 181 |
+
box-sizing: border-box;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
.how-it-works-title {
|
| 185 |
+
color: #d97706;
|
| 186 |
+
font-weight: 600;
|
| 187 |
+
font-size: 0.95em;
|
| 188 |
+
margin-bottom: 12px;
|
| 189 |
+
display: flex;
|
| 190 |
+
align-items: center;
|
| 191 |
+
gap: 6px;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
.step-item {
|
| 195 |
+
display: flex;
|
| 196 |
+
align-items: flex-start;
|
| 197 |
+
gap: 10px;
|
| 198 |
+
margin-bottom: 8px;
|
| 199 |
+
font-size: 0.9em;
|
| 200 |
+
color: #374151;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.step-number {
|
| 204 |
+
background: #f97316;
|
| 205 |
+
color: white;
|
| 206 |
+
width: 20px;
|
| 207 |
+
height: 20px;
|
| 208 |
+
border-radius: 50%;
|
| 209 |
+
display: flex;
|
| 210 |
+
align-items: center;
|
| 211 |
+
justify-content: center;
|
| 212 |
+
font-size: 0.75em;
|
| 213 |
+
font-weight: 600;
|
| 214 |
+
flex-shrink: 0;
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
/* ===== Upload Area ===== */
|
| 218 |
+
.upload-area {
|
| 219 |
+
border: 2px dashed #d1d5db;
|
| 220 |
+
border-radius: 12px;
|
| 221 |
+
padding: 40px 20px;
|
| 222 |
+
text-align: center;
|
| 223 |
+
background: #f9fafb;
|
| 224 |
+
transition: all 0.2s;
|
| 225 |
+
margin-bottom: 20px;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
.upload-area:hover {
|
| 229 |
+
border-color: #9ca3af;
|
| 230 |
+
background: #f3f4f6;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/* Hide default Gradio file upload prompt text ("Click to upload or drag and drop", etc.) */
|
| 234 |
+
.file-upload-minimal .gr-formatted-text,
|
| 235 |
+
.file-upload-minimal .gr-box > div:not([class*="file"]):not([class*="preview"]),
|
| 236 |
+
#pdf-upload .gr-formatted-text,
|
| 237 |
+
#pdf-upload .wrap-inner .gr-formatted-text {
|
| 238 |
+
display: none !important;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
.upload-icon {
|
| 242 |
+
font-size: 3em;
|
| 243 |
+
color: #9ca3af;
|
| 244 |
+
margin-bottom: 12px;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
.upload-text {
|
| 248 |
+
color: #6b7280;
|
| 249 |
+
font-size: 0.95em;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
.upload-link {
|
| 253 |
+
color: #6366f1;
|
| 254 |
+
font-weight: 600;
|
| 255 |
+
cursor: pointer;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
.upload-hint {
|
| 259 |
+
color: #9ca3af;
|
| 260 |
+
font-size: 0.85em;
|
| 261 |
+
margin-top: 8px;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
/* ===== Primary Action Button ===== */
|
| 265 |
+
.primary-btn {
|
| 266 |
+
background: linear-gradient(135deg, #fb923c 0%, #f97316 100%) !important;
|
| 267 |
+
color: white !important;
|
| 268 |
+
padding: 14px 28px !important;
|
| 269 |
+
border-radius: 10px !important;
|
| 270 |
+
font-weight: 600 !important;
|
| 271 |
+
font-size: 1em !important;
|
| 272 |
+
border: none !important;
|
| 273 |
+
cursor: pointer !important;
|
| 274 |
+
transition: all 0.2s !important;
|
| 275 |
+
width: 100% !important;
|
| 276 |
+
box-shadow: 0 4px 12px rgba(249, 115, 22, 0.3) !important;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
.primary-btn:hover {
|
| 280 |
+
transform: translateY(-1px) !important;
|
| 281 |
+
box-shadow: 0 6px 16px rgba(249, 115, 22, 0.4) !important;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
.primary-btn:disabled {
|
| 285 |
+
opacity: 0.6 !important;
|
| 286 |
+
cursor: not-allowed !important;
|
| 287 |
+
transform: none !important;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
/* ===== Secondary Buttons (Disabled State) ===== */
|
| 291 |
+
.secondary-btn {
|
| 292 |
+
background: #f3f4f6 !important;
|
| 293 |
+
color: #9ca3af !important;
|
| 294 |
+
padding: 12px 24px !important;
|
| 295 |
+
border-radius: 10px !important;
|
| 296 |
+
font-weight: 500 !important;
|
| 297 |
+
border: 1px solid #e5e7eb !important;
|
| 298 |
+
cursor: not-allowed !important;
|
| 299 |
+
width: 100% !important;
|
| 300 |
+
margin-top: 8px !important;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
.unavailable-badge {
|
| 304 |
+
background: #e5e7eb;
|
| 305 |
+
color: #9ca3af;
|
| 306 |
+
padding: 2px 8px;
|
| 307 |
+
border-radius: 4px;
|
| 308 |
+
font-size: 0.7em;
|
| 309 |
+
margin-left: 8px;
|
| 310 |
+
text-transform: uppercase;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
/* ===== Advanced Settings Accordion ===== */
|
| 314 |
+
/* Stretch to same width as button: cancel panel padding, align inner padding with panel */
|
| 315 |
+
.panel-card .advanced-settings,
|
| 316 |
+
.advanced-settings {
|
| 317 |
+
margin-left: -12px !important;
|
| 318 |
+
margin-right: -12px !important;
|
| 319 |
+
width: calc(100% + 24px) !important;
|
| 320 |
+
max-width: none !important;
|
| 321 |
+
box-sizing: border-box;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
.advanced-settings .label-wrap {
|
| 325 |
+
background: #f9fafb !important;
|
| 326 |
+
border-radius: 8px !important;
|
| 327 |
+
padding: 12px 24px !important;
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
.advanced-settings .label-wrap span {
|
| 331 |
+
font-weight: 500 !important;
|
| 332 |
+
color: #4b5563 !important;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
/* Accordion content area: same horizontal padding as panel */
|
| 336 |
+
.advanced-settings .gr-group,
|
| 337 |
+
.advanced-settings .wrap {
|
| 338 |
+
padding-left: 24px !important;
|
| 339 |
+
padding-right: 24px !important;
|
| 340 |
+
box-sizing: border-box;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
/* ===== Results Panel ===== */
|
| 344 |
+
.results-panel {
|
| 345 |
+
background: white;
|
| 346 |
+
border-radius: 12px;
|
| 347 |
+
padding: 24px;
|
| 348 |
+
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
|
| 349 |
+
border: 1px solid #e5e7eb;
|
| 350 |
+
min-height: 500px;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
.results-placeholder {
|
| 354 |
+
text-align: center;
|
| 355 |
+
padding: 60px 20px;
|
| 356 |
+
color: #9ca3af;
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
.results-placeholder-icon {
|
| 360 |
+
font-size: 4em;
|
| 361 |
+
margin-bottom: 16px;
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
.results-placeholder h3 {
|
| 365 |
+
color: #374151;
|
| 366 |
+
margin: 0 0 8px 0;
|
| 367 |
+
font-weight: 600;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
.results-placeholder p {
|
| 371 |
+
margin: 0;
|
| 372 |
+
font-size: 0.95em;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
/* ===== Tab Styling ===== */
|
| 376 |
+
.tabs .tab-nav {
|
| 377 |
+
border-bottom: 2px solid #e5e7eb !important;
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
.tabs .tab-nav button {
|
| 381 |
+
font-weight: 500 !important;
|
| 382 |
+
color: #6b7280 !important;
|
| 383 |
+
padding: 12px 20px !important;
|
| 384 |
+
border: none !important;
|
| 385 |
+
background: transparent !important;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
.tabs .tab-nav button.selected {
|
| 389 |
+
color: #6366f1 !important;
|
| 390 |
+
border-bottom: 2px solid #6366f1 !important;
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
/* ===== Initial Draft & Final Review content ===== */
|
| 394 |
+
.review-message {
|
| 395 |
+
color: #6b7280;
|
| 396 |
+
font-style: italic;
|
| 397 |
+
padding: 1em;
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
.review-draft-content,
|
| 401 |
+
.initial-draft-card {
|
| 402 |
+
max-width: 100%;
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
.review-text {
|
| 406 |
+
white-space: pre-wrap;
|
| 407 |
+
word-break: break-word;
|
| 408 |
+
line-height: 1.6;
|
| 409 |
+
color: #1f2937;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
/* Final Review: toolbar with copy button at top right */
|
| 413 |
+
.final-review-toolbar {
|
| 414 |
+
display: flex;
|
| 415 |
+
justify-content: flex-end;
|
| 416 |
+
margin-bottom: 12px;
|
| 417 |
+
}
|
| 418 |
+
.copy-final-btn {
|
| 419 |
+
padding: 8px 16px;
|
| 420 |
+
border-radius: 8px;
|
| 421 |
+
border: 1px solid #e5e7eb;
|
| 422 |
+
background: #f9fafb;
|
| 423 |
+
color: #374151;
|
| 424 |
+
font-size: 0.9em;
|
| 425 |
+
cursor: pointer;
|
| 426 |
+
transition: background 0.2s, color 0.2s;
|
| 427 |
+
}
|
| 428 |
+
.copy-final-btn:hover {
|
| 429 |
+
background: #f3f4f6;
|
| 430 |
+
color: #111827;
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
/* Final Review markdown area: improve readability */
|
| 434 |
+
.results-panel .gr-markdown,
|
| 435 |
+
.results-panel .prose {
|
| 436 |
+
line-height: 1.7 !important;
|
| 437 |
+
color: #1f2937 !important;
|
| 438 |
+
max-width: 100% !important;
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
.results-panel .gr-markdown h1,
|
| 442 |
+
.results-panel .gr-markdown h2,
|
| 443 |
+
.results-panel .gr-markdown h3 {
|
| 444 |
+
margin-top: 1em !important;
|
| 445 |
+
margin-bottom: 0.5em !important;
|
| 446 |
+
color: #111827 !important;
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
.results-panel .gr-markdown ul,
|
| 450 |
+
.results-panel .gr-markdown ol {
|
| 451 |
+
padding-left: 1.5em !important;
|
| 452 |
+
margin: 0.5em 0 !important;
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
.results-panel .gr-markdown p {
|
| 456 |
+
margin: 0.5em 0 !important;
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
/* ===== Card Grid for Results ===== */
|
| 460 |
+
.card-grid {
|
| 461 |
+
display: flex;
|
| 462 |
+
flex-direction: column;
|
| 463 |
+
gap: 12px;
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
.card {
|
| 467 |
+
background: #f8f9fa;
|
| 468 |
+
border: 1px solid #e9ecef;
|
| 469 |
+
border-radius: 10px;
|
| 470 |
+
padding: 14px 16px;
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
.card h4 {
|
| 474 |
+
margin: 0 0 8px 0;
|
| 475 |
+
font-size: 1.05em;
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
.kv {
|
| 479 |
+
margin: 8px 0;
|
| 480 |
+
padding: 10px;
|
| 481 |
+
background: #ffffff;
|
| 482 |
+
border-radius: 8px;
|
| 483 |
+
border: 1px solid #eef1f4;
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
.k {
|
| 487 |
+
font-weight: 650;
|
| 488 |
+
color: #495057;
|
| 489 |
+
margin-bottom: 4px;
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
.v {
|
| 493 |
+
color: #212529;
|
| 494 |
+
line-height: 1.55;
|
| 495 |
+
white-space: pre-wrap;
|
| 496 |
+
word-break: break-word;
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
details {
|
| 500 |
+
background: #ffffff;
|
| 501 |
+
border: 1px solid #eef1f4;
|
| 502 |
+
border-radius: 8px;
|
| 503 |
+
padding: 10px 12px;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
summary {
|
| 507 |
+
cursor: pointer;
|
| 508 |
+
font-weight: 650;
|
| 509 |
+
color: #212529;
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
.mono {
|
| 513 |
+
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace;
|
| 514 |
+
font-size: 12.5px;
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
.pill {
|
| 518 |
+
display: inline-block;
|
| 519 |
+
padding: 2px 8px;
|
| 520 |
+
border-radius: 999px;
|
| 521 |
+
background: #e7f1ff;
|
| 522 |
+
color: #0b5ed7;
|
| 523 |
+
font-size: 12px;
|
| 524 |
+
margin-left: 8px;
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
/* ===== Related Work Cards ===== */
|
| 528 |
+
.related-work-container {
|
| 529 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
| 530 |
+
max-width: 100%;
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
.related-paper-card {
|
| 534 |
+
background: #f8f9fa;
|
| 535 |
+
border-left: 4px solid #007bff;
|
| 536 |
+
border-radius: 6px;
|
| 537 |
+
padding: 16px;
|
| 538 |
+
margin-bottom: 16px;
|
| 539 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
| 540 |
+
transition: transform 0.2s, box-shadow 0.2s;
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
.related-paper-card:hover {
|
| 544 |
+
transform: translateY(-2px);
|
| 545 |
+
box-shadow: 0 4px 8px rgba(0,0,0,0.15);
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
.paper-header {
|
| 549 |
+
font-weight: 600;
|
| 550 |
+
font-size: 1.1em;
|
| 551 |
+
color: #212529;
|
| 552 |
+
margin-bottom: 12px;
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
.paper-field {
|
| 556 |
+
margin: 8px 0;
|
| 557 |
+
padding: 8px;
|
| 558 |
+
background: white;
|
| 559 |
+
border-radius: 4px;
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
.paper-field-label {
|
| 563 |
+
font-weight: 600;
|
| 564 |
+
color: #495057;
|
| 565 |
+
font-size: 0.9em;
|
| 566 |
+
margin-bottom: 4px;
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
.paper-field-value {
|
| 570 |
+
color: #212529;
|
| 571 |
+
line-height: 1.5;
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
/* ===== Log/Status Text Area ===== */
|
| 575 |
+
.status-log textarea {
|
| 576 |
+
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace !important;
|
| 577 |
+
font-size: 0.85em !important;
|
| 578 |
+
background: #1f2937 !important;
|
| 579 |
+
color: #10b981 !important;
|
| 580 |
+
border-radius: 8px !important;
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
/* ===== Footer ===== */
|
| 584 |
+
.app-footer {
|
| 585 |
+
text-align: center;
|
| 586 |
+
padding: 20px;
|
| 587 |
+
color: #9ca3af;
|
| 588 |
+
font-size: 0.85em;
|
| 589 |
+
border-top: 1px solid #e5e7eb;
|
| 590 |
+
margin-top: 30px;
|
| 591 |
+
}
|
| 592 |
+
"""
|
gradio_app/components/upload_section.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Upload section component for the Review Grounder Gradio app.
|
| 3 |
+
|
| 4 |
+
This module provides the left panel with:
|
| 5 |
+
- "How it works" instructions
|
| 6 |
+
- PDF upload area
|
| 7 |
+
- Action buttons
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
from typing import Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_how_it_works() -> None:
|
| 15 |
+
"""
|
| 16 |
+
Create the "How it works" instruction section.
|
| 17 |
+
|
| 18 |
+
Displays a numbered list of steps explaining the review process.
|
| 19 |
+
"""
|
| 20 |
+
gr.HTML("""
|
| 21 |
+
<div class="how-it-works">
|
| 22 |
+
<div class="how-it-works-title">
|
| 23 |
+
⚡ How it works
|
| 24 |
+
</div>
|
| 25 |
+
<div class="step-item">
|
| 26 |
+
<span class="step-number">1</span>
|
| 27 |
+
<span>Upload your research paper in PDF format</span>
|
| 28 |
+
</div>
|
| 29 |
+
<div class="step-item">
|
| 30 |
+
<span class="step-number">2</span>
|
| 31 |
+
<span>Configure your LLM settings (or use defaults)</span>
|
| 32 |
+
</div>
|
| 33 |
+
<div class="step-item">
|
| 34 |
+
<span class="step-number">3</span>
|
| 35 |
+
<span>Click "🚀 Generate AI Review" to start</span>
|
| 36 |
+
</div>
|
| 37 |
+
<div class="step-item">
|
| 38 |
+
<span class="step-number">4</span>
|
| 39 |
+
<span>Watch as our AI analyzes your paper in real-time</span>
|
| 40 |
+
</div>
|
| 41 |
+
<div class="step-item">
|
| 42 |
+
<span class="step-number">5</span>
|
| 43 |
+
<span>Get comprehensive, grounded feedback on your research</span>
|
| 44 |
+
</div>
|
| 45 |
+
</div>
|
| 46 |
+
""")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def create_upload_area() -> gr.File:
|
| 50 |
+
"""
|
| 51 |
+
Create the PDF upload area component.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
gr.File: The file upload component for PDF files
|
| 55 |
+
"""
|
| 56 |
+
pdf_input = gr.File(
|
| 57 |
+
label="",
|
| 58 |
+
file_types=[".pdf"],
|
| 59 |
+
type="filepath",
|
| 60 |
+
elem_classes=["upload-area", "file-upload-minimal"],
|
| 61 |
+
elem_id="pdf-upload",
|
| 62 |
+
show_label=False,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
gr.HTML("""
|
| 66 |
+
<div style="text-align: center; color: #9ca3af; font-size: 0.85em; margin-top: -10px; margin-bottom: 15px;">
|
| 67 |
+
PDF files only (max 10MB)
|
| 68 |
+
</div>
|
| 69 |
+
""")
|
| 70 |
+
|
| 71 |
+
return pdf_input
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def create_action_buttons() -> gr.Button:
|
| 75 |
+
"""
|
| 76 |
+
Create the action buttons for starting the review.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
gr.Button: The primary "Generate Review" button
|
| 80 |
+
"""
|
| 81 |
+
run_button = gr.Button(
|
| 82 |
+
"🚀 Generate AI Review",
|
| 83 |
+
variant="primary",
|
| 84 |
+
elem_classes=["primary-btn"],
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# # Placeholder buttons for future features (disabled)
|
| 88 |
+
# gr.HTML("""
|
| 89 |
+
# <button class="secondary-btn" disabled>
|
| 90 |
+
# 📊 DeepReviewer
|
| 91 |
+
# <span class="unavailable-badge">COMING SOON</span>
|
| 92 |
+
# </button>
|
| 93 |
+
# <button class="secondary-btn" disabled>
|
| 94 |
+
# 🛡️ SafeReviewer
|
| 95 |
+
# <span class="unavailable-badge">COMING SOON</span>
|
| 96 |
+
# </button>
|
| 97 |
+
# """)
|
| 98 |
+
|
| 99 |
+
return run_button
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def create_upload_section() -> Tuple[gr.File, gr.Button]:
|
| 103 |
+
"""
|
| 104 |
+
Create the complete upload section panel.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Tuple containing:
|
| 108 |
+
- pdf_input: The file upload component
|
| 109 |
+
- run_button: The generate review button
|
| 110 |
+
"""
|
| 111 |
+
gr.HTML('<div class="panel-title">📁 Upload Your Paper</div>')
|
| 112 |
+
|
| 113 |
+
create_how_it_works()
|
| 114 |
+
pdf_input = create_upload_area()
|
| 115 |
+
run_button = create_action_buttons()
|
| 116 |
+
|
| 117 |
+
return pdf_input, run_button
|
gradio_app/utils_single_paper_inference.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility wrappers and a minimal CLI for single-paper inference using:
|
| 3 |
+
- ASTA API key from environment variable `ASTA_API_KEY`
|
| 4 |
+
- OpenAI/OpenRouter endpoint and key from environment variables
|
| 5 |
+
|
| 6 |
+
This module is designed to be imported by Gradio or other web frontends,
|
| 7 |
+
while still remaining executable as a standalone CLI tool for debugging.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import logging
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import json
|
| 15 |
+
from typing import Any, Dict, Iterator, Optional
|
| 16 |
+
|
| 17 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 18 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 19 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 20 |
+
|
| 21 |
+
from src.reviewer_agent.single_paper_inference import (
|
| 22 |
+
extract_text_from_pdf,
|
| 23 |
+
_split_paper_latex_sections,
|
| 24 |
+
_init_single_paper_pipeline,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
logging.basicConfig(level=logging.INFO)
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def run_single_paper_review_from_pdf(
|
| 32 |
+
pdf_path: str,
|
| 33 |
+
*,
|
| 34 |
+
enable_logging: bool = True,
|
| 35 |
+
verbose: bool = True,
|
| 36 |
+
api_base_url: str | None = None,
|
| 37 |
+
api_key: str | None = None,
|
| 38 |
+
model_name: str | None = None,
|
| 39 |
+
) -> dict:
|
| 40 |
+
"""
|
| 41 |
+
High-level utility to run the single-paper review pipeline on a PDF path.
|
| 42 |
+
|
| 43 |
+
This is the main entry point intended to be called by Gradio or other UIs.
|
| 44 |
+
It delegates to `review_single_paper_from_pdf` which uses:
|
| 45 |
+
- ASTA API key from `ASTA_API_KEY`
|
| 46 |
+
- LLM settings and OpenAI/OpenRouter keys from environment/config files,
|
| 47 |
+
but can be overridden via `api_base_url`, `api_key`, and `model_name`.
|
| 48 |
+
"""
|
| 49 |
+
pdf_path = str(Path(pdf_path).expanduser())
|
| 50 |
+
logger.info(f"Running single-paper review for PDF: {pdf_path}")
|
| 51 |
+
# Keep the original one-shot behavior for backward compatibility.
|
| 52 |
+
# For true streaming updates, use `run_single_paper_review_from_pdf_stepwise`.
|
| 53 |
+
from src.reviewer_agent.single_paper_inference import review_single_paper_from_pdf
|
| 54 |
+
|
| 55 |
+
review = review_single_paper_from_pdf(
|
| 56 |
+
pdf_path,
|
| 57 |
+
enable_logging=enable_logging,
|
| 58 |
+
verbose=verbose,
|
| 59 |
+
gpt_api_key=api_key,
|
| 60 |
+
gpt_base_url=api_base_url,
|
| 61 |
+
gpt_model_name=model_name,
|
| 62 |
+
)
|
| 63 |
+
return review
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _normalize_base_url(base_url: Optional[str]) -> Optional[str]:
|
| 67 |
+
"""
|
| 68 |
+
Normalize an OpenAI-compatible base_url.
|
| 69 |
+
|
| 70 |
+
Your local gateway expects requests at:
|
| 71 |
+
http://localhost:8000/chat/completions
|
| 72 |
+
|
| 73 |
+
The OpenAI client will append `/chat/completions` to whatever `base_url`
|
| 74 |
+
we pass in. That means:
|
| 75 |
+
base_url = "http://localhost:8000"
|
| 76 |
+
-> "http://localhost:8000/chat/completions" ✅
|
| 77 |
+
|
| 78 |
+
If a user accidentally includes `/chat/completions` in the textbox, we
|
| 79 |
+
strip that suffix so the final URL is still correct.
|
| 80 |
+
"""
|
| 81 |
+
if not base_url:
|
| 82 |
+
return None
|
| 83 |
+
u = base_url.strip()
|
| 84 |
+
if not u:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
# Strip trailing slash for normalization
|
| 88 |
+
u = u.rstrip("/")
|
| 89 |
+
|
| 90 |
+
# If user pasted the full path (…/chat/completions), strip it back to the host.
|
| 91 |
+
if u.endswith("/chat/completions"):
|
| 92 |
+
u = u[: -len("/chat/completions")]
|
| 93 |
+
|
| 94 |
+
return u
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def run_single_paper_review_from_pdf_stepwise(
|
| 98 |
+
pdf_path: str,
|
| 99 |
+
*,
|
| 100 |
+
enable_logging: bool = True,
|
| 101 |
+
verbose: bool = True,
|
| 102 |
+
api_base_url: str | None = None,
|
| 103 |
+
api_key: str | None = None,
|
| 104 |
+
model_name: str | None = None,
|
| 105 |
+
) -> Iterator[Dict[str, Any]]:
|
| 106 |
+
"""
|
| 107 |
+
Stepwise (streamable) single-paper pipeline.
|
| 108 |
+
|
| 109 |
+
Yields dict events like:
|
| 110 |
+
{"stage": "extract_pdf", ...}
|
| 111 |
+
{"stage": "initial_review", "initial_review": {...}}
|
| 112 |
+
{"stage": "results_analysis", "results_analyzer_json": "..."}
|
| 113 |
+
{"stage": "insights", "insight_miner_json": "..."}
|
| 114 |
+
{"stage": "related_work", "related_work_json_list": [...], "search_keywords": [...]}
|
| 115 |
+
{"stage": "final", "review": {...}}
|
| 116 |
+
"""
|
| 117 |
+
pdf_path = str(Path(pdf_path).expanduser())
|
| 118 |
+
yield {"stage": "extract_pdf", "pdf_path": pdf_path}
|
| 119 |
+
|
| 120 |
+
paper_text = extract_text_from_pdf(pdf_path)
|
| 121 |
+
yield {"stage": "parsed_pdf_text", "text_len": len(paper_text)}
|
| 122 |
+
|
| 123 |
+
sections = _split_paper_latex_sections(paper_text)
|
| 124 |
+
title = (sections.get("title") or "").strip()
|
| 125 |
+
abstract = (sections.get("abstract") or "").strip()
|
| 126 |
+
content = (sections.get("content") or "").strip()
|
| 127 |
+
yield {"stage": "parsed_sections", "title": title, "abstract": abstract}
|
| 128 |
+
|
| 129 |
+
reviewer, refiner, related_work_searcher, paper_results_analyzer, paper_insight_miner = (
|
| 130 |
+
_init_single_paper_pipeline(
|
| 131 |
+
enable_logging=enable_logging,
|
| 132 |
+
use_test_llm=False,
|
| 133 |
+
gpt_api_key=api_key,
|
| 134 |
+
gpt_base_url=api_base_url,
|
| 135 |
+
gpt_model_name=model_name,
|
| 136 |
+
)
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Step 1: initial draft (reviewer)
|
| 140 |
+
initial_review = reviewer.review_paper(
|
| 141 |
+
title=title,
|
| 142 |
+
abstract=abstract,
|
| 143 |
+
content=content,
|
| 144 |
+
keywords=None,
|
| 145 |
+
review_format="ai_researcher",
|
| 146 |
+
auto_save_log=False,
|
| 147 |
+
verbose=verbose,
|
| 148 |
+
)
|
| 149 |
+
yield {"stage": "initial_review", "initial_review": initial_review}
|
| 150 |
+
|
| 151 |
+
# Helper: format initial review for analyzers.
|
| 152 |
+
try:
|
| 153 |
+
initial_review_text = (
|
| 154 |
+
refiner._format_review_dict(initial_review, "detailed")
|
| 155 |
+
if hasattr(refiner, "_format_review_dict")
|
| 156 |
+
else str(initial_review)
|
| 157 |
+
)
|
| 158 |
+
except Exception:
|
| 159 |
+
initial_review_text = str(initial_review)
|
| 160 |
+
|
| 161 |
+
# Step 2a: results analyzer
|
| 162 |
+
results_analyzer_json = None
|
| 163 |
+
if paper_results_analyzer and content:
|
| 164 |
+
try:
|
| 165 |
+
results_analyzer_json = paper_results_analyzer.analyze_paper_results(
|
| 166 |
+
content, initial_review_text
|
| 167 |
+
)
|
| 168 |
+
except Exception as e:
|
| 169 |
+
results_analyzer_json = None
|
| 170 |
+
yield {"stage": "results_analysis_error", "error": str(e)}
|
| 171 |
+
yield {"stage": "results_analysis", "results_analyzer_json": results_analyzer_json}
|
| 172 |
+
|
| 173 |
+
# Step 2b: insight miner
|
| 174 |
+
insight_miner_json = None
|
| 175 |
+
if paper_insight_miner and content:
|
| 176 |
+
try:
|
| 177 |
+
insight_miner_json = paper_insight_miner.mine_paper_insights(
|
| 178 |
+
content, initial_review_text
|
| 179 |
+
)
|
| 180 |
+
except Exception as e:
|
| 181 |
+
insight_miner_json = None
|
| 182 |
+
yield {"stage": "insights_error", "error": str(e)}
|
| 183 |
+
yield {"stage": "insights", "insight_miner_json": insight_miner_json}
|
| 184 |
+
|
| 185 |
+
# Step 2c: related work (structured list)
|
| 186 |
+
related_work_list = []
|
| 187 |
+
search_keywords = None
|
| 188 |
+
if related_work_searcher:
|
| 189 |
+
try:
|
| 190 |
+
related_work_list = related_work_searcher.generate_related_work_json_list(
|
| 191 |
+
title=title,
|
| 192 |
+
abstract=abstract,
|
| 193 |
+
content=content,
|
| 194 |
+
keywords=None,
|
| 195 |
+
publication_date_range=None,
|
| 196 |
+
venues=None,
|
| 197 |
+
)
|
| 198 |
+
search_keywords = getattr(related_work_searcher, "last_keywords", None)
|
| 199 |
+
except Exception as e:
|
| 200 |
+
related_work_list = []
|
| 201 |
+
yield {"stage": "related_work_error", "error": str(e)}
|
| 202 |
+
yield {
|
| 203 |
+
"stage": "related_work",
|
| 204 |
+
"related_work_json_list": related_work_list,
|
| 205 |
+
"search_keywords": search_keywords,
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
# Step 3: refine (final)
|
| 209 |
+
related_work_json_str = json.dumps(related_work_list, ensure_ascii=False)
|
| 210 |
+
refined = refiner.refine_review(
|
| 211 |
+
initial_review=initial_review,
|
| 212 |
+
insight_miner_json=insight_miner_json,
|
| 213 |
+
results_analyzer_json=results_analyzer_json,
|
| 214 |
+
related_work_json_list=related_work_json_str,
|
| 215 |
+
title=title,
|
| 216 |
+
abstract=abstract,
|
| 217 |
+
content=content,
|
| 218 |
+
review_format="detailed",
|
| 219 |
+
verbose=verbose,
|
| 220 |
+
)
|
| 221 |
+
if search_keywords is not None:
|
| 222 |
+
refined["search_keywords"] = search_keywords
|
| 223 |
+
|
| 224 |
+
yield {"stage": "final", "review": refined}
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def main() -> None:
|
| 228 |
+
"""
|
| 229 |
+
Simple CLI wrapper mainly for local debugging.
|
| 230 |
+
|
| 231 |
+
Example:
|
| 232 |
+
python -m gradio.test_single_paper_openrouter \\
|
| 233 |
+
--pdf /path/to/paper.pdf
|
| 234 |
+
"""
|
| 235 |
+
parser = argparse.ArgumentParser(
|
| 236 |
+
description="Run single-paper review on a PDF file."
|
| 237 |
+
)
|
| 238 |
+
parser.add_argument(
|
| 239 |
+
"--pdf",
|
| 240 |
+
type=str,
|
| 241 |
+
required=True,
|
| 242 |
+
help="Path to the PDF file to review.",
|
| 243 |
+
)
|
| 244 |
+
parser.add_argument(
|
| 245 |
+
"--no-logging",
|
| 246 |
+
action="store_true",
|
| 247 |
+
help="Disable on-disk logging for this run.",
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--quiet",
|
| 251 |
+
action="store_true",
|
| 252 |
+
help="Reduce console output verbosity.",
|
| 253 |
+
)
|
| 254 |
+
args = parser.parse_args()
|
| 255 |
+
|
| 256 |
+
from time import time
|
| 257 |
+
|
| 258 |
+
start_time = time()
|
| 259 |
+
review = run_single_paper_review_from_pdf(
|
| 260 |
+
args.pdf,
|
| 261 |
+
enable_logging=not args.no_logging,
|
| 262 |
+
verbose=not args.quiet,
|
| 263 |
+
)
|
| 264 |
+
end_time = time()
|
| 265 |
+
|
| 266 |
+
print(f"Time taken: {end_time - start_time:.2f} seconds")
|
| 267 |
+
print("\n=== Review keys ===")
|
| 268 |
+
print(list(review.keys()))
|
| 269 |
+
print("\n=== Review Markdown ===")
|
| 270 |
+
print(review.get("review_markdown", ""))
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
|
| 274 |
+
main()
|
| 275 |
+
|
| 276 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Unified Review System Requirements
|
| 2 |
+
|
| 3 |
+
# Core dependencies
|
| 4 |
+
pandas>=1.5.0
|
| 5 |
+
pyarrow>=10.0.0
|
| 6 |
+
pyyaml>=6.0.0
|
| 7 |
+
python-dotenv>=1.0.0
|
| 8 |
+
|
| 9 |
+
# LLM and ML dependencies
|
| 10 |
+
vllm>=0.6.0
|
| 11 |
+
openai>=1.0.0
|
| 12 |
+
transformers>=4.35.0
|
| 13 |
+
torch>=2.0.0
|
| 14 |
+
numpy>=1.24.0
|
| 15 |
+
FlagEmbedding>=1.2.0
|
| 16 |
+
|
| 17 |
+
# Utilities
|
| 18 |
+
requests>=2.31.0
|
| 19 |
+
pydantic>=2.0.0
|
| 20 |
+
tqdm>=4.66.0
|
| 21 |
+
scipy>=1.10.0
|
| 22 |
+
scikit-learn>=1.3.0
|
| 23 |
+
|
| 24 |
+
# PDF + UI (for Gradio / Hugging Face Space)
|
| 25 |
+
pdfminer.six>=20221105
|
| 26 |
+
gradio>=4.0.0
|
scripts/gpt_oss_start_vllm_service.sh
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Script to start vLLM service for Qwen3-235B-A22B-Instruct-2507
|
| 2 |
+
|
| 3 |
+
# optional: limit GPU usage
|
| 4 |
+
# export CUDA_VISIBLE_DEVICES=0,1,2,3
|
| 5 |
+
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
| 6 |
+
|
| 7 |
+
# Configuration
|
| 8 |
+
# MODEL_NAME="Qwen/Qwen3-235B-A22B-Instruct-2507"
|
| 9 |
+
MODEL_NAME="openai/gpt-oss-120b"
|
| 10 |
+
PORT=${VLLM_PORT:-8000}
|
| 11 |
+
TP_SIZE=${TP_SIZE:-4} # Tensor parallelism size, smaller or equal to the number of available GPUs
|
| 12 |
+
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.85} # ideally 0.85
|
| 13 |
+
MAX_MODEL_LEN=${MAX_MODEL_LEN:-131072} # Native context length, can extend to 1010000
|
| 14 |
+
|
| 15 |
+
# Check if model path is provided
|
| 16 |
+
if [ -z "$MODEL_PATH" ]; then
|
| 17 |
+
MODEL_PATH="$MODEL_NAME"
|
| 18 |
+
echo "Using HuggingFace model: $MODEL_PATH"
|
| 19 |
+
else
|
| 20 |
+
echo "Using local model: $MODEL_PATH"
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
echo "Starting vLLM service..."
|
| 24 |
+
echo "Model: $MODEL_PATH"
|
| 25 |
+
echo "Port: $PORT"
|
| 26 |
+
echo "Tensor Parallelism: $TP_SIZE"
|
| 27 |
+
echo "GPU Memory Utilization: $GPU_MEMORY_UTILIZATION"
|
| 28 |
+
echo "Max Model Length: $MAX_MODEL_LEN"
|
| 29 |
+
|
| 30 |
+
# python3 -m vllm.entrypoints.openai.api_server \
|
| 31 |
+
# --model "$MODEL_PATH" \
|
| 32 |
+
# --port $PORT \
|
| 33 |
+
# --tensor-parallel-size $TP_SIZE \
|
| 34 |
+
# --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
| 35 |
+
# --max-model-len $MAX_MODEL_LEN \
|
| 36 |
+
# --trust-remote-code \
|
| 37 |
+
# # --dtype bfloat16
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
vllm serve openai/gpt-oss-120b \
|
| 41 |
+
--port $PORT \
|
| 42 |
+
--tensor-parallel-size $TP_SIZE \
|
| 43 |
+
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
| 44 |
+
--max-model-len $MAX_MODEL_LEN \
|
| 45 |
+
--trust-remote-code \
|
| 46 |
+
--dtype bfloat16
|
| 47 |
+
|
scripts/start_load_balancer.sh
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Start Python Load Balancer for vLLM and Reranker services
|
| 3 |
+
# Usage: ./scripts/start_load_balancer.sh [service_type] [num_instances] [base_port] [lb_port]
|
| 4 |
+
|
| 5 |
+
set -e
|
| 6 |
+
|
| 7 |
+
SERVICE_TYPE="${1:-vllm}" # vllm or reranker
|
| 8 |
+
NUM_INSTANCES="${2:-4}"
|
| 9 |
+
BASE_PORT="${3:-8000}"
|
| 10 |
+
LB_PORT="${4:-$BASE_PORT}"
|
| 11 |
+
|
| 12 |
+
echo "Starting Load Balancer for $SERVICE_TYPE"
|
| 13 |
+
echo "Number of instances: $NUM_INSTANCES"
|
| 14 |
+
echo "Base port: $BASE_PORT"
|
| 15 |
+
echo "Load balancer port: $LB_PORT"
|
| 16 |
+
echo ""
|
| 17 |
+
|
| 18 |
+
# Get script directory
|
| 19 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 20 |
+
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
| 21 |
+
cd "$PROJECT_ROOT"
|
| 22 |
+
|
| 23 |
+
# Activate virtual environment if it exists
|
| 24 |
+
if [ -d ".venv" ]; then
|
| 25 |
+
source .venv/bin/activate
|
| 26 |
+
fi
|
| 27 |
+
|
| 28 |
+
# Check if FastAPI is installed
|
| 29 |
+
python3 -c "import fastapi" 2>/dev/null || {
|
| 30 |
+
echo "Error: FastAPI not installed. Install with: pip install fastapi uvicorn httpx"
|
| 31 |
+
exit 1
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
# Build backend list
|
| 35 |
+
BACKENDS=()
|
| 36 |
+
for i in $(seq 0 $((NUM_INSTANCES - 1))); do
|
| 37 |
+
PORT=$((BASE_PORT + i))
|
| 38 |
+
if [ "$SERVICE_TYPE" = "vllm" ]; then
|
| 39 |
+
BACKENDS+=("http://localhost:${PORT}/v1")
|
| 40 |
+
else
|
| 41 |
+
BACKENDS+=("http://localhost:${PORT}")
|
| 42 |
+
fi
|
| 43 |
+
done
|
| 44 |
+
|
| 45 |
+
# Create logs directory based on service type
|
| 46 |
+
if [ "$SERVICE_TYPE" = "vllm" ]; then
|
| 47 |
+
LB_LOG_DIR="logs/vllm"
|
| 48 |
+
else
|
| 49 |
+
LB_LOG_DIR="logs/reranker"
|
| 50 |
+
fi
|
| 51 |
+
mkdir -p "$LB_LOG_DIR"
|
| 52 |
+
|
| 53 |
+
echo "Backends:"
|
| 54 |
+
for backend in "${BACKENDS[@]}"; do
|
| 55 |
+
echo " - $backend"
|
| 56 |
+
done
|
| 57 |
+
echo ""
|
| 58 |
+
|
| 59 |
+
# Start load balancer
|
| 60 |
+
echo "Starting load balancer..."
|
| 61 |
+
python3 -m shared.utils.load_balancer \
|
| 62 |
+
--backends "${BACKENDS[@]}" \
|
| 63 |
+
--host 0.0.0.0 \
|
| 64 |
+
--port "$LB_PORT" \
|
| 65 |
+
--strategy round_robin \
|
| 66 |
+
--health-check-interval 10.0 \
|
| 67 |
+
> "${LB_LOG_DIR}/load_balancer_${SERVICE_TYPE}_port${LB_PORT}.log" 2>&1 &
|
| 68 |
+
|
| 69 |
+
LB_PID=$!
|
| 70 |
+
|
| 71 |
+
# Save PID to file based on service type
|
| 72 |
+
if [ "$SERVICE_TYPE" = "vllm" ]; then
|
| 73 |
+
PID_FILE="logs/vllm/vllm_lb_pid.txt"
|
| 74 |
+
mkdir -p logs/vllm
|
| 75 |
+
else
|
| 76 |
+
PID_FILE="logs/reranker/reranker_lb_pid.txt"
|
| 77 |
+
mkdir -p logs/reranker
|
| 78 |
+
fi
|
| 79 |
+
echo "$LB_PID" > "$PID_FILE"
|
| 80 |
+
|
| 81 |
+
echo "Load balancer started with PID: $LB_PID"
|
| 82 |
+
echo "Load balancer URL: http://localhost:${LB_PORT}"
|
| 83 |
+
echo "PID saved to: $PID_FILE"
|
| 84 |
+
echo ""
|
| 85 |
+
echo "To check status: curl http://localhost:${LB_PORT}/health"
|
| 86 |
+
echo "To stop: ./scripts/stop_vllm_services.sh (for vllm) or ./scripts/stop_reranker_services.sh (for reranker)"
|
| 87 |
+
echo "Or manually: kill $LB_PID"
|
scripts/start_reranker_service.sh
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Start Reranker API Service on multiple GPUs
|
| 3 |
+
# Usage: ./scripts/start_reranker_service.sh [model_path] [num_gpus] [base_port]
|
| 4 |
+
|
| 5 |
+
set -e
|
| 6 |
+
|
| 7 |
+
# Default values
|
| 8 |
+
# MODEL_PATH="${1:-OpenScholar/OpenScholar_Reranker}"
|
| 9 |
+
# NUM_GPUS="${2:-4}"
|
| 10 |
+
# BASE_PORT="${3:-8005}"
|
| 11 |
+
|
| 12 |
+
# MODEL_PATH="BAAI/bge-reranker-base"
|
| 13 |
+
# MODEL_PATH="BAAI/bge-reranker-large"
|
| 14 |
+
MODEL_PATH="${1:-OpenScholar/OpenScholar_Reranker}"
|
| 15 |
+
NUM_GPUS=8
|
| 16 |
+
BASE_PORT=8008
|
| 17 |
+
|
| 18 |
+
echo "Starting Reranker API Service"
|
| 19 |
+
echo "Model: $MODEL_PATH"
|
| 20 |
+
echo "Number of GPUs: $NUM_GPUS"
|
| 21 |
+
echo "Base port: $BASE_PORT"
|
| 22 |
+
echo ""
|
| 23 |
+
|
| 24 |
+
# Get script directory
|
| 25 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 26 |
+
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
| 27 |
+
cd "$PROJECT_ROOT"
|
| 28 |
+
|
| 29 |
+
# Activate virtual environment if it exists
|
| 30 |
+
if [ -d ".venv" ]; then
|
| 31 |
+
source .venv/bin/activate
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
# Check if FastAPI is installed
|
| 35 |
+
python3 -c "import fastapi" 2>/dev/null || {
|
| 36 |
+
echo "Error: FastAPI not installed. Install with: pip install fastapi uvicorn"
|
| 37 |
+
exit 1
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# Check if FlagEmbedding is installed
|
| 41 |
+
python3 -c "from FlagEmbedding import FlagReranker" 2>/dev/null || {
|
| 42 |
+
echo "Error: FlagEmbedding not installed. Install with: pip install FlagEmbedding"
|
| 43 |
+
exit 1
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
# Create logs directory
|
| 47 |
+
mkdir -p logs
|
| 48 |
+
|
| 49 |
+
# PID file for stopping services later
|
| 50 |
+
PID_FILE="logs/reranker/reranker_pids.txt"
|
| 51 |
+
LB_PID_FILE="logs/reranker/reranker_lb_pid.txt"
|
| 52 |
+
|
| 53 |
+
# Start services on each GPU
|
| 54 |
+
PIDS=()
|
| 55 |
+
ENDPOINTS=()
|
| 56 |
+
|
| 57 |
+
# Ensure we use GPUs 0, 1, 2, 3 (explicitly)
|
| 58 |
+
|
| 59 |
+
# now we use gpus: 1,2,3,4,5,6,7
|
| 60 |
+
for i in $(seq 0 $((NUM_GPUS - 1))); do
|
| 61 |
+
PORT=$((BASE_PORT + i))
|
| 62 |
+
GPU_ID=$i # Use GPU 0, 1, 2, 3 explicitly
|
| 63 |
+
|
| 64 |
+
echo "Starting reranker service on GPU $GPU_ID, port $PORT..."
|
| 65 |
+
|
| 66 |
+
# Set CUDA device (each service will see only one GPU)
|
| 67 |
+
export CUDA_VISIBLE_DEVICES=$GPU_ID
|
| 68 |
+
|
| 69 |
+
# Start service in background
|
| 70 |
+
# Note: When CUDA_VISIBLE_DEVICES is set, cuda:0 refers to the visible GPU
|
| 71 |
+
nohup python3 -m shared.utils.reranker_api_service \
|
| 72 |
+
--model_path "$MODEL_PATH" \
|
| 73 |
+
--host 0.0.0.0 \
|
| 74 |
+
--port "$PORT" \
|
| 75 |
+
--use_fp16 \
|
| 76 |
+
--device "cuda:0" \
|
| 77 |
+
> "logs/reranker/reranker_service_gpu${GPU_ID}_port${PORT}.log" 2>&1 &
|
| 78 |
+
|
| 79 |
+
PID=$!
|
| 80 |
+
PIDS+=($PID)
|
| 81 |
+
ENDPOINTS+=("http://localhost:${PORT}")
|
| 82 |
+
|
| 83 |
+
echo " Started with PID: $PID"
|
| 84 |
+
echo " Endpoint: http://localhost:${PORT}"
|
| 85 |
+
sleep 2 # Give service time to start
|
| 86 |
+
done
|
| 87 |
+
|
| 88 |
+
echo ""
|
| 89 |
+
echo "All reranker services started!"
|
| 90 |
+
echo ""
|
| 91 |
+
echo "Endpoints:"
|
| 92 |
+
for endpoint in "${ENDPOINTS[@]}"; do
|
| 93 |
+
echo " - $endpoint"
|
| 94 |
+
done
|
| 95 |
+
|
| 96 |
+
# Create endpoint pool file
|
| 97 |
+
ENDPOINT_POOL_FILE="shared/configs/reranker_endpoint_pool.txt"
|
| 98 |
+
mkdir -p "$(dirname "$ENDPOINT_POOL_FILE")"
|
| 99 |
+
printf "%s\n" "${ENDPOINTS[@]}" > "$ENDPOINT_POOL_FILE"
|
| 100 |
+
echo ""
|
| 101 |
+
echo "Endpoint pool file created: $ENDPOINT_POOL_FILE"
|
| 102 |
+
|
| 103 |
+
# Save PIDs to file (one per line)
|
| 104 |
+
printf "%s\n" "${PIDS[@]}" > "$PID_FILE"
|
| 105 |
+
echo ""
|
| 106 |
+
echo "PIDs saved to: $PID_FILE"
|
| 107 |
+
echo ""
|
| 108 |
+
echo "To stop these specific reranker services, run:"
|
| 109 |
+
echo " ./scripts/stop_reranker_services.sh"
|
| 110 |
+
echo ""
|
| 111 |
+
echo "This will only kill the processes listed above, not other reranker services."
|
| 112 |
+
echo ""
|
| 113 |
+
echo "To check service status, run:"
|
| 114 |
+
for endpoint in "${ENDPOINTS[@]}"; do
|
| 115 |
+
echo "curl $endpoint/health"
|
| 116 |
+
done
|
scripts/start_vllm_with_balancer.sh
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to start vLLM services on GPU 4,5,6,7 and load balancer
|
| 3 |
+
# Usage: ./scripts/start_vllm_with_balancer.sh
|
| 4 |
+
|
| 5 |
+
set -e
|
| 6 |
+
|
| 7 |
+
# Get script directory
|
| 8 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 9 |
+
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
| 10 |
+
cd "$PROJECT_ROOT"
|
| 11 |
+
|
| 12 |
+
# Configuration
|
| 13 |
+
# MODEL_NAME="ZhuofengLi/Qwen3-4B-Instruct-2507-DeepReview-lora-sft" #
|
| 14 |
+
MODEL_NAME="openai/gpt-oss-120b"
|
| 15 |
+
# GPU_CONFIG="2:8001,3:8002,4:8003,5:8004" # GPU:PORT pairs
|
| 16 |
+
# GPU_CONFIG="1:8001,2:8002,3:8003,4:8004, 5:8005, 6:8006, 7:8007, 0:7999" # GPU:PORT pairs
|
| 17 |
+
GPU_CONFIG="0:8001,1:8002,2:8003,3:8004, 5:8005, 6:8006, 7:8007" # GPU:PORT pairs
|
| 18 |
+
TP_SIZE=1 # Tensor parallelism size per instance
|
| 19 |
+
GPU_MEMORY_UTILIZATION=0.85
|
| 20 |
+
MAX_MODEL_LEN=131072
|
| 21 |
+
|
| 22 |
+
# Load balancer configuration
|
| 23 |
+
LB_PORT=8000 # Load balancer port
|
| 24 |
+
LB_STRATEGY="round_robin" # or "least_conn"
|
| 25 |
+
LB_HEALTH_CHECK_INTERVAL=10.0
|
| 26 |
+
|
| 27 |
+
# Log directory
|
| 28 |
+
LOG_DIR="./logs/vllm"
|
| 29 |
+
mkdir -p "$LOG_DIR"
|
| 30 |
+
|
| 31 |
+
# Endpoint pool file
|
| 32 |
+
ENDPOINT_POOL_FILE="shared/configs/vllm_endpoint_pool.txt"
|
| 33 |
+
mkdir -p "$(dirname "$ENDPOINT_POOL_FILE")"
|
| 34 |
+
|
| 35 |
+
echo "=========================================="
|
| 36 |
+
echo "Starting vLLM Services + Load Balancer"
|
| 37 |
+
echo "=========================================="
|
| 38 |
+
echo "Model: $MODEL_NAME"
|
| 39 |
+
echo "GPU Configuration: $GPU_CONFIG"
|
| 40 |
+
echo "Load Balancer Port: $LB_PORT"
|
| 41 |
+
echo "Log Directory: $LOG_DIR"
|
| 42 |
+
echo ""
|
| 43 |
+
|
| 44 |
+
# Step 1: Start vLLM services
|
| 45 |
+
echo "=== Step 1: Starting vLLM services ==="
|
| 46 |
+
echo ""
|
| 47 |
+
|
| 48 |
+
# Clear existing endpoints
|
| 49 |
+
> "$ENDPOINT_POOL_FILE"
|
| 50 |
+
|
| 51 |
+
# Parse GPU configuration
|
| 52 |
+
IFS=',' read -ra GPU_CONFIGS <<< "$GPU_CONFIG"
|
| 53 |
+
|
| 54 |
+
# Array to store PIDs
|
| 55 |
+
VLLM_PIDS=()
|
| 56 |
+
|
| 57 |
+
for gpu_config in "${GPU_CONFIGS[@]}"; do
|
| 58 |
+
IFS=':' read -r gpu_id port <<< "$gpu_config"
|
| 59 |
+
|
| 60 |
+
echo "Starting vLLM on GPU $gpu_id, port $port..."
|
| 61 |
+
|
| 62 |
+
# Set CUDA_VISIBLE_DEVICES for this specific GPU
|
| 63 |
+
export CUDA_VISIBLE_DEVICES=$gpu_id
|
| 64 |
+
|
| 65 |
+
# Log file
|
| 66 |
+
LOG_FILE="$LOG_DIR/vllm_gpu${gpu_id}_port${port}.log"
|
| 67 |
+
|
| 68 |
+
# Start vLLM service in background
|
| 69 |
+
(
|
| 70 |
+
echo "=== GPU $gpu_id, Port $port ===" >> "$LOG_FILE"
|
| 71 |
+
echo "Starting at $(date)" >> "$LOG_FILE"
|
| 72 |
+
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" >> "$LOG_FILE"
|
| 73 |
+
echo "" >> "$LOG_FILE"
|
| 74 |
+
|
| 75 |
+
vllm serve "$MODEL_NAME" \
|
| 76 |
+
--port "$port" \
|
| 77 |
+
--tensor-parallel-size "$TP_SIZE" \
|
| 78 |
+
--gpu-memory-utilization "$GPU_MEMORY_UTILIZATION" \
|
| 79 |
+
--max-model-len "$MAX_MODEL_LEN" \
|
| 80 |
+
--trust-remote-code \
|
| 81 |
+
--dtype bfloat16 \
|
| 82 |
+
>> "$LOG_FILE" 2>&1
|
| 83 |
+
) &
|
| 84 |
+
|
| 85 |
+
PID=$!
|
| 86 |
+
VLLM_PIDS+=($PID)
|
| 87 |
+
|
| 88 |
+
# Add endpoint to pool file (for load balancer)
|
| 89 |
+
echo "http://localhost:$port/v1" >> "$ENDPOINT_POOL_FILE"
|
| 90 |
+
|
| 91 |
+
echo " -> Started with PID $PID"
|
| 92 |
+
echo " -> Endpoint: http://localhost:$port/v1"
|
| 93 |
+
echo " -> Log: $LOG_FILE"
|
| 94 |
+
|
| 95 |
+
# Wait a bit before starting next service
|
| 96 |
+
sleep 3
|
| 97 |
+
done
|
| 98 |
+
|
| 99 |
+
# Save PIDs (one per line for easier parsing)
|
| 100 |
+
printf "%s\n" "${VLLM_PIDS[@]}" > "$LOG_DIR/vllm_pids.txt"
|
| 101 |
+
echo ""
|
| 102 |
+
echo "vLLM service PIDs saved to: $LOG_DIR/vllm_pids.txt"
|
| 103 |
+
echo ""
|
| 104 |
+
|
| 105 |
+
# Step 2: Wait for services to be ready
|
| 106 |
+
echo "=== Step 2: Waiting for vLLM services to be ready ==="
|
| 107 |
+
echo "Waiting 90 seconds for services to initialize..."
|
| 108 |
+
sleep 90
|
| 109 |
+
|
| 110 |
+
# Check service health
|
| 111 |
+
echo ""
|
| 112 |
+
echo "Checking service health..."
|
| 113 |
+
HEALTHY_COUNT=0
|
| 114 |
+
for gpu_config in "${GPU_CONFIGS[@]}"; do
|
| 115 |
+
IFS=':' read -r gpu_id port <<< "$gpu_config"
|
| 116 |
+
if curl -s "http://localhost:$port/v1/models" > /dev/null 2>&1; then
|
| 117 |
+
echo " GPU $gpu_id (port $port): HEALTHY"
|
| 118 |
+
HEALTHY_COUNT=$((HEALTHY_COUNT + 1))
|
| 119 |
+
else
|
| 120 |
+
echo " GPU $gpu_id (port $port): NOT READY (may still be initializing)"
|
| 121 |
+
fi
|
| 122 |
+
done
|
| 123 |
+
|
| 124 |
+
if [ $HEALTHY_COUNT -eq 0 ]; then
|
| 125 |
+
echo ""
|
| 126 |
+
echo "WARNING: No services are healthy yet. They may still be loading the model."
|
| 127 |
+
echo "You can check logs in $LOG_DIR/ for progress."
|
| 128 |
+
fi
|
| 129 |
+
|
| 130 |
+
echo ""
|
| 131 |
+
|
| 132 |
+
# Step 3: Start load balancer
|
| 133 |
+
echo "=== Step 3: Starting Load Balancer ==="
|
| 134 |
+
echo ""
|
| 135 |
+
|
| 136 |
+
# Build backend URLs
|
| 137 |
+
BACKEND_URLS=()
|
| 138 |
+
for gpu_config in "${GPU_CONFIGS[@]}"; do
|
| 139 |
+
IFS=':' read -r gpu_id port <<< "$gpu_config"
|
| 140 |
+
BACKEND_URLS+=("http://localhost:$port/v1")
|
| 141 |
+
done
|
| 142 |
+
|
| 143 |
+
echo "Load Balancer Configuration:"
|
| 144 |
+
echo " Port: $LB_PORT"
|
| 145 |
+
echo " Strategy: $LB_STRATEGY"
|
| 146 |
+
echo " Backends: ${BACKEND_URLS[*]}"
|
| 147 |
+
echo ""
|
| 148 |
+
|
| 149 |
+
# Activate virtual environment if it exists
|
| 150 |
+
if [ -d ".venv" ]; then
|
| 151 |
+
source .venv/bin/activate
|
| 152 |
+
fi
|
| 153 |
+
|
| 154 |
+
# Check if FastAPI is installed
|
| 155 |
+
python3 -c "import fastapi" 2>/dev/null || {
|
| 156 |
+
echo "Error: FastAPI not installed. Install with: pip install fastapi uvicorn httpx"
|
| 157 |
+
exit 1
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
# Start load balancer in background
|
| 161 |
+
echo "Starting load balancer..."
|
| 162 |
+
nohup python3 -m shared.utils.load_balancer \
|
| 163 |
+
--backends "${BACKEND_URLS[@]}" \
|
| 164 |
+
--host 0.0.0.0 \
|
| 165 |
+
--port "$LB_PORT" \
|
| 166 |
+
--strategy "$LB_STRATEGY" \
|
| 167 |
+
--health-check-interval "$LB_HEALTH_CHECK_INTERVAL" \
|
| 168 |
+
> "$LOG_DIR/load_balancer_port${LB_PORT}.log" 2>&1 &
|
| 169 |
+
|
| 170 |
+
LB_PID=$!
|
| 171 |
+
|
| 172 |
+
# Save load balancer PID
|
| 173 |
+
echo "$LB_PID" > "$LOG_DIR/vllm_lb_pid.txt"
|
| 174 |
+
|
| 175 |
+
echo " -> Load balancer started with PID $LB_PID"
|
| 176 |
+
echo " -> Endpoint: http://localhost:$LB_PORT"
|
| 177 |
+
echo " -> Log: $LOG_DIR/load_balancer_port${LB_PORT}.log"
|
| 178 |
+
echo " -> PID saved to: $LOG_DIR/vllm_lb_pid.txt"
|
| 179 |
+
|
| 180 |
+
# Wait a bit for load balancer to start
|
| 181 |
+
sleep 5
|
| 182 |
+
|
| 183 |
+
# Check load balancer health
|
| 184 |
+
echo ""
|
| 185 |
+
echo "Checking load balancer health..."
|
| 186 |
+
if curl -s "http://localhost:$LB_PORT/health" > /dev/null 2>&1; then
|
| 187 |
+
echo " Load balancer: HEALTHY"
|
| 188 |
+
curl -s "http://localhost:$LB_PORT/health" | python3 -m json.tool 2>/dev/null || curl -s "http://localhost:$LB_PORT/health"
|
| 189 |
+
else
|
| 190 |
+
echo " Load balancer: NOT READY (check log: $LOG_DIR/load_balancer_port${LB_PORT}.log)"
|
| 191 |
+
fi
|
| 192 |
+
|
| 193 |
+
echo ""
|
| 194 |
+
echo "=========================================="
|
| 195 |
+
echo "Deployment Complete!"
|
| 196 |
+
echo "=========================================="
|
| 197 |
+
echo ""
|
| 198 |
+
echo "vLLM Services:"
|
| 199 |
+
for i in "${!GPU_CONFIGS[@]}"; do
|
| 200 |
+
gpu_config="${GPU_CONFIGS[$i]}"
|
| 201 |
+
IFS=':' read -r gpu_id port <<< "$gpu_config"
|
| 202 |
+
PID="${VLLM_PIDS[$i]}"
|
| 203 |
+
echo " GPU $gpu_id: http://localhost:$port/v1 (PID: $PID)"
|
| 204 |
+
done
|
| 205 |
+
echo ""
|
| 206 |
+
echo "Load Balancer:"
|
| 207 |
+
echo " http://localhost:$LB_PORT (PID: $LB_PID)"
|
| 208 |
+
echo ""
|
| 209 |
+
echo "Configuration:"
|
| 210 |
+
echo " Update llm_service_config.yaml: base_url: \"http://localhost:$LB_PORT/v1\""
|
| 211 |
+
echo ""
|
| 212 |
+
echo "To stop these specific services, run:"
|
| 213 |
+
echo " ./scripts/stop_vllm_services.sh"
|
| 214 |
+
echo ""
|
| 215 |
+
echo "This will only kill the processes listed above, not other vLLM services."
|
| 216 |
+
echo ""
|
scripts/stop_reranker_services.sh
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to stop reranker services and load balancer (only the ones we started)
|
| 3 |
+
# Usage: ./scripts/stop_reranker_services.sh
|
| 4 |
+
|
| 5 |
+
set -e
|
| 6 |
+
|
| 7 |
+
# Get script directory
|
| 8 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 9 |
+
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
| 10 |
+
cd "$PROJECT_ROOT"
|
| 11 |
+
|
| 12 |
+
LOG_DIR="./logs/reranker"
|
| 13 |
+
PID_FILE="$LOG_DIR/reranker_pids.txt"
|
| 14 |
+
LB_PID_FILE="$LOG_DIR/reranker_lb_pid.txt"
|
| 15 |
+
|
| 16 |
+
echo "=== Stopping Reranker Services and Load Balancer ==="
|
| 17 |
+
echo ""
|
| 18 |
+
|
| 19 |
+
# Step 1: Stop load balancer (if PID file exists)
|
| 20 |
+
echo "Step 1: Stopping reranker load balancer..."
|
| 21 |
+
if [ -f "$LB_PID_FILE" ]; then
|
| 22 |
+
LB_PID=$(cat "$LB_PID_FILE" 2>/dev/null | head -1)
|
| 23 |
+
if [ -n "$LB_PID" ] && ps -p $LB_PID > /dev/null 2>&1; then
|
| 24 |
+
echo " Killing load balancer PID $LB_PID..."
|
| 25 |
+
kill -TERM $LB_PID 2>/dev/null || true
|
| 26 |
+
sleep 2
|
| 27 |
+
if ps -p $LB_PID > /dev/null 2>&1; then
|
| 28 |
+
echo " Force killing load balancer PID $LB_PID..."
|
| 29 |
+
kill -KILL $LB_PID 2>/dev/null || true
|
| 30 |
+
fi
|
| 31 |
+
echo " Load balancer stopped"
|
| 32 |
+
rm -f "$LB_PID_FILE"
|
| 33 |
+
else
|
| 34 |
+
echo " Load balancer PID from file not found (may have already terminated)"
|
| 35 |
+
rm -f "$LB_PID_FILE"
|
| 36 |
+
fi
|
| 37 |
+
else
|
| 38 |
+
echo " No load balancer PID file found ($LB_PID_FILE)"
|
| 39 |
+
echo " If load balancer is running, you may need to find and kill it manually"
|
| 40 |
+
fi
|
| 41 |
+
|
| 42 |
+
echo ""
|
| 43 |
+
|
| 44 |
+
# Step 2: Stop reranker services (ONLY the ones we started)
|
| 45 |
+
echo "Step 2: Stopping reranker services (only the ones we started)..."
|
| 46 |
+
|
| 47 |
+
if [ -f "$PID_FILE" ]; then
|
| 48 |
+
echo " Reading PIDs from $PID_FILE"
|
| 49 |
+
KILLED_COUNT=0
|
| 50 |
+
NOT_FOUND_COUNT=0
|
| 51 |
+
|
| 52 |
+
while IFS= read -r pid || [ -n "$pid" ]; do
|
| 53 |
+
# Skip empty lines
|
| 54 |
+
[ -z "$pid" ] && continue
|
| 55 |
+
|
| 56 |
+
if ps -p $pid > /dev/null 2>&1; then
|
| 57 |
+
echo " Killing reranker service PID $pid..."
|
| 58 |
+
kill -TERM $pid 2>/dev/null || true
|
| 59 |
+
KILLED_COUNT=$((KILLED_COUNT + 1))
|
| 60 |
+
else
|
| 61 |
+
echo " PID $pid: Process not found (may have already terminated)"
|
| 62 |
+
NOT_FOUND_COUNT=$((NOT_FOUND_COUNT + 1))
|
| 63 |
+
fi
|
| 64 |
+
done < "$PID_FILE"
|
| 65 |
+
|
| 66 |
+
if [ $KILLED_COUNT -gt 0 ]; then
|
| 67 |
+
echo " Waiting 3 seconds for graceful shutdown..."
|
| 68 |
+
sleep 3
|
| 69 |
+
|
| 70 |
+
# Force kill if still running
|
| 71 |
+
while IFS= read -r pid || [ -n "$pid" ]; do
|
| 72 |
+
[ -z "$pid" ] && continue
|
| 73 |
+
if ps -p $pid > /dev/null 2>&1; then
|
| 74 |
+
echo " Force killing reranker service PID $pid..."
|
| 75 |
+
kill -KILL $pid 2>/dev/null || true
|
| 76 |
+
fi
|
| 77 |
+
done < "$PID_FILE"
|
| 78 |
+
|
| 79 |
+
echo " Stopped $KILLED_COUNT reranker service(s)"
|
| 80 |
+
else
|
| 81 |
+
echo " No running processes found from saved PIDs"
|
| 82 |
+
fi
|
| 83 |
+
|
| 84 |
+
if [ $NOT_FOUND_COUNT -gt 0 ]; then
|
| 85 |
+
echo " ($NOT_FOUND_COUNT process(es) were already terminated)"
|
| 86 |
+
fi
|
| 87 |
+
|
| 88 |
+
# Remove PID file after stopping
|
| 89 |
+
rm -f "$PID_FILE"
|
| 90 |
+
else
|
| 91 |
+
echo " WARNING: $PID_FILE not found!"
|
| 92 |
+
echo " Cannot safely stop services without PID file."
|
| 93 |
+
echo " If you know the PIDs, you can manually kill them."
|
| 94 |
+
echo " To avoid affecting other users, DO NOT use pkill!"
|
| 95 |
+
fi
|
| 96 |
+
|
| 97 |
+
echo ""
|
| 98 |
+
echo " NOTE: Only processes from reranker_pids.txt were killed."
|
| 99 |
+
echo " Other reranker services (if any) were NOT affected."
|
| 100 |
+
|
| 101 |
+
echo ""
|
| 102 |
+
echo "=== Checking GPU status ==="
|
| 103 |
+
nvidia-smi --query-gpu=index,memory.used --format=csv,noheader | grep -E '^ 0,|^ 1,|^ 2,|^ 3,' || echo "GPU 0,1,2,3 status:"
|
| 104 |
+
|
| 105 |
+
echo ""
|
| 106 |
+
echo "Done!"
|
scripts/stop_vllm_services.sh
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to stop vLLM services and load balancer
|
| 3 |
+
# Usage: ./scripts/stop_vllm_services.sh
|
| 4 |
+
|
| 5 |
+
# Don't use set -e here because we want to continue even if some kills fail
|
| 6 |
+
# set -e
|
| 7 |
+
|
| 8 |
+
# Get script directory
|
| 9 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 10 |
+
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
| 11 |
+
cd "$PROJECT_ROOT"
|
| 12 |
+
|
| 13 |
+
LOG_DIR="./logs/vllm"
|
| 14 |
+
|
| 15 |
+
# Function to recursively collect all descendant PIDs of a given PID
|
| 16 |
+
# Returns space-separated list of all PIDs in the process tree
|
| 17 |
+
collect_descendant_pids() {
|
| 18 |
+
local root_pid=$1
|
| 19 |
+
local all_pids="$root_pid"
|
| 20 |
+
local to_check="$root_pid"
|
| 21 |
+
local new_pids=""
|
| 22 |
+
|
| 23 |
+
# Iteratively collect all descendants until no new children are found
|
| 24 |
+
while [ -n "$to_check" ]; do
|
| 25 |
+
new_pids=""
|
| 26 |
+
for pid in $to_check; do
|
| 27 |
+
# Find direct children of this PID
|
| 28 |
+
local children=$(ps -o pid --no-headers --ppid $pid 2>/dev/null | tr '\n' ' ')
|
| 29 |
+
if [ -n "$children" ]; then
|
| 30 |
+
# Add children to the list
|
| 31 |
+
all_pids="$all_pids $children"
|
| 32 |
+
new_pids="$new_pids $children"
|
| 33 |
+
fi
|
| 34 |
+
done
|
| 35 |
+
to_check="$new_pids"
|
| 36 |
+
done
|
| 37 |
+
|
| 38 |
+
echo "$all_pids"
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
# Function to collect log files opened by a process and its descendants
|
| 42 |
+
# Returns newline-separated list of log file paths
|
| 43 |
+
collect_process_log_files() {
|
| 44 |
+
local root_pid=$1
|
| 45 |
+
local log_files=""
|
| 46 |
+
|
| 47 |
+
# Collect all descendant PIDs (including the root)
|
| 48 |
+
local all_pids=$(collect_descendant_pids $root_pid)
|
| 49 |
+
|
| 50 |
+
# Use lsof to find all log files opened by these processes
|
| 51 |
+
# Look for files in the log directory that are opened by any of these PIDs
|
| 52 |
+
for pid in $all_pids; do
|
| 53 |
+
[ -z "$pid" ] && continue
|
| 54 |
+
if ps -p $pid > /dev/null 2>&1; then
|
| 55 |
+
# Find log files opened by this PID (files with .log extension in LOG_DIR)
|
| 56 |
+
# lsof output format: COMMAND PID USER FD TYPE DEVICE SIZE/OFF NODE NAME
|
| 57 |
+
# We need the last field (NAME) which is the file path
|
| 58 |
+
# Try both absolute and relative paths
|
| 59 |
+
local log_dir_abs=$(cd "$PROJECT_ROOT" && cd "$LOG_DIR" && pwd 2>/dev/null || echo "$LOG_DIR")
|
| 60 |
+
local pid_logs=$(lsof -p $pid 2>/dev/null | awk 'NR>1 {print $NF}' | grep -E "\.log$" | grep -E "(^$log_dir_abs/|$LOG_DIR/)" | sort -u)
|
| 61 |
+
if [ -n "$pid_logs" ]; then
|
| 62 |
+
log_files="$log_files"$'\n'"$pid_logs"
|
| 63 |
+
fi
|
| 64 |
+
fi
|
| 65 |
+
done
|
| 66 |
+
|
| 67 |
+
# Remove duplicates and empty lines, return unique log files
|
| 68 |
+
echo "$log_files" | grep -v '^$' | sort -u
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
# Function to kill a PID and all its descendants
|
| 72 |
+
# This ensures all child processes (including GPU processes) are terminated
|
| 73 |
+
kill_process_tree() {
|
| 74 |
+
local root_pid=$1
|
| 75 |
+
local signal=${2:-TERM}
|
| 76 |
+
|
| 77 |
+
if ! ps -p $root_pid > /dev/null 2>&1; then
|
| 78 |
+
return 1
|
| 79 |
+
fi
|
| 80 |
+
|
| 81 |
+
# Collect all descendant PIDs (including the root)
|
| 82 |
+
local all_pids=$(collect_descendant_pids $root_pid)
|
| 83 |
+
|
| 84 |
+
# Kill all processes
|
| 85 |
+
# For TERM, we kill from leaves to root (reverse order) for graceful shutdown
|
| 86 |
+
# For KILL, order doesn't matter
|
| 87 |
+
if [ "$signal" = "KILL" ]; then
|
| 88 |
+
# Force kill all processes
|
| 89 |
+
for pid in $all_pids; do
|
| 90 |
+
[ -z "$pid" ] && continue
|
| 91 |
+
kill -KILL $pid 2>/dev/null || true
|
| 92 |
+
done
|
| 93 |
+
else
|
| 94 |
+
# Graceful shutdown: kill children first, then parent
|
| 95 |
+
# Convert to array and kill in reverse order
|
| 96 |
+
local pids_array=($all_pids)
|
| 97 |
+
for ((idx=${#pids_array[@]}-1; idx>=0; idx--)); do
|
| 98 |
+
pid=${pids_array[$idx]}
|
| 99 |
+
[ -z "$pid" ] && continue
|
| 100 |
+
kill -TERM $pid 2>/dev/null || true
|
| 101 |
+
done
|
| 102 |
+
fi
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
echo "=== Stopping vLLM Services and Load Balancer ==="
|
| 106 |
+
echo ""
|
| 107 |
+
|
| 108 |
+
# Step 1: Stop load balancer (if PID file exists)
|
| 109 |
+
echo "Step 1: Stopping vLLM load balancer..."
|
| 110 |
+
LB_PID_FILE="$LOG_DIR/vllm_lb_pid.txt"
|
| 111 |
+
LB_LOG_FILES=""
|
| 112 |
+
if [ -f "$LB_PID_FILE" ]; then
|
| 113 |
+
LB_PID=$(cat "$LB_PID_FILE" 2>/dev/null | head -1)
|
| 114 |
+
if [ -n "$LB_PID" ] && ps -p $LB_PID > /dev/null 2>&1; then
|
| 115 |
+
echo " Killing load balancer PID $LB_PID..."
|
| 116 |
+
# Collect log files before killing
|
| 117 |
+
LB_LOG_FILES=$(collect_process_log_files $LB_PID)
|
| 118 |
+
# Also try to find load balancer log files by pattern (fallback if lsof doesn't work)
|
| 119 |
+
if [ -z "$LB_LOG_FILES" ]; then
|
| 120 |
+
LB_LOG_FILES=$(find "$LOG_DIR" -maxdepth 1 -name "load_balancer*.log" -type f 2>/dev/null)
|
| 121 |
+
fi
|
| 122 |
+
kill -TERM $LB_PID 2>/dev/null || true
|
| 123 |
+
sleep 2
|
| 124 |
+
if ps -p $LB_PID > /dev/null 2>&1; then
|
| 125 |
+
echo " Force killing load balancer PID $LB_PID..."
|
| 126 |
+
kill -KILL $LB_PID 2>/dev/null || true
|
| 127 |
+
fi
|
| 128 |
+
echo " Load balancer stopped"
|
| 129 |
+
rm -f "$LB_PID_FILE"
|
| 130 |
+
|
| 131 |
+
# Remove load balancer log files
|
| 132 |
+
if [ -n "$LB_LOG_FILES" ]; then
|
| 133 |
+
echo " Removing load balancer log files..."
|
| 134 |
+
while IFS= read -r log_file; do
|
| 135 |
+
[ -z "$log_file" ] && continue
|
| 136 |
+
if [ -f "$log_file" ]; then
|
| 137 |
+
rm -f "$log_file"
|
| 138 |
+
echo " Removed: $log_file"
|
| 139 |
+
fi
|
| 140 |
+
done <<< "$LB_LOG_FILES"
|
| 141 |
+
else
|
| 142 |
+
echo " Note: Could not detect load balancer log file (process may have already terminated)"
|
| 143 |
+
fi
|
| 144 |
+
else
|
| 145 |
+
echo " Load balancer PID from file not found (may have already terminated)"
|
| 146 |
+
rm -f "$LB_PID_FILE"
|
| 147 |
+
fi
|
| 148 |
+
else
|
| 149 |
+
echo " No load balancer PID file found ($LB_PID_FILE)"
|
| 150 |
+
echo " If load balancer is running, you may need to find and kill it manually"
|
| 151 |
+
fi
|
| 152 |
+
|
| 153 |
+
echo ""
|
| 154 |
+
|
| 155 |
+
# Step 2: Stop vLLM services (ONLY the ones we started)
|
| 156 |
+
echo "Step 2: Stopping vLLM services (only the ones we started)..."
|
| 157 |
+
|
| 158 |
+
# Try to read PIDs from file
|
| 159 |
+
if [ -f "$LOG_DIR/vllm_pids.txt" ]; then
|
| 160 |
+
echo " Reading PIDs from $LOG_DIR/vllm_pids.txt"
|
| 161 |
+
|
| 162 |
+
# Read all PIDs into an array
|
| 163 |
+
pids_array=()
|
| 164 |
+
while IFS= read -r pid || [ -n "$pid" ]; do
|
| 165 |
+
# Skip empty lines
|
| 166 |
+
[ -z "$pid" ] && continue
|
| 167 |
+
pids_array+=($pid)
|
| 168 |
+
done < "$LOG_DIR/vllm_pids.txt"
|
| 169 |
+
|
| 170 |
+
KILLED_COUNT=0
|
| 171 |
+
NOT_FOUND_COUNT=0
|
| 172 |
+
|
| 173 |
+
# Collect log files for all vLLM services before killing
|
| 174 |
+
vllm_log_files=""
|
| 175 |
+
for pid in "${pids_array[@]}"; do
|
| 176 |
+
if ps -p $pid > /dev/null 2>&1; then
|
| 177 |
+
# Collect log files for this PID
|
| 178 |
+
pid_logs=$(collect_process_log_files $pid)
|
| 179 |
+
if [ -n "$pid_logs" ]; then
|
| 180 |
+
vllm_log_files="$vllm_log_files"$'\n'"$pid_logs"
|
| 181 |
+
fi
|
| 182 |
+
fi
|
| 183 |
+
done
|
| 184 |
+
|
| 185 |
+
# First pass: graceful shutdown (TERM signal)
|
| 186 |
+
for pid in "${pids_array[@]}"; do
|
| 187 |
+
if ps -p $pid > /dev/null 2>&1; then
|
| 188 |
+
echo " Killing vLLM service PID $pid and all its descendant processes..."
|
| 189 |
+
# Collect and show how many processes will be killed
|
| 190 |
+
descendant_pids=$(collect_descendant_pids $pid)
|
| 191 |
+
pid_count=$(echo $descendant_pids | wc -w)
|
| 192 |
+
echo " Found $pid_count process(es) in the process tree"
|
| 193 |
+
# Use our recursive function to kill the entire process tree
|
| 194 |
+
kill_process_tree $pid TERM
|
| 195 |
+
KILLED_COUNT=$((KILLED_COUNT + 1))
|
| 196 |
+
else
|
| 197 |
+
echo " PID $pid: Process not found (may have already terminated)"
|
| 198 |
+
NOT_FOUND_COUNT=$((NOT_FOUND_COUNT + 1))
|
| 199 |
+
fi
|
| 200 |
+
done
|
| 201 |
+
|
| 202 |
+
if [ $KILLED_COUNT -gt 0 ]; then
|
| 203 |
+
echo " Waiting 3 seconds for graceful shutdown..."
|
| 204 |
+
sleep 3
|
| 205 |
+
|
| 206 |
+
# Second pass: force kill (KILL signal) if still running
|
| 207 |
+
for pid in "${pids_array[@]}"; do
|
| 208 |
+
if ps -p $pid > /dev/null 2>&1; then
|
| 209 |
+
echo " Force killing vLLM service PID $pid and all its descendant processes..."
|
| 210 |
+
# Collect and show how many processes will be force killed
|
| 211 |
+
descendant_pids=$(collect_descendant_pids $pid)
|
| 212 |
+
pid_count=$(echo $descendant_pids | wc -w)
|
| 213 |
+
echo " Force killing $pid_count process(es) in the process tree"
|
| 214 |
+
# Use our recursive function to force kill the entire process tree
|
| 215 |
+
kill_process_tree $pid KILL
|
| 216 |
+
fi
|
| 217 |
+
done
|
| 218 |
+
|
| 219 |
+
echo " Stopped $KILLED_COUNT vLLM service(s)"
|
| 220 |
+
else
|
| 221 |
+
echo " No running processes found from saved PIDs"
|
| 222 |
+
fi
|
| 223 |
+
|
| 224 |
+
if [ $NOT_FOUND_COUNT -gt 0 ]; then
|
| 225 |
+
echo " ($NOT_FOUND_COUNT process(es) were already terminated)"
|
| 226 |
+
fi
|
| 227 |
+
|
| 228 |
+
# Remove vLLM log files
|
| 229 |
+
if [ -n "$vllm_log_files" ]; then
|
| 230 |
+
echo ""
|
| 231 |
+
echo " Removing vLLM service log files..."
|
| 232 |
+
removed_count=0
|
| 233 |
+
while IFS= read -r log_file; do
|
| 234 |
+
[ -z "$log_file" ] && continue
|
| 235 |
+
if [ -f "$log_file" ]; then
|
| 236 |
+
rm -f "$log_file"
|
| 237 |
+
echo " Removed: $log_file"
|
| 238 |
+
removed_count=$((removed_count + 1))
|
| 239 |
+
fi
|
| 240 |
+
done <<< "$vllm_log_files"
|
| 241 |
+
if [ $removed_count -eq 0 ] && [ -n "$vllm_log_files" ]; then
|
| 242 |
+
echo " (No log files found to remove - they may have already been deleted)"
|
| 243 |
+
fi
|
| 244 |
+
else
|
| 245 |
+
echo ""
|
| 246 |
+
echo " Note: Could not detect vLLM log files (processes may have already terminated)"
|
| 247 |
+
fi
|
| 248 |
+
|
| 249 |
+
# Remove PID file after stopping
|
| 250 |
+
rm -f "$LOG_DIR/vllm_pids.txt"
|
| 251 |
+
else
|
| 252 |
+
echo " WARNING: $LOG_DIR/vllm_pids.txt not found!"
|
| 253 |
+
echo " Cannot safely stop services without PID file."
|
| 254 |
+
echo " If you know the PIDs, you can manually kill them."
|
| 255 |
+
echo " To avoid affecting other users, DO NOT use pkill!"
|
| 256 |
+
fi
|
| 257 |
+
|
| 258 |
+
echo ""
|
| 259 |
+
echo " NOTE: Only processes from vllm_pids.txt were killed."
|
| 260 |
+
echo " Other vLLM services (if any) were NOT affected."
|
| 261 |
+
|
| 262 |
+
echo ""
|
| 263 |
+
echo "=== Checking GPU status ==="
|
| 264 |
+
nvidia-smi --query-gpu=index,memory.used --format=csv,noheader | grep -E '^ 4,|^ 5,|^ 6,|^ 7,' || echo "GPU 4,5,6,7 status:"
|
| 265 |
+
|
| 266 |
+
echo ""
|
| 267 |
+
echo "Done!"
|
shared/configs/config.yaml
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration file for Paper Reviewer Agent
|
| 2 |
+
#
|
| 3 |
+
# Note: LLM service configuration is in configs/llm_service_config.yaml
|
| 4 |
+
# Note: Prompts configuration is in configs/prompts.yaml
|
| 5 |
+
|
| 6 |
+
# Global verbose mode - controls all step outputs (Step 1, Step 2, Step 3, etc.)
|
| 7 |
+
# Set to false to suppress intermediate progress output for faster execution
|
| 8 |
+
verbose: false
|
| 9 |
+
|
| 10 |
+
# Paper Search Configuration
|
| 11 |
+
paper_search:
|
| 12 |
+
# Asta API Configuration
|
| 13 |
+
asta:
|
| 14 |
+
# Single API key (backward compatible, lower priority)
|
| 15 |
+
api_key: null # Set via ASTA_API_KEY env var
|
| 16 |
+
# API key pool
|
| 17 |
+
api_key_pool_path: "asta_api_pool.txt" # path relative to shared/configs/ or absolute path
|
| 18 |
+
endpoint: "https://asta-tools.allen.ai/mcp/v1"
|
| 19 |
+
|
| 20 |
+
# Semantic Scholar API Configuration (alternative)
|
| 21 |
+
semantic_scholar:
|
| 22 |
+
api_key: null # Set via S2_API_KEY env var
|
| 23 |
+
|
| 24 |
+
# Reranker Configuration
|
| 25 |
+
reranker:
|
| 26 |
+
# Reranker model path (for direct mode)
|
| 27 |
+
model: "OpenScholar/OpenScholar_Reranker" # e.g., "OpenScholar/OpenScholar_Reranker" or "BAAI/bge-reranker-base"
|
| 28 |
+
use_fp16: true
|
| 29 |
+
|
| 30 |
+
# Reranker API Configuration (for API mode with load balancing)
|
| 31 |
+
# If base_url is set, use API mode with load balancer
|
| 32 |
+
# If endpoint_pool_path is set, use API mode with endpoint pool
|
| 33 |
+
# If both are None, use direct mode (load model directly)
|
| 34 |
+
api:
|
| 35 |
+
# Base URL for reranker API service (load balancer address)
|
| 36 |
+
# Example: "http://localhost:8009" (load balancer that distributes to 8005-8008)
|
| 37 |
+
# If set, will use API mode with load balancer
|
| 38 |
+
base_url: "http://localhost:8008"
|
| 39 |
+
|
| 40 |
+
# Endpoint pool file path (alternative to base_url)
|
| 41 |
+
# Example: "reranker_endpoint_pool.txt" (contains list of endpoints: http://localhost:8005, http://localhost:8006, ...)
|
| 42 |
+
# If set, will use API mode with endpoint pool (round-robin load balancing)
|
| 43 |
+
# endpoint_pool_path: "reranker_endpoint_pool.txt" # Set to use API mode with endpoint pool
|
| 44 |
+
endpoint_pool_path: null
|
| 45 |
+
|
| 46 |
+
# Request timeout in seconds
|
| 47 |
+
timeout: 30.0
|
| 48 |
+
|
| 49 |
+
# Retrieval Configuration
|
| 50 |
+
retrieval:
|
| 51 |
+
top_n: 10
|
| 52 |
+
use_abstract: true
|
| 53 |
+
norm_cite: false
|
| 54 |
+
min_citation: null
|
| 55 |
+
limit_per_keyword: 20
|
| 56 |
+
|
| 57 |
+
# Related Work Searcher Configuration
|
| 58 |
+
related_work_searcher:
|
| 59 |
+
max_related_papers: 10
|
| 60 |
+
max_parallel_summaries: 1
|
| 61 |
+
publication_date_range: null # e.g., "2020:" for papers from 2020 onwards
|
| 62 |
+
venues: null # e.g., "ICLR,NeurIPS"
|
| 63 |
+
verbose: false # Set to false to suppress intermediate progress output (faster, less output)
|
| 64 |
+
|
| 65 |
+
# Paper Reviewer Configuration
|
| 66 |
+
paper_reviewer:
|
| 67 |
+
review_format: "ai_researcher" # Options: "detailed", "summary", "structured", "ai_researcher"
|
| 68 |
+
max_tokens: 16384 # Maximum tokens for reviewer output (increased for detailed reviews)
|
| 69 |
+
|
| 70 |
+
# Review Refiner Configuration
|
| 71 |
+
review_refiner:
|
| 72 |
+
review_format: "strict_detailed" # Options: "detailed", "summary", "structured", "strict_detailed" - should match paper_reviewer format
|
| 73 |
+
max_tokens: 16384 # Maximum tokens for refiner output (increased for detailed reviews)
|
| 74 |
+
|
| 75 |
+
# Output Configuration
|
| 76 |
+
output:
|
| 77 |
+
save_reviews: true
|
| 78 |
+
output_dir: "./outputs"
|
| 79 |
+
format: "json" # Options: "json", "markdown", "txt"
|
| 80 |
+
|
| 81 |
+
# Evaluation Configuration
|
| 82 |
+
evaluation:
|
| 83 |
+
# Default number of worker threads for concurrent evaluation
|
| 84 |
+
max_workers: 16
|
| 85 |
+
|
| 86 |
+
# Component name for LLM service assignment (used with llm_service_config.yaml)
|
| 87 |
+
# Options: "keyword_generator", "paper_summarizer", "reviewer", "refiner"
|
| 88 |
+
# Defaults to "reviewer" if not specified
|
| 89 |
+
llm_component: "reviewer"
|
| 90 |
+
|
| 91 |
+
# Default model name (can be overridden by command line args or llm_service_config.yaml)
|
| 92 |
+
default_model_name: "Qwen/Qwen2.5-72B-Instruct"
|
| 93 |
+
|
| 94 |
+
# Prompt versions
|
| 95 |
+
rubric_generation_prompt_version: "v2" # Options: "v1", "v2"
|
| 96 |
+
evaluator_prompt_version: "v1" # Options: "v0", "v1"
|
| 97 |
+
|
shared/configs/llm_service_config.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# vLLM Service Configuration
|
| 2 |
+
|
| 3 |
+
# vLLM server settings
|
| 4 |
+
vllm:
|
| 5 |
+
# Base URL for vLLM service
|
| 6 |
+
# In production, this should point to a load balancer (nginx/HAProxy) that distributes
|
| 7 |
+
# requests across multiple vLLM service instances running on different GPUs.
|
| 8 |
+
# Example: "http://localhost:8000/v1" (load balancer address)
|
| 9 |
+
base_url: "http://localhost:8000/" # directly to the load balancer, 8000-8003 are used by the vllm services
|
| 10 |
+
|
| 11 |
+
api_key: "dummy-key" # Not used for local vLLM, but required by OpenAI client
|
| 12 |
+
model_name: "openai/gpt-oss-120b"
|
| 13 |
+
# model_name: "Qwen/Qwen3-4B-Instruct-2507"
|
| 14 |
+
# model_name: "Qwen/Qwen3-235B-A22B-Instruct-2507"
|
| 15 |
+
timeout: 300
|
| 16 |
+
|
| 17 |
+
# Rate limiting: Maximum concurrent requests to vLLM server
|
| 18 |
+
# Lower this if you're getting 500 errors (suggests server overload)
|
| 19 |
+
# Recommended: 4-8 for small models, 2-4 for large models
|
| 20 |
+
max_concurrent_requests: 64
|
| 21 |
+
|
| 22 |
+
# Retry configuration for server errors
|
| 23 |
+
max_retries: 3 # Number of retries for 500/502/503/504 errors
|
| 24 |
+
retry_delay: 1.0 # Initial delay in seconds
|
| 25 |
+
retry_backoff: 2.0 # Exponential backoff multiplier
|
| 26 |
+
|
| 27 |
+
# Default sampling parameters
|
| 28 |
+
temperature: 0.7
|
| 29 |
+
top_p: 0.8
|
| 30 |
+
top_k: 20
|
| 31 |
+
max_tokens: 16384
|
| 32 |
+
presence_penalty: 0.0
|
| 33 |
+
|
| 34 |
+
# GPT / OpenRouter API Configuration
|
| 35 |
+
gpt:
|
| 36 |
+
enabled: true
|
| 37 |
+
# Leave api_key null so it is taken from OPENAI_API_KEY or OPENROUTER_API_KEY
|
| 38 |
+
api_key: null
|
| 39 |
+
# Use the OpenRouter model you specified
|
| 40 |
+
model_name: "openai/gpt-oss-120b"
|
| 41 |
+
# Point the OpenAI-compatible client to OpenRouter
|
| 42 |
+
base_url: "http://localhost:8000/"
|
| 43 |
+
timeout: 300
|
| 44 |
+
|
| 45 |
+
# Default sampling parameters
|
| 46 |
+
temperature: 0.7
|
| 47 |
+
top_p: 0.95
|
| 48 |
+
max_tokens: 16384
|
| 49 |
+
presence_penalty: 0.0
|
| 50 |
+
|
| 51 |
+
# LLM Service Assignment
|
| 52 |
+
# Specify which LLM service to use for each component
|
| 53 |
+
llm_assignments:
|
| 54 |
+
keyword_generator: "gpt" # Options: "vllm", "gpt"
|
| 55 |
+
paper_summarizer: "gpt" # Options: "vllm", "gpt"
|
| 56 |
+
reviewer: "gpt" # Options: "vllm", "gpt"
|
| 57 |
+
refiner: "gpt" # Options: "vllm", "gpt" - defaults to reviewer if not specified
|
shared/configs/prompts.yaml
ADDED
|
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Prompt templates for the paper reviewer agent
|
| 2 |
+
|
| 3 |
+
# Keyword generation prompt
|
| 4 |
+
# Based on OpenScholar's keyword extraction prompt, adapted for JSON output format
|
| 5 |
+
keyword_generation:
|
| 6 |
+
system: "You are an experienced research assistant helping to find related work for a paper. Always respond with valid JSON."
|
| 7 |
+
user: |
|
| 8 |
+
Suggest search queries to retrieve relevant papers related to the following paper. The search queries must be simple, short, and comma separated. Focus on core technical concepts, methods, and key techniques that would help find related work.
|
| 9 |
+
|
| 10 |
+
Here's an example:
|
| 11 |
+
##
|
| 12 |
+
Paper: How have prior work incorporated personality attributes to train personalized dialogue generation models?
|
| 13 |
+
Search queries: personalized dialogue generation, personalized language models, personalized dialogue
|
| 14 |
+
##
|
| 15 |
+
Paper: How do retrieval-augmented LMs perform well in knowledge-intensive tasks?
|
| 16 |
+
Search queries: retrieval-augmented LMs, knowledge-intensive tasks, large language models for knowledge-intensive tasks, retrieval-augmented generation
|
| 17 |
+
##
|
| 18 |
+
|
| 19 |
+
Paper information:
|
| 20 |
+
{context}
|
| 21 |
+
|
| 22 |
+
IMPORTANT: You must respond with valid JSON only. Use this format:
|
| 23 |
+
{{
|
| 24 |
+
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"]
|
| 25 |
+
}}
|
| 26 |
+
|
| 27 |
+
Return only the JSON, no additional text or explanation. Generate 3-5 short, comma-separated keywords as search queries.
|
| 28 |
+
|
| 29 |
+
# DOMAIN-SPECIFICAGENTS
|
| 30 |
+
# Paper summarization prompt
|
| 31 |
+
# Generate structured summary of each related paper: summary, main methods, key findings, relation with the target paper
|
| 32 |
+
paper_summarization:
|
| 33 |
+
user: |
|
| 34 |
+
You are an senior research assistant who is proficient at identifying the main contribution, key findings of papers and the relations between different works.
|
| 35 |
+
|
| 36 |
+
For the reference paper below:
|
| 37 |
+
|
| 38 |
+
{reference_paper}
|
| 39 |
+
|
| 40 |
+
You are given this paper as a related work to the reference paper:
|
| 41 |
+
|
| 42 |
+
{related_paper}
|
| 43 |
+
|
| 44 |
+
Now, you need to provide concise information on what the related work is about, its main methods, results and the relation between the reference paper, to make it easier for the supervisor to write the review.
|
| 45 |
+
|
| 46 |
+
Focusing on the relationship between the reference paper and the related work, summarize the related work's main methods, results and the relation between them in a concise way.
|
| 47 |
+
|
| 48 |
+
IMPORTANT: You must respond with valid JSON only. Use this format:
|
| 49 |
+
{{
|
| 50 |
+
"summary": "Your concise summary here in 2-3 sentences.",
|
| 51 |
+
"main_methods": "The main methods of the related work.",
|
| 52 |
+
"key_findings": "The key findings of the related work.",
|
| 53 |
+
"relation": "The relation between the related work and the paper you are reviewing, such as how they share similar ideas, solving the same problem, have diverged claims, etc."
|
| 54 |
+
}}
|
| 55 |
+
Return only the JSON, no additional text or explanation.
|
| 56 |
+
|
| 57 |
+
# Paper Insight Miner Prompt
|
| 58 |
+
paper_insight_miner:
|
| 59 |
+
user: |
|
| 60 |
+
You are an expert research assistant. Your task is to help refine the method/contribution parts of a candidate review, using the paper content as the source of truth.
|
| 61 |
+
|
| 62 |
+
SCOPE (strict):
|
| 63 |
+
- ONLY cover: core contributions, technical approach, model/algorithm design, mathematical formulation, assumptions, optimization/training, implementation details, and limitations of the method.
|
| 64 |
+
- Novelty: ONLY assess novelty claims AS PRESENTED IN THE PAPER (no external knowledge, no web search, no comparing to papers not mentioned in the text).
|
| 65 |
+
- Do NOT comment on experimental results, benchmarks, or score/decision fields (handled by another module).
|
| 66 |
+
- Do NOT do external related-work positioning (handled by another module).
|
| 67 |
+
|
| 68 |
+
Paper content:
|
| 69 |
+
{content}
|
| 70 |
+
|
| 71 |
+
Candidate review:
|
| 72 |
+
{candidate_review}
|
| 73 |
+
|
| 74 |
+
What to do:
|
| 75 |
+
1) Extract the paper’s core contributions and method details (paper-grounded).
|
| 76 |
+
2) Check the candidate review’s method/contribution claims against the paper and identify:
|
| 77 |
+
- incorrect/hallucinated/contradicted claims,
|
| 78 |
+
- missing key technical points,
|
| 79 |
+
- vague or generic statements that should be made specific.
|
| 80 |
+
3) Provide short rewrite suggestions WITH evidence anchors (Section/Equation/Algorithm/Figure/snippet if available).
|
| 81 |
+
|
| 82 |
+
Output JSON only:
|
| 83 |
+
{
|
| 84 |
+
"facts": {
|
| 85 |
+
"core_contributions": [
|
| 86 |
+
{"claim": "...", "evidence": "..."}
|
| 87 |
+
],
|
| 88 |
+
"method_summary": [
|
| 89 |
+
{"point": "key component / step / design choice", "evidence": "..."}
|
| 90 |
+
],
|
| 91 |
+
"assumptions_and_scope": [
|
| 92 |
+
{"item": "...", "evidence": "..."}
|
| 93 |
+
],
|
| 94 |
+
"novelty_claims_in_paper": [
|
| 95 |
+
{"claim": "as stated by the authors", "evidence": "..."}
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
"review_issues": {
|
| 99 |
+
"incorrect_or_hallucinated": [
|
| 100 |
+
{"review_claim": "...", "why_wrong": "...", "evidence": "..."}
|
| 101 |
+
],
|
| 102 |
+
"missing_key_points": [
|
| 103 |
+
{"what_missing": "...", "why_important": "...", "evidence": "..."}
|
| 104 |
+
],
|
| 105 |
+
"needs_specificity": [
|
| 106 |
+
{"review_text": "...", "how_to_fix": "name the component/assumption/equation/step", "evidence": "..."}
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
"rewrite_suggestions": [
|
| 110 |
+
{
|
| 111 |
+
"apply_to": "Summary|Strengths|Weaknesses|Suggestions|Questions (method-related only)",
|
| 112 |
+
"target": "Core Contribution Accuracy|Evidence-Based Critique|Critique Clarity",
|
| 113 |
+
"suggested_text": "1-2 sentences",
|
| 114 |
+
"evidence": "..."
|
| 115 |
+
}
|
| 116 |
+
]
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
Rules:
|
| 120 |
+
- If you cannot find support in the paper text, set evidence to "not_found_in_text"; do NOT assert the paper is missing it.
|
| 121 |
+
- Keep each list short (<=5 items). Prefer the most important contributions/components/issues.
|
| 122 |
+
- Return JSON only. No extra text.
|
| 123 |
+
|
| 124 |
+
# Paper results summarization prompt for Result Analyzer Agent
|
| 125 |
+
# according to the paper content and candidate review, pinpoint the issues and provide refinement suggestions with concrete evidence to the candidate review.
|
| 126 |
+
paper_results_analyzer:
|
| 127 |
+
user: |
|
| 128 |
+
You are an expert research assistant. Your task is to help refine the experiment/evaluation parts of a candidate review, using the paper content as the source of truth.
|
| 129 |
+
|
| 130 |
+
SCOPE (strict):
|
| 131 |
+
- ONLY cover experimental evaluation: datasets, baselines, metrics, tables/figures, quantitative results, statistical evidence, ablations.
|
| 132 |
+
- Do NOT comment on novelty, related-work positioning, writing/presentation quality, or overall recommendation. Other agents will handle those.
|
| 133 |
+
|
| 134 |
+
Paper content:
|
| 135 |
+
{content}
|
| 136 |
+
|
| 137 |
+
Candidate review:
|
| 138 |
+
{candidate_review}
|
| 139 |
+
|
| 140 |
+
What to do:
|
| 141 |
+
1) Extract key experimental facts from the paper.
|
| 142 |
+
2) Check experiment-related claims in the candidate review and identify:
|
| 143 |
+
- incorrect/hallucinated/contradicted claims,
|
| 144 |
+
- missing key experimental points,
|
| 145 |
+
- vague statements that should be made specific.
|
| 146 |
+
3) Provide short rewrite suggestions WITH evidence anchors (Table/Figure/Section/snippet if available).
|
| 147 |
+
|
| 148 |
+
Output JSON only:
|
| 149 |
+
{
|
| 150 |
+
"facts": {
|
| 151 |
+
"datasets": ["..."],
|
| 152 |
+
"metrics": ["..."],
|
| 153 |
+
"baselines": ["..."],
|
| 154 |
+
"key_results": [
|
| 155 |
+
{"claim": "...", "evidence": "..."}
|
| 156 |
+
]
|
| 157 |
+
},
|
| 158 |
+
"review_issues": {
|
| 159 |
+
"incorrect_or_hallucinated": [
|
| 160 |
+
{"review_claim": "...", "why_wrong": "...", "evidence": "..."}
|
| 161 |
+
],
|
| 162 |
+
"missing_key_points": [
|
| 163 |
+
{"what_missing": "...", "why_important": "...", "evidence": "..."}
|
| 164 |
+
],
|
| 165 |
+
"needs_specificity": [
|
| 166 |
+
{"review_text": "...", "how_to_fix": "...", "evidence": "..."}
|
| 167 |
+
]
|
| 168 |
+
},
|
| 169 |
+
"rewrite_suggestions": [
|
| 170 |
+
{
|
| 171 |
+
"apply_to": "Summary|Strengths|Weaknesses|Suggestions|Questions (experiment-related only)",
|
| 172 |
+
"target": "Results Interpretation|Evidence-Based Critique",
|
| 173 |
+
"suggested_text": "1-2 sentences",
|
| 174 |
+
"evidence": "..."
|
| 175 |
+
}
|
| 176 |
+
]
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
Rules:
|
| 180 |
+
- If you cannot find support in the paper text, set evidence to "not_found_in_text"; do NOT assert the paper is missing it.
|
| 181 |
+
- Keep each list short (<=5 items). Prefer the most important issues/results.
|
| 182 |
+
- Return JSON only. No extra text.
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# Reviewer model prompts (this may never been used)
|
| 186 |
+
# Need to be adapted to the actual review model
|
| 187 |
+
review_prompts:
|
| 188 |
+
detailed: |
|
| 189 |
+
You are reviewing a research paper. Please provide a detailed review in markdown format covering the following sections in this exact order:
|
| 190 |
+
|
| 191 |
+
## Summary
|
| 192 |
+
A brief summary of the paper's contributions and main findings.
|
| 193 |
+
|
| 194 |
+
## Soundness
|
| 195 |
+
Rate the technical soundness and correctness of the methodology (provide a score from 1 to 5, where 1 is very poor and 5 is excellent). Do not add any explanation.
|
| 196 |
+
|
| 197 |
+
## Presentation
|
| 198 |
+
Rate the clarity and quality of presentation (provide a score from 1 to 5, where 1 is very poor and 5 is excellent). Do not add any explanation.
|
| 199 |
+
|
| 200 |
+
## Contribution
|
| 201 |
+
Rate the significance and novelty of the contribution (provide a score from 1 to 5, where 1 is very poor and 5 is excellent). Do not add any explanation.
|
| 202 |
+
|
| 203 |
+
## Strengths
|
| 204 |
+
What are the main strengths of this paper? Consider:
|
| 205 |
+
- Novelty and originality
|
| 206 |
+
- Technical soundness
|
| 207 |
+
- Experimental validation
|
| 208 |
+
- Clarity of presentation
|
| 209 |
+
- Significance of contributions
|
| 210 |
+
|
| 211 |
+
## Weaknesses
|
| 212 |
+
What are the main weaknesses or concerns? Consider:
|
| 213 |
+
- Methodological issues
|
| 214 |
+
- Missing experiments or baselines
|
| 215 |
+
- Limitations not acknowledged
|
| 216 |
+
- Clarity issues
|
| 217 |
+
- Reproducibility concerns
|
| 218 |
+
|
| 219 |
+
## Questions
|
| 220 |
+
Any questions you have for the authors that would help clarify aspects of the work.
|
| 221 |
+
|
| 222 |
+
## Rating
|
| 223 |
+
Overall rating of the paper (provide a score from 1 to 10, following the reviewer scale). Do not add any explanation.
|
| 224 |
+
|
| 225 |
+
## Confidence
|
| 226 |
+
Your confidence in your assessment (provide a score from 1 to 5, following the reviewer scale). Do not add any explanation.
|
| 227 |
+
|
| 228 |
+
## Decision
|
| 229 |
+
Your recommendation: "accept", "reject", or "undecided". Do not add any explanation.
|
| 230 |
+
|
| 231 |
+
IMPORTANT: Write your review in markdown format with the exact section headers above (## Summary, ## Soundness, etc.). Include scores and explanations for each scoring section. Be constructive, specific, and fair in your review.
|
| 232 |
+
|
| 233 |
+
ai_researcher: |
|
| 234 |
+
You are an expert academic reviewer. Your task is to provide a thorough, structured, and balanced review of the following research paper.
|
| 235 |
+
|
| 236 |
+
Step 1: Read and Analyze the Paper Carefully
|
| 237 |
+
- Read the paper paragraph by paragraph.
|
| 238 |
+
- For each paragraph:
|
| 239 |
+
- Perform detailed analysis and document your thought process using <think></think> tags.
|
| 240 |
+
- Identify strengths, weaknesses, unclear points, logical flaws, technical inconsistencies, or missing references.
|
| 241 |
+
- Highlight both strengths and weaknesses.
|
| 242 |
+
- Ensure all observations are supported by reasoning inside <think></think> tags.
|
| 243 |
+
|
| 244 |
+
Step 2: Conduct the Review
|
| 245 |
+
- After completing paragraph-by-paragraph analysis, provide an overall assessment following the structure below.
|
| 246 |
+
- Provide scores, recommendations, and a final decision in the strict JSON format.
|
| 247 |
+
- Do **not** include <think> tags inside JSON format.
|
| 248 |
+
- Be concise yet sufficiently detailed for an academic review.
|
| 249 |
+
|
| 250 |
+
Step 3: Organize your reviews into the following JSON format:
|
| 251 |
+
{
|
| 252 |
+
"summary": [Concise, detailed summary covering methodology, key ideas, and results.],
|
| 253 |
+
"soundness": [Score 1-5],
|
| 254 |
+
"presentation": [Score 1-5],
|
| 255 |
+
"contribution": [Score 1-5],
|
| 256 |
+
"strengths": [List major strengths],
|
| 257 |
+
"weaknesses": [List major weaknesses, with confidence levels if applicable],
|
| 258 |
+
"suggestions": [Concrete recommendations to address weaknesses],
|
| 259 |
+
"questions": [Outstanding questions or clarifications needed],
|
| 260 |
+
"rating": [Overall score, e.g., 1-10],
|
| 261 |
+
"confidence": [Confidence in assessment, e.g., 1-5],
|
| 262 |
+
"decision": [Accept, Reject]
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
Few-shot example (for format reference and not copy the example):
|
| 266 |
+
|
| 267 |
+
{
|
| 268 |
+
"summary": "This paper introduces a novel algorithm for modeling shared dynamics between multiple observation processes, validated on both simulated and real-world data.",
|
| 269 |
+
"soundness": 3.0,
|
| 270 |
+
"presentation": 3.0,
|
| 271 |
+
"contribution": 3.0,
|
| 272 |
+
"strengths": "- Novel decomposition approach.\n - Separation of shared and residual dynamics.\n - Validated on real and simulated data.",
|
| 273 |
+
"weaknesses": "- Strong linearity assumption (high confidence).\n - Limited experiments (medium confidence).",
|
| 274 |
+
"suggestions": "- Test on nonlinear systems.\n - Expand evaluation datasets.",
|
| 275 |
+
"questions": "- Sensitivity to deviations from assumed structure?\n - Performance on nonlinear data?",
|
| 276 |
+
"rating": 6.5,
|
| 277 |
+
"confidence": 3.0,
|
| 278 |
+
"decision": "Accept"
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
Few-shot example (strictly follow this structure and do not copy the example):
|
| 282 |
+
{
|
| 283 |
+
"summary": "This paper introduces PG-LDS-ID, a novel algorithm designed to model the shared dynamics between two observation processes: a continuous-time Gaussian process and a discrete-time Poisson process. The core idea is to use a latent state-space model to capture the underlying dynamics that influence both observation streams, while also accounting for residual dynamics unique to the Poisson process. The authors propose a two-stage approach, where the first stage identifies the shared dynamics using a covariance-based subspace identification method, and the second stage identifies the residual dynamics that are only observable through the Poisson process. A key contribution is the introduction of a block-structured system matrix, which facilitates the separation of shared and residual dynamics. The method is motivated by applications in neuroscience, where one might want to model the relationship between continuous behavioral trajectories and discrete neural spiking activity. The authors validate their approach using both simulated data and a real-world dataset of non-human primate neural spiking activity and arm movements. The simulation results demonstrate the algorithm's ability to accurately recover the shared dynamics, and the real data experiment shows improved prediction accuracy compared to a prior method, PLDSID. The paper's significance lies in its ability to handle coupled observation processes with different statistical properties, while explicitly disentangling shared and residual dynamics, a capability not simultaneously offered by existing analytical methods. However, the paper's reliance on strong assumptions, such as linearity and a specific block structure, and the limited scope of its experimental validation, raise important questions about its broader applicability and robustness.",
|
| 284 |
+
"soundness": 2.8,
|
| 285 |
+
"presentation": 2.4,
|
| 286 |
+
"contribution": 2.6,
|
| 287 |
+
"strengths": "One of the primary strengths of this paper is the novel decomposition technique introduced through Equation (6), which employs a block-structured system matrix. This decomposition is a significant contribution as it simplifies the modeling of shared dynamics between data streams with different statistical properties, specifically Gaussian and Poisson processes. By breaking down the problem into manageable components, the authors enhance the overall approach's ease of handling and implementation. This is particularly valuable in practical scenarios where dealing with coupled dynamics can be complex. Furthermore, the paper's focus on explicitly disentangling shared and residual dynamics is a notable advancement. Existing methods often model the collective dynamics of multiple modalities in the same latent states, whereas PG-LDS-ID explicitly separates the shared dynamics from those unique to the Poisson process. This distinction is crucial for understanding the underlying mechanisms that drive different observation streams. The authors demonstrate the practical applicability of their method through both simulated and real-world experiments. The simulation results show that the proposed method can accurately recover the shared dynamics, and the real data experiment on non-human primate data shows that PG-LDS-ID achieves better prediction accuracies compared to PLDSID, a prior method. This empirical validation provides evidence for the effectiveness of the algorithm in a realistic setting. Finally, the method's ability to handle generalized linear processes, as opposed to being limited to Gaussian processes, is another strength. By using second-order moments, the proposed method can now deal with a broader class of observation models, making it more versatile and applicable to a wider range of problems.",
|
| 288 |
+
"weaknesses": "After a thorough review of the paper and the reviewer comments, I have identified several key weaknesses that significantly impact the paper's conclusions and broader applicability. First, the paper's strong reliance on the assumption of latent linear dynamics is a major limitation. The entire method is built upon linear dynamical state-space models, which, as the authors acknowledge in the 'Limitations' section, can only provide an approximation of nonlinear dynamics. This assumption is particularly concerning given that many real-world systems, especially those in neuroscience, exhibit nonlinear behavior. The authors do not provide any experimental results or analysis of how the method performs when applied to observations generated by a latent nonlinear system. This lack of evaluation makes it difficult to assess the method's robustness and applicability in real-world scenarios. The confidence level for this weakness is high, as the paper explicitly states its reliance on linear models and lacks any analysis of nonlinear systems. Second, the paper introduces a specific block structure in Equation (6) for the system matrices, which is a critical assumption for the method's ability to dissociate shared and residual dynamics. While the authors justify this structure as a design choice to facilitate the separation of dynamics, they do not sufficiently discuss the conditions under which this decomposition can be effectively implemented, or the consequences of deviations from this structure. Specifically, the paper does not explore what happens if the true coefficient matrix has non-zero values in the upper right block, which would violate the assumed block structure. The practical implications of this choice are not fully explored, and the paper lacks any sensitivity analysis to assess the robustness of the method to such deviations. The confidence level for this weakness is high, as the paper introduces the block structure as a key design choice without addressing its limitations or potential for misapplication. Third, the paper lacks a detailed comparison with recent, relevant subspace identification methods that also leverage multimodal data. The authors compare their method against PLDSID, a method from 2012, but do not compare against more recent techniques such as those presented in Ahmadipour et al. (2023) and Vahidi et al. (2023). This lack of comparison makes it difficult to assess the novelty and specific advantages of the proposed method compared to the current state-of-the-art. The paper mentions that existing methods do not explicitly tease apart shared and residual dynamics, but a more thorough comparison is needed to justify the contribution of this work. The confidence level for this weakness is high, as the paper does not include a comparison with recent, relevant methods. Fourth, the paper does not adequately address the estimation of the Gaussian observation noise variance. While the optimization procedure in Section 3.2.3 ensures valid noise statistics, the explicit estimation of the noise variance of the Gaussian observation process is not clearly outlined as a separate step before the optimization. This omission raises concerns about the method's sensitivity to variations in the noise variance and its impact on the accuracy of the estimated latent states. The confidence level for this weakness is medium, as the paper implicitly addresses noise statistics but does not explicitly detail the estimation of the Gaussian noise variance. Fifth, the paper's experimental evaluation is limited in scope. The authors primarily compare their method against PLDSID and do not include comparisons with more recent and competitive methods. This limited evaluation makes it difficult to assess the proposed algorithm's strengths and weaknesses in the current research landscape. Furthermore, the paper uses only one real-world dataset (NHP data), which limits the assessment of the model's broader applicability. The confidence level for this weakness is high, as the experimental section lacks comparisons with recent methods and uses a limited number of datasets. Finally, the paper claims that the algorithm can be generalized to non-Poisson/non-Gaussian models but does not provide any experimental evidence to support this claim. The paper states that the moment transformation step is key to extending the method, but no results are shown for any other distributions. This lack of empirical evidence makes the claim of generalizability unsubstantiated. The confidence level for this weakness is high, as the claim is made without any supporting experimental results.",
|
| 289 |
+
"suggestions": "To address the identified weaknesses, I recommend several concrete improvements. First, the authors should conduct a thorough analysis of the method's sensitivity to violations of the linearity assumption. This could involve simulating data from a variety of nonlinear dynamical systems and assessing the accuracy of the estimated latent states and their dimensions. For example, they could use simple nonlinear systems like the Duffing oscillator or the Lorenz attractor to generate synthetic data and then apply their method to this data. The performance of the method could be evaluated by comparing the estimated latent states and their dimensions to the true values. Furthermore, it would be beneficial to explore how the method's performance changes as the degree of nonlinearity increases. This analysis would provide a more comprehensive understanding of the method's limitations and its applicability to real-world scenarios where nonlinearities are common. Second, the authors should provide a more detailed analysis of the method's sensitivity to deviations from the assumed block structure in Equation (6). This could involve simulations where the true coefficient matrix has small non-zero values in the upper right block and assessing whether the method still converges to a reasonable estimate. A sensitivity analysis exploring the robustness of the method to such deviations would be crucial. Furthermore, the paper should provide more guidance on how to choose the dimensions of the latent spaces ($n_1$ and $n_x$). The current description is somewhat vague, and a more concrete procedure, perhaps based on information criteria or cross-validation, would be highly valuable. Third, the authors should include a detailed comparison of their approach with recent subspace identification techniques that use both behavioral and neural data, such as those presented in [1] and [2]. This comparison should include a discussion of the assumptions made by each method, the optimization procedures used, and the types of data that can be handled. For example, the authors should clearly explain how their method differs from the approaches presented in [1] and [2] in terms of the way they model the shared and residual dynamics. They should also discuss the advantages and disadvantages of their method compared to these existing techniques. Fourth, the authors should clarify the role of the noise variance of the Gaussian observation process in their method. They should provide a detailed analysis of the method's sensitivity to variations in the noise variance. This analysis could include simulations with different noise levels and a quantitative assessment of the error in latent state estimation and dimensionality identification. Furthermore, they should discuss how the method's performance is affected by the choice of the noise model. Fifth, the experimental evaluation should be expanded to include comparisons with more recent and competitive methods. While PLDSID is a relevant baseline, the field has seen significant advancements since 2012. Including comparisons with state-of-the-art methods, such as more recent deep learning approaches for time series modeling, would provide a more comprehensive assessment of the proposed algorithm's performance. This would not only highlight the strengths of the proposed method but also reveal its limitations and areas for future improvement. Finally, the authors should provide empirical support for their claim that the algorithm can be generalized to non-Poisson/non-Gaussian models. This could involve testing the method on synthetic datasets from simple models with alternative distributions. The authors should also consider including a simple example of how the moment transformation would be derived for a different distribution, such as Bernoulli, to further support their claim.",
|
| 290 |
+
"questions": "Several key uncertainties remain after my review of this paper. First, I am particularly interested in the justification for the block structure assumed in Equation (6). While the authors claim this structure does not lose generality, I would like to understand the practical implications of this choice more thoroughly. Specifically, how does the method behave when the true underlying system deviates from this block structure, even slightly? What happens if there are small non-zero values in the upper right block of the coefficient matrix? Does the method still converge to a reasonable estimate, or does it break down? A more detailed explanation of the assumptions underlying this block structure, and a sensitivity analysis exploring its robustness, would be highly beneficial. Second, I am curious about the method's performance when applied to nonlinear systems. The paper acknowledges the limitation of assuming linear dynamics but does not provide any analysis of how the method performs when this assumption is violated. How does the method perform when the underlying system is nonlinear? How does the accuracy of the estimated latent states and their dimensions change as the degree of nonlinearity increases? I would like to see more systematic evaluations of the method's performance under nonlinear conditions. Third, I would like to understand how the method compares to existing approaches that use subspace identification for multimodal data, specifically those mentioned in [1] and [2]. How does the proposed method differ in terms of the assumptions made, the optimization procedures used, and the types of data that can be handled? A more detailed comparison with these methods is needed to justify the specific contribution of this work. Fourth, I am interested in the role of the noise variance of the Gaussian observation process in the method. How does the noise variance affect the accuracy of the estimated latent states and their dimensions? How does the method's performance change as the noise variance varies? A more thorough analysis of the method's sensitivity to variations in the noise variance would be valuable. Finally, I would like to understand the practical limitations of the proposed method. What are the assumptions underlying the method, and when might these assumptions be violated in practice? Are there specific types of dynamical systems for which the method is not suitable? A clear discussion of these limitations would help readers understand the scope of the method and avoid misapplications.",
|
| 291 |
+
"rating": 6.8,
|
| 292 |
+
"confidence": 2.6,
|
| 293 |
+
"decision": "Reject"
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
NOTE: Output the JSON format ONLY, DO NOT output anything other than the json, DO NOT include your thinking process.
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# System message for reviewer
|
| 300 |
+
reviewer_system: "You are an expert academic reviewer with deep knowledge in the field. Always respond with valid JSON format."
|
| 301 |
+
|
| 302 |
+
# Refiner prompts
|
| 303 |
+
refiner_prompts:
|
| 304 |
+
detailed: |
|
| 305 |
+
You are a senior researcher refining an existing peer review. Your job is to improve factual grounding, coverage, and usefulness while preserving the draft’s structure and intent. Treat the paper text as the source of truth.
|
| 306 |
+
|
| 307 |
+
You will be given:
|
| 308 |
+
(1) Paper text (plain text converted from PDF)
|
| 309 |
+
(2) Draft review (structured)
|
| 310 |
+
(3) Method/Contribution audit report (from Paper Insight Miner; paper-grounded)
|
| 311 |
+
(4) Experiments/Results audit report (from Paper Results Analyzer; paper-grounded)
|
| 312 |
+
(5) Related-work summaries (each item is a JSON summary of one retrieved paper, written relative to the target paper)
|
| 313 |
+
|
| 314 |
+
========================
|
| 315 |
+
Primary objectives (what to improve)
|
| 316 |
+
========================
|
| 317 |
+
Refine the review to satisfy these content-quality dimensions:
|
| 318 |
+
|
| 319 |
+
1) Core Contribution Accuracy
|
| 320 |
+
2) Results Interpretation
|
| 321 |
+
3) Comparative Analysis / Positioning
|
| 322 |
+
4) Evidence-Based Critique
|
| 323 |
+
5) Critique Clarity
|
| 324 |
+
6) Completeness Coverage
|
| 325 |
+
7) Constructive Tone
|
| 326 |
+
8) Avoid False or Contradictory Claims (critical)
|
| 327 |
+
|
| 328 |
+
========================
|
| 329 |
+
Hard constraints (must follow)
|
| 330 |
+
========================
|
| 331 |
+
A. Paper-grounded correctness is mandatory:
|
| 332 |
+
- If the audit reports mark a draft claim as incorrect/hallucinated/contradicted, you MUST fix or remove it.
|
| 333 |
+
- Do NOT introduce new factual claims about the paper unless you can anchor them to the paper text or the audit reports’ evidence.
|
| 334 |
+
|
| 335 |
+
B. Evidence anchoring rule:
|
| 336 |
+
- Every major critique (esp. in Weaknesses/Suggestions/Questions) must include a verifiable anchor:
|
| 337 |
+
section name, table/figure identifier, equation/algorithm reference, dataset/metric name, or a short quote snippet (<= 20 words).
|
| 338 |
+
- If you cannot find support, convert the statement into a question or a suggestion for clarification (do not assert absence).
|
| 339 |
+
|
| 340 |
+
C. Related-work usage rule (anti-leak / anti-overclaim):
|
| 341 |
+
- Retrieved related-work summaries are NOT guaranteed to be cited by the submission.
|
| 342 |
+
- Never claim “the paper compares to/cites X” unless the paper text actually contains X.
|
| 343 |
+
- When using retrieved works, attribute them as external context:
|
| 344 |
+
��The related-work search suggests …; it would help to clarify/compare …”
|
| 345 |
+
- Use related work to: (i) sharpen positioning, (ii) propose missing baselines/comparisons, (iii) raise targeted questions.
|
| 346 |
+
|
| 347 |
+
D. Minimal-change policy:
|
| 348 |
+
- Keep the original structure and as much of the draft wording as possible.
|
| 349 |
+
- Do NOT shorten aggressively; do NOT rewrite into a totally new review.
|
| 350 |
+
- Prefer targeted edits, insertions, and corrections.
|
| 351 |
+
|
| 352 |
+
E. Numeric fields policy (IMPORTANT):
|
| 353 |
+
- Default: keep ALL numeric fields and the decision unchanged.
|
| 354 |
+
- Change numeric fields ONLY if the refined textual assessment would otherwise be clearly inconsistent, or if a major factual correction materially changes the evaluation.
|
| 355 |
+
- If you change any numeric field: change the minimum number of fields, and keep changes small unless necessary.
|
| 356 |
+
|
| 357 |
+
========================
|
| 358 |
+
How to use the tool reports (operational)
|
| 359 |
+
========================
|
| 360 |
+
1) Apply Paper Insight Miner (method/contribution):
|
| 361 |
+
- Use `review_issues.incorrect_or_hallucinated` to remove/correct wrong claims in Summary/Strengths/Weaknesses.
|
| 362 |
+
- Use `missing_key_points` and `needs_specificity` to improve technical specificity.
|
| 363 |
+
- Incorporate `rewrite_suggestions` where appropriate (method-related only).
|
| 364 |
+
|
| 365 |
+
2) Apply Paper Results Analyzer (experiments/results):
|
| 366 |
+
- Correct any wrong result interpretation.
|
| 367 |
+
- Add missing datasets/baselines/metrics/key results if they are important and supported.
|
| 368 |
+
- Convert vague experiment critiques into concrete, testable suggestions with anchors.
|
| 369 |
+
- Incorporate `rewrite_suggestions` where appropriate (experiment-related only).
|
| 370 |
+
|
| 371 |
+
3) Use Related-work summaries:
|
| 372 |
+
- Use each item’s `relation` to craft 1–3 concrete positioning points:
|
| 373 |
+
- what is similar/different,
|
| 374 |
+
- what comparisons would strengthen the paper,
|
| 375 |
+
- what claims need clarification.
|
| 376 |
+
- Do NOT dump a bibliography; only mention the most relevant comparisons (typically <= 3 items).
|
| 377 |
+
- Phrase as external suggestions, not accusations.
|
| 378 |
+
|
| 379 |
+
========================
|
| 380 |
+
Refinement checklist (do in order)
|
| 381 |
+
========================
|
| 382 |
+
Step 1: Fix incorrect/hallucinated statements flagged by the two audit reports.
|
| 383 |
+
Step 2: Improve Summary and Strengths with paper-grounded method + results highlights.
|
| 384 |
+
Step 3: Strengthen Weaknesses with evidence anchors and clearer critique.
|
| 385 |
+
Step 4: Add actionable Suggestions (each mapped to a weakness).
|
| 386 |
+
Step 5: Improve Questions to resolve uncertainties (especially when evidence is not found).
|
| 387 |
+
Step 6: Improve Comparative Analysis using related-work summaries with proper attribution.
|
| 388 |
+
Step 7: Ensure constructive tone and completeness across method / experiments / positioning.
|
| 389 |
+
|
| 390 |
+
========================
|
| 391 |
+
Output format (JSON ONLY)
|
| 392 |
+
========================
|
| 393 |
+
Return a JSON object with the following keys ONLY.
|
| 394 |
+
- Numeric fields must be numbers (not strings).
|
| 395 |
+
- decision must be one of: "accept", "reject".
|
| 396 |
+
- Do not output any text outside JSON.
|
| 397 |
+
|
| 398 |
+
{
|
| 399 |
+
"summary": "...",
|
| 400 |
+
"strengths": "...",
|
| 401 |
+
"weaknesses": "...",
|
| 402 |
+
"questions": "...",
|
| 403 |
+
"soundness": 0,
|
| 404 |
+
"presentation": 0,
|
| 405 |
+
"contribution": 0,
|
| 406 |
+
"rating": 0,
|
| 407 |
+
"confidence": 0,
|
| 408 |
+
"decision": "your_decision"
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
========================
|
| 412 |
+
Inputs
|
| 413 |
+
========================
|
| 414 |
+
[Paper Text]
|
| 415 |
+
<<paper_text>>
|
| 416 |
+
|
| 417 |
+
[Draft Review]
|
| 418 |
+
<<draft_review>>
|
| 419 |
+
|
| 420 |
+
[Paper Insight Miner Output (JSON)]
|
| 421 |
+
<<insight_miner_json>>
|
| 422 |
+
|
| 423 |
+
[Paper Results Analyzer Output (JSON)]
|
| 424 |
+
<<results_analyzer_json>>
|
| 425 |
+
|
| 426 |
+
[Related-work Summaries (JSON list)]
|
| 427 |
+
<<related_work_json_list>>
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# System message for refiner
|
| 431 |
+
refiner_system: "You are an expert review refiner with deep knowledge in academic review quality standards and meta rubrics."
|
| 432 |
+
|
| 433 |
+
# Evaluation prompts for review-based rubric generation and evaluation
|
| 434 |
+
# Rubric template for generating paper-specific rubrics
|
| 435 |
+
rubrics: |
|
| 436 |
+
{
|
| 437 |
+
"title": "Core Contribution Accuracy",
|
| 438 |
+
"description": "Essential Criteria: Identifies whether the review accurately describes the paper's main contributions and central methodological innovations in the summary and strengths sections without misinterpretation.",
|
| 439 |
+
"weight": 1
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"title": "Results Interpretation",
|
| 443 |
+
"description": "Important Criteria: Explains whether the review correctly interprets the empirical results in the summary and strengths sections, including tables, figures, and statistical comparisons.",
|
| 444 |
+
"weight": 1
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"title": "Comparative Analysis",
|
| 448 |
+
"description": "Important Criteria: States whether the review appropriately discusses comparisons with baselines and related work presented in the paper.",
|
| 449 |
+
"weight": 1
|
| 450 |
+
},
|
| 451 |
+
{
|
| 452 |
+
"title": "Evidence-Based Critique",
|
| 453 |
+
"description": "Essential Criteria: Identifies whether weaknesses and criticisms in the Weaknesses, Suggestions, and Questions sections are supported by specific references to paper content such as sections, equations, tables, or figures.",
|
| 454 |
+
"weight": 1
|
| 455 |
+
},
|
| 456 |
+
{
|
| 457 |
+
"title": "Critique Clarity",
|
| 458 |
+
"description": "Important Criteria: Explains whether the identified weaknesses in the Weaknesses, Suggestions, and Questions sections are stated clearly and specifically enough for authors to understand what needs improvement.",
|
| 459 |
+
"weight": 1
|
| 460 |
+
},
|
| 461 |
+
{
|
| 462 |
+
"title": "Completeness Coverage",
|
| 463 |
+
"description": "Important Criteria: Identifies whether the review addresses all major components of the paper including methodology, theory, experiments, and related work.",
|
| 464 |
+
"weight": 1
|
| 465 |
+
},
|
| 466 |
+
{
|
| 467 |
+
"title": "Constructive Tone",
|
| 468 |
+
"description": "Optional Criteria: States whether the review maintains a professional and constructive tone that encourages improvement rather than discouragement.",
|
| 469 |
+
"weight": 1
|
| 470 |
+
},
|
| 471 |
+
{
|
| 472 |
+
"title": "False or Contradictory Claims",
|
| 473 |
+
"description": "Pitfall Criteria: Does not mention content or experiments that are actually absent from the paper, incorrectly claim something is missing when it exists, or make statements that contradict the paper's stated results, conclusions, or explicitly documented design choices.",
|
| 474 |
+
"weight": -1
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
# Rubric generation prompt (v2 - uses template)
|
| 478 |
+
v2_rubric_generation_prompt: |
|
| 479 |
+
You are an expert rubric writer for evaluating the quality of paper reviews in AI-related academic fields, including:
|
| 480 |
+
|
| 481 |
+
- Machine Learning
|
| 482 |
+
- Deep Learning
|
| 483 |
+
- Natural Language Processing (NLP)
|
| 484 |
+
- Computer Vision
|
| 485 |
+
- Robotics
|
| 486 |
+
- Reinforcement Learning
|
| 487 |
+
- Optimization
|
| 488 |
+
- Data-centric AI
|
| 489 |
+
- Related subdisciplines
|
| 490 |
+
|
| 491 |
+
You are given a paper content, a golden review, and a review evaluation rubric template.
|
| 492 |
+
According the paper content, golden review, and review evaluation rubric template, your task is to combine the rubric template with the paper content and golden review, to form a complete review evaluation rubric set, so that it can be used to judge if other candidates reviews captures the key contents in the golden review and candidly reflect the understanding, strengths, and weaknesses of the paper.
|
| 493 |
+
The combined rubrics should be precise, comprehensive, and actionable.
|
| 494 |
+
|
| 495 |
+
## Rubric Construction Rules
|
| 496 |
+
|
| 497 |
+
### Total Items
|
| 498 |
+
|
| 499 |
+
- Rewrite each rubric item in the template as a self-contained evaluation criterion following the above areas and criteria, adhere to the content of the golden review dynamically. Keep the original weight of the rubric item.
|
| 500 |
+
- Do NOT add or remove any rubric item. Strictly adhere to the areas and criteria in the template.
|
| 501 |
+
|
| 502 |
+
### Category Guidance
|
| 503 |
+
|
| 504 |
+
- **Essential:** Critical facts or safety checks; missing this invalidates the response.
|
| 505 |
+
- **Important:** Key reasoning, completeness, or clarity; strongly affects quality.
|
| 506 |
+
- **Optional:** Helpful stylistic or depth additions; not required
|
| 507 |
+
- **Pitfall:** Common important mistakes; each must begin with "Pitfall Criteria: Does not mention …" or "Pitfall Criteria: Recommends …"
|
| 508 |
+
|
| 509 |
+
---
|
| 510 |
+
|
| 511 |
+
## Output Requirements
|
| 512 |
+
|
| 513 |
+
- Provide a **JSON array** of rubric objects.
|
| 514 |
+
- Each object must contain **exactly three keys**:`{ "title": "...", "description": "...", "weight": ... }`
|
| 515 |
+
- No additional keys allowed.
|
| 516 |
+
- Do **not** copy large blocks of the question or reference answer.
|
| 517 |
+
- Every description must **start with its category prefix**.
|
| 518 |
+
|
| 519 |
+
Now, provided is the golden review for you to refer on:
|
| 520 |
+
|
| 521 |
+
<<golden_review>>
|
| 522 |
+
|
| 523 |
+
Following is the paper content for you to refer on:
|
| 524 |
+
|
| 525 |
+
<<paper_context>>
|
| 526 |
+
|
| 527 |
+
And below is the rubric template as json for you to refer on:
|
| 528 |
+
|
| 529 |
+
<<rubric_template>>
|
| 530 |
+
|
| 531 |
+
Please start your rubric generation following the above information provided and the instructions.
|
| 532 |
+
|
| 533 |
+
# Evaluator prompt (v1 - for evaluating reviews using rubrics)
|
| 534 |
+
v1_evaluator_prompt: |
|
| 535 |
+
You are an expert academic reviewer tasked with evaluating a research paper review following a list of rubrics.
|
| 536 |
+
|
| 537 |
+
Coupled with the paper content and the review, you need to score the review on each rubric and provide corresponding rationales.
|
| 538 |
+
|
| 539 |
+
The rubrics are as follows:
|
| 540 |
+
|
| 541 |
+
{rubrics_json}
|
| 542 |
+
|
| 543 |
+
The score should be in the range of -2 to 2, do NOT refer to the value of the weight when assigning the score during evaluation. Treat them as indications of whether this rubric is positive or negative.
|
| 544 |
+
If the weight is positive, this rubric is positive. If the weight is negative, this rubric is negative.
|
| 545 |
+
|
| 546 |
+
For each rubric:
|
| 547 |
+
- If this rubric is positive:
|
| 548 |
+
- If the review meet none of the key points in this rubric, assign 0.
|
| 549 |
+
- If the review meet at least half of the key points in this rubric, assign 1.
|
| 550 |
+
- If the review meet all of the key points in this rubric, assign 2.
|
| 551 |
+
- If this rubric is negative:
|
| 552 |
+
- If the review does NOT suffer from the pitfall (good), assign 0.
|
| 553 |
+
- If it DOES suffer any of the key points in this pitfall rubric, assign -1.
|
| 554 |
+
- If it DOES suffer all of the key points in this pitfall rubric, assign -2.
|
| 555 |
+
|
| 556 |
+
Your output format should be:
|
| 557 |
+
{{
|
| 558 |
+
"<first_rubric_title>": {{
|
| 559 |
+
"score": <-2 to 2>,
|
| 560 |
+
"rationale": "<rationale explaining the score>"
|
| 561 |
+
}},
|
| 562 |
+
...
|
| 563 |
+
"<last_rubric_title>": {{
|
| 564 |
+
"score": <-2 to 2>,
|
| 565 |
+
"rationale": "<rationale explaining the score>"
|
| 566 |
+
}}
|
| 567 |
+
}}
|
| 568 |
+
|
| 569 |
+
DO NOT include any other text in your output. Output the JSON string ONLY.
|
| 570 |
+
|
| 571 |
+
Now, provided is the paper content:
|
| 572 |
+
|
| 573 |
+
<<paper_content>>
|
| 574 |
+
|
| 575 |
+
Now, provided is the review:
|
| 576 |
+
|
| 577 |
+
<<review>>
|
| 578 |
+
|
| 579 |
+
Please start your evaluation following the above information provided and the instructions.
|
| 580 |
+
|
shared/configs/reranker_endpoint_pool.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
http://localhost:8008
|
| 2 |
+
http://localhost:8009
|
| 3 |
+
http://localhost:8010
|
| 4 |
+
http://localhost:8011
|
| 5 |
+
http://localhost:8012
|
| 6 |
+
http://localhost:8013
|
| 7 |
+
http://localhost:8014
|
| 8 |
+
http://localhost:8015
|
shared/configs/vllm_endpoint_pool.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
http://localhost:8001/v1
|
| 2 |
+
http://localhost:8002/v1
|
| 3 |
+
http://localhost:8003/v1
|
| 4 |
+
http://localhost:8004/v1
|
| 5 |
+
http://localhost:8005/v1
|
| 6 |
+
http://localhost:8006/v1
|
| 7 |
+
http://localhost:8007/v1
|
shared/utils/__init__.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared utilities for the unified review system.
|
| 3 |
+
"""
|
| 4 |
+
# Core utilities (always available)
|
| 5 |
+
from .json_parser import (
|
| 6 |
+
parse_review_markdown,
|
| 7 |
+
parse_keywords_json,
|
| 8 |
+
parse_summary_json,
|
| 9 |
+
parse_json_response,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from .prompt_loader import get_prompt_loader
|
| 13 |
+
|
| 14 |
+
# API Key Pool and Endpoint Pool (always available)
|
| 15 |
+
try:
|
| 16 |
+
from .asta_api_key_pool import AstaAPIKeyPool
|
| 17 |
+
_all_pools = ['AstaAPIKeyPool']
|
| 18 |
+
except ImportError:
|
| 19 |
+
AstaAPIKeyPool = None
|
| 20 |
+
_all_pools = []
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from .vllm_endpoint_pool import VLLMEndpointPool
|
| 24 |
+
_all_pools.append('VLLMEndpointPool')
|
| 25 |
+
except ImportError:
|
| 26 |
+
VLLMEndpointPool = None
|
| 27 |
+
|
| 28 |
+
if _all_pools:
|
| 29 |
+
__all__ = _all_pools
|
| 30 |
+
|
| 31 |
+
# Lazy imports for heavy dependencies (LLM-related)
|
| 32 |
+
# These may fail if dependencies are not installed, but that's okay
|
| 33 |
+
def _lazy_import_llm_services():
|
| 34 |
+
"""Lazy import LLM services to avoid dependency issues"""
|
| 35 |
+
try:
|
| 36 |
+
from .llm_service import LLMService, ChatMessage
|
| 37 |
+
from .llm_service_factory import (
|
| 38 |
+
get_llm_service_factory,
|
| 39 |
+
LLMServiceFactory,
|
| 40 |
+
load_api_key_from_config,
|
| 41 |
+
)
|
| 42 |
+
return {
|
| 43 |
+
'LLMService': LLMService,
|
| 44 |
+
'ChatMessage': ChatMessage,
|
| 45 |
+
'get_llm_service_factory': get_llm_service_factory,
|
| 46 |
+
'LLMServiceFactory': LLMServiceFactory,
|
| 47 |
+
'load_api_key_from_config': load_api_key_from_config,
|
| 48 |
+
}
|
| 49 |
+
except ImportError:
|
| 50 |
+
return {}
|
| 51 |
+
|
| 52 |
+
def _lazy_import_llm_implementations():
|
| 53 |
+
"""Lazy import LLM service implementations"""
|
| 54 |
+
result = {}
|
| 55 |
+
try:
|
| 56 |
+
from .vllm_service import VLLMService
|
| 57 |
+
result['VLLMService'] = VLLMService
|
| 58 |
+
except ImportError:
|
| 59 |
+
pass
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
from .gpt_service import GPTService
|
| 63 |
+
result['GPTService'] = GPTService
|
| 64 |
+
except ImportError:
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
from .mock_llm_service import MockLLMService, extract_title_from_latex, extract_abstract_from_latex
|
| 69 |
+
result['MockLLMService'] = MockLLMService
|
| 70 |
+
result['extract_title_from_latex'] = extract_title_from_latex
|
| 71 |
+
result['extract_abstract_from_latex'] = extract_abstract_from_latex
|
| 72 |
+
except ImportError:
|
| 73 |
+
pass
|
| 74 |
+
|
| 75 |
+
return result
|
| 76 |
+
|
| 77 |
+
def _lazy_import_other():
|
| 78 |
+
"""Lazy import other utilities"""
|
| 79 |
+
result = {}
|
| 80 |
+
try:
|
| 81 |
+
from .reranker import rerank_paragraphs_bge
|
| 82 |
+
result['rerank_paragraphs_bge'] = rerank_paragraphs_bge
|
| 83 |
+
except ImportError:
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
try:
|
| 87 |
+
from .review_logger import ReviewLogger
|
| 88 |
+
result['ReviewLogger'] = ReviewLogger
|
| 89 |
+
except ImportError:
|
| 90 |
+
pass
|
| 91 |
+
|
| 92 |
+
return result
|
| 93 |
+
|
| 94 |
+
# Populate __all__ dynamically
|
| 95 |
+
_llm_services = _lazy_import_llm_services()
|
| 96 |
+
_llm_impls = _lazy_import_llm_implementations()
|
| 97 |
+
_other = _lazy_import_other()
|
| 98 |
+
|
| 99 |
+
# Make all lazy imports available at module level
|
| 100 |
+
globals().update(_llm_services)
|
| 101 |
+
globals().update(_llm_impls)
|
| 102 |
+
globals().update(_other)
|
| 103 |
+
|
| 104 |
+
__all__ = [
|
| 105 |
+
'parse_review_markdown',
|
| 106 |
+
'parse_keywords_json',
|
| 107 |
+
'parse_summary_json',
|
| 108 |
+
'parse_json_response',
|
| 109 |
+
'get_prompt_loader',
|
| 110 |
+
] + list(_llm_services.keys()) + list(_llm_impls.keys()) + list(_other.keys())
|
| 111 |
+
|
| 112 |
+
if AstaAPIKeyPool:
|
| 113 |
+
__all__.append('AstaAPIKeyPool')
|
shared/utils/asta_api_key_pool.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Asta API Key Pool Manager
|
| 3 |
+
|
| 4 |
+
Manage multiple Asta API keys, implement key rotation and error handling.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Optional, Dict
|
| 11 |
+
from threading import Lock
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AstaAPIKeyPool:
|
| 15 |
+
"""
|
| 16 |
+
Asta API Key Pool Manager
|
| 17 |
+
|
| 18 |
+
Features:
|
| 19 |
+
1. Load multiple API keys from file
|
| 20 |
+
2. Randomly rotate keys
|
| 21 |
+
3. Track each key's usage status and errors
|
| 22 |
+
4. Implement debounce retry strategy
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, pool_path: Optional[str] = None, keys: Optional[List[str]] = None):
|
| 26 |
+
"""
|
| 27 |
+
Initialize API Key Pool
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
pool_path: API keys file path (one key per line)
|
| 31 |
+
keys: directly provide keys list (prior to pool_path)
|
| 32 |
+
"""
|
| 33 |
+
self.keys: List[str] = []
|
| 34 |
+
self.used_indices: List[int] = [] # indices used in current rotation
|
| 35 |
+
self.key_status: Dict[str, Dict] = {} # status information for each key
|
| 36 |
+
self.lock = Lock() # thread safe lock
|
| 37 |
+
|
| 38 |
+
# load keys
|
| 39 |
+
if keys:
|
| 40 |
+
self.keys = [k.strip() for k in keys if k.strip()]
|
| 41 |
+
elif os.environ.get("ASTA_API_KEY"):
|
| 42 |
+
# Try to get one or more keys from environment variable (comma-separated)
|
| 43 |
+
self.keys = [k.strip() for k in os.environ.get("ASTA_API_KEY").split(",") if k.strip()]
|
| 44 |
+
elif pool_path:
|
| 45 |
+
self._load_from_file(pool_path)
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError(
|
| 48 |
+
"No API keys available. Provide keys via pool_path, keys parameter, "
|
| 49 |
+
"or ASTA_API_KEY environment variable."
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
if not self.keys:
|
| 53 |
+
raise ValueError(
|
| 54 |
+
"No API keys available. Provide keys via pool_path, keys parameter, "
|
| 55 |
+
"or ASTA_API_KEY environment variable."
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# initialize status for each key
|
| 59 |
+
for key in self.keys:
|
| 60 |
+
self.key_status[key] = {
|
| 61 |
+
'error_count': 0,
|
| 62 |
+
'last_error_time': None,
|
| 63 |
+
'consecutive_errors': 0,
|
| 64 |
+
'total_requests': 0,
|
| 65 |
+
'successful_requests': 0,
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
def _load_from_file(self, pool_path: str):
|
| 69 |
+
"""Load API keys from file"""
|
| 70 |
+
path = Path(pool_path)
|
| 71 |
+
|
| 72 |
+
# if relative path, try to find file relative to shared/configs
|
| 73 |
+
if not path.is_absolute():
|
| 74 |
+
# try to find file relative to project root
|
| 75 |
+
project_root = Path(__file__).parent.parent.parent
|
| 76 |
+
path = project_root / "shared" / "configs" / pool_path
|
| 77 |
+
if not path.exists():
|
| 78 |
+
# try to find file relative to shared/configs
|
| 79 |
+
path = Path(__file__).parent.parent / "configs" / pool_path
|
| 80 |
+
|
| 81 |
+
if not path.exists():
|
| 82 |
+
raise FileNotFoundError(
|
| 83 |
+
f"API key pool file not found: {pool_path} (tried: {path})"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 87 |
+
lines = f.readlines()
|
| 88 |
+
|
| 89 |
+
self.keys = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
|
| 90 |
+
|
| 91 |
+
if not self.keys:
|
| 92 |
+
raise ValueError(f"No valid API keys found in pool file: {pool_path}")
|
| 93 |
+
|
| 94 |
+
def get_key(self) -> str:
|
| 95 |
+
"""
|
| 96 |
+
Get next available API key (rotation strategy)
|
| 97 |
+
|
| 98 |
+
Strategy:
|
| 99 |
+
1. If current rotation is not complete, continue using unused keys
|
| 100 |
+
2. If current rotation is complete, start a new round (reset used_indices)
|
| 101 |
+
3. Prioritize keys with no recent errors
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Available API key
|
| 105 |
+
"""
|
| 106 |
+
with self.lock:
|
| 107 |
+
if not self.keys:
|
| 108 |
+
raise ValueError("No API keys available in pool")
|
| 109 |
+
|
| 110 |
+
# if current rotation is complete, start a new round
|
| 111 |
+
if len(self.used_indices) >= len(self.keys):
|
| 112 |
+
self.used_indices = []
|
| 113 |
+
|
| 114 |
+
# get indices not used in current rotation
|
| 115 |
+
available_indices = [i for i in range(len(self.keys)) if i not in self.used_indices]
|
| 116 |
+
|
| 117 |
+
if not available_indices:
|
| 118 |
+
# all keys are used in current rotation, start a new round
|
| 119 |
+
available_indices = list(range(len(self.keys)))
|
| 120 |
+
self.used_indices = []
|
| 121 |
+
|
| 122 |
+
# prioritize keys with fewer errors (randomly select, but prioritize keys with higher success rate and fewer errors)
|
| 123 |
+
key_scores = []
|
| 124 |
+
for idx in available_indices:
|
| 125 |
+
key = self.keys[idx]
|
| 126 |
+
status = self.key_status[key]
|
| 127 |
+
|
| 128 |
+
# calculate score: error count, success rate, score越高
|
| 129 |
+
error_count = status['error_count']
|
| 130 |
+
total = status['total_requests']
|
| 131 |
+
success_rate = (status['successful_requests'] / total) if total > 0 else 1.0
|
| 132 |
+
|
| 133 |
+
# if recent error, reduce score
|
| 134 |
+
recent_error_penalty = 0
|
| 135 |
+
if status['last_error_time']:
|
| 136 |
+
time_since_error = time.time() - status['last_error_time']
|
| 137 |
+
if time_since_error < 60: # 1 minute
|
| 138 |
+
recent_error_penalty = 0.5
|
| 139 |
+
|
| 140 |
+
score = success_rate - (error_count * 0.1) - recent_error_penalty
|
| 141 |
+
key_scores.append((idx, score))
|
| 142 |
+
|
| 143 |
+
# sort by score, select highest score (but add some randomness)
|
| 144 |
+
key_scores.sort(key=lambda x: x[1], reverse=True)
|
| 145 |
+
|
| 146 |
+
# select from top 50% (add randomness but prioritize better keys)
|
| 147 |
+
top_n = max(1, len(key_scores) // 2) if len(key_scores) > 1 else 1
|
| 148 |
+
selected_idx, _ = random.choice(key_scores[:top_n])
|
| 149 |
+
|
| 150 |
+
# mark as used
|
| 151 |
+
self.used_indices.append(selected_idx)
|
| 152 |
+
|
| 153 |
+
selected_key = self.keys[selected_idx]
|
| 154 |
+
self.key_status[selected_key]['total_requests'] += 1
|
| 155 |
+
|
| 156 |
+
return selected_key
|
| 157 |
+
|
| 158 |
+
def mark_success(self, key: str):
|
| 159 |
+
"""mark key as successful"""
|
| 160 |
+
with self.lock:
|
| 161 |
+
if key in self.key_status:
|
| 162 |
+
self.key_status[key]['successful_requests'] += 1
|
| 163 |
+
self.key_status[key]['consecutive_errors'] = 0
|
| 164 |
+
|
| 165 |
+
def mark_error(self, key: str, error_type: str = "rate_limit"):
|
| 166 |
+
"""
|
| 167 |
+
mark key as failed
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
key: failed API key
|
| 171 |
+
error_type: error type ("rate_limit", "auth_error", "server_error", "other")
|
| 172 |
+
"""
|
| 173 |
+
with self.lock:
|
| 174 |
+
if key in self.key_status:
|
| 175 |
+
status = self.key_status[key]
|
| 176 |
+
status['error_count'] += 1
|
| 177 |
+
status['consecutive_errors'] += 1
|
| 178 |
+
status['last_error_time'] = time.time()
|
| 179 |
+
|
| 180 |
+
def get_status(self) -> Dict:
|
| 181 |
+
"""get pool status information (for debugging)"""
|
| 182 |
+
with self.lock:
|
| 183 |
+
return {
|
| 184 |
+
'total_keys': len(self.keys),
|
| 185 |
+
'current_round_progress': f"{len(self.used_indices)}/{len(self.keys)}",
|
| 186 |
+
'keys_status': {
|
| 187 |
+
key: {
|
| 188 |
+
'error_count': status['error_count'],
|
| 189 |
+
'successful_requests': status['successful_requests'],
|
| 190 |
+
'total_requests': status['total_requests'],
|
| 191 |
+
'success_rate': (
|
| 192 |
+
status['successful_requests'] / status['total_requests']
|
| 193 |
+
if status['total_requests'] > 0 else 0.0
|
| 194 |
+
),
|
| 195 |
+
'consecutive_errors': status['consecutive_errors'],
|
| 196 |
+
'last_error_time': status['last_error_time'],
|
| 197 |
+
}
|
| 198 |
+
for key, status in self.key_status.items()
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
def reset_round(self):
|
| 203 |
+
"""reset current rotation (force start a new round)"""
|
| 204 |
+
with self.lock:
|
| 205 |
+
self.used_indices = []
|
shared/utils/gpt_service.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI GPT API service implementation
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from typing import List, Dict, Optional, Any, Union
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from .llm_service import LLMService, ChatMessage
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GPTService(LLMService):
|
| 11 |
+
"""
|
| 12 |
+
OpenAI GPT API service wrapper
|
| 13 |
+
|
| 14 |
+
This service connects to OpenAI's API (or compatible API)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
api_key: Optional[str] = None,
|
| 20 |
+
model_name: str = "gpt-4o",
|
| 21 |
+
base_url: Optional[str] = None,
|
| 22 |
+
timeout: int = 300,
|
| 23 |
+
):
|
| 24 |
+
"""
|
| 25 |
+
Initialize GPT service
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
api_key: OpenAI / OpenRouter API key (set via env if omitted)
|
| 29 |
+
model_name: Model name (e.g., gpt-4o, openai/gpt-oss-120b:free, etc.)
|
| 30 |
+
base_url: API base URL (default: https://api.openai.com/v1)
|
| 31 |
+
timeout: Request timeout in seconds
|
| 32 |
+
"""
|
| 33 |
+
# Prefer explicit parameter, then common environment variables.
|
| 34 |
+
# This allows using OpenRouter (OPENROUTER_API_KEY) without hard-coding secrets.
|
| 35 |
+
self.api_key = (
|
| 36 |
+
api_key
|
| 37 |
+
or os.environ.get("OPENAI_API_KEY")
|
| 38 |
+
or os.environ.get("OPENROUTER_API_KEY")
|
| 39 |
+
)
|
| 40 |
+
if not self.api_key:
|
| 41 |
+
raise ValueError(
|
| 42 |
+
"API key is required. Set OPENAI_API_KEY or OPENROUTER_API_KEY "
|
| 43 |
+
"environment variable, or pass api_key parameter."
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
self.model_name = model_name
|
| 47 |
+
# Prefer explicit base_url, then environment variables, then OpenAI default.
|
| 48 |
+
# This allows swapping in any OpenAI-compatible endpoint (e.g., OpenRouter)
|
| 49 |
+
# without changing code.
|
| 50 |
+
self.base_url = (
|
| 51 |
+
base_url
|
| 52 |
+
or os.environ.get("OPENAI_BASE_URL")
|
| 53 |
+
or os.environ.get("OPENROUTER_BASE_URL")
|
| 54 |
+
or "https://api.openai.com/v1"
|
| 55 |
+
)
|
| 56 |
+
self.timeout = timeout
|
| 57 |
+
|
| 58 |
+
self.client = OpenAI(
|
| 59 |
+
api_key=self.api_key,
|
| 60 |
+
base_url=self.base_url,
|
| 61 |
+
timeout=self.timeout,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def _format_messages(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> List[Dict[str, str]]:
|
| 65 |
+
"""Format messages for OpenAI API"""
|
| 66 |
+
formatted = []
|
| 67 |
+
for msg in messages:
|
| 68 |
+
if isinstance(msg, ChatMessage):
|
| 69 |
+
formatted.append({"role": msg.role, "content": msg.content})
|
| 70 |
+
elif isinstance(msg, dict):
|
| 71 |
+
formatted.append(msg)
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError(f"Invalid message type: {type(msg)}")
|
| 74 |
+
return formatted
|
| 75 |
+
|
| 76 |
+
def generate(
|
| 77 |
+
self,
|
| 78 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 79 |
+
temperature: float = 0.7,
|
| 80 |
+
top_p: float = 0.95,
|
| 81 |
+
top_k: int = 20,
|
| 82 |
+
max_tokens: int = 16384,
|
| 83 |
+
presence_penalty: float = 0.0,
|
| 84 |
+
**kwargs
|
| 85 |
+
) -> str:
|
| 86 |
+
"""
|
| 87 |
+
Generate text from messages
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
messages: List of chat messages
|
| 91 |
+
temperature: Sampling temperature
|
| 92 |
+
top_p: Top-p sampling parameter
|
| 93 |
+
top_k: Top-k sampling parameter (not used by GPT API, but kept for compatibility)
|
| 94 |
+
max_tokens: Maximum tokens to generate
|
| 95 |
+
presence_penalty: Presence penalty (0-2)
|
| 96 |
+
**kwargs: Additional parameters
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Generated text
|
| 100 |
+
"""
|
| 101 |
+
formatted_messages = self._format_messages(messages)
|
| 102 |
+
|
| 103 |
+
try:
|
| 104 |
+
# GPT API doesn't support top_k, so we exclude it
|
| 105 |
+
# Some newer models (like GPT 5.2) use max_completion_tokens instead of max_tokens
|
| 106 |
+
params = {
|
| 107 |
+
"model": self.model_name,
|
| 108 |
+
"messages": formatted_messages,
|
| 109 |
+
"temperature": temperature,
|
| 110 |
+
"top_p": top_p,
|
| 111 |
+
"presence_penalty": presence_penalty,
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
# Check if model requires max_completion_tokens instead of max_tokens
|
| 115 |
+
# Models that use max_completion_tokens: o1, o1-preview, o1-mini, and newer models
|
| 116 |
+
if any(model_name in self.model_name.lower() for model_name in ["o1", "gpt-5", "gpt5"]):
|
| 117 |
+
params["max_completion_tokens"] = max_tokens
|
| 118 |
+
else:
|
| 119 |
+
params["max_tokens"] = max_tokens
|
| 120 |
+
|
| 121 |
+
params.update({k: v for k, v in kwargs.items() if k not in ["top_k", "max_tokens", "max_completion_tokens"]})
|
| 122 |
+
|
| 123 |
+
response = self.client.chat.completions.create(**params)
|
| 124 |
+
|
| 125 |
+
return response.choices[0].message.content
|
| 126 |
+
|
| 127 |
+
except Exception as e:
|
| 128 |
+
# If max_tokens fails, try max_completion_tokens as fallback
|
| 129 |
+
if "max_tokens" in str(e) and "max_completion_tokens" in str(e):
|
| 130 |
+
try:
|
| 131 |
+
params = {
|
| 132 |
+
"model": self.model_name,
|
| 133 |
+
"messages": formatted_messages,
|
| 134 |
+
"temperature": temperature,
|
| 135 |
+
"top_p": top_p,
|
| 136 |
+
"max_completion_tokens": max_tokens,
|
| 137 |
+
"presence_penalty": presence_penalty,
|
| 138 |
+
}
|
| 139 |
+
params.update({k: v for k, v in kwargs.items() if k not in ["top_k", "max_tokens", "max_completion_tokens"]})
|
| 140 |
+
response = self.client.chat.completions.create(**params)
|
| 141 |
+
return response.choices[0].message.content
|
| 142 |
+
except Exception as e2:
|
| 143 |
+
raise RuntimeError(f"Error generating text from GPT service: {e2}")
|
| 144 |
+
raise RuntimeError(f"Error generating text from GPT service: {e}")
|
| 145 |
+
|
| 146 |
+
def stream_generate(
|
| 147 |
+
self,
|
| 148 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 149 |
+
temperature: float = 0.7,
|
| 150 |
+
top_p: float = 0.95,
|
| 151 |
+
top_k: int = 20,
|
| 152 |
+
max_tokens: int = 16384,
|
| 153 |
+
presence_penalty: float = 0.0,
|
| 154 |
+
**kwargs
|
| 155 |
+
):
|
| 156 |
+
"""
|
| 157 |
+
Stream generate text from messages
|
| 158 |
+
|
| 159 |
+
Yields:
|
| 160 |
+
Generated text chunks
|
| 161 |
+
"""
|
| 162 |
+
formatted_messages = self._format_messages(messages)
|
| 163 |
+
|
| 164 |
+
try:
|
| 165 |
+
params = {
|
| 166 |
+
"model": self.model_name,
|
| 167 |
+
"messages": formatted_messages,
|
| 168 |
+
"temperature": temperature,
|
| 169 |
+
"top_p": top_p,
|
| 170 |
+
"presence_penalty": presence_penalty,
|
| 171 |
+
"stream": True,
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
# Check if model requires max_completion_tokens instead of max_tokens
|
| 175 |
+
if any(model_name in self.model_name.lower() for model_name in ["o1", "gpt-5", "gpt5"]):
|
| 176 |
+
params["max_completion_tokens"] = max_tokens
|
| 177 |
+
else:
|
| 178 |
+
params["max_tokens"] = max_tokens
|
| 179 |
+
|
| 180 |
+
params.update({k: v for k, v in kwargs.items() if k not in ["top_k", "max_tokens", "max_completion_tokens"]})
|
| 181 |
+
|
| 182 |
+
stream = self.client.chat.completions.create(**params)
|
| 183 |
+
|
| 184 |
+
for chunk in stream:
|
| 185 |
+
if chunk.choices[0].delta.content:
|
| 186 |
+
yield chunk.choices[0].delta.content
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
# If max_tokens fails, try max_completion_tokens as fallback
|
| 190 |
+
if "max_tokens" in str(e) and "max_completion_tokens" in str(e):
|
| 191 |
+
try:
|
| 192 |
+
params = {
|
| 193 |
+
"model": self.model_name,
|
| 194 |
+
"messages": formatted_messages,
|
| 195 |
+
"temperature": temperature,
|
| 196 |
+
"top_p": top_p,
|
| 197 |
+
"max_completion_tokens": max_tokens,
|
| 198 |
+
"presence_penalty": presence_penalty,
|
| 199 |
+
"stream": True,
|
| 200 |
+
}
|
| 201 |
+
params.update({k: v for k, v in kwargs.items() if k not in ["top_k", "max_tokens", "max_completion_tokens"]})
|
| 202 |
+
stream = self.client.chat.completions.create(**params)
|
| 203 |
+
for chunk in stream:
|
| 204 |
+
if chunk.choices[0].delta.content:
|
| 205 |
+
yield chunk.choices[0].delta.content
|
| 206 |
+
return
|
| 207 |
+
except Exception as e2:
|
| 208 |
+
raise RuntimeError(f"Error streaming text from GPT service: {e2}")
|
| 209 |
+
raise RuntimeError(f"Error streaming text from GPT service: {e}")
|
| 210 |
+
|
shared/utils/json_parser.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Robust JSON parsing utilities for LLM responses
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def extract_json_from_text(text: str) -> Optional[str]:
|
| 10 |
+
"""
|
| 11 |
+
Extract JSON from text by removing markdown code block markers
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
text: Text that may contain JSON in markdown code blocks or plain JSON
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
Extracted JSON string or None if not found
|
| 18 |
+
"""
|
| 19 |
+
if not text:
|
| 20 |
+
return None
|
| 21 |
+
|
| 22 |
+
text_stripped = text.strip()
|
| 23 |
+
|
| 24 |
+
# Try to parse as plain JSON first (no code blocks)
|
| 25 |
+
try:
|
| 26 |
+
json.loads(text_stripped)
|
| 27 |
+
return text_stripped
|
| 28 |
+
except json.JSONDecodeError:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
# Remove markdown code block markers: ```json ... ``` or ``` ... ```
|
| 32 |
+
if text_stripped.startswith('```json'):
|
| 33 |
+
# Remove ```json at start and ``` at end
|
| 34 |
+
if text_stripped.endswith('```'):
|
| 35 |
+
text_stripped = text_stripped[7:-3].strip()
|
| 36 |
+
else:
|
| 37 |
+
# No closing ```, just remove opening
|
| 38 |
+
text_stripped = text_stripped[7:].strip()
|
| 39 |
+
elif text_stripped.startswith('```'):
|
| 40 |
+
# Handle ``` ... ``` (without json label)
|
| 41 |
+
if text_stripped.endswith('```'):
|
| 42 |
+
text_stripped = text_stripped[3:-3].strip()
|
| 43 |
+
else:
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
# Try to parse as JSON after removing code block markers
|
| 47 |
+
try:
|
| 48 |
+
json.loads(text_stripped)
|
| 49 |
+
return text_stripped
|
| 50 |
+
except json.JSONDecodeError:
|
| 51 |
+
return None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def parse_json_response(text: str, fallback: Any = None) -> Any:
|
| 55 |
+
"""
|
| 56 |
+
Parse JSON from LLM response with robust error handling
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
text: LLM response text
|
| 60 |
+
fallback: Fallback value if parsing fails
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Parsed JSON object or fallback
|
| 64 |
+
"""
|
| 65 |
+
if not text:
|
| 66 |
+
return fallback
|
| 67 |
+
|
| 68 |
+
# Extract JSON from text
|
| 69 |
+
json_str = extract_json_from_text(text)
|
| 70 |
+
|
| 71 |
+
if json_str is None:
|
| 72 |
+
return fallback
|
| 73 |
+
|
| 74 |
+
try:
|
| 75 |
+
return json.loads(json_str)
|
| 76 |
+
except json.JSONDecodeError as e:
|
| 77 |
+
# Try to fix common JSON issues
|
| 78 |
+
json_str = fix_json_common_issues(json_str)
|
| 79 |
+
try:
|
| 80 |
+
return json.loads(json_str)
|
| 81 |
+
except json.JSONDecodeError:
|
| 82 |
+
return fallback
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def fix_json_common_issues(json_str: str) -> str:
|
| 86 |
+
"""
|
| 87 |
+
Fix common JSON formatting issues
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
json_str: JSON string that may have issues
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
Fixed JSON string
|
| 94 |
+
"""
|
| 95 |
+
# Remove trailing commas
|
| 96 |
+
json_str = re.sub(r',\s*}', '}', json_str)
|
| 97 |
+
json_str = re.sub(r',\s*]', ']', json_str)
|
| 98 |
+
|
| 99 |
+
# Fix single quotes to double quotes (basic)
|
| 100 |
+
json_str = re.sub(r"'(\w+)':", r'"\1":', json_str)
|
| 101 |
+
|
| 102 |
+
# Remove comments (basic)
|
| 103 |
+
json_str = re.sub(r'//.*?$', '', json_str, flags=re.MULTILINE)
|
| 104 |
+
json_str = re.sub(r'/\*.*?\*/', '', json_str, flags=re.DOTALL)
|
| 105 |
+
|
| 106 |
+
return json_str
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def parse_keywords_json(response: str) -> List[str]:
|
| 110 |
+
"""
|
| 111 |
+
Parse keywords from JSON response
|
| 112 |
+
|
| 113 |
+
Expected format:
|
| 114 |
+
{"keywords": ["keyword1", "keyword2", ...]}
|
| 115 |
+
or
|
| 116 |
+
["keyword1", "keyword2", ...]
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
response: LLM response text
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
List of keywords, or empty list if parsing fails
|
| 123 |
+
"""
|
| 124 |
+
if response is None:
|
| 125 |
+
return []
|
| 126 |
+
|
| 127 |
+
parsed = parse_json_response(response, fallback=None)
|
| 128 |
+
|
| 129 |
+
if parsed is None:
|
| 130 |
+
return []
|
| 131 |
+
|
| 132 |
+
# Handle dict format: {"keywords": [...]}
|
| 133 |
+
if isinstance(parsed, dict):
|
| 134 |
+
if "keywords" in parsed and isinstance(parsed["keywords"], list):
|
| 135 |
+
return parsed["keywords"][:5]
|
| 136 |
+
return []
|
| 137 |
+
|
| 138 |
+
# Handle list format: ["keyword1", "keyword2", ...]
|
| 139 |
+
if isinstance(parsed, list):
|
| 140 |
+
return parsed[:5]
|
| 141 |
+
|
| 142 |
+
return []
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def parse_summary_json(response: str) -> str:
|
| 146 |
+
"""
|
| 147 |
+
Parse summary from JSON response
|
| 148 |
+
|
| 149 |
+
Expected format:
|
| 150 |
+
{"summary": "summary text"}
|
| 151 |
+
or
|
| 152 |
+
{"text": "summary text", "summary": "summary text"}
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
response: LLM response text
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Summary text
|
| 159 |
+
"""
|
| 160 |
+
parsed = parse_json_response(response, fallback=None)
|
| 161 |
+
|
| 162 |
+
if parsed is None:
|
| 163 |
+
# Fallback to text parsing
|
| 164 |
+
return response.strip()
|
| 165 |
+
|
| 166 |
+
if isinstance(parsed, dict):
|
| 167 |
+
# Try different possible keys
|
| 168 |
+
for key in ["summary", "text", "content", "description"]:
|
| 169 |
+
if key in parsed:
|
| 170 |
+
summary = str(parsed[key]).strip()
|
| 171 |
+
if summary:
|
| 172 |
+
return summary
|
| 173 |
+
|
| 174 |
+
# Fallback to text parsing
|
| 175 |
+
return response.strip()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def parse_review_json(response: str, review_format: str = "detailed") -> Dict[str, Any]:
|
| 179 |
+
"""
|
| 180 |
+
Parse review from JSON or markdown response
|
| 181 |
+
|
| 182 |
+
Expected formats:
|
| 183 |
+
- JSON: {"summary": "...", "soundness": 5, ...}
|
| 184 |
+
- Markdown: ## Summary\n\n...\n## Soundness\n\n...
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
response: LLM response text (JSON or markdown)
|
| 188 |
+
review_format: Review format type (detailed, summary, structured)
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
Review dictionary with parsed fields
|
| 192 |
+
"""
|
| 193 |
+
# First try to parse as JSON
|
| 194 |
+
parsed = parse_json_response(response, fallback=None)
|
| 195 |
+
|
| 196 |
+
if parsed is not None and isinstance(parsed, dict):
|
| 197 |
+
# JSON format - ensure it has required fields
|
| 198 |
+
if "review" not in parsed:
|
| 199 |
+
parsed["review"] = response.strip()
|
| 200 |
+
return parsed
|
| 201 |
+
|
| 202 |
+
# If not JSON, try to parse as markdown
|
| 203 |
+
if "## " in response or "##" in response:
|
| 204 |
+
markdown_parsed = parse_review_markdown(response)
|
| 205 |
+
if len(markdown_parsed) > 1: # More than just "review" field
|
| 206 |
+
return markdown_parsed
|
| 207 |
+
|
| 208 |
+
# Fallback to text parsing
|
| 209 |
+
return {"review": response.strip()}
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def parse_review_markdown(markdown_text: str) -> Dict[str, Any]:
|
| 213 |
+
"""
|
| 214 |
+
Parse review from markdown format with sections like:
|
| 215 |
+
## Summary
|
| 216 |
+
...
|
| 217 |
+
## Soundness
|
| 218 |
+
...
|
| 219 |
+
etc.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
markdown_text: Markdown formatted review text
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
Review dictionary with parsed fields
|
| 226 |
+
"""
|
| 227 |
+
review_dict = {"review": markdown_text.strip()}
|
| 228 |
+
|
| 229 |
+
# Pattern to match markdown sections: ## SectionName\n\ncontent
|
| 230 |
+
section_pattern = r'##\s*([^\n]+)\s*\n\n(.*?)(?=\n##\s*|$)'
|
| 231 |
+
matches = re.finditer(section_pattern, markdown_text, re.DOTALL)
|
| 232 |
+
|
| 233 |
+
for match in matches:
|
| 234 |
+
section_name = match.group(1).strip()
|
| 235 |
+
section_content = match.group(2).strip()
|
| 236 |
+
|
| 237 |
+
# Normalize section name (case-insensitive, remove extra spaces)
|
| 238 |
+
section_name_lower = section_name.lower()
|
| 239 |
+
|
| 240 |
+
# Map section names to dictionary keys
|
| 241 |
+
if "summary" in section_name_lower:
|
| 242 |
+
review_dict["summary"] = section_content
|
| 243 |
+
elif "soundness" in section_name_lower:
|
| 244 |
+
# Extract score - prioritize single float number (e.g., "3.0", "4.5")
|
| 245 |
+
# If format is "3 / 5" or "**3 / 5**", extract the number before the slash
|
| 246 |
+
score_val = None
|
| 247 |
+
|
| 248 |
+
lines = section_content.split('\n')
|
| 249 |
+
if lines:
|
| 250 |
+
first_line = lines[0].strip()
|
| 251 |
+
first_line_clean = re.sub(r'[`\*]', '', first_line)
|
| 252 |
+
|
| 253 |
+
# Try to match number at start that's NOT followed by "/"
|
| 254 |
+
num_match = re.match(r'^(\d+\.?\d*)(\s*)', first_line_clean)
|
| 255 |
+
if num_match:
|
| 256 |
+
remaining = first_line_clean[len(num_match.group(0)):].strip()
|
| 257 |
+
if not remaining.startswith('/'):
|
| 258 |
+
try:
|
| 259 |
+
score_val = float(num_match.group(1))
|
| 260 |
+
except (ValueError, IndexError):
|
| 261 |
+
pass
|
| 262 |
+
|
| 263 |
+
# If not found and there's a "/", try to extract number before "/" (e.g., "3 / 5" -> 3)
|
| 264 |
+
if score_val is None and '/' in first_line_clean:
|
| 265 |
+
fraction_match = re.match(r'^\s*[`\*]*\s*(\d+\.?\d*)\s*[`\*]*\s*/\s*\d+', first_line_clean)
|
| 266 |
+
if fraction_match:
|
| 267 |
+
try:
|
| 268 |
+
score_val = float(fraction_match.group(1))
|
| 269 |
+
except (ValueError, IndexError):
|
| 270 |
+
pass
|
| 271 |
+
|
| 272 |
+
# If not found, try to find number after "score:" or "rating:"
|
| 273 |
+
if score_val is None:
|
| 274 |
+
score_match = re.search(r'(?:score|rating)\s*[:=]\s*(\d+\.?\d*)', section_content, re.IGNORECASE)
|
| 275 |
+
if score_match:
|
| 276 |
+
try:
|
| 277 |
+
score_val = float(score_match.group(1))
|
| 278 |
+
except (ValueError, IndexError):
|
| 279 |
+
pass
|
| 280 |
+
|
| 281 |
+
if score_val is not None:
|
| 282 |
+
review_dict["soundness"] = score_val # Keep as float
|
| 283 |
+
elif "presentation" in section_name_lower:
|
| 284 |
+
score_val = None
|
| 285 |
+
lines = section_content.split('\n')
|
| 286 |
+
if lines:
|
| 287 |
+
first_line = lines[0].strip()
|
| 288 |
+
first_line_clean = re.sub(r'[`\*]', '', first_line)
|
| 289 |
+
|
| 290 |
+
num_match = re.match(r'^(\d+\.?\d*)(\s*)', first_line_clean)
|
| 291 |
+
if num_match:
|
| 292 |
+
remaining = first_line_clean[len(num_match.group(0)):].strip()
|
| 293 |
+
if not remaining.startswith('/'):
|
| 294 |
+
try:
|
| 295 |
+
score_val = float(num_match.group(1))
|
| 296 |
+
except (ValueError, IndexError):
|
| 297 |
+
pass
|
| 298 |
+
|
| 299 |
+
if score_val is None and '/' in first_line_clean:
|
| 300 |
+
fraction_match = re.match(r'^\s*[`\*]*\s*(\d+\.?\d*)\s*[`\*]*\s*/\s*\d+', first_line_clean)
|
| 301 |
+
if fraction_match:
|
| 302 |
+
try:
|
| 303 |
+
score_val = float(fraction_match.group(1))
|
| 304 |
+
except (ValueError, IndexError):
|
| 305 |
+
pass
|
| 306 |
+
|
| 307 |
+
if score_val is None:
|
| 308 |
+
score_match = re.search(r'(?:score|rating)\s*[:=]\s*(\d+\.?\d*)', section_content, re.IGNORECASE)
|
| 309 |
+
if score_match:
|
| 310 |
+
try:
|
| 311 |
+
score_val = float(score_match.group(1))
|
| 312 |
+
except (ValueError, IndexError):
|
| 313 |
+
pass
|
| 314 |
+
|
| 315 |
+
if score_val is not None:
|
| 316 |
+
review_dict["presentation"] = score_val
|
| 317 |
+
elif "contribution" in section_name_lower:
|
| 318 |
+
score_val = None
|
| 319 |
+
lines = section_content.split('\n')
|
| 320 |
+
if lines:
|
| 321 |
+
first_line = lines[0].strip()
|
| 322 |
+
first_line_clean = re.sub(r'[`\*]', '', first_line)
|
| 323 |
+
|
| 324 |
+
num_match = re.match(r'^(\d+\.?\d*)(\s*)', first_line_clean)
|
| 325 |
+
if num_match:
|
| 326 |
+
remaining = first_line_clean[len(num_match.group(0)):].strip()
|
| 327 |
+
if not remaining.startswith('/'):
|
| 328 |
+
try:
|
| 329 |
+
score_val = float(num_match.group(1))
|
| 330 |
+
except (ValueError, IndexError):
|
| 331 |
+
pass
|
| 332 |
+
|
| 333 |
+
if score_val is None and '/' in first_line_clean:
|
| 334 |
+
fraction_match = re.match(r'^\s*[`\*]*\s*(\d+\.?\d*)\s*[`\*]*\s*/\s*\d+', first_line_clean)
|
| 335 |
+
if fraction_match:
|
| 336 |
+
try:
|
| 337 |
+
score_val = float(fraction_match.group(1))
|
| 338 |
+
except (ValueError, IndexError):
|
| 339 |
+
pass
|
| 340 |
+
|
| 341 |
+
if score_val is None:
|
| 342 |
+
score_match = re.search(r'(?:score|rating)\s*[:=]\s*(\d+\.?\d*)', section_content, re.IGNORECASE)
|
| 343 |
+
if score_match:
|
| 344 |
+
try:
|
| 345 |
+
score_val = float(score_match.group(1))
|
| 346 |
+
except (ValueError, IndexError):
|
| 347 |
+
pass
|
| 348 |
+
|
| 349 |
+
if score_val is not None:
|
| 350 |
+
review_dict["contribution"] = score_val
|
| 351 |
+
elif "strength" in section_name_lower:
|
| 352 |
+
review_dict["strengths"] = section_content
|
| 353 |
+
elif "weakness" in section_name_lower:
|
| 354 |
+
review_dict["weaknesses"] = section_content
|
| 355 |
+
elif "question" in section_name_lower:
|
| 356 |
+
review_dict["questions"] = section_content
|
| 357 |
+
elif "rating" in section_name_lower and "confidence" not in section_name_lower:
|
| 358 |
+
score_val = None
|
| 359 |
+
lines = section_content.split('\n')
|
| 360 |
+
if lines:
|
| 361 |
+
first_line = lines[0].strip()
|
| 362 |
+
first_line_clean = re.sub(r'[`\*]', '', first_line)
|
| 363 |
+
|
| 364 |
+
num_match = re.match(r'^(\d+\.?\d*)(\s*)', first_line_clean)
|
| 365 |
+
if num_match:
|
| 366 |
+
remaining = first_line_clean[len(num_match.group(0)):].strip()
|
| 367 |
+
if not remaining.startswith('/'):
|
| 368 |
+
try:
|
| 369 |
+
score_val = float(num_match.group(1))
|
| 370 |
+
except (ValueError, IndexError):
|
| 371 |
+
pass
|
| 372 |
+
|
| 373 |
+
if score_val is None and '/' in first_line_clean:
|
| 374 |
+
fraction_match = re.match(r'^\s*[`\*]*\s*(\d+\.?\d*)\s*[`\*]*\s*/\s*\d+', first_line_clean)
|
| 375 |
+
if fraction_match:
|
| 376 |
+
try:
|
| 377 |
+
score_val = float(fraction_match.group(1))
|
| 378 |
+
except (ValueError, IndexError):
|
| 379 |
+
pass
|
| 380 |
+
|
| 381 |
+
if score_val is None:
|
| 382 |
+
score_match = re.search(r'(?:score|rating)\s*[:=]\s*(\d+\.?\d*)', section_content, re.IGNORECASE)
|
| 383 |
+
if score_match:
|
| 384 |
+
try:
|
| 385 |
+
score_val = float(score_match.group(1))
|
| 386 |
+
except (ValueError, IndexError):
|
| 387 |
+
pass
|
| 388 |
+
|
| 389 |
+
if score_val is not None:
|
| 390 |
+
review_dict["rating"] = score_val
|
| 391 |
+
elif "confidence" in section_name_lower:
|
| 392 |
+
score_val = None
|
| 393 |
+
lines = section_content.split('\n')
|
| 394 |
+
if lines:
|
| 395 |
+
first_line = lines[0].strip()
|
| 396 |
+
first_line_clean = re.sub(r'[`\*]', '', first_line)
|
| 397 |
+
|
| 398 |
+
num_match = re.match(r'^(\d+\.?\d*)(\s*)', first_line_clean)
|
| 399 |
+
if num_match:
|
| 400 |
+
remaining = first_line_clean[len(num_match.group(0)):].strip()
|
| 401 |
+
if not remaining.startswith('/'):
|
| 402 |
+
try:
|
| 403 |
+
score_val = float(num_match.group(1))
|
| 404 |
+
except (ValueError, IndexError):
|
| 405 |
+
pass
|
| 406 |
+
|
| 407 |
+
if score_val is None and '/' in first_line_clean:
|
| 408 |
+
fraction_match = re.match(r'^\s*[`\*]*\s*(\d+\.?\d*)\s*[`\*]*\s*/\s*\d+', first_line_clean)
|
| 409 |
+
if fraction_match:
|
| 410 |
+
try:
|
| 411 |
+
score_val = float(fraction_match.group(1))
|
| 412 |
+
except (ValueError, IndexError):
|
| 413 |
+
pass
|
| 414 |
+
|
| 415 |
+
if score_val is None:
|
| 416 |
+
score_match = re.search(r'(?:score|rating)\s*[:=]\s*(\d+\.?\d*)', section_content, re.IGNORECASE)
|
| 417 |
+
if score_match:
|
| 418 |
+
try:
|
| 419 |
+
score_val = float(score_match.group(1))
|
| 420 |
+
except (ValueError, IndexError):
|
| 421 |
+
pass
|
| 422 |
+
|
| 423 |
+
if score_val is not None:
|
| 424 |
+
review_dict["confidence"] = score_val
|
| 425 |
+
elif "decision" in section_name_lower:
|
| 426 |
+
review_dict["decision"] = section_content
|
| 427 |
+
|
| 428 |
+
return review_dict
|
shared/utils/llm_service.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Abstract base class for LLM services
|
| 3 |
+
"""
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import List, Dict, Optional, Any, Union
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ChatMessage(BaseModel):
|
| 10 |
+
"""Chat message model"""
|
| 11 |
+
role: str # "system", "user", "assistant"
|
| 12 |
+
content: str
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LLMService(ABC):
|
| 16 |
+
"""Abstract base class for LLM services"""
|
| 17 |
+
|
| 18 |
+
@abstractmethod
|
| 19 |
+
def generate(
|
| 20 |
+
self,
|
| 21 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 22 |
+
temperature: float = 0.7,
|
| 23 |
+
top_p: float = 0.8,
|
| 24 |
+
top_k: int = 20,
|
| 25 |
+
max_tokens: int = 16384,
|
| 26 |
+
presence_penalty: float = 0.0,
|
| 27 |
+
**kwargs
|
| 28 |
+
) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Generate text from messages
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
messages: List of chat messages
|
| 34 |
+
temperature: Sampling temperature
|
| 35 |
+
top_p: Top-p sampling parameter
|
| 36 |
+
top_k: Top-k sampling parameter
|
| 37 |
+
max_tokens: Maximum tokens to generate
|
| 38 |
+
presence_penalty: Presence penalty (0-2)
|
| 39 |
+
**kwargs: Additional parameters
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
Generated text
|
| 43 |
+
"""
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
@abstractmethod
|
| 47 |
+
def stream_generate(
|
| 48 |
+
self,
|
| 49 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 50 |
+
temperature: float = 0.7,
|
| 51 |
+
top_p: float = 0.8,
|
| 52 |
+
top_k: int = 20,
|
| 53 |
+
max_tokens: int = 16384,
|
| 54 |
+
presence_penalty: float = 0.0,
|
| 55 |
+
**kwargs
|
| 56 |
+
):
|
| 57 |
+
"""
|
| 58 |
+
Stream generate text from messages
|
| 59 |
+
|
| 60 |
+
Yields:
|
| 61 |
+
Generated text chunks
|
| 62 |
+
"""
|
| 63 |
+
pass
|
| 64 |
+
|
shared/utils/llm_service_factory.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Factory for creating LLM services from configuration
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import yaml
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Optional, Dict, Any
|
| 8 |
+
from .llm_service import LLMService
|
| 9 |
+
from .vllm_service import VLLMService
|
| 10 |
+
from .gpt_service import GPTService
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LLMServiceFactory:
|
| 14 |
+
"""Factory for creating LLM services from configuration"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, config_file: Optional[str] = None):
|
| 17 |
+
"""
|
| 18 |
+
Initialize factory with configuration
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
config_file: Path to vLLM service config YAML file
|
| 22 |
+
"""
|
| 23 |
+
if config_file is None:
|
| 24 |
+
project_root = Path(__file__).parent.parent.parent
|
| 25 |
+
config_file = project_root / "shared" / "configs" / "llm_service_config.yaml"
|
| 26 |
+
|
| 27 |
+
self.config_file = Path(config_file)
|
| 28 |
+
self._config = None
|
| 29 |
+
self._load_config()
|
| 30 |
+
|
| 31 |
+
def _load_config(self):
|
| 32 |
+
"""Load configuration from YAML file"""
|
| 33 |
+
if not self.config_file.exists():
|
| 34 |
+
raise FileNotFoundError(f"Config file not found: {self.config_file}")
|
| 35 |
+
|
| 36 |
+
with open(self.config_file, 'r', encoding='utf-8') as f:
|
| 37 |
+
self._config = yaml.safe_load(f)
|
| 38 |
+
|
| 39 |
+
def create_vllm_service(self, **override_params) -> VLLMService:
|
| 40 |
+
"""
|
| 41 |
+
Create vLLM service from configuration
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
**override_params: Parameters to override config values
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
VLLMService instance
|
| 48 |
+
"""
|
| 49 |
+
vllm_config = self._config.get("vllm", {})
|
| 50 |
+
|
| 51 |
+
# Merge config with overrides
|
| 52 |
+
params = {
|
| 53 |
+
"base_url": vllm_config.get("base_url", "http://localhost:8000/v1"),
|
| 54 |
+
"api_key": vllm_config.get("api_key", "dummy-key"),
|
| 55 |
+
"model_name": vllm_config.get("model_name", "Qwen/Qwen3-4B-Instruct-2507"),
|
| 56 |
+
"timeout": vllm_config.get("timeout", 300),
|
| 57 |
+
}
|
| 58 |
+
params.update(override_params)
|
| 59 |
+
|
| 60 |
+
return VLLMService(**params)
|
| 61 |
+
|
| 62 |
+
def create_gpt_service(self, **override_params) -> GPTService:
|
| 63 |
+
"""
|
| 64 |
+
Create GPT service from configuration
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
**override_params: Parameters to override config values
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
GPTService instance
|
| 71 |
+
"""
|
| 72 |
+
gpt_config = self._config.get("gpt", {})
|
| 73 |
+
|
| 74 |
+
# if not gpt_config.get("enabled", False):
|
| 75 |
+
# raise ValueError("GPT service is not enabled in configuration")
|
| 76 |
+
|
| 77 |
+
# Merge config with overrides
|
| 78 |
+
params = {
|
| 79 |
+
"api_key": gpt_config.get("api_key") or os.environ.get("OPENAI_API_KEY"),
|
| 80 |
+
"model_name": gpt_config.get("model_name", "gpt-4o"),
|
| 81 |
+
"base_url": gpt_config.get("base_url"),
|
| 82 |
+
"timeout": gpt_config.get("timeout", 300),
|
| 83 |
+
}
|
| 84 |
+
params.update(override_params)
|
| 85 |
+
|
| 86 |
+
return GPTService(**params)
|
| 87 |
+
|
| 88 |
+
def create_service(self, service_type: str = "vllm", **override_params) -> LLMService:
|
| 89 |
+
"""
|
| 90 |
+
Create LLM service by type
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
service_type: Service type ("vllm" or "gpt")
|
| 94 |
+
**override_params: Parameters to override config values
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
LLMService instance
|
| 98 |
+
"""
|
| 99 |
+
if service_type == "vllm":
|
| 100 |
+
return self.create_vllm_service(**override_params)
|
| 101 |
+
elif service_type == "gpt":
|
| 102 |
+
return self.create_gpt_service(**override_params)
|
| 103 |
+
else:
|
| 104 |
+
raise ValueError(f"Unknown service type: {service_type}")
|
| 105 |
+
|
| 106 |
+
def get_llm_assignment(self, component: str) -> str:
|
| 107 |
+
"""
|
| 108 |
+
Get LLM service assignment for a component
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
component: Component name ("keyword_generator", "paper_summarizer", "reviewer", "refiner")
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Service type ("vllm" or "gpt")
|
| 115 |
+
|
| 116 |
+
Raises:
|
| 117 |
+
KeyError: If component is not found and no fallback is available
|
| 118 |
+
"""
|
| 119 |
+
assignments = self._config.get("llm_assignments", {})
|
| 120 |
+
if component in assignments:
|
| 121 |
+
return assignments[component]
|
| 122 |
+
|
| 123 |
+
# Fallback: if refiner not configured, use reviewer's assignment
|
| 124 |
+
if component == "refiner" and "reviewer" in assignments:
|
| 125 |
+
return assignments["reviewer"]
|
| 126 |
+
|
| 127 |
+
# Default fallback to vllm (may cause connection errors if vllm not available)
|
| 128 |
+
return assignments.get(component, "vllm")
|
| 129 |
+
|
| 130 |
+
def create_service_for_component(self, component: str, **override_params) -> LLMService:
|
| 131 |
+
"""
|
| 132 |
+
Create LLM service for a specific component based on configuration
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
component: Component name ("keyword_generator", "paper_summarizer", "reviewer")
|
| 136 |
+
**override_params: Parameters to override config values
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
LLMService instance
|
| 140 |
+
"""
|
| 141 |
+
service_type = self.get_llm_assignment(component)
|
| 142 |
+
return self.create_service(service_type, **override_params)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Global factory instance
|
| 146 |
+
_factory: Optional[LLMServiceFactory] = None
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def load_api_key_from_config(config_path: str) -> Optional[str]:
|
| 150 |
+
"""
|
| 151 |
+
Load API key from a YAML config file.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
config_path: Path to YAML config file
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
API key string, or None if not found
|
| 158 |
+
|
| 159 |
+
Note:
|
| 160 |
+
Returns None (instead of raising) if file doesn't exist or key not found,
|
| 161 |
+
to allow graceful fallback to environment variables.
|
| 162 |
+
"""
|
| 163 |
+
from pathlib import Path
|
| 164 |
+
|
| 165 |
+
config_file = Path(config_path)
|
| 166 |
+
if not config_file.exists():
|
| 167 |
+
return None
|
| 168 |
+
|
| 169 |
+
try:
|
| 170 |
+
with open(config_file, 'r', encoding='utf-8') as f:
|
| 171 |
+
config = yaml.safe_load(f)
|
| 172 |
+
return config.get('api_key')
|
| 173 |
+
except Exception:
|
| 174 |
+
return None
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_llm_service_factory(config_file: Optional[str] = None) -> LLMServiceFactory:
|
| 178 |
+
"""
|
| 179 |
+
Get or create global LLM service factory
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
config_file: Optional path to config file
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
LLMServiceFactory instance
|
| 186 |
+
"""
|
| 187 |
+
global _factory
|
| 188 |
+
if _factory is None or config_file is not None:
|
| 189 |
+
_factory = LLMServiceFactory(config_file)
|
| 190 |
+
return _factory
|
| 191 |
+
|
shared/utils/load_balancer.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple Python Load Balancer
|
| 3 |
+
|
| 4 |
+
A lightweight load balancer for vLLM and Reranker services.
|
| 5 |
+
Uses FastAPI to forward requests to multiple backend services.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import asyncio
|
| 11 |
+
import httpx
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import List, Dict, Optional, Any
|
| 14 |
+
from threading import Lock
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from fastapi import FastAPI, Request, HTTPException
|
| 19 |
+
from fastapi.responses import StreamingResponse, Response
|
| 20 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 21 |
+
import uvicorn
|
| 22 |
+
HAS_FASTAPI = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
HAS_FASTAPI = False
|
| 25 |
+
print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SimpleLoadBalancer:
|
| 29 |
+
"""Simple load balancer with round-robin and least-connection strategies"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
backends: List[str],
|
| 34 |
+
strategy: str = "round_robin",
|
| 35 |
+
health_check_interval: float = 10.0,
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Initialize load balancer
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
backends: List of backend URLs (e.g., ["http://localhost:8000", "http://localhost:8001"])
|
| 42 |
+
strategy: Load balancing strategy ("round_robin" or "least_conn")
|
| 43 |
+
health_check_interval: Health check interval in seconds
|
| 44 |
+
"""
|
| 45 |
+
self.backends = backends
|
| 46 |
+
self.strategy = strategy
|
| 47 |
+
self.health_check_interval = health_check_interval
|
| 48 |
+
|
| 49 |
+
# Round-robin state
|
| 50 |
+
self.current_index = 0
|
| 51 |
+
self.index_lock = Lock()
|
| 52 |
+
|
| 53 |
+
# Least-connection state
|
| 54 |
+
self.connection_counts: Dict[str, int] = {backend: 0 for backend in backends}
|
| 55 |
+
self.conn_lock = Lock()
|
| 56 |
+
|
| 57 |
+
# Health check state
|
| 58 |
+
self.healthy_backends: Dict[str, bool] = {backend: True for backend in backends}
|
| 59 |
+
self.health_lock = Lock()
|
| 60 |
+
|
| 61 |
+
# HTTP client for forwarding requests
|
| 62 |
+
self.client = httpx.AsyncClient(timeout=300.0)
|
| 63 |
+
|
| 64 |
+
print(f"Load balancer initialized with {len(backends)} backends")
|
| 65 |
+
print(f"Strategy: {strategy}")
|
| 66 |
+
for i, backend in enumerate(backends):
|
| 67 |
+
print(f" [{i+1}] {backend}")
|
| 68 |
+
|
| 69 |
+
def get_backend(self) -> Optional[str]:
|
| 70 |
+
"""Get next backend based on strategy"""
|
| 71 |
+
with self.health_lock:
|
| 72 |
+
available_backends = [b for b in self.backends if self.healthy_backends.get(b, True)]
|
| 73 |
+
|
| 74 |
+
if not available_backends:
|
| 75 |
+
# If no healthy backends, try all backends
|
| 76 |
+
available_backends = self.backends
|
| 77 |
+
|
| 78 |
+
if not available_backends:
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
if self.strategy == "round_robin":
|
| 82 |
+
with self.index_lock:
|
| 83 |
+
backend = available_backends[self.current_index % len(available_backends)]
|
| 84 |
+
self.current_index = (self.current_index + 1) % len(available_backends)
|
| 85 |
+
return backend
|
| 86 |
+
|
| 87 |
+
elif self.strategy == "least_conn":
|
| 88 |
+
with self.conn_lock:
|
| 89 |
+
# Find backend with least connections
|
| 90 |
+
backend = min(available_backends, key=lambda b: self.connection_counts.get(b, 0))
|
| 91 |
+
self.connection_counts[backend] = self.connection_counts.get(backend, 0) + 1
|
| 92 |
+
return backend
|
| 93 |
+
|
| 94 |
+
else:
|
| 95 |
+
# Default to round-robin
|
| 96 |
+
with self.index_lock:
|
| 97 |
+
backend = available_backends[self.current_index % len(available_backends)]
|
| 98 |
+
self.current_index = (self.current_index + 1) % len(available_backends)
|
| 99 |
+
return backend
|
| 100 |
+
|
| 101 |
+
def release_backend(self, backend: str):
|
| 102 |
+
"""Release a backend (for least-conn strategy)"""
|
| 103 |
+
if self.strategy == "least_conn":
|
| 104 |
+
with self.conn_lock:
|
| 105 |
+
self.connection_counts[backend] = max(0, self.connection_counts.get(backend, 0) - 1)
|
| 106 |
+
|
| 107 |
+
async def health_check(self, backend: str) -> bool:
|
| 108 |
+
"""Check if a backend is healthy"""
|
| 109 |
+
try:
|
| 110 |
+
# For vLLM backends (URLs ending with /v1), use /models endpoint
|
| 111 |
+
# For other backends, try /health first, then root
|
| 112 |
+
if backend.endswith("/v1"):
|
| 113 |
+
# vLLM endpoint: try /models (which becomes /v1/models)
|
| 114 |
+
endpoints = ["/models", "/"]
|
| 115 |
+
else:
|
| 116 |
+
# Other services: try /health, then root
|
| 117 |
+
endpoints = ["/health", "/"]
|
| 118 |
+
|
| 119 |
+
for endpoint in endpoints:
|
| 120 |
+
try:
|
| 121 |
+
response = await self.client.get(f"{backend}{endpoint}", timeout=5.0)
|
| 122 |
+
if response.status_code < 500:
|
| 123 |
+
return True
|
| 124 |
+
except:
|
| 125 |
+
continue
|
| 126 |
+
return False
|
| 127 |
+
except Exception:
|
| 128 |
+
return False
|
| 129 |
+
|
| 130 |
+
async def check_all_backends(self):
|
| 131 |
+
"""Check health of all backends"""
|
| 132 |
+
while True:
|
| 133 |
+
for backend in self.backends:
|
| 134 |
+
is_healthy = await self.health_check(backend)
|
| 135 |
+
with self.health_lock:
|
| 136 |
+
self.healthy_backends[backend] = is_healthy
|
| 137 |
+
if not is_healthy:
|
| 138 |
+
print(f"Warning: Backend {backend} is unhealthy")
|
| 139 |
+
await asyncio.sleep(self.health_check_interval)
|
| 140 |
+
|
| 141 |
+
async def forward_request(
|
| 142 |
+
self,
|
| 143 |
+
method: str,
|
| 144 |
+
path: str,
|
| 145 |
+
request: Request,
|
| 146 |
+
backend: Optional[str] = None
|
| 147 |
+
) -> Response:
|
| 148 |
+
"""Forward a request to a backend"""
|
| 149 |
+
if backend is None:
|
| 150 |
+
backend = self.get_backend()
|
| 151 |
+
|
| 152 |
+
if backend is None:
|
| 153 |
+
raise HTTPException(status_code=503, detail="No healthy backends available")
|
| 154 |
+
|
| 155 |
+
try:
|
| 156 |
+
# Get request body
|
| 157 |
+
body = await request.body()
|
| 158 |
+
|
| 159 |
+
# Get query parameters
|
| 160 |
+
query_params = dict(request.query_params)
|
| 161 |
+
|
| 162 |
+
# Get headers (exclude host and connection)
|
| 163 |
+
headers = dict(request.headers)
|
| 164 |
+
headers.pop("host", None)
|
| 165 |
+
headers.pop("connection", None)
|
| 166 |
+
headers.pop("content-length", None)
|
| 167 |
+
|
| 168 |
+
# Forward request
|
| 169 |
+
url = f"{backend}{path}"
|
| 170 |
+
if query_params:
|
| 171 |
+
url += "?" + "&".join(f"{k}={v}" for k, v in query_params.items())
|
| 172 |
+
|
| 173 |
+
response = await self.client.request(
|
| 174 |
+
method=method,
|
| 175 |
+
url=url,
|
| 176 |
+
content=body,
|
| 177 |
+
headers=headers,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Create response
|
| 181 |
+
return Response(
|
| 182 |
+
content=response.content,
|
| 183 |
+
status_code=response.status_code,
|
| 184 |
+
headers=dict(response.headers),
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
except Exception as e:
|
| 188 |
+
# Mark backend as unhealthy
|
| 189 |
+
with self.health_lock:
|
| 190 |
+
self.healthy_backends[backend] = False
|
| 191 |
+
|
| 192 |
+
self.release_backend(backend)
|
| 193 |
+
raise HTTPException(status_code=502, detail=f"Backend error: {str(e)}")
|
| 194 |
+
finally:
|
| 195 |
+
self.release_backend(backend)
|
| 196 |
+
|
| 197 |
+
async def forward_streaming_request(
|
| 198 |
+
self,
|
| 199 |
+
method: str,
|
| 200 |
+
path: str,
|
| 201 |
+
request: Request,
|
| 202 |
+
backend: Optional[str] = None
|
| 203 |
+
):
|
| 204 |
+
"""Forward a streaming request to a backend"""
|
| 205 |
+
if backend is None:
|
| 206 |
+
backend = self.get_backend()
|
| 207 |
+
|
| 208 |
+
if backend is None:
|
| 209 |
+
raise HTTPException(status_code=503, detail="No healthy backends available")
|
| 210 |
+
|
| 211 |
+
try:
|
| 212 |
+
# Get request body
|
| 213 |
+
body = await request.body()
|
| 214 |
+
|
| 215 |
+
# Get query parameters
|
| 216 |
+
query_params = dict(request.query_params)
|
| 217 |
+
|
| 218 |
+
# Get headers
|
| 219 |
+
headers = dict(request.headers)
|
| 220 |
+
headers.pop("host", None)
|
| 221 |
+
headers.pop("connection", None)
|
| 222 |
+
headers.pop("content-length", None)
|
| 223 |
+
|
| 224 |
+
# Forward request
|
| 225 |
+
url = f"{backend}{path}"
|
| 226 |
+
if query_params:
|
| 227 |
+
url += "?" + "&".join(f"{k}={v}" for k, v in query_params.items())
|
| 228 |
+
|
| 229 |
+
async with httpx.AsyncClient(timeout=300.0) as client:
|
| 230 |
+
async with client.stream(
|
| 231 |
+
method=method,
|
| 232 |
+
url=url,
|
| 233 |
+
content=body,
|
| 234 |
+
headers=headers,
|
| 235 |
+
) as response:
|
| 236 |
+
async def generate():
|
| 237 |
+
async for chunk in response.aiter_bytes():
|
| 238 |
+
yield chunk
|
| 239 |
+
|
| 240 |
+
return StreamingResponse(
|
| 241 |
+
generate(),
|
| 242 |
+
status_code=response.status_code,
|
| 243 |
+
headers=dict(response.headers),
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
except Exception as e:
|
| 247 |
+
# Mark backend as unhealthy
|
| 248 |
+
with self.health_lock:
|
| 249 |
+
self.healthy_backends[backend] = False
|
| 250 |
+
|
| 251 |
+
self.release_backend(backend)
|
| 252 |
+
raise HTTPException(status_code=502, detail=f"Backend error: {str(e)}")
|
| 253 |
+
finally:
|
| 254 |
+
self.release_backend(backend)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def create_load_balancer_app(
|
| 258 |
+
backends: List[str],
|
| 259 |
+
strategy: str = "round_robin",
|
| 260 |
+
health_check_interval: float = 10.0,
|
| 261 |
+
) -> FastAPI:
|
| 262 |
+
"""Create FastAPI app with load balancer"""
|
| 263 |
+
if not HAS_FASTAPI:
|
| 264 |
+
raise RuntimeError("FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
|
| 265 |
+
|
| 266 |
+
app = FastAPI(title="Simple Load Balancer", version="1.0.0")
|
| 267 |
+
|
| 268 |
+
# Add CORS middleware
|
| 269 |
+
app.add_middleware(
|
| 270 |
+
CORSMiddleware,
|
| 271 |
+
allow_origins=["*"],
|
| 272 |
+
allow_credentials=True,
|
| 273 |
+
allow_methods=["*"],
|
| 274 |
+
allow_headers=["*"],
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Create load balancer
|
| 278 |
+
lb = SimpleLoadBalancer(backends, strategy, health_check_interval)
|
| 279 |
+
|
| 280 |
+
# Start health check task
|
| 281 |
+
@app.on_event("startup")
|
| 282 |
+
async def start_health_check():
|
| 283 |
+
asyncio.create_task(lb.check_all_backends())
|
| 284 |
+
|
| 285 |
+
# Health check endpoint
|
| 286 |
+
@app.get("/health")
|
| 287 |
+
async def health():
|
| 288 |
+
healthy_count = sum(1 for h in lb.healthy_backends.values() if h)
|
| 289 |
+
return {
|
| 290 |
+
"status": "healthy" if healthy_count > 0 else "unhealthy",
|
| 291 |
+
"healthy_backends": healthy_count,
|
| 292 |
+
"total_backends": len(lb.backends),
|
| 293 |
+
"backends": [
|
| 294 |
+
{
|
| 295 |
+
"url": backend,
|
| 296 |
+
"healthy": lb.healthy_backends.get(backend, False),
|
| 297 |
+
"connections": lb.connection_counts.get(backend, 0) if lb.strategy == "least_conn" else None,
|
| 298 |
+
}
|
| 299 |
+
for backend in lb.backends
|
| 300 |
+
]
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
# Forward all other requests
|
| 304 |
+
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
|
| 305 |
+
async def forward(request: Request, path: str):
|
| 306 |
+
method = request.method
|
| 307 |
+
is_streaming = "stream" in request.query_params or "stream=true" in str(request.url)
|
| 308 |
+
|
| 309 |
+
if is_streaming:
|
| 310 |
+
return await lb.forward_streaming_request(method, f"/{path}", request)
|
| 311 |
+
else:
|
| 312 |
+
return await lb.forward_request(method, f"/{path}", request)
|
| 313 |
+
|
| 314 |
+
return app
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def main():
|
| 318 |
+
"""Main entry point for load balancer"""
|
| 319 |
+
parser = argparse.ArgumentParser(description="Simple Python Load Balancer")
|
| 320 |
+
parser.add_argument(
|
| 321 |
+
"--backends",
|
| 322 |
+
type=str,
|
| 323 |
+
nargs="+",
|
| 324 |
+
required=True,
|
| 325 |
+
help="Backend URLs (e.g., http://localhost:8000 http://localhost:8001)"
|
| 326 |
+
)
|
| 327 |
+
parser.add_argument(
|
| 328 |
+
"--host",
|
| 329 |
+
type=str,
|
| 330 |
+
default="0.0.0.0",
|
| 331 |
+
help="Host to bind to (default: 0.0.0.0)"
|
| 332 |
+
)
|
| 333 |
+
parser.add_argument(
|
| 334 |
+
"--port",
|
| 335 |
+
type=int,
|
| 336 |
+
default=8000,
|
| 337 |
+
help="Port to bind to (default: 8000)"
|
| 338 |
+
)
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--strategy",
|
| 341 |
+
type=str,
|
| 342 |
+
default="round_robin",
|
| 343 |
+
choices=["round_robin", "least_conn"],
|
| 344 |
+
help="Load balancing strategy (default: round_robin)"
|
| 345 |
+
)
|
| 346 |
+
parser.add_argument(
|
| 347 |
+
"--health-check-interval",
|
| 348 |
+
type=float,
|
| 349 |
+
default=10.0,
|
| 350 |
+
help="Health check interval in seconds (default: 10.0)"
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
args = parser.parse_args()
|
| 354 |
+
|
| 355 |
+
if not HAS_FASTAPI:
|
| 356 |
+
print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
|
| 357 |
+
sys.exit(1)
|
| 358 |
+
|
| 359 |
+
# Create app
|
| 360 |
+
app = create_load_balancer_app(
|
| 361 |
+
backends=args.backends,
|
| 362 |
+
strategy=args.strategy,
|
| 363 |
+
health_check_interval=args.health_check_interval,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
# Run server
|
| 367 |
+
print(f"Starting load balancer on {args.host}:{args.port}")
|
| 368 |
+
print(f"Strategy: {args.strategy}")
|
| 369 |
+
print(f"Backends: {', '.join(args.backends)}")
|
| 370 |
+
|
| 371 |
+
uvicorn.run(
|
| 372 |
+
app,
|
| 373 |
+
host=args.host,
|
| 374 |
+
port=args.port,
|
| 375 |
+
log_level="info"
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
if __name__ == "__main__":
|
| 380 |
+
main()
|
| 381 |
+
|
| 382 |
+
# python -m shared.utils.load_balancer --backends http://localhost:8000 http://localhost:8001 http://localhost:8002 http://localhost:8003 --port 8004 --strategy round_robin
|
shared/utils/mock_llm_service.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mock LLM Service that returns pre-generated reviews from a JSON file
|
| 3 |
+
This is a hack for testing the refiner pipeline with existing reviews
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Dict, Optional, Any, Union
|
| 9 |
+
|
| 10 |
+
from .llm_service import LLMService, ChatMessage
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def extract_title_from_latex(paper_context: str) -> Optional[str]:
|
| 14 |
+
"""Extract title from LaTeX format \\title{...}"""
|
| 15 |
+
match = re.search(r'\\title\{([^}]+)\}', paper_context)
|
| 16 |
+
if match:
|
| 17 |
+
return match.group(1).strip()
|
| 18 |
+
return None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def extract_abstract_from_latex(paper_context: str) -> Optional[str]:
|
| 22 |
+
"""Extract abstract from LaTeX format \\begin{abstract}...\\end{abstract}"""
|
| 23 |
+
match = re.search(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', paper_context, re.DOTALL)
|
| 24 |
+
if match:
|
| 25 |
+
abstract = match.group(1).strip()
|
| 26 |
+
# Clean up LaTeX commands
|
| 27 |
+
abstract = re.sub(r'\\[a-zA-Z]+\{([^}]+)\}', r'\1', abstract) # Remove LaTeX commands
|
| 28 |
+
abstract = re.sub(r'\$([^$]+)\$', r'\1', abstract) # Remove math mode
|
| 29 |
+
return abstract
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class MockLLMService(LLMService):
|
| 34 |
+
"""
|
| 35 |
+
Mock LLM Service that returns pre-generated reviews from a JSON file
|
| 36 |
+
|
| 37 |
+
This service matches papers by extracting title and abstract from paper_context
|
| 38 |
+
and returns the corresponding pred_fast_mode_baseline from the JSON file.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, json_file_path: str):
|
| 42 |
+
"""
|
| 43 |
+
Initialize Mock LLM Service
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
json_file_path: Path to JSON file containing pre-generated reviews
|
| 47 |
+
"""
|
| 48 |
+
self.json_file_path = Path(json_file_path)
|
| 49 |
+
if not self.json_file_path.exists():
|
| 50 |
+
raise FileNotFoundError(f"JSON file not found: {json_file_path}")
|
| 51 |
+
|
| 52 |
+
# Load JSON data
|
| 53 |
+
with open(self.json_file_path, 'r', encoding='utf-8') as f:
|
| 54 |
+
self.data = json.load(f)
|
| 55 |
+
|
| 56 |
+
# Build index for faster lookup
|
| 57 |
+
self._build_index()
|
| 58 |
+
|
| 59 |
+
def _build_index(self):
|
| 60 |
+
"""Build index mapping (title, abstract) to review"""
|
| 61 |
+
self.index = {}
|
| 62 |
+
self.entries = [] # Store full entries for fallback matching
|
| 63 |
+
self.initial_scores_index = {} # Store initial scores and decision for each entry
|
| 64 |
+
|
| 65 |
+
for entry in self.data:
|
| 66 |
+
paper_context = entry.get('paper_context', '')
|
| 67 |
+
title = extract_title_from_latex(paper_context)
|
| 68 |
+
abstract = extract_abstract_from_latex(paper_context)
|
| 69 |
+
|
| 70 |
+
# Extract review content (prefer meta_review.content, fallback to raw_text)
|
| 71 |
+
model_prediction = entry.get('model_prediction', {})
|
| 72 |
+
meta_review = model_prediction.get('meta_review', {})
|
| 73 |
+
review_content = meta_review.get('content', '') or model_prediction.get('raw_text', '')
|
| 74 |
+
|
| 75 |
+
# Extract initial scores and decision
|
| 76 |
+
initial_scores = {
|
| 77 |
+
'rating': meta_review.get('rating'),
|
| 78 |
+
'soundness': meta_review.get('soundness'),
|
| 79 |
+
'presentation': meta_review.get('presentation'),
|
| 80 |
+
'contribution': meta_review.get('contribution'),
|
| 81 |
+
'decision': model_prediction.get('decision'),
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
if title and abstract:
|
| 85 |
+
# Use normalized title and first 200 chars of abstract as key
|
| 86 |
+
normalized_title = title.lower().strip()
|
| 87 |
+
normalized_abstract = abstract[:200].lower().strip()
|
| 88 |
+
key = (normalized_title, normalized_abstract)
|
| 89 |
+
self.index[key] = review_content
|
| 90 |
+
self.initial_scores_index[key] = initial_scores
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# Store entry for fallback matching
|
| 94 |
+
self.entries.append({
|
| 95 |
+
'title': title,
|
| 96 |
+
'abstract': abstract,
|
| 97 |
+
'paper_context': paper_context,
|
| 98 |
+
'review': review_content,
|
| 99 |
+
'id': entry.get('id', ''),
|
| 100 |
+
'initial_scores': initial_scores,
|
| 101 |
+
})
|
| 102 |
+
|
| 103 |
+
def _find_entry(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> Optional[Dict[str, Any]]:
|
| 104 |
+
"""
|
| 105 |
+
Find entry by matching title and abstract from messages
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
messages: List of chat messages
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
Entry dict with 'review' and 'initial_scores' or None if not found
|
| 112 |
+
"""
|
| 113 |
+
# Extract paper context from user message
|
| 114 |
+
user_message = None
|
| 115 |
+
for msg in messages:
|
| 116 |
+
if isinstance(msg, dict):
|
| 117 |
+
if msg.get('role') == 'user':
|
| 118 |
+
user_message = msg.get('content', '')
|
| 119 |
+
elif isinstance(msg, ChatMessage):
|
| 120 |
+
if msg.role == 'user':
|
| 121 |
+
user_message = msg.content
|
| 122 |
+
|
| 123 |
+
if not user_message:
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
# Try to extract title and abstract from user message
|
| 127 |
+
# Look for patterns like "Title: ..." or "Abstract: ..."
|
| 128 |
+
title_match = re.search(r'Title:\s*(.+?)(?:\n|$)', user_message, re.IGNORECASE)
|
| 129 |
+
abstract_match = re.search(r'Abstract:\s*(.+?)(?:\n\n|Content:|$)', user_message, re.DOTALL | re.IGNORECASE)
|
| 130 |
+
|
| 131 |
+
extracted_title = None
|
| 132 |
+
extracted_abstract = None
|
| 133 |
+
|
| 134 |
+
if title_match and abstract_match:
|
| 135 |
+
extracted_title = title_match.group(1).strip()
|
| 136 |
+
extracted_abstract = abstract_match.group(1).strip()
|
| 137 |
+
else:
|
| 138 |
+
# Fallback: search in paper_context if available
|
| 139 |
+
paper_context_match = re.search(r'Paper to review:\s*(.+?)(?:Please provide|$)', user_message, re.DOTALL)
|
| 140 |
+
if paper_context_match:
|
| 141 |
+
paper_context = paper_context_match.group(1)
|
| 142 |
+
extracted_title = extract_title_from_latex(paper_context)
|
| 143 |
+
extracted_abstract = extract_abstract_from_latex(paper_context)
|
| 144 |
+
|
| 145 |
+
if extracted_title and extracted_abstract:
|
| 146 |
+
# Normalize for matching
|
| 147 |
+
normalized_title = extracted_title.lower().strip()
|
| 148 |
+
normalized_abstract = extracted_abstract[:200].lower().strip()
|
| 149 |
+
|
| 150 |
+
# Try exact match first
|
| 151 |
+
key = (normalized_title, normalized_abstract)
|
| 152 |
+
if key in self.index:
|
| 153 |
+
return {
|
| 154 |
+
'review': self.index[key],
|
| 155 |
+
'initial_scores': self.initial_scores_index.get(key, {})
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
# Try fuzzy match (check if title matches)
|
| 159 |
+
for (index_title, index_abstract), review in self.index.items():
|
| 160 |
+
# Check title similarity (either contains or is contained)
|
| 161 |
+
title_similar = (
|
| 162 |
+
normalized_title in index_title or
|
| 163 |
+
index_title in normalized_title or
|
| 164 |
+
normalized_title == index_title
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Check abstract similarity (first 100 chars)
|
| 168 |
+
abstract_similar = (
|
| 169 |
+
normalized_abstract[:100] in index_abstract[:100] or
|
| 170 |
+
index_abstract[:100] in normalized_abstract[:100] or
|
| 171 |
+
normalized_abstract[:100] == index_abstract[:100]
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if title_similar and abstract_similar:
|
| 175 |
+
return {
|
| 176 |
+
'review': review,
|
| 177 |
+
'initial_scores': self.initial_scores_index.get((index_title, index_abstract), {})
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
# Final fallback: try to match by paper_context in entries
|
| 181 |
+
for entry in self.entries:
|
| 182 |
+
if entry['paper_context']:
|
| 183 |
+
# Check if user message contains similar content
|
| 184 |
+
entry_title = entry['title']
|
| 185 |
+
if entry_title and extracted_title:
|
| 186 |
+
if entry_title.lower().strip() in extracted_title.lower() or extracted_title.lower() in entry_title.lower():
|
| 187 |
+
return {
|
| 188 |
+
'review': entry['review'],
|
| 189 |
+
'initial_scores': entry.get('initial_scores', {})
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
return None
|
| 193 |
+
|
| 194 |
+
def _find_review(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> Optional[str]:
|
| 195 |
+
"""
|
| 196 |
+
Find review by matching title and abstract from messages
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
messages: List of chat messages
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Review text or None if not found
|
| 203 |
+
"""
|
| 204 |
+
entry = self._find_entry(messages)
|
| 205 |
+
if entry:
|
| 206 |
+
return entry['review']
|
| 207 |
+
return None
|
| 208 |
+
|
| 209 |
+
def get_initial_scores(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> Optional[Dict[str, Any]]:
|
| 210 |
+
"""
|
| 211 |
+
Get initial scores and decision by matching title and abstract from messages
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
messages: List of chat messages
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
Dict with initial scores (rating, soundness, presentation, contribution, decision) or None if not found
|
| 218 |
+
"""
|
| 219 |
+
entry = self._find_entry(messages)
|
| 220 |
+
if entry:
|
| 221 |
+
return entry.get('initial_scores', {})
|
| 222 |
+
return None
|
| 223 |
+
|
| 224 |
+
def generate(
|
| 225 |
+
self,
|
| 226 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 227 |
+
temperature: float = 0.7,
|
| 228 |
+
top_p: float = 0.8,
|
| 229 |
+
top_k: int = 20,
|
| 230 |
+
max_tokens: int = 16384,
|
| 231 |
+
presence_penalty: float = 0.0,
|
| 232 |
+
**kwargs
|
| 233 |
+
) -> str:
|
| 234 |
+
"""
|
| 235 |
+
Generate text from messages (returns pre-generated review)
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
messages: List of chat messages
|
| 239 |
+
temperature: Ignored (for compatibility)
|
| 240 |
+
top_p: Ignored (for compatibility)
|
| 241 |
+
top_k: Ignored (for compatibility)
|
| 242 |
+
max_tokens: Ignored (for compatibility)
|
| 243 |
+
presence_penalty: Ignored (for compatibility)
|
| 244 |
+
**kwargs: Additional parameters (ignored)
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
Pre-generated review text
|
| 248 |
+
"""
|
| 249 |
+
review = self._find_review(messages)
|
| 250 |
+
if review:
|
| 251 |
+
return review
|
| 252 |
+
|
| 253 |
+
# Fallback: return a default message
|
| 254 |
+
return "## Summary:\n\nReview not found in pre-generated data."
|
| 255 |
+
|
| 256 |
+
def stream_generate(
|
| 257 |
+
self,
|
| 258 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 259 |
+
temperature: float = 0.7,
|
| 260 |
+
top_p: float = 0.8,
|
| 261 |
+
top_k: int = 20,
|
| 262 |
+
max_tokens: int = 16384,
|
| 263 |
+
presence_penalty: float = 0.0,
|
| 264 |
+
**kwargs
|
| 265 |
+
):
|
| 266 |
+
"""
|
| 267 |
+
Stream generate text from messages (yields pre-generated review)
|
| 268 |
+
|
| 269 |
+
Yields:
|
| 270 |
+
Pre-generated review text chunks
|
| 271 |
+
"""
|
| 272 |
+
review = self._find_review(messages)
|
| 273 |
+
if review:
|
| 274 |
+
# Yield in chunks to simulate streaming
|
| 275 |
+
chunk_size = 100
|
| 276 |
+
for i in range(0, len(review), chunk_size):
|
| 277 |
+
yield review[i:i + chunk_size]
|
| 278 |
+
else:
|
| 279 |
+
yield "## Summary:\n\nReview not found in pre-generated data."
|
| 280 |
+
|
shared/utils/prompt_loader.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility for loading prompts from YAML configuration files
|
| 3 |
+
"""
|
| 4 |
+
import yaml
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PromptLoader:
|
| 10 |
+
"""Load and manage prompts from YAML files"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, prompts_file: Optional[str] = None):
|
| 13 |
+
"""
|
| 14 |
+
Initialize prompt loader
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
prompts_file: Path to prompts YAML file. If None, uses default location.
|
| 18 |
+
"""
|
| 19 |
+
if prompts_file is None:
|
| 20 |
+
# Default to shared/configs/prompts.yaml relative to project root
|
| 21 |
+
project_root = Path(__file__).parent.parent.parent
|
| 22 |
+
prompts_file = project_root / "shared" / "configs" / "prompts.yaml"
|
| 23 |
+
|
| 24 |
+
self.prompts_file = Path(prompts_file)
|
| 25 |
+
self._prompts = None
|
| 26 |
+
self._load_prompts()
|
| 27 |
+
|
| 28 |
+
def _load_prompts(self):
|
| 29 |
+
"""Load prompts from YAML file"""
|
| 30 |
+
if not self.prompts_file.exists():
|
| 31 |
+
raise FileNotFoundError(f"Prompts file not found: {self.prompts_file}")
|
| 32 |
+
|
| 33 |
+
with open(self.prompts_file, 'r', encoding='utf-8') as f:
|
| 34 |
+
self._prompts = yaml.safe_load(f)
|
| 35 |
+
|
| 36 |
+
def get_keyword_generation_prompt(self, context: str) -> str:
|
| 37 |
+
"""
|
| 38 |
+
Get keyword generation prompt with context filled in
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
context: Paper information context
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Formatted prompt string
|
| 45 |
+
"""
|
| 46 |
+
template = self._prompts["keyword_generation"]["user"]
|
| 47 |
+
return template.format(context=context)
|
| 48 |
+
|
| 49 |
+
def get_keyword_generation_system(self) -> str:
|
| 50 |
+
"""Get keyword generation system message"""
|
| 51 |
+
return self._prompts["keyword_generation"].get("system", "")
|
| 52 |
+
|
| 53 |
+
def get_paper_summarization_prompt(self, reference_paper: str, related_paper: str) -> str:
|
| 54 |
+
"""
|
| 55 |
+
Get paper summarization prompt with reference_paper and related_paper filled in
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
reference_paper: Reference paper information (the paper being reviewed)
|
| 59 |
+
related_paper: Related paper information
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Formatted prompt string
|
| 63 |
+
"""
|
| 64 |
+
template = self._prompts["paper_summarization"]["user"]
|
| 65 |
+
return template.format(reference_paper=reference_paper, related_paper=related_paper)
|
| 66 |
+
|
| 67 |
+
def get_paper_results_summarization_prompt(self, content: str) -> str:
|
| 68 |
+
"""
|
| 69 |
+
Get paper results summarization prompt with content filled in
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
content: Paper content (experiment results section)
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Formatted prompt string
|
| 76 |
+
"""
|
| 77 |
+
template = self._prompts["paper_results_summarization"]["user"]
|
| 78 |
+
return template.format(content=content)
|
| 79 |
+
|
| 80 |
+
def get_paper_insight_miner_prompt(self, content: str, candidate_review: str) -> str:
|
| 81 |
+
"""
|
| 82 |
+
Get paper insight miner prompt with content and candidate_review filled in
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
content: Paper content
|
| 86 |
+
candidate_review: Candidate review draft
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Formatted prompt string
|
| 90 |
+
"""
|
| 91 |
+
template = self._prompts["paper_insight_miner"]["user"]
|
| 92 |
+
# Use replace instead of format to avoid issues with JSON braces in the template
|
| 93 |
+
prompt = template.replace("{content}", content)
|
| 94 |
+
prompt = prompt.replace("{candidate_review}", candidate_review)
|
| 95 |
+
return prompt
|
| 96 |
+
|
| 97 |
+
def get_paper_results_analyzer_prompt(self, content: str, candidate_review: str) -> str:
|
| 98 |
+
"""
|
| 99 |
+
Get paper results analyzer prompt with content and candidate_review filled in
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
content: Paper content
|
| 103 |
+
candidate_review: Candidate review draft
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
Formatted prompt string
|
| 107 |
+
"""
|
| 108 |
+
template = self._prompts["paper_results_analyzer"]["user"]
|
| 109 |
+
# Use replace instead of format to avoid issues with JSON braces in the template
|
| 110 |
+
prompt = template.replace("{content}", content)
|
| 111 |
+
prompt = prompt.replace("{candidate_review}", candidate_review)
|
| 112 |
+
return prompt
|
| 113 |
+
|
| 114 |
+
def get_review_prompt(self, review_format: str = "detailed") -> str:
|
| 115 |
+
"""
|
| 116 |
+
Get review prompt for specified format
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
review_format: Review format ("detailed", "summary", "structured")
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Review prompt string
|
| 123 |
+
"""
|
| 124 |
+
if review_format not in self._prompts["review_prompts"]:
|
| 125 |
+
review_format = "detailed"
|
| 126 |
+
|
| 127 |
+
return self._prompts["review_prompts"][review_format]
|
| 128 |
+
|
| 129 |
+
def get_reviewer_system_message(self) -> str:
|
| 130 |
+
"""Get system message for reviewer"""
|
| 131 |
+
return self._prompts.get("reviewer_system", "You are an expert academic reviewer with deep knowledge in the field.")
|
| 132 |
+
|
| 133 |
+
def get_refiner_prompt(self, review_format: str = "detailed") -> str:
|
| 134 |
+
"""
|
| 135 |
+
Get refiner prompt for specified format
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
review_format: Review format ("detailed", "summary", "structured")
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
Refiner prompt string
|
| 142 |
+
"""
|
| 143 |
+
if "refiner_prompts" not in self._prompts:
|
| 144 |
+
raise ValueError("refiner_prompts not found in prompts file")
|
| 145 |
+
|
| 146 |
+
if review_format not in self._prompts["refiner_prompts"]:
|
| 147 |
+
review_format = "detailed"
|
| 148 |
+
|
| 149 |
+
return self._prompts["refiner_prompts"][review_format]
|
| 150 |
+
|
| 151 |
+
def get_refiner_system_message(self) -> str:
|
| 152 |
+
"""Get system message for refiner"""
|
| 153 |
+
return self._prompts.get("refiner_system", "You are an expert review refiner with deep knowledge in academic review quality standards and meta rubrics.")
|
| 154 |
+
|
| 155 |
+
def get_rubrics_template(self) -> str:
|
| 156 |
+
"""
|
| 157 |
+
Get the rubrics template for generating paper-specific rubrics.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
Rubrics template string (JSON array format)
|
| 161 |
+
"""
|
| 162 |
+
return self._prompts.get("rubrics", "")
|
| 163 |
+
|
| 164 |
+
def get_rubric_generation_prompt(self, version: str = "v2") -> str:
|
| 165 |
+
"""
|
| 166 |
+
Get rubric generation prompt.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
version: Prompt version ("v1" or "v2", default: "v2")
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
Rubric generation prompt template string
|
| 173 |
+
"""
|
| 174 |
+
key = f"{version}_rubric_generation_prompt"
|
| 175 |
+
prompt = self._prompts.get(key, "")
|
| 176 |
+
|
| 177 |
+
# For v2, replace rubric_template placeholder with actual template
|
| 178 |
+
if version == "v2" and "<<rubric_template>>" in prompt:
|
| 179 |
+
rubric_template = self.get_rubrics_template()
|
| 180 |
+
prompt = prompt.replace("<<rubric_template>>", rubric_template)
|
| 181 |
+
|
| 182 |
+
return prompt
|
| 183 |
+
|
| 184 |
+
def get_evaluator_prompt(self, version: str = "v1") -> str:
|
| 185 |
+
"""
|
| 186 |
+
Get evaluator prompt for evaluating reviews using rubrics.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
version: Prompt version ("v0" or "v1", default: "v1")
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
Evaluator prompt template string
|
| 193 |
+
"""
|
| 194 |
+
key = f"{version}_evaluator_prompt"
|
| 195 |
+
return self._prompts.get(key, "")
|
| 196 |
+
|
| 197 |
+
def reload(self):
|
| 198 |
+
"""Reload prompts from file"""
|
| 199 |
+
self._load_prompts()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# Global prompt loader instance
|
| 203 |
+
_prompt_loader: Optional[PromptLoader] = None
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_prompt_loader(prompts_file: Optional[str] = None) -> PromptLoader:
|
| 207 |
+
"""
|
| 208 |
+
Get or create global prompt loader instance
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
prompts_file: Optional path to prompts file
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
PromptLoader instance
|
| 215 |
+
"""
|
| 216 |
+
global _prompt_loader
|
| 217 |
+
if _prompt_loader is None or prompts_file is not None:
|
| 218 |
+
_prompt_loader = PromptLoader(prompts_file)
|
| 219 |
+
return _prompt_loader
|
| 220 |
+
|
shared/utils/reranker.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reranker utilities for paper retrieval
|
| 3 |
+
Based on OpenScholar's rerank_paragraphs_bge function
|
| 4 |
+
|
| 5 |
+
Supports two modes:
|
| 6 |
+
1. Direct mode: Use FlagReranker directly (requires global lock for thread-safety)
|
| 7 |
+
2. API mode: Use reranker API service with load balancing (recommended for multi-GPU)
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
import threading
|
| 11 |
+
import time
|
| 12 |
+
import requests
|
| 13 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 14 |
+
|
| 15 |
+
# Suppress transformers progress bars
|
| 16 |
+
os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')
|
| 17 |
+
|
| 18 |
+
# Global lock for reranker usage (FlagReranker's tokenizer is not thread-safe)
|
| 19 |
+
# This prevents "Already borrowed" errors when multiple threads use the same reranker
|
| 20 |
+
# NOTE: Not needed when using API mode
|
| 21 |
+
_reranker_usage_lock = threading.Lock()
|
| 22 |
+
|
| 23 |
+
# Try to import endpoint pool for API mode
|
| 24 |
+
try:
|
| 25 |
+
from .reranker_endpoint_pool import RerankerEndpointPool
|
| 26 |
+
HAS_ENDPOINT_POOL = True
|
| 27 |
+
except ImportError:
|
| 28 |
+
HAS_ENDPOINT_POOL = False
|
| 29 |
+
RerankerEndpointPool = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def rerank_paragraphs_bge(
|
| 33 |
+
query: str,
|
| 34 |
+
paragraphs: List[Dict[str, Any]],
|
| 35 |
+
reranker: Optional[Any] = None,
|
| 36 |
+
reranker_endpoint_pool: Optional[Any] = None,
|
| 37 |
+
norm_cite: bool = False,
|
| 38 |
+
start_index: int = 0,
|
| 39 |
+
use_abstract: bool = False,
|
| 40 |
+
timeout: float = 30.0,
|
| 41 |
+
) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
|
| 42 |
+
"""
|
| 43 |
+
Rerank paragraphs using BGE reranker (from OpenScholar)
|
| 44 |
+
|
| 45 |
+
Supports two modes:
|
| 46 |
+
1. Direct mode: Pass FlagReranker instance (uses global lock, thread-safe but serialized)
|
| 47 |
+
2. API mode: Pass RerankerEndpointPool (recommended for multi-GPU, parallel requests)
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
query: Search query
|
| 51 |
+
paragraphs: List of paragraph/paper dictionaries
|
| 52 |
+
reranker: FlagReranker instance (for direct mode, optional if using API mode)
|
| 53 |
+
reranker_endpoint_pool: RerankerEndpointPool instance (for API mode, optional if using direct mode)
|
| 54 |
+
norm_cite: Whether to normalize citation counts and add to scores
|
| 55 |
+
start_index: Starting index for id mapping
|
| 56 |
+
use_abstract: Whether to include abstract in reranking text
|
| 57 |
+
timeout: Request timeout for API mode (seconds)
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Tuple of:
|
| 61 |
+
- reranked_paragraphs: List of reranked paragraphs
|
| 62 |
+
- result_dict: Dictionary mapping original index to score
|
| 63 |
+
- id_mapping: Dictionary mapping new index to original index
|
| 64 |
+
"""
|
| 65 |
+
# Filter out paragraphs without text
|
| 66 |
+
paragraphs = [p for p in paragraphs if p.get("text") is not None]
|
| 67 |
+
|
| 68 |
+
if not paragraphs:
|
| 69 |
+
return [], {}, {}
|
| 70 |
+
|
| 71 |
+
# Build paragraph texts for reranking
|
| 72 |
+
if use_abstract:
|
| 73 |
+
paragraph_texts = [
|
| 74 |
+
p["title"] + "\n" + p["abstract"] + "\n" + p["text"]
|
| 75 |
+
if "title" in p and "abstract" in p and p.get("title") and p.get("abstract")
|
| 76 |
+
else p["text"]
|
| 77 |
+
for p in paragraphs
|
| 78 |
+
]
|
| 79 |
+
else:
|
| 80 |
+
paragraph_texts = [
|
| 81 |
+
p["title"] + " " + p["text"]
|
| 82 |
+
if "title" in p and p.get("title") is not None
|
| 83 |
+
else p["text"]
|
| 84 |
+
for p in paragraphs
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
# Filter out empty or None texts
|
| 88 |
+
valid_indices = []
|
| 89 |
+
valid_texts = []
|
| 90 |
+
for i, text in enumerate(paragraph_texts):
|
| 91 |
+
if text and isinstance(text, str) and text.strip():
|
| 92 |
+
valid_indices.append(i)
|
| 93 |
+
valid_texts.append(text)
|
| 94 |
+
|
| 95 |
+
# If no valid texts, return empty results
|
| 96 |
+
if not valid_texts:
|
| 97 |
+
return [], {}, {}
|
| 98 |
+
|
| 99 |
+
# If some texts were filtered out, update paragraphs list
|
| 100 |
+
if len(valid_indices) < len(paragraphs):
|
| 101 |
+
paragraphs = [paragraphs[i] for i in valid_indices]
|
| 102 |
+
paragraph_texts = valid_texts
|
| 103 |
+
|
| 104 |
+
# Compute reranking scores
|
| 105 |
+
if reranker is None and reranker_endpoint_pool is None:
|
| 106 |
+
# If no reranker, return original order
|
| 107 |
+
id_mapping = {i: i + start_index for i in range(len(paragraphs))}
|
| 108 |
+
result_dict = {i: 0.0 for i in range(len(paragraphs))}
|
| 109 |
+
return paragraphs, result_dict, id_mapping
|
| 110 |
+
|
| 111 |
+
# API mode: Use reranker API service (recommended for multi-GPU)
|
| 112 |
+
if reranker_endpoint_pool is not None:
|
| 113 |
+
return _rerank_via_api(
|
| 114 |
+
query=query,
|
| 115 |
+
paragraph_texts=paragraph_texts,
|
| 116 |
+
paragraphs=paragraphs,
|
| 117 |
+
reranker_endpoint_pool=reranker_endpoint_pool,
|
| 118 |
+
norm_cite=norm_cite,
|
| 119 |
+
start_index=start_index,
|
| 120 |
+
timeout=timeout
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Direct mode: Use FlagReranker directly (requires global lock)
|
| 124 |
+
# Suppress transformers warnings and progress bars during computation
|
| 125 |
+
original_verbosity = os.environ.get('TRANSFORMERS_VERBOSITY', '')
|
| 126 |
+
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
|
| 127 |
+
|
| 128 |
+
# Use lock to prevent "Already borrowed" errors from Rust tokenizer
|
| 129 |
+
# FlagReranker's tokenizer is not thread-safe, so we need to serialize access
|
| 130 |
+
with _reranker_usage_lock:
|
| 131 |
+
try:
|
| 132 |
+
# Ensure we have at least one valid text before calling compute_score
|
| 133 |
+
if not paragraph_texts:
|
| 134 |
+
return [], {}, {}
|
| 135 |
+
scores = reranker.compute_score([[query, p] for p in paragraph_texts], batch_size=100)
|
| 136 |
+
finally:
|
| 137 |
+
# Restore original verbosity
|
| 138 |
+
if original_verbosity:
|
| 139 |
+
os.environ['TRANSFORMERS_VERBOSITY'] = original_verbosity
|
| 140 |
+
elif 'TRANSFORMERS_VERBOSITY' in os.environ:
|
| 141 |
+
del os.environ['TRANSFORMERS_VERBOSITY']
|
| 142 |
+
|
| 143 |
+
# Handle score format (can be float or list)
|
| 144 |
+
if isinstance(scores, float):
|
| 145 |
+
result_dict = {0: scores}
|
| 146 |
+
else:
|
| 147 |
+
result_dict = {p_id: score for p_id, score in enumerate(scores)}
|
| 148 |
+
|
| 149 |
+
# Add normalized citation counts if enabled
|
| 150 |
+
if norm_cite:
|
| 151 |
+
citation_items = [
|
| 152 |
+
item["citation_counts"]
|
| 153 |
+
for item in paragraphs
|
| 154 |
+
if "citation_counts" in item and item["citation_counts"] is not None
|
| 155 |
+
]
|
| 156 |
+
if len(citation_items) > 0:
|
| 157 |
+
max_citations = max(citation_items)
|
| 158 |
+
for p_id in result_dict:
|
| 159 |
+
if (
|
| 160 |
+
"citation_counts" in paragraphs[p_id]
|
| 161 |
+
and paragraphs[p_id]["citation_counts"] is not None
|
| 162 |
+
):
|
| 163 |
+
result_dict[p_id] = result_dict[p_id] + (
|
| 164 |
+
paragraphs[p_id]["citation_counts"] / max_citations
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Sort by score
|
| 168 |
+
p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
|
| 169 |
+
|
| 170 |
+
# Build reranked list and id mapping
|
| 171 |
+
new_orders = []
|
| 172 |
+
id_mapping = {}
|
| 173 |
+
for i, (p_id, _) in enumerate(p_ids):
|
| 174 |
+
new_orders.append(paragraphs[p_id])
|
| 175 |
+
id_mapping[i] = int(p_id) + start_index
|
| 176 |
+
|
| 177 |
+
return new_orders, result_dict, id_mapping
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _rerank_via_api(
|
| 181 |
+
query: str,
|
| 182 |
+
paragraph_texts: List[str],
|
| 183 |
+
paragraphs: List[Dict[str, Any]],
|
| 184 |
+
reranker_endpoint_pool: Any,
|
| 185 |
+
norm_cite: bool = False,
|
| 186 |
+
start_index: int = 0,
|
| 187 |
+
timeout: float = 30.0,
|
| 188 |
+
) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
|
| 189 |
+
"""
|
| 190 |
+
Rerank paragraphs via API service (supports load balancing across multiple GPUs)
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
query: Search query
|
| 194 |
+
paragraph_texts: List of paragraph texts (already formatted)
|
| 195 |
+
paragraphs: List of paragraph dictionaries
|
| 196 |
+
reranker_endpoint_pool: RerankerEndpointPool instance
|
| 197 |
+
norm_cite: Whether to normalize citation counts
|
| 198 |
+
start_index: Starting index for id mapping
|
| 199 |
+
timeout: Request timeout
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Tuple of reranked paragraphs, result dict, and id mapping
|
| 203 |
+
"""
|
| 204 |
+
if not paragraph_texts:
|
| 205 |
+
return [], {}, {}
|
| 206 |
+
|
| 207 |
+
# Get endpoint from pool (round-robin load balancing)
|
| 208 |
+
endpoint = reranker_endpoint_pool.get_endpoint()
|
| 209 |
+
api_url = f"{endpoint}/rerank"
|
| 210 |
+
|
| 211 |
+
# Prepare request
|
| 212 |
+
request_data = {
|
| 213 |
+
"query": query,
|
| 214 |
+
"paragraphs": paragraph_texts,
|
| 215 |
+
"batch_size": 100
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
start_time = time.time()
|
| 219 |
+
try:
|
| 220 |
+
# Make API request
|
| 221 |
+
response = requests.post(
|
| 222 |
+
api_url,
|
| 223 |
+
json=request_data,
|
| 224 |
+
timeout=timeout
|
| 225 |
+
)
|
| 226 |
+
response.raise_for_status()
|
| 227 |
+
|
| 228 |
+
result = response.json()
|
| 229 |
+
scores = result.get("scores", [])
|
| 230 |
+
response_time = time.time() - start_time
|
| 231 |
+
|
| 232 |
+
# Mark success
|
| 233 |
+
reranker_endpoint_pool.mark_success(endpoint, response_time)
|
| 234 |
+
|
| 235 |
+
except requests.exceptions.RequestException as e:
|
| 236 |
+
# Mark error
|
| 237 |
+
reranker_endpoint_pool.mark_error(endpoint, str(e))
|
| 238 |
+
raise RuntimeError(f"Reranker API request failed: {e}")
|
| 239 |
+
|
| 240 |
+
# Handle score format (should be list from API)
|
| 241 |
+
if isinstance(scores, float):
|
| 242 |
+
result_dict = {0: scores}
|
| 243 |
+
else:
|
| 244 |
+
result_dict = {p_id: score for p_id, score in enumerate(scores)}
|
| 245 |
+
|
| 246 |
+
# Add normalized citation counts if enabled
|
| 247 |
+
if norm_cite:
|
| 248 |
+
citation_items = [
|
| 249 |
+
item["citation_counts"]
|
| 250 |
+
for item in paragraphs
|
| 251 |
+
if "citation_counts" in item and item["citation_counts"] is not None
|
| 252 |
+
]
|
| 253 |
+
if len(citation_items) > 0:
|
| 254 |
+
max_citations = max(citation_items)
|
| 255 |
+
for p_id in result_dict:
|
| 256 |
+
if (
|
| 257 |
+
"citation_counts" in paragraphs[p_id]
|
| 258 |
+
and paragraphs[p_id]["citation_counts"] is not None
|
| 259 |
+
):
|
| 260 |
+
result_dict[p_id] = result_dict[p_id] + (
|
| 261 |
+
paragraphs[p_id]["citation_counts"] / max_citations
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Sort by score
|
| 265 |
+
p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
|
| 266 |
+
|
| 267 |
+
# Build reranked list and id mapping
|
| 268 |
+
new_orders = []
|
| 269 |
+
id_mapping = {}
|
| 270 |
+
for i, (p_id, _) in enumerate(p_ids):
|
| 271 |
+
new_orders.append(paragraphs[p_id])
|
| 272 |
+
id_mapping[i] = int(p_id) + start_index
|
| 273 |
+
|
| 274 |
+
return new_orders, result_dict, id_mapping
|
| 275 |
+
|
shared/utils/reranker_api_service.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reranker API Service
|
| 3 |
+
|
| 4 |
+
Pack FlagReranker into an HTTP API service, supporting multi-GPU load balancing.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
import argparse
|
| 11 |
+
|
| 12 |
+
# Suppress transformers warnings
|
| 13 |
+
os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from fastapi import FastAPI, HTTPException
|
| 17 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 18 |
+
from pydantic import BaseModel
|
| 19 |
+
import uvicorn
|
| 20 |
+
HAS_FASTAPI = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
HAS_FASTAPI = False
|
| 23 |
+
print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn")
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from FlagEmbedding import FlagReranker
|
| 27 |
+
HAS_FLAGEMBEDDING = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
HAS_FLAGEMBEDDING = False
|
| 30 |
+
print("Warning: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Request/Response models
|
| 34 |
+
class RerankRequest(BaseModel):
|
| 35 |
+
query: str
|
| 36 |
+
paragraphs: List[str]
|
| 37 |
+
batch_size: int = 100
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class RerankResponse(BaseModel):
|
| 41 |
+
scores: List[float]
|
| 42 |
+
success: bool
|
| 43 |
+
message: Optional[str] = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Global reranker instance
|
| 47 |
+
_reranker: Optional[Any] = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def create_app(model_path: str, use_fp16: bool = True, device: Optional[str] = None):
|
| 51 |
+
"""Create FastAPI app with reranker"""
|
| 52 |
+
global _reranker
|
| 53 |
+
|
| 54 |
+
app = FastAPI(title="Reranker API Service", version="1.0.0")
|
| 55 |
+
|
| 56 |
+
# Add CORS middleware
|
| 57 |
+
app.add_middleware(
|
| 58 |
+
CORSMiddleware,
|
| 59 |
+
allow_origins=["*"],
|
| 60 |
+
allow_credentials=True,
|
| 61 |
+
allow_methods=["*"],
|
| 62 |
+
allow_headers=["*"],
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
@app.on_event("startup")
|
| 66 |
+
async def load_reranker():
|
| 67 |
+
"""Load reranker model on startup"""
|
| 68 |
+
global _reranker
|
| 69 |
+
if not HAS_FLAGEMBEDDING:
|
| 70 |
+
raise RuntimeError("FlagEmbedding not installed")
|
| 71 |
+
|
| 72 |
+
print(f"Loading reranker model: {model_path}")
|
| 73 |
+
print(f"Using FP16: {use_fp16}")
|
| 74 |
+
if device:
|
| 75 |
+
print(f"Using device: {device}")
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
_reranker = FlagReranker(
|
| 79 |
+
model_path,
|
| 80 |
+
use_fp16=use_fp16,
|
| 81 |
+
)
|
| 82 |
+
if device:
|
| 83 |
+
# Note: FlagReranker may not support explicit device setting
|
| 84 |
+
# This is a placeholder for future support
|
| 85 |
+
pass
|
| 86 |
+
print("Reranker model loaded successfully")
|
| 87 |
+
except Exception as e:
|
| 88 |
+
print(f"Error loading reranker: {e}")
|
| 89 |
+
raise
|
| 90 |
+
|
| 91 |
+
@app.get("/health")
|
| 92 |
+
async def health_check():
|
| 93 |
+
"""Health check endpoint"""
|
| 94 |
+
return {
|
| 95 |
+
"status": "healthy",
|
| 96 |
+
"model_loaded": _reranker is not None
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
@app.post("/rerank", response_model=RerankResponse)
|
| 100 |
+
async def rerank(request: RerankRequest):
|
| 101 |
+
"""Rerank paragraphs given a query"""
|
| 102 |
+
global _reranker
|
| 103 |
+
|
| 104 |
+
if _reranker is None:
|
| 105 |
+
raise HTTPException(status_code=503, detail="Reranker not loaded")
|
| 106 |
+
|
| 107 |
+
if not request.paragraphs:
|
| 108 |
+
return RerankResponse(
|
| 109 |
+
scores=[],
|
| 110 |
+
success=True,
|
| 111 |
+
message="No paragraphs to rerank"
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
try:
|
| 115 |
+
# Prepare sentence pairs: [[query, paragraph], ...]
|
| 116 |
+
sentence_pairs = [[request.query, p] for p in request.paragraphs]
|
| 117 |
+
|
| 118 |
+
# Compute scores
|
| 119 |
+
scores = _reranker.compute_score(
|
| 120 |
+
sentence_pairs,
|
| 121 |
+
batch_size=request.batch_size
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Handle score format (can be float or list)
|
| 125 |
+
if isinstance(scores, float):
|
| 126 |
+
scores = [scores]
|
| 127 |
+
elif not isinstance(scores, list):
|
| 128 |
+
scores = list(scores)
|
| 129 |
+
|
| 130 |
+
return RerankResponse(
|
| 131 |
+
scores=scores,
|
| 132 |
+
success=True
|
| 133 |
+
)
|
| 134 |
+
except Exception as e:
|
| 135 |
+
print(f"Error during reranking: {e}")
|
| 136 |
+
import traceback
|
| 137 |
+
traceback.print_exc()
|
| 138 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 139 |
+
|
| 140 |
+
return app
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def main():
|
| 144 |
+
"""Main entry point for reranker API service"""
|
| 145 |
+
parser = argparse.ArgumentParser(description="Reranker API Service")
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--model_path",
|
| 148 |
+
type=str,
|
| 149 |
+
required=True,
|
| 150 |
+
help="Path to reranker model (e.g., 'OpenScholar/OpenScholar_Reranker')"
|
| 151 |
+
)
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--host",
|
| 154 |
+
type=str,
|
| 155 |
+
default="0.0.0.0",
|
| 156 |
+
help="Host to bind to (default: 0.0.0.0)"
|
| 157 |
+
)
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--port",
|
| 160 |
+
type=int,
|
| 161 |
+
default=8004,
|
| 162 |
+
help="Port to bind to (default: 8004)"
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--use_fp16",
|
| 166 |
+
action="store_true",
|
| 167 |
+
default=True,
|
| 168 |
+
help="Use FP16 precision (default: True)"
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--no_fp16",
|
| 172 |
+
dest="use_fp16",
|
| 173 |
+
action="store_false",
|
| 174 |
+
help="Disable FP16 precision"
|
| 175 |
+
)
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--device",
|
| 178 |
+
type=str,
|
| 179 |
+
default=None,
|
| 180 |
+
help="Device to use (e.g., 'cuda:0', 'cuda:1')"
|
| 181 |
+
)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--workers",
|
| 184 |
+
type=int,
|
| 185 |
+
default=1,
|
| 186 |
+
help="Number of worker processes (default: 1, use 1 for reranker)"
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
args = parser.parse_args()
|
| 190 |
+
|
| 191 |
+
if not HAS_FASTAPI:
|
| 192 |
+
print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn")
|
| 193 |
+
sys.exit(1)
|
| 194 |
+
|
| 195 |
+
if not HAS_FLAGEMBEDDING:
|
| 196 |
+
print("Error: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
|
| 197 |
+
sys.exit(1)
|
| 198 |
+
|
| 199 |
+
# Create app
|
| 200 |
+
app = create_app(
|
| 201 |
+
model_path=args.model_path,
|
| 202 |
+
use_fp16=args.use_fp16,
|
| 203 |
+
device=args.device
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Run server
|
| 207 |
+
print(f"Starting reranker API service on {args.host}:{args.port}")
|
| 208 |
+
print(f"Model: {args.model_path}")
|
| 209 |
+
print(f"FP16: {args.use_fp16}")
|
| 210 |
+
|
| 211 |
+
uvicorn.run(
|
| 212 |
+
app,
|
| 213 |
+
host=args.host,
|
| 214 |
+
port=args.port,
|
| 215 |
+
workers=args.workers,
|
| 216 |
+
log_level="info"
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == "__main__":
|
| 221 |
+
main()
|
shared/utils/reranker_endpoint_pool.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reranker Endpoint Pool Manager
|
| 3 |
+
|
| 4 |
+
Manage multiple reranker API endpoints, implement round-robin access and load balancing.
|
| 5 |
+
Reuse the logic of VLLMEndpointPool.
|
| 6 |
+
"""
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Optional, Dict
|
| 9 |
+
from threading import Lock
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RerankerEndpointPool:
|
| 13 |
+
"""
|
| 14 |
+
Reranker Endpoint Pool Manager
|
| 15 |
+
|
| 16 |
+
Features:
|
| 17 |
+
1. Load multiple reranker API endpoints from file
|
| 18 |
+
2. Round-robin access endpoints (ensure uniform distribution)
|
| 19 |
+
3. Track usage status and errors for each endpoint
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, pool_path: Optional[str] = None, endpoints: Optional[List[str]] = None, use_round_robin: bool = True):
|
| 23 |
+
"""
|
| 24 |
+
Initialize Reranker Endpoint Pool
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
pool_path: Endpoints file path (one endpoint URL per line)
|
| 28 |
+
endpoints: directly provide endpoints list (prior to pool_path)
|
| 29 |
+
use_round_robin: whether to use round-robin strategy (default True, recommended)
|
| 30 |
+
"""
|
| 31 |
+
self.endpoints: List[str] = []
|
| 32 |
+
self.current_index: int = 0 # Round-robin current index
|
| 33 |
+
self.endpoint_status: Dict[str, Dict] = {} # status information for each endpoint
|
| 34 |
+
self.lock = Lock() # thread safe lock
|
| 35 |
+
self.use_round_robin = use_round_robin # whether to use round-robin
|
| 36 |
+
|
| 37 |
+
# load endpoints
|
| 38 |
+
if endpoints:
|
| 39 |
+
self.endpoints = endpoints
|
| 40 |
+
elif pool_path:
|
| 41 |
+
self._load_from_file(pool_path)
|
| 42 |
+
else:
|
| 43 |
+
raise ValueError("Either pool_path or endpoints must be provided")
|
| 44 |
+
|
| 45 |
+
if not self.endpoints:
|
| 46 |
+
raise ValueError("No endpoints loaded")
|
| 47 |
+
|
| 48 |
+
# initialize status for each endpoint
|
| 49 |
+
for endpoint in self.endpoints:
|
| 50 |
+
if endpoint not in self.endpoint_status:
|
| 51 |
+
self.endpoint_status[endpoint] = {
|
| 52 |
+
'total_requests': 0,
|
| 53 |
+
'successful_requests': 0,
|
| 54 |
+
'failed_requests': 0,
|
| 55 |
+
'total_response_time': 0.0,
|
| 56 |
+
'last_error': None,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
print(f"RerankerEndpointPool initialized with {len(self.endpoints)} endpoints")
|
| 60 |
+
for i, endpoint in enumerate(self.endpoints):
|
| 61 |
+
print(f" [{i+1}] {endpoint}")
|
| 62 |
+
|
| 63 |
+
def _load_from_file(self, pool_path: str):
|
| 64 |
+
"""load endpoints from file"""
|
| 65 |
+
path = Path(pool_path)
|
| 66 |
+
if not path.is_absolute():
|
| 67 |
+
# try to find file relative to shared/configs/
|
| 68 |
+
project_root = Path(__file__).parent.parent.parent
|
| 69 |
+
path = project_root / "shared" / "configs" / pool_path
|
| 70 |
+
|
| 71 |
+
if not path.exists():
|
| 72 |
+
raise FileNotFoundError(f"Reranker endpoint pool file not found: {path}")
|
| 73 |
+
|
| 74 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 75 |
+
lines = f.readlines()
|
| 76 |
+
|
| 77 |
+
self.endpoints = []
|
| 78 |
+
for line in lines:
|
| 79 |
+
line = line.strip()
|
| 80 |
+
if line and not line.startswith('#'):
|
| 81 |
+
# ensure URL format is correct
|
| 82 |
+
if not line.startswith('http://') and not line.startswith('https://'):
|
| 83 |
+
line = f"http://{line}"
|
| 84 |
+
self.endpoints.append(line)
|
| 85 |
+
|
| 86 |
+
def get_endpoint(self) -> str:
|
| 87 |
+
"""
|
| 88 |
+
Get next available endpoint (round-robin strategy)
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Available endpoint URL
|
| 92 |
+
"""
|
| 93 |
+
with self.lock:
|
| 94 |
+
if not self.endpoints:
|
| 95 |
+
raise ValueError("No reranker endpoints available in pool")
|
| 96 |
+
|
| 97 |
+
if self.use_round_robin:
|
| 98 |
+
# Round-robin: simple round-robin, ensure uniform distribution
|
| 99 |
+
selected_idx = self.current_index
|
| 100 |
+
self.current_index = (self.current_index + 1) % len(self.endpoints)
|
| 101 |
+
|
| 102 |
+
selected_endpoint = self.endpoints[selected_idx]
|
| 103 |
+
self.endpoint_status[selected_endpoint]['total_requests'] += 1
|
| 104 |
+
|
| 105 |
+
return selected_endpoint
|
| 106 |
+
else:
|
| 107 |
+
# smart selection mode (can select based on error rate, etc.)
|
| 108 |
+
# simple implementation: select endpoint with least requests
|
| 109 |
+
min_requests = min(
|
| 110 |
+
self.endpoint_status[ep]['total_requests']
|
| 111 |
+
for ep in self.endpoints
|
| 112 |
+
)
|
| 113 |
+
candidates = [
|
| 114 |
+
ep for ep in self.endpoints
|
| 115 |
+
if self.endpoint_status[ep]['total_requests'] == min_requests
|
| 116 |
+
]
|
| 117 |
+
selected_endpoint = candidates[0]
|
| 118 |
+
self.endpoint_status[selected_endpoint]['total_requests'] += 1
|
| 119 |
+
return selected_endpoint
|
| 120 |
+
|
| 121 |
+
def mark_success(self, endpoint: str, response_time: float = 0.0):
|
| 122 |
+
"""mark endpoint request as successful"""
|
| 123 |
+
with self.lock:
|
| 124 |
+
if endpoint in self.endpoint_status:
|
| 125 |
+
self.endpoint_status[endpoint]['successful_requests'] += 1
|
| 126 |
+
self.endpoint_status[endpoint]['total_response_time'] += response_time
|
| 127 |
+
|
| 128 |
+
def mark_error(self, endpoint: str, error: str):
|
| 129 |
+
"""mark endpoint request as failed"""
|
| 130 |
+
with self.lock:
|
| 131 |
+
if endpoint in self.endpoint_status:
|
| 132 |
+
self.endpoint_status[endpoint]['failed_requests'] += 1
|
| 133 |
+
self.endpoint_status[endpoint]['last_error'] = error
|
| 134 |
+
|
| 135 |
+
def get_status(self) -> Dict:
|
| 136 |
+
"""get pool status information"""
|
| 137 |
+
with self.lock:
|
| 138 |
+
endpoints_status = {}
|
| 139 |
+
for endpoint, status in self.endpoint_status.items():
|
| 140 |
+
total = status['total_requests']
|
| 141 |
+
success = status['successful_requests']
|
| 142 |
+
failed = status['failed_requests']
|
| 143 |
+
avg_time = (
|
| 144 |
+
status['total_response_time'] / success
|
| 145 |
+
if success > 0 else 0.0
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
endpoints_status[endpoint] = {
|
| 149 |
+
'total_requests': total,
|
| 150 |
+
'successful_requests': success,
|
| 151 |
+
'failed_requests': failed,
|
| 152 |
+
'success_rate': success / total if total > 0 else 0.0,
|
| 153 |
+
'avg_response_time': avg_time,
|
| 154 |
+
'last_error': status['last_error'],
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
return {
|
| 158 |
+
'total_endpoints': len(self.endpoints),
|
| 159 |
+
'endpoints_status': endpoints_status,
|
| 160 |
+
}
|
shared/utils/reranker_pool.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reranker Pool for multiprocessing-safe reranker sharing
|
| 3 |
+
|
| 4 |
+
Solve the issue of FlagReranker not being pickleable in multi-process/multi-thread environment.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import threading
|
| 8 |
+
from typing import Optional, Dict
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# Global reranker storage (thread-safe)
|
| 12 |
+
# Note: Dictionary access is atomic in Python for simple operations,
|
| 13 |
+
# but we use a lock for thread-safety when modifying the dict
|
| 14 |
+
_reranker_pool: Dict[str, object] = {}
|
| 15 |
+
_reranker_lock = threading.Lock()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_reranker(model_path: str, use_fp16: bool = True):
|
| 19 |
+
"""
|
| 20 |
+
Get or create reranker (thread-safe, process-shared)
|
| 21 |
+
|
| 22 |
+
Performance optimization:
|
| 23 |
+
- Load and cache on first call
|
| 24 |
+
- Return cached instance on subsequent calls (no lock check)
|
| 25 |
+
- Use double-check locking pattern, reduce lock contention
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
model_path: Reranker model path
|
| 29 |
+
use_fp16: whether to use FP16
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
FlagReranker instance
|
| 33 |
+
"""
|
| 34 |
+
global _reranker_pool
|
| 35 |
+
|
| 36 |
+
# create unique key
|
| 37 |
+
key = f"{model_path}_{use_fp16}"
|
| 38 |
+
|
| 39 |
+
# performance optimization: fast path check (no lock)
|
| 40 |
+
if key in _reranker_pool:
|
| 41 |
+
return _reranker_pool[key]
|
| 42 |
+
|
| 43 |
+
# slow path: needs loading (needs lock)
|
| 44 |
+
with _reranker_lock:
|
| 45 |
+
# double check: other threads may have loaded while waiting for lock
|
| 46 |
+
if key not in _reranker_pool:
|
| 47 |
+
# lazy import, avoid importing when module is loaded
|
| 48 |
+
try:
|
| 49 |
+
from FlagEmbedding import FlagReranker
|
| 50 |
+
|
| 51 |
+
# set environment variable to suppress progress bar
|
| 52 |
+
original_verbosity = os.environ.get('TRANSFORMERS_VERBOSITY', '')
|
| 53 |
+
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
# load model
|
| 57 |
+
reranker = FlagReranker(model_path, use_fp16=use_fp16)
|
| 58 |
+
_reranker_pool[key] = reranker
|
| 59 |
+
finally:
|
| 60 |
+
# restore original verbosity
|
| 61 |
+
if original_verbosity:
|
| 62 |
+
os.environ['TRANSFORMERS_VERBOSITY'] = original_verbosity
|
| 63 |
+
elif 'TRANSFORMERS_VERBOSITY' in os.environ:
|
| 64 |
+
del os.environ['TRANSFORMERS_VERBOSITY']
|
| 65 |
+
|
| 66 |
+
except ImportError:
|
| 67 |
+
raise ImportError(
|
| 68 |
+
"FlagEmbedding not installed. Install it with: pip install FlagEmbedding"
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
return _reranker_pool[key]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def clear_reranker_pool():
|
| 75 |
+
"""clear reranker pool (mainly for testing)"""
|
| 76 |
+
global _reranker_pool
|
| 77 |
+
with _reranker_lock:
|
| 78 |
+
_reranker_pool.clear()
|
shared/utils/review_logger.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Review Logger Utility
|
| 3 |
+
|
| 4 |
+
Captures and logs all intermediate outputs from the review pipeline
|
| 5 |
+
"""
|
| 6 |
+
import json
|
| 7 |
+
import uuid
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, Any, Optional, List
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ReviewLogger:
|
| 14 |
+
"""
|
| 15 |
+
Logger for capturing complete review pipeline execution logs
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, log_dir: Optional[str] = None, enabled: bool = True):
|
| 19 |
+
"""
|
| 20 |
+
Initialize Review Logger
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
log_dir: Directory to save log files. If None, uses current directory.
|
| 24 |
+
enabled: Whether logging is enabled
|
| 25 |
+
"""
|
| 26 |
+
self.enabled = enabled
|
| 27 |
+
self.log_dir = Path(log_dir) if log_dir else Path.cwd()
|
| 28 |
+
self.log_dir.mkdir(parents=True, exist_ok=True)
|
| 29 |
+
|
| 30 |
+
# Current run data
|
| 31 |
+
self.current_run_id: Optional[str] = None
|
| 32 |
+
self.current_run_data: Optional[Dict[str, Any]] = None
|
| 33 |
+
|
| 34 |
+
def start_run(
|
| 35 |
+
self,
|
| 36 |
+
title: str,
|
| 37 |
+
abstract: str,
|
| 38 |
+
content: Optional[str] = None,
|
| 39 |
+
keywords: Optional[List[str]] = None,
|
| 40 |
+
publication_date_range: Optional[str] = None,
|
| 41 |
+
venues: Optional[str] = None,
|
| 42 |
+
review_format: str = "detailed",
|
| 43 |
+
) -> str:
|
| 44 |
+
"""
|
| 45 |
+
Start a new review run and generate UUID
|
| 46 |
+
|
| 47 |
+
IMPORTANT: If current_run_data already exists, this method will preserve existing
|
| 48 |
+
intermediate_outputs data to prevent data loss. Only input data and metadata are updated.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
title: Paper title
|
| 52 |
+
abstract: Paper abstract
|
| 53 |
+
content: Paper content (optional)
|
| 54 |
+
keywords: Existing keywords (optional)
|
| 55 |
+
publication_date_range: Date range filter (optional)
|
| 56 |
+
venues: Venue filter (optional)
|
| 57 |
+
review_format: Review format
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Run UUID string
|
| 61 |
+
"""
|
| 62 |
+
if not self.enabled:
|
| 63 |
+
return ""
|
| 64 |
+
|
| 65 |
+
# Generate UUID based on timestamp
|
| 66 |
+
timestamp = datetime.now()
|
| 67 |
+
# Use timestamp-based UUID (UUID1 uses MAC address + timestamp)
|
| 68 |
+
run_id = str(uuid.uuid1())
|
| 69 |
+
|
| 70 |
+
# PRESERVE existing intermediate_outputs if current_run_data already exists
|
| 71 |
+
# This prevents data loss if start_run() is called multiple times
|
| 72 |
+
existing_intermediate_outputs = None
|
| 73 |
+
existing_final_output = None
|
| 74 |
+
existing_errors = None
|
| 75 |
+
if self.current_run_data is not None:
|
| 76 |
+
existing_intermediate_outputs = self.current_run_data.get("intermediate_outputs")
|
| 77 |
+
existing_final_output = self.current_run_data.get("final_output")
|
| 78 |
+
existing_errors = self.current_run_data.get("errors", [])
|
| 79 |
+
|
| 80 |
+
self.current_run_id = run_id
|
| 81 |
+
|
| 82 |
+
# Initialize intermediate_outputs: use existing data if available, otherwise create new
|
| 83 |
+
if existing_intermediate_outputs is not None:
|
| 84 |
+
# Preserve existing intermediate outputs
|
| 85 |
+
intermediate_outputs = existing_intermediate_outputs
|
| 86 |
+
# Only initialize None fields if they don't exist
|
| 87 |
+
if "generated_keywords" not in intermediate_outputs:
|
| 88 |
+
intermediate_outputs["generated_keywords"] = None
|
| 89 |
+
if "retrieved_papers" not in intermediate_outputs:
|
| 90 |
+
intermediate_outputs["retrieved_papers"] = []
|
| 91 |
+
if "paper_summaries" not in intermediate_outputs:
|
| 92 |
+
intermediate_outputs["paper_summaries"] = []
|
| 93 |
+
if "related_work_json_list" not in intermediate_outputs:
|
| 94 |
+
intermediate_outputs["related_work_json_list"] = None
|
| 95 |
+
if "paper_results_analyzer_output" not in intermediate_outputs:
|
| 96 |
+
intermediate_outputs["paper_results_analyzer_output"] = None
|
| 97 |
+
if "paper_insight_miner_output" not in intermediate_outputs:
|
| 98 |
+
intermediate_outputs["paper_insight_miner_output"] = None
|
| 99 |
+
if "review_prompt" not in intermediate_outputs:
|
| 100 |
+
intermediate_outputs["review_prompt"] = None
|
| 101 |
+
if "review_llm_response" not in intermediate_outputs:
|
| 102 |
+
intermediate_outputs["review_llm_response"] = None
|
| 103 |
+
if "parsed_review" not in intermediate_outputs:
|
| 104 |
+
intermediate_outputs["parsed_review"] = None
|
| 105 |
+
if "refiner_prompt" not in intermediate_outputs:
|
| 106 |
+
intermediate_outputs["refiner_prompt"] = None
|
| 107 |
+
if "refiner_llm_response" not in intermediate_outputs:
|
| 108 |
+
intermediate_outputs["refiner_llm_response"] = None
|
| 109 |
+
if "parsed_refined_review" not in intermediate_outputs:
|
| 110 |
+
intermediate_outputs["parsed_refined_review"] = None
|
| 111 |
+
else:
|
| 112 |
+
# Create new intermediate_outputs structure
|
| 113 |
+
intermediate_outputs = {
|
| 114 |
+
"generated_keywords": None,
|
| 115 |
+
"retrieved_papers": [],
|
| 116 |
+
"paper_summaries": [],
|
| 117 |
+
"related_work_json_list": None,
|
| 118 |
+
"paper_results_analyzer_output": None,
|
| 119 |
+
"paper_insight_miner_output": None,
|
| 120 |
+
"review_prompt": None,
|
| 121 |
+
"review_llm_response": None,
|
| 122 |
+
"parsed_review": None,
|
| 123 |
+
"refiner_prompt": None,
|
| 124 |
+
"refiner_llm_response": None,
|
| 125 |
+
"parsed_refined_review": None,
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
self.current_run_data = {
|
| 129 |
+
"run_id": run_id,
|
| 130 |
+
"timestamp": timestamp.isoformat(),
|
| 131 |
+
"input": {
|
| 132 |
+
"title": title,
|
| 133 |
+
"abstract": abstract,
|
| 134 |
+
"content": content,
|
| 135 |
+
"keywords": keywords,
|
| 136 |
+
"publication_date_range": publication_date_range,
|
| 137 |
+
"venues": venues,
|
| 138 |
+
"review_format": review_format,
|
| 139 |
+
},
|
| 140 |
+
"intermediate_outputs": intermediate_outputs,
|
| 141 |
+
"final_output": existing_final_output,
|
| 142 |
+
"errors": existing_errors if existing_errors is not None else [],
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
return run_id
|
| 146 |
+
|
| 147 |
+
def log_keywords(self, keywords: List[str]):
|
| 148 |
+
"""Log generated search keywords"""
|
| 149 |
+
if self.enabled and self.current_run_data:
|
| 150 |
+
# Ensure intermediate_outputs exists
|
| 151 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 152 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 153 |
+
self.current_run_data["intermediate_outputs"]["generated_keywords"] = keywords
|
| 154 |
+
|
| 155 |
+
def log_retrieved_papers(self, papers: List[Dict[str, Any]]):
|
| 156 |
+
"""Log retrieved papers (raw)"""
|
| 157 |
+
if self.enabled and self.current_run_data:
|
| 158 |
+
# Ensure intermediate_outputs exists
|
| 159 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 160 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 161 |
+
# Store paper metadata (may be large, so we store essential info)
|
| 162 |
+
self.current_run_data["intermediate_outputs"]["retrieved_papers"] = [
|
| 163 |
+
{
|
| 164 |
+
"paper_id": p.get("paper_id"),
|
| 165 |
+
"title": p.get("title"),
|
| 166 |
+
"authors": p.get("authors", [])[:10], # Limit authors
|
| 167 |
+
"year": p.get("year"),
|
| 168 |
+
"venue": p.get("venue"),
|
| 169 |
+
"abstract": p.get("abstract", "")[:500], # Truncate abstract
|
| 170 |
+
"citation_counts": p.get("citation_counts", 0),
|
| 171 |
+
}
|
| 172 |
+
for p in papers
|
| 173 |
+
]
|
| 174 |
+
|
| 175 |
+
def log_paper_summary(self, paper_title: str, summary: str, paper_index: int):
|
| 176 |
+
"""Log a single paper summary"""
|
| 177 |
+
if self.enabled and self.current_run_data:
|
| 178 |
+
# Ensure intermediate_outputs exists
|
| 179 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 180 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 181 |
+
if "paper_summaries" not in self.current_run_data["intermediate_outputs"]:
|
| 182 |
+
self.current_run_data["intermediate_outputs"]["paper_summaries"] = []
|
| 183 |
+
self.current_run_data["intermediate_outputs"]["paper_summaries"].append({
|
| 184 |
+
"paper_index": paper_index,
|
| 185 |
+
"paper_title": paper_title,
|
| 186 |
+
"summary": summary,
|
| 187 |
+
})
|
| 188 |
+
|
| 189 |
+
def log_related_work_json_list(self, related_work_json_list: List[Dict[str, Any]]):
|
| 190 |
+
"""Log the final related work JSON list"""
|
| 191 |
+
if self.enabled and self.current_run_data:
|
| 192 |
+
# Ensure intermediate_outputs exists
|
| 193 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 194 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 195 |
+
self.current_run_data["intermediate_outputs"]["related_work_json_list"] = related_work_json_list
|
| 196 |
+
|
| 197 |
+
def log_paper_results_analyzer_output(self, results_analyzer_output: str):
|
| 198 |
+
"""Log the paper results analyzer JSON output"""
|
| 199 |
+
if self.enabled and self.current_run_data:
|
| 200 |
+
# Ensure intermediate_outputs exists
|
| 201 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 202 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 203 |
+
self.current_run_data["intermediate_outputs"]["paper_results_analyzer_output"] = results_analyzer_output
|
| 204 |
+
|
| 205 |
+
def log_paper_insight_miner_output(self, insight_miner_output: str):
|
| 206 |
+
"""Log the paper insight miner JSON output"""
|
| 207 |
+
if self.enabled and self.current_run_data:
|
| 208 |
+
# Ensure intermediate_outputs exists
|
| 209 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 210 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 211 |
+
self.current_run_data["intermediate_outputs"]["paper_insight_miner_output"] = insight_miner_output
|
| 212 |
+
|
| 213 |
+
def log_review_prompt(self, prompt: str, system_message: Optional[str] = None):
|
| 214 |
+
"""Log the review prompt sent to LLM"""
|
| 215 |
+
if self.enabled and self.current_run_data:
|
| 216 |
+
# Ensure intermediate_outputs exists
|
| 217 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 218 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 219 |
+
self.current_run_data["intermediate_outputs"]["review_prompt"] = {
|
| 220 |
+
"system_message": system_message,
|
| 221 |
+
"user_prompt": prompt,
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
def log_review_llm_response(self, response: str):
|
| 225 |
+
"""Log the raw LLM response for review"""
|
| 226 |
+
if self.enabled and self.current_run_data:
|
| 227 |
+
# Ensure intermediate_outputs exists
|
| 228 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 229 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 230 |
+
self.current_run_data["intermediate_outputs"]["review_llm_response"] = response
|
| 231 |
+
|
| 232 |
+
def log_parsed_review(self, parsed_review: Dict[str, Any]):
|
| 233 |
+
"""Log the parsed review dictionary"""
|
| 234 |
+
if self.enabled and self.current_run_data:
|
| 235 |
+
# Ensure intermediate_outputs exists
|
| 236 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 237 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 238 |
+
self.current_run_data["intermediate_outputs"]["parsed_review"] = parsed_review
|
| 239 |
+
|
| 240 |
+
def log_refiner_prompt(self, prompt: str, system_message: Optional[str] = None):
|
| 241 |
+
"""Log the refiner prompt sent to LLM"""
|
| 242 |
+
if self.enabled and self.current_run_data:
|
| 243 |
+
# Ensure intermediate_outputs exists
|
| 244 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 245 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 246 |
+
self.current_run_data["intermediate_outputs"]["refiner_prompt"] = {
|
| 247 |
+
"system_message": system_message,
|
| 248 |
+
"user_prompt": prompt,
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
def log_refiner_llm_response(self, response: str):
|
| 252 |
+
"""Log the raw LLM response for refiner"""
|
| 253 |
+
if self.enabled and self.current_run_data:
|
| 254 |
+
# Ensure intermediate_outputs exists
|
| 255 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 256 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 257 |
+
self.current_run_data["intermediate_outputs"]["refiner_llm_response"] = response
|
| 258 |
+
|
| 259 |
+
def log_parsed_refined_review(self, parsed_review: Dict[str, Any]):
|
| 260 |
+
"""Log the parsed refined review dictionary"""
|
| 261 |
+
if self.enabled and self.current_run_data:
|
| 262 |
+
# Ensure intermediate_outputs exists
|
| 263 |
+
if "intermediate_outputs" not in self.current_run_data:
|
| 264 |
+
self.current_run_data["intermediate_outputs"] = {}
|
| 265 |
+
self.current_run_data["intermediate_outputs"]["parsed_refined_review"] = parsed_review
|
| 266 |
+
|
| 267 |
+
def log_final_output(self, final_output: Dict[str, Any]):
|
| 268 |
+
"""Log the final review output"""
|
| 269 |
+
if self.enabled and self.current_run_data:
|
| 270 |
+
self.current_run_data["final_output"] = final_output
|
| 271 |
+
|
| 272 |
+
def log_error(self, error: str, step: Optional[str] = None):
|
| 273 |
+
"""Log an error that occurred during execution"""
|
| 274 |
+
if self.enabled and self.current_run_data:
|
| 275 |
+
if "errors" not in self.current_run_data:
|
| 276 |
+
self.current_run_data["errors"] = []
|
| 277 |
+
self.current_run_data["errors"].append({
|
| 278 |
+
"step": step,
|
| 279 |
+
"error": error,
|
| 280 |
+
"timestamp": datetime.now().isoformat(),
|
| 281 |
+
})
|
| 282 |
+
|
| 283 |
+
def save_run(self) -> Optional[str]:
|
| 284 |
+
"""
|
| 285 |
+
Save the current run to a JSON file
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
Path to saved log file, or None if logging is disabled
|
| 289 |
+
"""
|
| 290 |
+
if not self.enabled or not self.current_run_data:
|
| 291 |
+
return None
|
| 292 |
+
|
| 293 |
+
# Generate filename with timestamp and UUID
|
| 294 |
+
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 295 |
+
filename = f"review_log_{timestamp_str}_{self.current_run_id[:8]}.json"
|
| 296 |
+
log_path = self.log_dir / filename
|
| 297 |
+
|
| 298 |
+
# Save to JSON
|
| 299 |
+
with open(log_path, 'w', encoding='utf-8') as f:
|
| 300 |
+
json.dump(self.current_run_data, f, indent=2, ensure_ascii=False)
|
| 301 |
+
|
| 302 |
+
return str(log_path)
|
| 303 |
+
|
| 304 |
+
def get_current_run_id(self) -> Optional[str]:
|
| 305 |
+
"""Get the current run ID"""
|
| 306 |
+
return self.current_run_id if self.enabled else None
|
shared/utils/vllm_endpoint_pool.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VLLM Endpoint Pool Manager
|
| 3 |
+
|
| 4 |
+
Manage multiple vLLM endpoints, implement round-robin access and load balancing.
|
| 5 |
+
"""
|
| 6 |
+
import random
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Optional, Dict
|
| 9 |
+
from threading import Lock
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VLLMEndpointPool:
|
| 13 |
+
"""
|
| 14 |
+
VLLM Endpoint Pool Manager
|
| 15 |
+
|
| 16 |
+
Features:
|
| 17 |
+
1. Load multiple vLLM endpoints from file
|
| 18 |
+
2. Round-robin access endpoints (ensure uniform distribution)
|
| 19 |
+
3. Track usage status and errors for each endpoint
|
| 20 |
+
4. Smart selection (based on error rate and success rate, as backup)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, pool_path: Optional[str] = None, endpoints: Optional[List[str]] = None, use_round_robin: bool = True):
|
| 24 |
+
"""
|
| 25 |
+
Initialize VLLM Endpoint Pool
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
pool_path: Endpoints file path (one endpoint URL per line)
|
| 29 |
+
endpoints: directly provide endpoints list (prior to pool_path)
|
| 30 |
+
use_round_robin: whether to use round-robin strategy (True=uniform distribution, False=smart selection)
|
| 31 |
+
"""
|
| 32 |
+
self.endpoints: List[str] = []
|
| 33 |
+
self.current_index: int = 0 # Round-robin current index
|
| 34 |
+
self.used_indices: List[int] = [] # used indices in current round (for smart selection)
|
| 35 |
+
self.endpoint_status: Dict[str, Dict] = {} # status information for each endpoint
|
| 36 |
+
self.lock = Lock() # thread safe lock
|
| 37 |
+
self.use_round_robin = use_round_robin # whether to use round-robin
|
| 38 |
+
|
| 39 |
+
# load endpoints
|
| 40 |
+
if endpoints:
|
| 41 |
+
self.endpoints = [e.strip() for e in endpoints if e.strip()]
|
| 42 |
+
elif pool_path:
|
| 43 |
+
self._load_from_file(pool_path)
|
| 44 |
+
else:
|
| 45 |
+
# try to get single endpoint from environment variable (backward compatibility)
|
| 46 |
+
import os
|
| 47 |
+
env_endpoint = os.environ.get("VLLM_BASE_URL")
|
| 48 |
+
if env_endpoint:
|
| 49 |
+
# ensure format is correct (may need to add /v1)
|
| 50 |
+
if not env_endpoint.endswith('/v1'):
|
| 51 |
+
if env_endpoint.endswith('/'):
|
| 52 |
+
env_endpoint = env_endpoint.rstrip('/') + '/v1'
|
| 53 |
+
else:
|
| 54 |
+
env_endpoint = env_endpoint + '/v1'
|
| 55 |
+
self.endpoints = [env_endpoint]
|
| 56 |
+
|
| 57 |
+
if not self.endpoints:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
"No vLLM endpoints available. Provide endpoints via pool_path, endpoints parameter, "
|
| 60 |
+
"or VLLM_BASE_URL environment variable."
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# initialize status for each endpoint
|
| 64 |
+
for endpoint in self.endpoints:
|
| 65 |
+
self.endpoint_status[endpoint] = {
|
| 66 |
+
'error_count': 0,
|
| 67 |
+
'last_error_time': None,
|
| 68 |
+
'consecutive_errors': 0,
|
| 69 |
+
'total_requests': 0,
|
| 70 |
+
'successful_requests': 0,
|
| 71 |
+
'total_response_time': 0.0, # 累计响应时间
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
def _load_from_file(self, pool_path: str):
|
| 75 |
+
"""load vLLM endpoints from file"""
|
| 76 |
+
path = Path(pool_path)
|
| 77 |
+
|
| 78 |
+
# if relative path, try to find file relative to shared/configs
|
| 79 |
+
if not path.is_absolute():
|
| 80 |
+
# try to find file relative to project root
|
| 81 |
+
project_root = Path(__file__).parent.parent.parent
|
| 82 |
+
path = project_root / "shared" / "configs" / pool_path
|
| 83 |
+
if not path.exists():
|
| 84 |
+
# try to find file relative to shared/configs
|
| 85 |
+
path = Path(__file__).parent.parent / "configs" / pool_path
|
| 86 |
+
|
| 87 |
+
if not path.exists():
|
| 88 |
+
raise FileNotFoundError(
|
| 89 |
+
f"VLLM endpoint pool file not found: {pool_path} (tried: {path})"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 93 |
+
lines = f.readlines()
|
| 94 |
+
|
| 95 |
+
endpoints = []
|
| 96 |
+
for line in lines:
|
| 97 |
+
line = line.strip()
|
| 98 |
+
if line and not line.startswith('#'):
|
| 99 |
+
# ensure format is correct (may need to add /v1)
|
| 100 |
+
if not line.endswith('/v1'):
|
| 101 |
+
if line.endswith('/'):
|
| 102 |
+
line = line.rstrip('/') + '/v1'
|
| 103 |
+
else:
|
| 104 |
+
line = line + '/v1'
|
| 105 |
+
endpoints.append(line)
|
| 106 |
+
|
| 107 |
+
self.endpoints = endpoints
|
| 108 |
+
|
| 109 |
+
if not self.endpoints:
|
| 110 |
+
raise ValueError(f"No valid vLLM endpoints found in pool file: {pool_path}")
|
| 111 |
+
|
| 112 |
+
def get_endpoint(self) -> str:
|
| 113 |
+
"""
|
| 114 |
+
Get next available endpoint (round-robin strategy)
|
| 115 |
+
|
| 116 |
+
Strategy:
|
| 117 |
+
- Round-robin mode (default): simple round-robin, ensure uniform distribution
|
| 118 |
+
- Smart selection mode: select best endpoint based on error rate, success rate, response time
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Available endpoint URL
|
| 122 |
+
"""
|
| 123 |
+
import time
|
| 124 |
+
|
| 125 |
+
with self.lock:
|
| 126 |
+
if not self.endpoints:
|
| 127 |
+
raise ValueError("No vLLM endpoints available in pool")
|
| 128 |
+
|
| 129 |
+
if self.use_round_robin:
|
| 130 |
+
# Round-robin: simple round-robin, ensure uniform distribution
|
| 131 |
+
selected_idx = self.current_index
|
| 132 |
+
self.current_index = (self.current_index + 1) % len(self.endpoints)
|
| 133 |
+
|
| 134 |
+
selected_endpoint = self.endpoints[selected_idx]
|
| 135 |
+
self.endpoint_status[selected_endpoint]['total_requests'] += 1
|
| 136 |
+
|
| 137 |
+
return selected_endpoint
|
| 138 |
+
else:
|
| 139 |
+
# smart selection mode (original logic)
|
| 140 |
+
# if current round is complete, start a new round
|
| 141 |
+
if len(self.used_indices) >= len(self.endpoints):
|
| 142 |
+
self.used_indices = []
|
| 143 |
+
|
| 144 |
+
# get indices not used in current round
|
| 145 |
+
available_indices = [i for i in range(len(self.endpoints)) if i not in self.used_indices]
|
| 146 |
+
|
| 147 |
+
if not available_indices:
|
| 148 |
+
# all endpoints are in current round, start a new round
|
| 149 |
+
available_indices = list(range(len(self.endpoints)))
|
| 150 |
+
self.used_indices = []
|
| 151 |
+
|
| 152 |
+
# prioritize endpoints with fewer errors and higher success rate
|
| 153 |
+
endpoint_scores = []
|
| 154 |
+
for idx in available_indices:
|
| 155 |
+
endpoint = self.endpoints[idx]
|
| 156 |
+
status = self.endpoint_status[endpoint]
|
| 157 |
+
|
| 158 |
+
# calculate score: error count, success rate, response time, score越高
|
| 159 |
+
error_count = status['error_count']
|
| 160 |
+
total = status['total_requests']
|
| 161 |
+
success_rate = (status['successful_requests'] / total) if total > 0 else 1.0
|
| 162 |
+
|
| 163 |
+
# calculate average response time (shorter is better)
|
| 164 |
+
avg_response_time = (
|
| 165 |
+
status['total_response_time'] / status['successful_requests']
|
| 166 |
+
if status['successful_requests'] > 0 else 0.0
|
| 167 |
+
)
|
| 168 |
+
# normalize response time score (assume 10 seconds as baseline, faster score higher)
|
| 169 |
+
response_time_score = 1.0 / (1.0 + avg_response_time / 10.0)
|
| 170 |
+
|
| 171 |
+
# if recent error, reduce score
|
| 172 |
+
recent_error_penalty = 0
|
| 173 |
+
if status['last_error_time']:
|
| 174 |
+
time_since_error = time.time() - status['last_error_time']
|
| 175 |
+
if time_since_error < 60: # 1 minute内
|
| 176 |
+
recent_error_penalty = 0.5
|
| 177 |
+
|
| 178 |
+
score = success_rate - (error_count * 0.1) - recent_error_penalty + (response_time_score * 0.2)
|
| 179 |
+
endpoint_scores.append((idx, score))
|
| 180 |
+
|
| 181 |
+
# sort by score, select highest score (but add some randomness)
|
| 182 |
+
endpoint_scores.sort(key=lambda x: x[1], reverse=True)
|
| 183 |
+
|
| 184 |
+
# select from top 50% (add randomness but prioritize better)
|
| 185 |
+
top_n = max(1, len(endpoint_scores) // 2) if len(endpoint_scores) > 1 else 1
|
| 186 |
+
selected_idx, _ = random.choice(endpoint_scores[:top_n])
|
| 187 |
+
|
| 188 |
+
# mark as used
|
| 189 |
+
self.used_indices.append(selected_idx)
|
| 190 |
+
|
| 191 |
+
selected_endpoint = self.endpoints[selected_idx]
|
| 192 |
+
self.endpoint_status[selected_endpoint]['total_requests'] += 1
|
| 193 |
+
|
| 194 |
+
return selected_endpoint
|
| 195 |
+
|
| 196 |
+
def mark_success(self, endpoint: str, response_time: float = 0.0):
|
| 197 |
+
"""
|
| 198 |
+
mark endpoint as successful
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
endpoint: successful endpoint URL
|
| 202 |
+
response_time: response time (seconds)
|
| 203 |
+
"""
|
| 204 |
+
with self.lock:
|
| 205 |
+
if endpoint in self.endpoint_status:
|
| 206 |
+
status = self.endpoint_status[endpoint]
|
| 207 |
+
status['successful_requests'] += 1
|
| 208 |
+
status['consecutive_errors'] = 0
|
| 209 |
+
status['total_response_time'] += response_time
|
| 210 |
+
|
| 211 |
+
def mark_error(self, endpoint: str, error_type: str = "server_error"):
|
| 212 |
+
"""
|
| 213 |
+
mark endpoint as failed
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
endpoint: failed endpoint URL
|
| 217 |
+
error_type: error type ("server_error", "timeout", "connection_error", "other")
|
| 218 |
+
"""
|
| 219 |
+
import time
|
| 220 |
+
|
| 221 |
+
with self.lock:
|
| 222 |
+
if endpoint in self.endpoint_status:
|
| 223 |
+
status = self.endpoint_status[endpoint]
|
| 224 |
+
status['error_count'] += 1
|
| 225 |
+
status['consecutive_errors'] += 1
|
| 226 |
+
status['last_error_time'] = time.time()
|
| 227 |
+
|
| 228 |
+
def get_status(self) -> Dict:
|
| 229 |
+
"""get pool status information (for debugging)"""
|
| 230 |
+
with self.lock:
|
| 231 |
+
return {
|
| 232 |
+
'total_endpoints': len(self.endpoints),
|
| 233 |
+
'current_round_progress': f"{len(self.used_indices)}/{len(self.endpoints)}",
|
| 234 |
+
'endpoints_status': {
|
| 235 |
+
endpoint: {
|
| 236 |
+
'error_count': status['error_count'],
|
| 237 |
+
'successful_requests': status['successful_requests'],
|
| 238 |
+
'total_requests': status['total_requests'],
|
| 239 |
+
'success_rate': (
|
| 240 |
+
status['successful_requests'] / status['total_requests']
|
| 241 |
+
if status['total_requests'] > 0 else 0.0
|
| 242 |
+
),
|
| 243 |
+
'avg_response_time': (
|
| 244 |
+
status['total_response_time'] / status['successful_requests']
|
| 245 |
+
if status['successful_requests'] > 0 else 0.0
|
| 246 |
+
),
|
| 247 |
+
'consecutive_errors': status['consecutive_errors'],
|
| 248 |
+
'last_error_time': status['last_error_time'],
|
| 249 |
+
}
|
| 250 |
+
for endpoint, status in self.endpoint_status.items()
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
def reset_round(self):
|
| 255 |
+
"""reset current round (force start a new round)"""
|
| 256 |
+
with self.lock:
|
| 257 |
+
self.used_indices = []
|
shared/utils/vllm_service.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simplified vLLM service for review system
|
| 3 |
+
|
| 4 |
+
This service only handles API calls, no load balancing logic.
|
| 5 |
+
Load balancing should be handled at the deployment service level (e.g., nginx reverse proxy).
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import random
|
| 10 |
+
import yaml
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import List, Dict, Optional, Any, Union
|
| 13 |
+
from threading import Semaphore, Lock as ThreadLock
|
| 14 |
+
from openai import OpenAI
|
| 15 |
+
from .llm_service import LLMService, ChatMessage
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class VLLMService(LLMService):
|
| 19 |
+
"""
|
| 20 |
+
Simplified vLLM service wrapper for local LLM deployment
|
| 21 |
+
|
| 22 |
+
This service connects to a vLLM server endpoint.
|
| 23 |
+
Load balancing should be handled at the deployment level (e.g., nginx, multiple services behind a load balancer).
|
| 24 |
+
|
| 25 |
+
Features:
|
| 26 |
+
- Simple API calls to a single endpoint
|
| 27 |
+
- Automatic retry with exponential backoff for 500 errors
|
| 28 |
+
- Configurable max concurrent requests (per service instance)
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
# Class-level semaphore for rate limiting (shared across all instances of this service)
|
| 32 |
+
# Use lazy initialization to avoid pickle issues with multiprocessing
|
| 33 |
+
_request_semaphore: Optional[Semaphore] = None
|
| 34 |
+
_max_concurrent_requests: int = 8 # Default limit
|
| 35 |
+
_semaphore_lock = ThreadLock() # Thread-safe initialization lock
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
base_url: Optional[str] = None,
|
| 40 |
+
api_key: Optional[str] = None,
|
| 41 |
+
model_name: Optional[str] = None,
|
| 42 |
+
timeout: Optional[int] = None,
|
| 43 |
+
config_file: Optional[str] = None,
|
| 44 |
+
max_concurrent_requests: Optional[int] = None,
|
| 45 |
+
max_retries: int = 3,
|
| 46 |
+
retry_delay: float = 1.0,
|
| 47 |
+
retry_backoff: float = 2.0,
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Initialize vLLM service
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
base_url: vLLM server base URL (default: from config or http://localhost:8000/v1)
|
| 54 |
+
api_key: API key (overrides config)
|
| 55 |
+
model_name: Model name identifier (overrides config)
|
| 56 |
+
timeout: Request timeout in seconds (overrides config)
|
| 57 |
+
config_file: Path to config file (default: configs/llm_service_config.yaml)
|
| 58 |
+
max_concurrent_requests: Maximum concurrent requests per service instance (default: 8)
|
| 59 |
+
max_retries: Maximum number of retries for failed requests (default: 3)
|
| 60 |
+
retry_delay: Initial retry delay in seconds (default: 1.0)
|
| 61 |
+
retry_backoff: Retry delay multiplier (default: 2.0)
|
| 62 |
+
"""
|
| 63 |
+
# Load config from YAML
|
| 64 |
+
config = self._load_config(config_file)
|
| 65 |
+
vllm_config = config.get("vllm", {})
|
| 66 |
+
|
| 67 |
+
# Use provided values or fall back to config, then environment variables
|
| 68 |
+
self.base_url = base_url or vllm_config.get("base_url") or os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1")
|
| 69 |
+
self.model_name = model_name or vllm_config.get("model_name", "Qwen/Qwen3-4B-Instruct-2507")
|
| 70 |
+
self.api_key = api_key or vllm_config.get("api_key", "dummy-key")
|
| 71 |
+
self.timeout = timeout or vllm_config.get("timeout", 300)
|
| 72 |
+
|
| 73 |
+
# Retry configuration
|
| 74 |
+
self.max_retries = max_retries
|
| 75 |
+
self.retry_delay = retry_delay
|
| 76 |
+
self.retry_backoff = retry_backoff
|
| 77 |
+
|
| 78 |
+
# Rate limiting: Initialize class-level semaphore if not already initialized
|
| 79 |
+
# Use lazy initialization with thread-safe check to avoid pickle issues
|
| 80 |
+
if max_concurrent_requests is not None:
|
| 81 |
+
VLLMService._max_concurrent_requests = max_concurrent_requests
|
| 82 |
+
else:
|
| 83 |
+
# Try to get from config
|
| 84 |
+
config_max_concurrent = vllm_config.get("max_concurrent_requests")
|
| 85 |
+
if config_max_concurrent is not None:
|
| 86 |
+
VLLMService._max_concurrent_requests = config_max_concurrent
|
| 87 |
+
|
| 88 |
+
# Lazy initialization of semaphore will happen on first use
|
| 89 |
+
# This avoids pickle issues when using multiprocessing/ThreadPoolExecutor
|
| 90 |
+
|
| 91 |
+
# Store default sampling parameters from config
|
| 92 |
+
self.default_temperature = vllm_config.get("temperature", 0.7)
|
| 93 |
+
self.default_top_p = vllm_config.get("top_p", 0.8)
|
| 94 |
+
self.default_top_k = vllm_config.get("top_k", 20)
|
| 95 |
+
self.default_max_tokens = vllm_config.get("max_tokens", 16384)
|
| 96 |
+
self.default_presence_penalty = vllm_config.get("presence_penalty", 0.0)
|
| 97 |
+
|
| 98 |
+
# Create OpenAI client
|
| 99 |
+
self.client = OpenAI(
|
| 100 |
+
api_key=self.api_key,
|
| 101 |
+
base_url=self.base_url,
|
| 102 |
+
timeout=self.timeout,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def _load_config(config_file: Optional[str] = None) -> Dict[str, Any]:
|
| 107 |
+
"""
|
| 108 |
+
Load configuration from YAML file
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
config_file: Path to config file
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Configuration dictionary
|
| 115 |
+
"""
|
| 116 |
+
if config_file is None:
|
| 117 |
+
project_root = Path(__file__).parent.parent.parent
|
| 118 |
+
config_file = project_root / "shared" / "configs" / "llm_service_config.yaml"
|
| 119 |
+
|
| 120 |
+
config_path = Path(config_file)
|
| 121 |
+
if not config_path.exists():
|
| 122 |
+
# Return defaults if config file doesn't exist
|
| 123 |
+
return {
|
| 124 |
+
"vllm": {
|
| 125 |
+
"base_url": "http://localhost:8000/v1",
|
| 126 |
+
"api_key": "dummy-key",
|
| 127 |
+
"model_name": "Qwen/Qwen3-4B-Instruct-2507",
|
| 128 |
+
"timeout": 300,
|
| 129 |
+
"max_concurrent_requests": 8,
|
| 130 |
+
"max_retries": 3,
|
| 131 |
+
"retry_delay": 1.0,
|
| 132 |
+
"retry_backoff": 2.0,
|
| 133 |
+
"temperature": 0.7,
|
| 134 |
+
"top_p": 0.8,
|
| 135 |
+
"top_k": 20,
|
| 136 |
+
"max_tokens": 16384,
|
| 137 |
+
"presence_penalty": 0.0,
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 142 |
+
return yaml.safe_load(f) or {}
|
| 143 |
+
|
| 144 |
+
@classmethod
|
| 145 |
+
def _ensure_semaphore(cls):
|
| 146 |
+
"""Thread-safe lazy initialization of semaphore to avoid pickle issues"""
|
| 147 |
+
if cls._request_semaphore is None:
|
| 148 |
+
with cls._semaphore_lock:
|
| 149 |
+
# Double-check pattern
|
| 150 |
+
if cls._request_semaphore is None:
|
| 151 |
+
cls._request_semaphore = Semaphore(cls._max_concurrent_requests)
|
| 152 |
+
|
| 153 |
+
def _format_messages(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> List[Dict[str, str]]:
|
| 154 |
+
"""Format messages for OpenAI API"""
|
| 155 |
+
formatted = []
|
| 156 |
+
for msg in messages:
|
| 157 |
+
if isinstance(msg, ChatMessage):
|
| 158 |
+
formatted.append({"role": msg.role, "content": msg.content})
|
| 159 |
+
elif isinstance(msg, dict):
|
| 160 |
+
formatted.append(msg)
|
| 161 |
+
else:
|
| 162 |
+
raise ValueError(f"Invalid message type: {type(msg)}")
|
| 163 |
+
return formatted
|
| 164 |
+
|
| 165 |
+
def generate(
|
| 166 |
+
self,
|
| 167 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 168 |
+
temperature: Optional[float] = None,
|
| 169 |
+
top_p: Optional[float] = None,
|
| 170 |
+
top_k: Optional[int] = None,
|
| 171 |
+
max_tokens: Optional[int] = None,
|
| 172 |
+
presence_penalty: Optional[float] = None,
|
| 173 |
+
**kwargs
|
| 174 |
+
) -> str:
|
| 175 |
+
"""
|
| 176 |
+
Generate text from messages
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
messages: List of chat messages
|
| 180 |
+
temperature: Sampling temperature (uses config default if None)
|
| 181 |
+
top_p: Top-p sampling parameter (uses config default if None)
|
| 182 |
+
top_k: Top-k sampling parameter (uses config default if None)
|
| 183 |
+
max_tokens: Maximum tokens to generate (uses config default if None)
|
| 184 |
+
presence_penalty: Presence penalty (uses config default if None)
|
| 185 |
+
**kwargs: Additional parameters
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
Generated text
|
| 189 |
+
"""
|
| 190 |
+
formatted_messages = self._format_messages(messages)
|
| 191 |
+
|
| 192 |
+
# Use provided values or fall back to config defaults
|
| 193 |
+
temperature = temperature if temperature is not None else self.default_temperature
|
| 194 |
+
top_p = top_p if top_p is not None else self.default_top_p
|
| 195 |
+
max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
|
| 196 |
+
presence_penalty = presence_penalty if presence_penalty is not None else self.default_presence_penalty
|
| 197 |
+
|
| 198 |
+
# Ensure semaphore is initialized (lazy, thread-safe)
|
| 199 |
+
self._ensure_semaphore()
|
| 200 |
+
|
| 201 |
+
# Use semaphore to limit concurrent requests
|
| 202 |
+
with VLLMService._request_semaphore:
|
| 203 |
+
last_exception = None
|
| 204 |
+
|
| 205 |
+
for retry_attempt in range(self.max_retries + 1):
|
| 206 |
+
try:
|
| 207 |
+
response = self.client.chat.completions.create(
|
| 208 |
+
model=self.model_name,
|
| 209 |
+
messages=formatted_messages,
|
| 210 |
+
temperature=temperature,
|
| 211 |
+
top_p=top_p,
|
| 212 |
+
max_tokens=max_tokens,
|
| 213 |
+
presence_penalty=presence_penalty,
|
| 214 |
+
**kwargs
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
return response.choices[0].message.content
|
| 218 |
+
|
| 219 |
+
except Exception as e:
|
| 220 |
+
last_exception = e
|
| 221 |
+
|
| 222 |
+
# Check if it's a server error (500, 502, 503, 504) that we should retry
|
| 223 |
+
should_retry = False
|
| 224 |
+
error_str = str(e).lower()
|
| 225 |
+
|
| 226 |
+
if any(code in error_str for code in ["500", "502", "503", "504"]):
|
| 227 |
+
should_retry = True
|
| 228 |
+
elif "server error" in error_str or "internal server error" in error_str:
|
| 229 |
+
should_retry = True
|
| 230 |
+
|
| 231 |
+
# Don't retry on last attempt
|
| 232 |
+
if retry_attempt < self.max_retries and should_retry:
|
| 233 |
+
# Calculate delay with exponential backoff and jitter
|
| 234 |
+
delay = self.retry_delay * (self.retry_backoff ** retry_attempt)
|
| 235 |
+
jitter = random.uniform(0, delay * 0.1) # 10% jitter
|
| 236 |
+
time.sleep(delay + jitter)
|
| 237 |
+
continue
|
| 238 |
+
else:
|
| 239 |
+
# Either not a retryable error or out of retries
|
| 240 |
+
raise last_exception
|
| 241 |
+
|
| 242 |
+
def stream_generate(
|
| 243 |
+
self,
|
| 244 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 245 |
+
temperature: Optional[float] = None,
|
| 246 |
+
top_p: Optional[float] = None,
|
| 247 |
+
top_k: Optional[int] = None,
|
| 248 |
+
max_tokens: Optional[int] = None,
|
| 249 |
+
presence_penalty: Optional[float] = None,
|
| 250 |
+
**kwargs
|
| 251 |
+
):
|
| 252 |
+
"""
|
| 253 |
+
Stream generate text from messages
|
| 254 |
+
|
| 255 |
+
Yields:
|
| 256 |
+
Generated text chunks
|
| 257 |
+
"""
|
| 258 |
+
formatted_messages = self._format_messages(messages)
|
| 259 |
+
|
| 260 |
+
# Use provided values or fall back to config defaults
|
| 261 |
+
temperature = temperature if temperature is not None else self.default_temperature
|
| 262 |
+
top_p = top_p if top_p is not None else self.default_top_p
|
| 263 |
+
max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
|
| 264 |
+
presence_penalty = presence_penalty if presence_penalty is not None else self.default_presence_penalty
|
| 265 |
+
|
| 266 |
+
# Ensure semaphore is initialized (lazy, thread-safe)
|
| 267 |
+
self._ensure_semaphore()
|
| 268 |
+
|
| 269 |
+
# Use semaphore to limit concurrent requests
|
| 270 |
+
with VLLMService._request_semaphore:
|
| 271 |
+
last_exception = None
|
| 272 |
+
|
| 273 |
+
for retry_attempt in range(self.max_retries + 1):
|
| 274 |
+
try:
|
| 275 |
+
stream = self.client.chat.completions.create(
|
| 276 |
+
model=self.model_name,
|
| 277 |
+
messages=formatted_messages,
|
| 278 |
+
temperature=temperature,
|
| 279 |
+
top_p=top_p,
|
| 280 |
+
max_tokens=max_tokens,
|
| 281 |
+
presence_penalty=presence_penalty,
|
| 282 |
+
stream=True,
|
| 283 |
+
**kwargs
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Stream chunks
|
| 287 |
+
for chunk in stream:
|
| 288 |
+
if chunk.choices[0].delta.content:
|
| 289 |
+
yield chunk.choices[0].delta.content
|
| 290 |
+
|
| 291 |
+
return # Success, exit retry loop
|
| 292 |
+
|
| 293 |
+
except Exception as e:
|
| 294 |
+
last_exception = e
|
| 295 |
+
|
| 296 |
+
# Check if it's a server error that we should retry
|
| 297 |
+
should_retry = False
|
| 298 |
+
error_str = str(e).lower()
|
| 299 |
+
|
| 300 |
+
if any(code in error_str for code in ["500", "502", "503", "504"]):
|
| 301 |
+
should_retry = True
|
| 302 |
+
elif "server error" in error_str or "internal server error" in error_str:
|
| 303 |
+
should_retry = True
|
| 304 |
+
|
| 305 |
+
# Don't retry on last attempt
|
| 306 |
+
if retry_attempt < self.max_retries and should_retry:
|
| 307 |
+
# Calculate delay with exponential backoff and jitter
|
| 308 |
+
delay = self.retry_delay * (self.retry_backoff ** retry_attempt)
|
| 309 |
+
jitter = random.uniform(0, delay * 0.1) # 10% jitter
|
| 310 |
+
time.sleep(delay + jitter)
|
| 311 |
+
continue
|
| 312 |
+
else:
|
| 313 |
+
# Either not a retryable error or out of retries
|
| 314 |
+
raise last_exception
|
shared/utils/vllm_service_simple.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simplified vLLM service for review system
|
| 3 |
+
|
| 4 |
+
This service only handles API calls, no load balancing logic.
|
| 5 |
+
Load balancing should be handled at the deployment service level (e.g., nginx reverse proxy).
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import random
|
| 10 |
+
import yaml
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import List, Dict, Optional, Any, Union
|
| 13 |
+
from threading import Semaphore, Lock as ThreadLock
|
| 14 |
+
from openai import OpenAI
|
| 15 |
+
from .llm_service import LLMService, ChatMessage
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class VLLMService(LLMService):
|
| 19 |
+
"""
|
| 20 |
+
Simplified vLLM service wrapper for local LLM deployment
|
| 21 |
+
|
| 22 |
+
This service connects to a vLLM server endpoint.
|
| 23 |
+
Load balancing should be handled at the deployment level (e.g., nginx, multiple services behind a load balancer).
|
| 24 |
+
|
| 25 |
+
Features:
|
| 26 |
+
- Simple API calls to a single endpoint
|
| 27 |
+
- Automatic retry with exponential backoff for 500 errors
|
| 28 |
+
- Configurable max concurrent requests (per service instance)
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
# Class-level semaphore for rate limiting (shared across all instances of this service)
|
| 32 |
+
# Use lazy initialization to avoid pickle issues with multiprocessing
|
| 33 |
+
_request_semaphore: Optional[Semaphore] = None
|
| 34 |
+
_max_concurrent_requests: int = 8 # Default limit
|
| 35 |
+
_semaphore_lock = ThreadLock() # Thread-safe initialization lock
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
base_url: Optional[str] = None,
|
| 40 |
+
api_key: Optional[str] = None,
|
| 41 |
+
model_name: Optional[str] = None,
|
| 42 |
+
timeout: Optional[int] = None,
|
| 43 |
+
config_file: Optional[str] = None,
|
| 44 |
+
max_concurrent_requests: Optional[int] = None,
|
| 45 |
+
max_retries: int = 3,
|
| 46 |
+
retry_delay: float = 1.0,
|
| 47 |
+
retry_backoff: float = 2.0,
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Initialize vLLM service
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
base_url: vLLM server base URL (default: from config or http://localhost:8000/v1)
|
| 54 |
+
api_key: API key (overrides config)
|
| 55 |
+
model_name: Model name identifier (overrides config)
|
| 56 |
+
timeout: Request timeout in seconds (overrides config)
|
| 57 |
+
config_file: Path to config file (default: configs/llm_service_config.yaml)
|
| 58 |
+
max_concurrent_requests: Maximum concurrent requests per service instance (default: 8)
|
| 59 |
+
max_retries: Maximum number of retries for failed requests (default: 3)
|
| 60 |
+
retry_delay: Initial retry delay in seconds (default: 1.0)
|
| 61 |
+
retry_backoff: Retry delay multiplier (default: 2.0)
|
| 62 |
+
"""
|
| 63 |
+
# Load config from YAML
|
| 64 |
+
config = self._load_config(config_file)
|
| 65 |
+
vllm_config = config.get("vllm", {})
|
| 66 |
+
|
| 67 |
+
# Use provided values or fall back to config, then environment variables
|
| 68 |
+
self.base_url = base_url or vllm_config.get("base_url") or os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1")
|
| 69 |
+
self.model_name = model_name or vllm_config.get("model_name", "Qwen/Qwen3-4B-Instruct-2507")
|
| 70 |
+
self.api_key = api_key or vllm_config.get("api_key", "dummy-key")
|
| 71 |
+
self.timeout = timeout or vllm_config.get("timeout", 300)
|
| 72 |
+
|
| 73 |
+
# Retry configuration
|
| 74 |
+
self.max_retries = max_retries
|
| 75 |
+
self.retry_delay = retry_delay
|
| 76 |
+
self.retry_backoff = retry_backoff
|
| 77 |
+
|
| 78 |
+
# Rate limiting: Initialize class-level semaphore if not already initialized
|
| 79 |
+
# Use lazy initialization with thread-safe check to avoid pickle issues
|
| 80 |
+
if max_concurrent_requests is not None:
|
| 81 |
+
VLLMService._max_concurrent_requests = max_concurrent_requests
|
| 82 |
+
else:
|
| 83 |
+
# Try to get from config
|
| 84 |
+
config_max_concurrent = vllm_config.get("max_concurrent_requests")
|
| 85 |
+
if config_max_concurrent is not None:
|
| 86 |
+
VLLMService._max_concurrent_requests = config_max_concurrent
|
| 87 |
+
|
| 88 |
+
# Lazy initialization of semaphore will happen on first use
|
| 89 |
+
# This avoids pickle issues when using multiprocessing/ThreadPoolExecutor
|
| 90 |
+
|
| 91 |
+
# Store default sampling parameters from config
|
| 92 |
+
self.default_temperature = vllm_config.get("temperature", 0.7)
|
| 93 |
+
self.default_top_p = vllm_config.get("top_p", 0.8)
|
| 94 |
+
self.default_top_k = vllm_config.get("top_k", 20)
|
| 95 |
+
self.default_max_tokens = vllm_config.get("max_tokens", 16384)
|
| 96 |
+
self.default_presence_penalty = vllm_config.get("presence_penalty", 0.0)
|
| 97 |
+
|
| 98 |
+
# Create OpenAI client
|
| 99 |
+
self.client = OpenAI(
|
| 100 |
+
api_key=self.api_key,
|
| 101 |
+
base_url=self.base_url,
|
| 102 |
+
timeout=self.timeout,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def _load_config(config_file: Optional[str] = None) -> Dict[str, Any]:
|
| 107 |
+
"""
|
| 108 |
+
Load configuration from YAML file
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
config_file: Path to config file
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Configuration dictionary
|
| 115 |
+
"""
|
| 116 |
+
if config_file is None:
|
| 117 |
+
project_root = Path(__file__).parent.parent.parent
|
| 118 |
+
config_file = project_root / "shared" / "configs" / "llm_service_config.yaml"
|
| 119 |
+
|
| 120 |
+
config_path = Path(config_file)
|
| 121 |
+
if not config_path.exists():
|
| 122 |
+
# Return defaults if config file doesn't exist
|
| 123 |
+
return {
|
| 124 |
+
"vllm": {
|
| 125 |
+
"base_url": "http://localhost:8000/v1",
|
| 126 |
+
"api_key": "dummy-key",
|
| 127 |
+
"model_name": "Qwen/Qwen3-4B-Instruct-2507",
|
| 128 |
+
"timeout": 300,
|
| 129 |
+
"max_concurrent_requests": 8,
|
| 130 |
+
"max_retries": 3,
|
| 131 |
+
"retry_delay": 1.0,
|
| 132 |
+
"retry_backoff": 2.0,
|
| 133 |
+
"temperature": 0.7,
|
| 134 |
+
"top_p": 0.8,
|
| 135 |
+
"top_k": 20,
|
| 136 |
+
"max_tokens": 16384,
|
| 137 |
+
"presence_penalty": 0.0,
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 142 |
+
return yaml.safe_load(f) or {}
|
| 143 |
+
|
| 144 |
+
@classmethod
|
| 145 |
+
def _ensure_semaphore(cls):
|
| 146 |
+
"""Thread-safe lazy initialization of semaphore to avoid pickle issues"""
|
| 147 |
+
if cls._request_semaphore is None:
|
| 148 |
+
with cls._semaphore_lock:
|
| 149 |
+
# Double-check pattern
|
| 150 |
+
if cls._request_semaphore is None:
|
| 151 |
+
cls._request_semaphore = Semaphore(cls._max_concurrent_requests)
|
| 152 |
+
|
| 153 |
+
def _format_messages(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> List[Dict[str, str]]:
|
| 154 |
+
"""Format messages for OpenAI API"""
|
| 155 |
+
formatted = []
|
| 156 |
+
for msg in messages:
|
| 157 |
+
if isinstance(msg, ChatMessage):
|
| 158 |
+
formatted.append({"role": msg.role, "content": msg.content})
|
| 159 |
+
elif isinstance(msg, dict):
|
| 160 |
+
formatted.append(msg)
|
| 161 |
+
else:
|
| 162 |
+
raise ValueError(f"Invalid message type: {type(msg)}")
|
| 163 |
+
return formatted
|
| 164 |
+
|
| 165 |
+
def generate(
|
| 166 |
+
self,
|
| 167 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 168 |
+
temperature: Optional[float] = None,
|
| 169 |
+
top_p: Optional[float] = None,
|
| 170 |
+
top_k: Optional[int] = None,
|
| 171 |
+
max_tokens: Optional[int] = None,
|
| 172 |
+
presence_penalty: Optional[float] = None,
|
| 173 |
+
**kwargs
|
| 174 |
+
) -> str:
|
| 175 |
+
"""
|
| 176 |
+
Generate text from messages
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
messages: List of chat messages
|
| 180 |
+
temperature: Sampling temperature (uses config default if None)
|
| 181 |
+
top_p: Top-p sampling parameter (uses config default if None)
|
| 182 |
+
top_k: Top-k sampling parameter (uses config default if None)
|
| 183 |
+
max_tokens: Maximum tokens to generate (uses config default if None)
|
| 184 |
+
presence_penalty: Presence penalty (uses config default if None)
|
| 185 |
+
**kwargs: Additional parameters
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
Generated text
|
| 189 |
+
"""
|
| 190 |
+
formatted_messages = self._format_messages(messages)
|
| 191 |
+
|
| 192 |
+
# Use provided values or fall back to config defaults
|
| 193 |
+
temperature = temperature if temperature is not None else self.default_temperature
|
| 194 |
+
top_p = top_p if top_p is not None else self.default_top_p
|
| 195 |
+
max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
|
| 196 |
+
presence_penalty = presence_penalty if presence_penalty is not None else self.default_presence_penalty
|
| 197 |
+
|
| 198 |
+
# Ensure semaphore is initialized (lazy, thread-safe)
|
| 199 |
+
self._ensure_semaphore()
|
| 200 |
+
|
| 201 |
+
# Use semaphore to limit concurrent requests
|
| 202 |
+
with VLLMService._request_semaphore:
|
| 203 |
+
last_exception = None
|
| 204 |
+
|
| 205 |
+
for retry_attempt in range(self.max_retries + 1):
|
| 206 |
+
try:
|
| 207 |
+
response = self.client.chat.completions.create(
|
| 208 |
+
model=self.model_name,
|
| 209 |
+
messages=formatted_messages,
|
| 210 |
+
temperature=temperature,
|
| 211 |
+
top_p=top_p,
|
| 212 |
+
max_tokens=max_tokens,
|
| 213 |
+
presence_penalty=presence_penalty,
|
| 214 |
+
**kwargs
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
return response.choices[0].message.content
|
| 218 |
+
|
| 219 |
+
except Exception as e:
|
| 220 |
+
last_exception = e
|
| 221 |
+
|
| 222 |
+
# Check if it's a server error (500, 502, 503, 504) that we should retry
|
| 223 |
+
should_retry = False
|
| 224 |
+
error_str = str(e).lower()
|
| 225 |
+
|
| 226 |
+
if any(code in error_str for code in ["500", "502", "503", "504"]):
|
| 227 |
+
should_retry = True
|
| 228 |
+
elif "server error" in error_str or "internal server error" in error_str:
|
| 229 |
+
should_retry = True
|
| 230 |
+
|
| 231 |
+
# Don't retry on last attempt
|
| 232 |
+
if retry_attempt < self.max_retries and should_retry:
|
| 233 |
+
# Calculate delay with exponential backoff and jitter
|
| 234 |
+
delay = self.retry_delay * (self.retry_backoff ** retry_attempt)
|
| 235 |
+
jitter = random.uniform(0, delay * 0.1) # 10% jitter
|
| 236 |
+
time.sleep(delay + jitter)
|
| 237 |
+
continue
|
| 238 |
+
else:
|
| 239 |
+
# Either not a retryable error or out of retries
|
| 240 |
+
raise last_exception
|
| 241 |
+
|
| 242 |
+
def stream_generate(
|
| 243 |
+
self,
|
| 244 |
+
messages: List[Union[ChatMessage, Dict[str, str]]],
|
| 245 |
+
temperature: Optional[float] = None,
|
| 246 |
+
top_p: Optional[float] = None,
|
| 247 |
+
top_k: Optional[int] = None,
|
| 248 |
+
max_tokens: Optional[int] = None,
|
| 249 |
+
presence_penalty: Optional[float] = None,
|
| 250 |
+
**kwargs
|
| 251 |
+
):
|
| 252 |
+
"""
|
| 253 |
+
Stream generate text from messages
|
| 254 |
+
|
| 255 |
+
Yields:
|
| 256 |
+
Generated text chunks
|
| 257 |
+
"""
|
| 258 |
+
formatted_messages = self._format_messages(messages)
|
| 259 |
+
|
| 260 |
+
# Use provided values or fall back to config defaults
|
| 261 |
+
temperature = temperature if temperature is not None else self.default_temperature
|
| 262 |
+
top_p = top_p if top_p is not None else self.default_top_p
|
| 263 |
+
max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
|
| 264 |
+
presence_penalty = presence_penalty if presence_penalty is not None else self.default_presence_penalty
|
| 265 |
+
|
| 266 |
+
# Ensure semaphore is initialized (lazy, thread-safe)
|
| 267 |
+
self._ensure_semaphore()
|
| 268 |
+
|
| 269 |
+
# Use semaphore to limit concurrent requests
|
| 270 |
+
with VLLMService._request_semaphore:
|
| 271 |
+
last_exception = None
|
| 272 |
+
|
| 273 |
+
for retry_attempt in range(self.max_retries + 1):
|
| 274 |
+
try:
|
| 275 |
+
stream = self.client.chat.completions.create(
|
| 276 |
+
model=self.model_name,
|
| 277 |
+
messages=formatted_messages,
|
| 278 |
+
temperature=temperature,
|
| 279 |
+
top_p=top_p,
|
| 280 |
+
max_tokens=max_tokens,
|
| 281 |
+
presence_penalty=presence_penalty,
|
| 282 |
+
stream=True,
|
| 283 |
+
**kwargs
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Stream chunks
|
| 287 |
+
for chunk in stream:
|
| 288 |
+
if chunk.choices[0].delta.content:
|
| 289 |
+
yield chunk.choices[0].delta.content
|
| 290 |
+
|
| 291 |
+
return # Success, exit retry loop
|
| 292 |
+
|
| 293 |
+
except Exception as e:
|
| 294 |
+
last_exception = e
|
| 295 |
+
|
| 296 |
+
# Check if it's a server error that we should retry
|
| 297 |
+
should_retry = False
|
| 298 |
+
error_str = str(e).lower()
|
| 299 |
+
|
| 300 |
+
if any(code in error_str for code in ["500", "502", "503", "504"]):
|
| 301 |
+
should_retry = True
|
| 302 |
+
elif "server error" in error_str or "internal server error" in error_str:
|
| 303 |
+
should_retry = True
|
| 304 |
+
|
| 305 |
+
# Don't retry on last attempt
|
| 306 |
+
if retry_attempt < self.max_retries and should_retry:
|
| 307 |
+
# Calculate delay with exponential backoff and jitter
|
| 308 |
+
delay = self.retry_delay * (self.retry_backoff ** retry_attempt)
|
| 309 |
+
jitter = random.uniform(0, delay * 0.1) # 10% jitter
|
| 310 |
+
time.sleep(delay + jitter)
|
| 311 |
+
continue
|
| 312 |
+
else:
|
| 313 |
+
# Either not a retryable error or out of retries
|
| 314 |
+
raise last_exception
|
src/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Review System
|
| 3 |
+
|
| 4 |
+
A comprehensive system for paper review generation and evaluation.
|
| 5 |
+
"""
|
| 6 |
+
__version__ = "1.0.0"
|
src/evaluator/1_get_rubrics.py
ADDED
|
@@ -0,0 +1,601 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate review-based rubrics by querying LLMs with concurrent parallel requests.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Reads the JSON file with review data
|
| 6 |
+
2. Extracts entries with 'id', 'pred_fast_mode_baseline', 'paper_context', and 'decision'
|
| 7 |
+
3. Loads the rubric generation prompt from prompts.yaml
|
| 8 |
+
4. Loads LLM configuration from configs.yaml (supports gpt and vllm modes)
|
| 9 |
+
5. For each entry, generates rubrics by replacing <<golden_review>> with the ground truth review
|
| 10 |
+
6. Uses concurrent parallel requests (ThreadPoolExecutor) for efficient LLM queries
|
| 11 |
+
7. Extracts rubrics from LLM responses and saves to eval_rubrics.json
|
| 12 |
+
|
| 13 |
+
Output JSON file (eval_rubrics.json) contains a list of dicts with:
|
| 14 |
+
- id: Entry identifier
|
| 15 |
+
- paper_context: Paper content
|
| 16 |
+
- decision: Decision field from input
|
| 17 |
+
- golden_review: The pred_fast_mode_baseline review (ground truth)
|
| 18 |
+
- rubrics: List of rubric objects, each with title, description, and weight
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
python 1_generate_review_based_rubrics.py \
|
| 22 |
+
--json_path input.json \
|
| 23 |
+
--output_path eval_rubrics.json \
|
| 24 |
+
--yaml_path prompts.yaml \
|
| 25 |
+
--config_path configs.yaml \
|
| 26 |
+
--max_workers 5
|
| 27 |
+
|
| 28 |
+
The configs.yaml should specify either "gpt" or "vllm" mode and corresponding settings.
|
| 29 |
+
"""
|
| 30 |
+
import json
|
| 31 |
+
import os
|
| 32 |
+
import sys
|
| 33 |
+
import argparse
|
| 34 |
+
import yaml
|
| 35 |
+
from typing import Dict, List, Any, Optional
|
| 36 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
import pandas as pd
|
| 39 |
+
from dotenv import load_dotenv
|
| 40 |
+
|
| 41 |
+
# Add parent directory to path to import llm_service
|
| 42 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 43 |
+
# Import parse_llm_response from local llm_service module (for parsing LLM responses)
|
| 44 |
+
import llm_service as local_llm_service
|
| 45 |
+
parse_llm_response = local_llm_service.parse_llm_response
|
| 46 |
+
|
| 47 |
+
# Import from shared/utils for gpt/vllm support
|
| 48 |
+
# Add project root to path to enable absolute imports from shared.utils
|
| 49 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 50 |
+
if project_root not in sys.path:
|
| 51 |
+
sys.path.insert(0, project_root)
|
| 52 |
+
|
| 53 |
+
# Use absolute imports from shared.utils package
|
| 54 |
+
from shared.utils.llm_service import LLMService
|
| 55 |
+
from shared.utils.vllm_service import VLLMService
|
| 56 |
+
from shared.utils.gpt_service import GPTService
|
| 57 |
+
|
| 58 |
+
# Load environment variables
|
| 59 |
+
load_dotenv()
|
| 60 |
+
|
| 61 |
+
class ReviewProcessor:
|
| 62 |
+
"""Handles the extraction and processing of reviews from different sources."""
|
| 63 |
+
|
| 64 |
+
@staticmethod
|
| 65 |
+
def extract_review_content(pred_context):
|
| 66 |
+
"""
|
| 67 |
+
Extract the review content from the prediction context.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
pred_context: Raw prediction data that contains the review
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
str: Extracted review content
|
| 74 |
+
"""
|
| 75 |
+
try:
|
| 76 |
+
# First attempt to extract from boxed format
|
| 77 |
+
return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 78 |
+
except Exception:
|
| 79 |
+
# Alternative extraction if the first method fails
|
| 80 |
+
if isinstance(pred_context, dict) and 'output' in pred_context:
|
| 81 |
+
return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 82 |
+
else:
|
| 83 |
+
# Return as is if extraction fails
|
| 84 |
+
return pred_context
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def load_json_data(json_path: str) -> List[Dict[str, Any]]:
|
| 89 |
+
"""
|
| 90 |
+
Load JSON data from file.
|
| 91 |
+
Handles both list and dict formats.
|
| 92 |
+
"""
|
| 93 |
+
with open(json_path, 'r', encoding='utf-8') as f:
|
| 94 |
+
data = json.load(f)
|
| 95 |
+
|
| 96 |
+
# Convert dict to list if needed
|
| 97 |
+
if isinstance(data, dict):
|
| 98 |
+
data = list(data.values())
|
| 99 |
+
|
| 100 |
+
return data
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def load_prompt_template(yaml_path: str) -> str:
|
| 104 |
+
"""
|
| 105 |
+
Load the rubric generation prompt from YAML file.
|
| 106 |
+
"""
|
| 107 |
+
with open(yaml_path, 'r', encoding='utf-8') as f:
|
| 108 |
+
prompts = yaml.safe_load(f)
|
| 109 |
+
|
| 110 |
+
prompt_template = prompts.get('v2_rubric_generation_prompt', '')
|
| 111 |
+
rubric_template = prompts.get('rubrics', '')
|
| 112 |
+
|
| 113 |
+
prompt_template = prompt_template.replace('<<rubric_template>>', rubric_template)
|
| 114 |
+
|
| 115 |
+
return prompt_template
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def clean_rubrics_json(json_str: str) -> str:
|
| 119 |
+
"""
|
| 120 |
+
Clean JSON string by escaping unescaped double quotes inside string values.
|
| 121 |
+
|
| 122 |
+
This function handles cases where the model outputs double quotes inside
|
| 123 |
+
string values (especially in description fields) without proper escaping.
|
| 124 |
+
|
| 125 |
+
The expected format is a JSON array of objects with "title", "description", "weight" fields.
|
| 126 |
+
Strategy: Find each field's value and escape unescaped quotes inside it.
|
| 127 |
+
"""
|
| 128 |
+
import re
|
| 129 |
+
|
| 130 |
+
# First, try to extract JSON array if wrapped in markdown code blocks
|
| 131 |
+
json_match = re.search(r'```json\s*(\[.*?\])\s*```', json_str, re.DOTALL)
|
| 132 |
+
if json_match:
|
| 133 |
+
json_str = json_match.group(1)
|
| 134 |
+
else:
|
| 135 |
+
# Try to find JSON array directly
|
| 136 |
+
json_match = re.search(r'(\[.*?\])', json_str, re.DOTALL)
|
| 137 |
+
if json_match:
|
| 138 |
+
json_str = json_match.group(1)
|
| 139 |
+
|
| 140 |
+
# Process the JSON string character by character to find and fix string values
|
| 141 |
+
# We'll look for patterns like "field": " and then find the matching closing quote
|
| 142 |
+
result = []
|
| 143 |
+
i = 0
|
| 144 |
+
|
| 145 |
+
while i < len(json_str):
|
| 146 |
+
# Look for field pattern: "field_name": "
|
| 147 |
+
field_match = re.search(r'"(title|description|weight)"\s*:\s*"', json_str[i:])
|
| 148 |
+
if not field_match:
|
| 149 |
+
# No more fields to process, append rest and break
|
| 150 |
+
result.append(json_str[i:])
|
| 151 |
+
break
|
| 152 |
+
|
| 153 |
+
# Append everything before the match
|
| 154 |
+
match_start = i + field_match.start()
|
| 155 |
+
result.append(json_str[i:match_start])
|
| 156 |
+
|
| 157 |
+
# Process the field value
|
| 158 |
+
value_start = i + field_match.end() # Position after opening quote
|
| 159 |
+
|
| 160 |
+
# Find the closing quote by scanning character by character
|
| 161 |
+
# The closing quote should be followed by comma, closing brace, or closing bracket
|
| 162 |
+
j = value_start
|
| 163 |
+
found_closing = False
|
| 164 |
+
|
| 165 |
+
while j < len(json_str):
|
| 166 |
+
if json_str[j] == '\\':
|
| 167 |
+
# Skip escaped character (could be \", \\, etc.)
|
| 168 |
+
if j + 1 < len(json_str):
|
| 169 |
+
j += 2
|
| 170 |
+
continue
|
| 171 |
+
else:
|
| 172 |
+
j += 1
|
| 173 |
+
break
|
| 174 |
+
elif json_str[j] == '"':
|
| 175 |
+
# Found a quote - check if it's the closing quote
|
| 176 |
+
# Look ahead (skip whitespace) to see if followed by comma, brace, or bracket
|
| 177 |
+
k = j + 1
|
| 178 |
+
while k < len(json_str) and json_str[k] in ' \t\n\r':
|
| 179 |
+
k += 1
|
| 180 |
+
|
| 181 |
+
if k < len(json_str) and json_str[k] in ',}]':
|
| 182 |
+
# This is the closing quote!
|
| 183 |
+
value_content = json_str[value_start:j]
|
| 184 |
+
closing_part = json_str[j:k+1] # " followed by , } or ]
|
| 185 |
+
|
| 186 |
+
# Fix unescaped quotes in value_content
|
| 187 |
+
# Strategy: preserve already-escaped quotes, escape others
|
| 188 |
+
fixed_content = value_content.replace('\\"', '__TEMP_ESC__')
|
| 189 |
+
fixed_content = fixed_content.replace('"', '\\"')
|
| 190 |
+
fixed_content = fixed_content.replace('__TEMP_ESC__', '\\"')
|
| 191 |
+
|
| 192 |
+
# Append the fixed field
|
| 193 |
+
result.append(json_str[match_start:value_start]) # "field": "
|
| 194 |
+
result.append(fixed_content) # fixed value content
|
| 195 |
+
result.append(closing_part) # " followed by punctuation
|
| 196 |
+
|
| 197 |
+
i = k + 1
|
| 198 |
+
found_closing = True
|
| 199 |
+
break
|
| 200 |
+
j += 1
|
| 201 |
+
|
| 202 |
+
if not found_closing:
|
| 203 |
+
# Couldn't find proper closing quote, append rest and break
|
| 204 |
+
result.append(json_str[match_start:])
|
| 205 |
+
break
|
| 206 |
+
|
| 207 |
+
return ''.join(result)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def extract_rubrics_from_response(response: str) -> Optional[List[Dict[str, Any]]]:
|
| 211 |
+
"""
|
| 212 |
+
Extract rubrics (JSON array) from LLM response.
|
| 213 |
+
Handles cases where description fields contain unescaped double quotes.
|
| 214 |
+
Returns None if parsing fails (silently, no error messages printed).
|
| 215 |
+
"""
|
| 216 |
+
try:
|
| 217 |
+
# First, try using parse_llm_response (handles markdown blocks)
|
| 218 |
+
try:
|
| 219 |
+
parsed = parse_llm_response(response)
|
| 220 |
+
|
| 221 |
+
# Check if parsed result is a list (array of rubrics)
|
| 222 |
+
if isinstance(parsed, list):
|
| 223 |
+
return parsed
|
| 224 |
+
|
| 225 |
+
# If parsed result is a dict, check for common keys that might contain the array
|
| 226 |
+
if isinstance(parsed, dict):
|
| 227 |
+
# Check for common keys
|
| 228 |
+
for key in ['rubrics', 'rubric', 'items', 'criteria']:
|
| 229 |
+
if key in parsed and isinstance(parsed[key], list):
|
| 230 |
+
return parsed[key]
|
| 231 |
+
|
| 232 |
+
# If no key found, try to find the first list value
|
| 233 |
+
for value in parsed.values():
|
| 234 |
+
if isinstance(value, list):
|
| 235 |
+
return value
|
| 236 |
+
except Exception:
|
| 237 |
+
# parse_llm_response failed, try manual cleaning
|
| 238 |
+
pass
|
| 239 |
+
|
| 240 |
+
# If parse_llm_response failed, try manual extraction and cleaning
|
| 241 |
+
import re
|
| 242 |
+
|
| 243 |
+
# Try to find JSON array in response
|
| 244 |
+
json_match = re.search(r'\[.*?\]', response, re.DOTALL)
|
| 245 |
+
if json_match:
|
| 246 |
+
json_str = json_match.group(0)
|
| 247 |
+
|
| 248 |
+
# Try direct parsing first
|
| 249 |
+
try:
|
| 250 |
+
rubrics = json.loads(json_str)
|
| 251 |
+
if isinstance(rubrics, list):
|
| 252 |
+
return rubrics
|
| 253 |
+
except json.JSONDecodeError:
|
| 254 |
+
# JSON parsing failed, try cleaning
|
| 255 |
+
try:
|
| 256 |
+
cleaned_json = clean_rubrics_json(json_str)
|
| 257 |
+
rubrics = json.loads(cleaned_json)
|
| 258 |
+
if isinstance(rubrics, list):
|
| 259 |
+
return rubrics
|
| 260 |
+
except Exception:
|
| 261 |
+
# Last resort: try a more aggressive cleaning approach
|
| 262 |
+
try:
|
| 263 |
+
# Replace unescaped quotes in description fields more aggressively
|
| 264 |
+
# Pattern: "description": "..." where ... may contain quotes
|
| 265 |
+
def fix_description_quotes(match):
|
| 266 |
+
prefix = match.group(1) # "description": "
|
| 267 |
+
content = match.group(2) # the content
|
| 268 |
+
suffix = match.group(3) # closing quote
|
| 269 |
+
|
| 270 |
+
# Escape all quotes in content, but preserve escaped ones
|
| 271 |
+
# First, mark escaped quotes temporarily
|
| 272 |
+
content = content.replace('\\"', '__ESCAPED_QUOTE__')
|
| 273 |
+
# Escape all remaining quotes
|
| 274 |
+
content = content.replace('"', '\\"')
|
| 275 |
+
# Restore escaped quotes
|
| 276 |
+
content = content.replace('__ESCAPED_QUOTE__', '\\"')
|
| 277 |
+
|
| 278 |
+
return prefix + content + suffix
|
| 279 |
+
|
| 280 |
+
# More specific pattern for description field
|
| 281 |
+
desc_pattern = r'("description"\s*:\s*")(.*?)("(?:\s*[,}])?)'
|
| 282 |
+
fixed_json = re.sub(desc_pattern, fix_description_quotes, json_str, flags=re.DOTALL)
|
| 283 |
+
|
| 284 |
+
rubrics = json.loads(fixed_json)
|
| 285 |
+
if isinstance(rubrics, list):
|
| 286 |
+
return rubrics
|
| 287 |
+
except Exception:
|
| 288 |
+
pass
|
| 289 |
+
|
| 290 |
+
# If all else fails, return None (silently)
|
| 291 |
+
return None
|
| 292 |
+
|
| 293 |
+
except Exception:
|
| 294 |
+
# Any unexpected error, return None (silently)
|
| 295 |
+
return None
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def generate_rubrics_for_entry(
|
| 299 |
+
entry: Dict[str, Any],
|
| 300 |
+
prompt_template: str,
|
| 301 |
+
llm_service: LLMService,
|
| 302 |
+
max_retries: int = 16
|
| 303 |
+
) -> Dict[str, Any]:
|
| 304 |
+
"""
|
| 305 |
+
Generate rubrics for a single entry with retry mechanism.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
entry: Dictionary with 'id', 'pred_fast_mode_baseline', 'paper_context', 'decision'
|
| 309 |
+
prompt_template: Prompt template with <<golden_review>> placeholder
|
| 310 |
+
llm_service: LLMService instance (VLLMService or GPTService)
|
| 311 |
+
max_retries: Maximum number of retries if JSON parsing fails (default: 16)
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
Dictionary with 'id', 'paper_context', 'decision', 'golden_review', 'rubrics' (list)
|
| 315 |
+
"""
|
| 316 |
+
entry_id = entry.get('id', 'unknown')
|
| 317 |
+
golden_review = entry.get('pred_fast_mode_baseline', '')
|
| 318 |
+
paper_context = entry.get('paper_context', '')
|
| 319 |
+
decision = entry.get('decision', '')
|
| 320 |
+
|
| 321 |
+
# Replace placeholder in prompt template
|
| 322 |
+
prompt = prompt_template.replace('<<golden_review>>', golden_review)
|
| 323 |
+
prompt = prompt.replace('<<paper_context>>', paper_context)
|
| 324 |
+
|
| 325 |
+
# Convert prompt to messages format (shared/utils services use messages format)
|
| 326 |
+
messages = [{"role": "user", "content": prompt}]
|
| 327 |
+
|
| 328 |
+
# Retry loop for JSON parsing failures
|
| 329 |
+
last_error = None
|
| 330 |
+
for attempt in range(max_retries):
|
| 331 |
+
try:
|
| 332 |
+
# Generate response from LLM
|
| 333 |
+
response = llm_service.generate(messages=messages)
|
| 334 |
+
|
| 335 |
+
# Extract rubrics from response
|
| 336 |
+
rubrics_list = extract_rubrics_from_response(response)
|
| 337 |
+
|
| 338 |
+
# If successful, return the result (silently, no output during retries)
|
| 339 |
+
if rubrics_list is not None and isinstance(rubrics_list, list):
|
| 340 |
+
return {
|
| 341 |
+
'id': entry_id,
|
| 342 |
+
'paper_context': paper_context,
|
| 343 |
+
'decision': decision,
|
| 344 |
+
'golden_review': golden_review,
|
| 345 |
+
'rubrics': rubrics_list
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
# If extraction failed, continue retrying silently
|
| 349 |
+
# Store the error message for the last attempt
|
| 350 |
+
if attempt == max_retries - 1:
|
| 351 |
+
last_error = "Failed to extract valid rubrics from response"
|
| 352 |
+
|
| 353 |
+
except Exception as e:
|
| 354 |
+
# Store the error (will be overwritten by subsequent attempts until the last one)
|
| 355 |
+
last_error = e
|
| 356 |
+
|
| 357 |
+
# All retries failed, output warning only once
|
| 358 |
+
if last_error:
|
| 359 |
+
print(f"[WARN] Failed to generate rubrics for entry {entry_id} after {max_retries} attempts: {last_error}")
|
| 360 |
+
|
| 361 |
+
# All retries failed, return with empty rubrics
|
| 362 |
+
result = {
|
| 363 |
+
'id': entry_id,
|
| 364 |
+
'paper_context': paper_context,
|
| 365 |
+
'decision': decision,
|
| 366 |
+
'golden_review': golden_review,
|
| 367 |
+
'rubrics': [] # Empty list as fallback
|
| 368 |
+
}
|
| 369 |
+
if last_error:
|
| 370 |
+
result['error'] = str(last_error)
|
| 371 |
+
return result
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def load_llm_config(config_path: str) -> Dict[str, Any]:
|
| 375 |
+
"""
|
| 376 |
+
Load LLM configuration from YAML file.
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
config_path: Path to configs.yaml file
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
Configuration dictionary
|
| 383 |
+
"""
|
| 384 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 385 |
+
config = yaml.safe_load(f)
|
| 386 |
+
return config
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
|
| 390 |
+
"""
|
| 391 |
+
Create LLM service from configuration.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
config: Configuration dictionary from configs.yaml
|
| 395 |
+
|
| 396 |
+
Returns:
|
| 397 |
+
LLMService instance (VLLMService or GPTService)
|
| 398 |
+
"""
|
| 399 |
+
mode = config.get('mode', 'gpt').lower()
|
| 400 |
+
|
| 401 |
+
if mode == 'gpt':
|
| 402 |
+
gpt_config = config.get('gpt', {})
|
| 403 |
+
api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
|
| 404 |
+
if not api_key:
|
| 405 |
+
raise ValueError("GPT mode requires api_key in configs.yaml or OPENAI_API_KEY environment variable")
|
| 406 |
+
|
| 407 |
+
service = GPTService(
|
| 408 |
+
api_key=api_key,
|
| 409 |
+
model_name=gpt_config.get('model_name', 'gpt-4o'),
|
| 410 |
+
base_url=gpt_config.get('base_url'),
|
| 411 |
+
timeout=gpt_config.get('timeout', 300)
|
| 412 |
+
)
|
| 413 |
+
return service
|
| 414 |
+
|
| 415 |
+
elif mode == 'vllm':
|
| 416 |
+
vllm_config = config.get('vllm', {})
|
| 417 |
+
service = VLLMService(
|
| 418 |
+
base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
|
| 419 |
+
api_key=vllm_config.get('api_key', 'dummy-key'),
|
| 420 |
+
model_name=vllm_config.get('model_name'),
|
| 421 |
+
timeout=vllm_config.get('timeout', 300),
|
| 422 |
+
max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
|
| 423 |
+
max_retries=vllm_config.get('max_retries', 3),
|
| 424 |
+
retry_delay=vllm_config.get('retry_delay', 1.0),
|
| 425 |
+
retry_backoff=vllm_config.get('retry_backoff', 2.0)
|
| 426 |
+
)
|
| 427 |
+
return service
|
| 428 |
+
|
| 429 |
+
else:
|
| 430 |
+
raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def parse_args():
|
| 434 |
+
"""Parse command line arguments."""
|
| 435 |
+
parser = argparse.ArgumentParser(description="Generate review-based rubrics using LLMs")
|
| 436 |
+
|
| 437 |
+
# Input/Output paths
|
| 438 |
+
parser.add_argument("--json_path", type=str, required=True,
|
| 439 |
+
help="Path to input JSON file with review data")
|
| 440 |
+
parser.add_argument("--output_path", type=str, default=None,
|
| 441 |
+
help="Path to output JSON file (default: eval_rubrics.json in same dir as input)")
|
| 442 |
+
parser.add_argument("--yaml_path", type=str, default=None,
|
| 443 |
+
help="Path to prompts.yaml file (default: prompts.yaml in same dir as script)")
|
| 444 |
+
parser.add_argument("--config_path", type=str, default=None,
|
| 445 |
+
help="Path to configs.yaml file (default: configs.yaml in same dir as script)")
|
| 446 |
+
|
| 447 |
+
# Multi-threading
|
| 448 |
+
parser.add_argument("--max_workers", type=int, default=None,
|
| 449 |
+
help="Maximum number of worker threads (default: from MAX_WORKERS env var or 5)")
|
| 450 |
+
|
| 451 |
+
return parser.parse_args()
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def main():
|
| 455 |
+
"""Main execution function."""
|
| 456 |
+
args = parse_args()
|
| 457 |
+
|
| 458 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 459 |
+
|
| 460 |
+
# File paths
|
| 461 |
+
json_path = args.json_path
|
| 462 |
+
if not os.path.isabs(json_path):
|
| 463 |
+
json_path = os.path.join(script_dir, json_path)
|
| 464 |
+
|
| 465 |
+
if args.output_path:
|
| 466 |
+
output_path = args.output_path
|
| 467 |
+
if not os.path.isabs(output_path):
|
| 468 |
+
output_path = os.path.join(script_dir, output_path)
|
| 469 |
+
else:
|
| 470 |
+
# Default: same directory as input JSON, with eval_rubrics.json name
|
| 471 |
+
output_dir = os.path.dirname(json_path)
|
| 472 |
+
output_path = os.path.join(output_dir, 'eval_rubrics.json')
|
| 473 |
+
|
| 474 |
+
if args.yaml_path:
|
| 475 |
+
yaml_path = args.yaml_path
|
| 476 |
+
if not os.path.isabs(yaml_path):
|
| 477 |
+
yaml_path = os.path.join(script_dir, yaml_path)
|
| 478 |
+
else:
|
| 479 |
+
yaml_path = os.path.join(script_dir, 'prompts.yaml')
|
| 480 |
+
|
| 481 |
+
if args.config_path:
|
| 482 |
+
config_path = args.config_path
|
| 483 |
+
if not os.path.isabs(config_path):
|
| 484 |
+
config_path = os.path.join(script_dir, config_path)
|
| 485 |
+
else:
|
| 486 |
+
config_path = os.path.join(script_dir, 'configs.yaml')
|
| 487 |
+
|
| 488 |
+
max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
|
| 489 |
+
|
| 490 |
+
# Check if files exist
|
| 491 |
+
if not os.path.exists(json_path):
|
| 492 |
+
raise FileNotFoundError(f"JSON file not found: {json_path}")
|
| 493 |
+
if not os.path.exists(yaml_path):
|
| 494 |
+
raise FileNotFoundError(f"YAML file not found: {yaml_path}")
|
| 495 |
+
if not os.path.exists(config_path):
|
| 496 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 497 |
+
|
| 498 |
+
print(f"Loading JSON data from {json_path}...")
|
| 499 |
+
data = load_json_data(json_path)
|
| 500 |
+
print(f"Loaded {len(data)} entries")
|
| 501 |
+
|
| 502 |
+
print(f"Loading prompt template from {yaml_path}...")
|
| 503 |
+
prompt_template = load_prompt_template(yaml_path)
|
| 504 |
+
if not prompt_template:
|
| 505 |
+
raise ValueError("Could not find 'v2_rubric_generation_prompt' in YAML file")
|
| 506 |
+
print("Prompt template loaded successfully")
|
| 507 |
+
|
| 508 |
+
# Load LLM configuration and create service
|
| 509 |
+
print(f"Loading LLM configuration from {config_path}...")
|
| 510 |
+
llm_config = load_llm_config(config_path)
|
| 511 |
+
llm_service = create_llm_service_from_config(llm_config)
|
| 512 |
+
mode = llm_config.get('mode', 'gpt')
|
| 513 |
+
print(f"LLM service initialized (mode: {mode})")
|
| 514 |
+
if hasattr(llm_service, 'model_name'):
|
| 515 |
+
print(f"Using model: {llm_service.model_name}")
|
| 516 |
+
|
| 517 |
+
# Extract required fields from each entry
|
| 518 |
+
print("Extracting required fields from entries...")
|
| 519 |
+
entries = []
|
| 520 |
+
for item in data:
|
| 521 |
+
if 'id' in item and 'pred_fast_mode_baseline' in item:
|
| 522 |
+
entries.append(item)
|
| 523 |
+
else:
|
| 524 |
+
print(f"[WARN] Skipping entry missing required fields: {item.get('id', 'unknown')}")
|
| 525 |
+
|
| 526 |
+
print(f"Processing {len(entries)} entries with {max_workers} workers...")
|
| 527 |
+
|
| 528 |
+
# Generate rubrics using concurrent processing
|
| 529 |
+
results = []
|
| 530 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 531 |
+
# Submit all tasks
|
| 532 |
+
future_to_entry = {
|
| 533 |
+
executor.submit(
|
| 534 |
+
generate_rubrics_for_entry,
|
| 535 |
+
entry,
|
| 536 |
+
prompt_template,
|
| 537 |
+
llm_service
|
| 538 |
+
): entry
|
| 539 |
+
for entry in entries
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
# Process completed tasks with progress bar
|
| 543 |
+
for future in tqdm(as_completed(future_to_entry), total=len(entries), desc="Generating rubrics"):
|
| 544 |
+
try:
|
| 545 |
+
result = future.result()
|
| 546 |
+
results.append(result)
|
| 547 |
+
except Exception as e:
|
| 548 |
+
entry = future_to_entry[future]
|
| 549 |
+
entry_id = entry.get('id', 'unknown')
|
| 550 |
+
print(f"\n[ERROR] Failed to process entry {entry_id}: {e}")
|
| 551 |
+
# Add error entry with empty rubrics
|
| 552 |
+
results.append({
|
| 553 |
+
'id': entry_id,
|
| 554 |
+
'paper_context': entry.get('paper_context', ''),
|
| 555 |
+
'decision': entry.get('decision', ''),
|
| 556 |
+
'golden_review': entry.get('pred_fast_mode_baseline', ''),
|
| 557 |
+
'rubrics': [],
|
| 558 |
+
'error': str(e)
|
| 559 |
+
})
|
| 560 |
+
|
| 561 |
+
print(f"\nSuccessfully generated rubrics for {len(results)} entries")
|
| 562 |
+
|
| 563 |
+
# Save to JSON
|
| 564 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 565 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 566 |
+
print(f"\nResults saved to {output_path}")
|
| 567 |
+
|
| 568 |
+
# Print summary statistics
|
| 569 |
+
print("\n" + "="*80)
|
| 570 |
+
print("SUMMARY STATISTICS")
|
| 571 |
+
print("="*80)
|
| 572 |
+
print(f"Total entries processed: {len(results)}")
|
| 573 |
+
|
| 574 |
+
# Count successful vs failed
|
| 575 |
+
successful = sum(1 for r in results if 'error' not in r and len(r.get('rubrics', [])) > 0)
|
| 576 |
+
failed = len(results) - successful
|
| 577 |
+
print(f"Successful: {successful}")
|
| 578 |
+
print(f"Failed: {failed}")
|
| 579 |
+
|
| 580 |
+
# Check rubrics statistics
|
| 581 |
+
rubric_counts = [len(r.get('rubrics', [])) for r in results if isinstance(r.get('rubrics'), list)]
|
| 582 |
+
|
| 583 |
+
if rubric_counts:
|
| 584 |
+
print(f"\nRubrics per entry:")
|
| 585 |
+
print(f" Mean: {sum(rubric_counts) / len(rubric_counts):.2f}")
|
| 586 |
+
print(f" Min: {min(rubric_counts)}")
|
| 587 |
+
print(f" Max: {max(rubric_counts)}")
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
if __name__ == "__main__":
|
| 591 |
+
main()
|
| 592 |
+
|
| 593 |
+
"""
|
| 594 |
+
Example usage:
|
| 595 |
+
python 1_generate_review_based_rubrics.py \
|
| 596 |
+
--json_path ./examples/input.json \
|
| 597 |
+
--output_path eval_rubrics.json \
|
| 598 |
+
--yaml_path prompts.yaml \
|
| 599 |
+
--config_path configs.yaml \
|
| 600 |
+
--max_workers 5
|
| 601 |
+
"""
|
src/evaluator/2_evaluate.py
ADDED
|
@@ -0,0 +1,1730 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified evaluation script for semantic (LLM-based) and auto_metric (rule-based) evaluation.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Reads eval_rubrics.json (from 1_generate_review_based_rubrics.py) containing rubrics for each paper
|
| 6 |
+
2. Reads input JSON file containing model reviews (supports multiple formats)
|
| 7 |
+
3. Supports three evaluation modes:
|
| 8 |
+
- semantic: LLM-based rubrics evaluation (from 2_evaluate_direct.py)
|
| 9 |
+
- auto_metric: Rule-based metrics evaluation (from 3_rule_evaluate.py)
|
| 10 |
+
- both: Run both evaluations separately
|
| 11 |
+
4. Supports strict mode: normalize scores to discrete scales before computing metrics (--strict_mode)
|
| 12 |
+
5. Outputs separate JSON files for results and summaries
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Semantic evaluation only
|
| 16 |
+
python 2_evaluate.py \
|
| 17 |
+
--rubrics_path eval_rubrics.json \
|
| 18 |
+
--reviews_path model_reviews.json \
|
| 19 |
+
--mode semantic \
|
| 20 |
+
--yaml_path prompts.yaml \
|
| 21 |
+
--config_path configs.yaml \
|
| 22 |
+
--semantic_output semantic_results.json \
|
| 23 |
+
--max_workers 5
|
| 24 |
+
|
| 25 |
+
# Auto-metric evaluation only
|
| 26 |
+
python 2_evaluate.py \
|
| 27 |
+
--rubrics_path eval_rubrics.json \
|
| 28 |
+
--reviews_path model_reviews.json \
|
| 29 |
+
--mode auto_metric \
|
| 30 |
+
--auto_metric_output auto_metric_results.json
|
| 31 |
+
|
| 32 |
+
# Auto-metric evaluation with strict mode (normalize scores to discrete scales)
|
| 33 |
+
python 2_evaluate.py \
|
| 34 |
+
--rubrics_path eval_rubrics.json \
|
| 35 |
+
--reviews_path model_reviews.json \
|
| 36 |
+
--mode auto_metric \
|
| 37 |
+
--auto_metric_output auto_metric_results.json \
|
| 38 |
+
--strict_mode
|
| 39 |
+
|
| 40 |
+
# Auto-metric evaluation with manually specified input format (refined)
|
| 41 |
+
python 2_evaluate.py \
|
| 42 |
+
--rubrics_path eval_rubrics.json \
|
| 43 |
+
--reviews_path model_reviews.json \
|
| 44 |
+
--mode auto_metric \
|
| 45 |
+
--auto_metric_output auto_metric_results.json \
|
| 46 |
+
--input_format refined
|
| 47 |
+
|
| 48 |
+
# Auto-metric evaluation with manually specified input format (original)
|
| 49 |
+
python 2_evaluate.py \
|
| 50 |
+
--rubrics_path eval_rubrics.json \
|
| 51 |
+
--reviews_path ours.json \
|
| 52 |
+
--mode auto_metric \
|
| 53 |
+
--auto_metric_output auto_metric_results.json \
|
| 54 |
+
--input_format original
|
| 55 |
+
|
| 56 |
+
# Both evaluations
|
| 57 |
+
python 2_evaluate.py \
|
| 58 |
+
--rubrics_path eval_rubrics.json \
|
| 59 |
+
--reviews_path model_reviews.json \
|
| 60 |
+
--mode both \
|
| 61 |
+
--yaml_path prompts.yaml \
|
| 62 |
+
--config_path configs.yaml \
|
| 63 |
+
--semantic_output semantic_results.json \
|
| 64 |
+
--auto_metric_output auto_metric_results.json \
|
| 65 |
+
--max_workers 32
|
| 66 |
+
"""
|
| 67 |
+
from __future__ import annotations
|
| 68 |
+
|
| 69 |
+
import json
|
| 70 |
+
import os
|
| 71 |
+
import sys
|
| 72 |
+
import argparse
|
| 73 |
+
import yaml
|
| 74 |
+
import math
|
| 75 |
+
from typing import Dict, List, Any, Optional
|
| 76 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 77 |
+
from tqdm import tqdm
|
| 78 |
+
from itertools import combinations
|
| 79 |
+
from scipy.stats import spearmanr
|
| 80 |
+
from sklearn.metrics import precision_recall_fscore_support
|
| 81 |
+
|
| 82 |
+
# Add parent directory to path
|
| 83 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 84 |
+
# Import parse_llm_response from local llm_service module
|
| 85 |
+
import llm_service as local_llm_service
|
| 86 |
+
parse_llm_response = local_llm_service.parse_llm_response
|
| 87 |
+
|
| 88 |
+
# Import from shared/utils for gpt/vllm support
|
| 89 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 90 |
+
if project_root not in sys.path:
|
| 91 |
+
sys.path.insert(0, project_root)
|
| 92 |
+
|
| 93 |
+
from shared.utils.llm_service import LLMService
|
| 94 |
+
from shared.utils.vllm_service import VLLMService
|
| 95 |
+
from shared.utils.gpt_service import GPTService
|
| 96 |
+
sys.path.insert(0, os.path.join(project_root, 'shared', 'utils'))
|
| 97 |
+
from json_parser import parse_review_markdown
|
| 98 |
+
|
| 99 |
+
class ReviewProcessor:
|
| 100 |
+
"""Handles the extraction and processing of reviews from different sources."""
|
| 101 |
+
|
| 102 |
+
@staticmethod
|
| 103 |
+
def extract_review_content(pred_context):
|
| 104 |
+
"""
|
| 105 |
+
Extract the review content from the prediction context.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
pred_context: Raw prediction data that contains the review
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
str: Extracted review content
|
| 112 |
+
"""
|
| 113 |
+
try:
|
| 114 |
+
# First attempt to extract from boxed format
|
| 115 |
+
return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 116 |
+
except Exception:
|
| 117 |
+
# Alternative extraction if the first method fails
|
| 118 |
+
if isinstance(pred_context, dict) and 'output' in pred_context:
|
| 119 |
+
return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 120 |
+
else:
|
| 121 |
+
# Return as is if extraction fails
|
| 122 |
+
return pred_context
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ============================================================================
|
| 126 |
+
# Semantic Evaluation Functions (from 2_evaluate_direct.py)
|
| 127 |
+
# ============================================================================
|
| 128 |
+
|
| 129 |
+
def load_prompt_template(yaml_path: str) -> str:
|
| 130 |
+
"""Load the evaluator prompt from YAML file."""
|
| 131 |
+
with open(yaml_path, 'r', encoding='utf-8') as f:
|
| 132 |
+
prompts = yaml.safe_load(f)
|
| 133 |
+
return prompts.get('v1_evaluator_prompt', '')
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def build_evaluation_prompt(
|
| 137 |
+
rubrics: List[Dict[str, Any]],
|
| 138 |
+
paper_content: str,
|
| 139 |
+
review: str,
|
| 140 |
+
prompt_template: str
|
| 141 |
+
) -> str:
|
| 142 |
+
"""Build the evaluation prompt by replacing placeholders."""
|
| 143 |
+
rubrics_json = json.dumps(rubrics, indent=4, ensure_ascii=False)
|
| 144 |
+
prompt = prompt_template.replace('{rubrics_json}', rubrics_json)
|
| 145 |
+
prompt = prompt.replace('<<paper_content>>', paper_content)
|
| 146 |
+
prompt = prompt.replace('<<review>>', review)
|
| 147 |
+
return prompt
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def calculate_weighted_scores(
|
| 151 |
+
raw_scores: Dict[str, Dict[str, Any]],
|
| 152 |
+
rubrics: List[Dict[str, Any]]
|
| 153 |
+
) -> Dict[str, float]:
|
| 154 |
+
"""Calculate weighted scores for each rubric."""
|
| 155 |
+
rubric_weights = {r['title']: r['weight'] for r in rubrics}
|
| 156 |
+
weighted_scores = {}
|
| 157 |
+
|
| 158 |
+
for rubric_title, rubric_data in raw_scores.items():
|
| 159 |
+
if rubric_title not in rubric_weights:
|
| 160 |
+
continue
|
| 161 |
+
|
| 162 |
+
rubric_score = rubric_data.get('score', 0)
|
| 163 |
+
if isinstance(rubric_score, str):
|
| 164 |
+
try:
|
| 165 |
+
rubric_score = int(rubric_score)
|
| 166 |
+
except ValueError:
|
| 167 |
+
rubric_score = 0
|
| 168 |
+
|
| 169 |
+
if rubric_score not in [0, 1]:
|
| 170 |
+
rubric_score = 1 if rubric_score > 0 else 0
|
| 171 |
+
|
| 172 |
+
weight = rubric_weights[rubric_title]
|
| 173 |
+
weighted_scores[rubric_title] = rubric_score * weight
|
| 174 |
+
|
| 175 |
+
return weighted_scores
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def calculate_scores(raw_scores: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
|
| 179 |
+
"""Calculate scores for each rubric."""
|
| 180 |
+
scores = {}
|
| 181 |
+
for rubric_title, rubric_data in raw_scores.items():
|
| 182 |
+
scores[rubric_title] = rubric_data.get('score', 0)
|
| 183 |
+
return scores
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def evaluate_review_semantic(
|
| 187 |
+
entry: Dict[str, Any],
|
| 188 |
+
paper_content: str,
|
| 189 |
+
prompt_template: str,
|
| 190 |
+
llm_service: LLMService
|
| 191 |
+
) -> Dict[str, Any]:
|
| 192 |
+
"""Evaluate a single review using article-specific rubrics."""
|
| 193 |
+
entry_id = entry.get('id', 'unknown')
|
| 194 |
+
rubrics = entry.get('rubrics', [])
|
| 195 |
+
model_review = entry.get('model_review', '')
|
| 196 |
+
|
| 197 |
+
if not rubrics:
|
| 198 |
+
return {
|
| 199 |
+
'id': entry_id,
|
| 200 |
+
'raw_scores': {},
|
| 201 |
+
'weighted_scores': {},
|
| 202 |
+
'total_score': 0.0,
|
| 203 |
+
'error': 'No valid rubrics found',
|
| 204 |
+
'raw_response': ''
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
# Build prompt
|
| 208 |
+
prompt = build_evaluation_prompt(rubrics, paper_content, model_review, prompt_template)
|
| 209 |
+
|
| 210 |
+
# Call LLM
|
| 211 |
+
try:
|
| 212 |
+
messages = [{"role": "user", "content": prompt}]
|
| 213 |
+
response = llm_service.generate(messages=messages)
|
| 214 |
+
|
| 215 |
+
# Parse response
|
| 216 |
+
raw_scores = parse_llm_response(response)
|
| 217 |
+
weighted_scores = calculate_scores(raw_scores)
|
| 218 |
+
total_score = sum(weighted_scores.values())
|
| 219 |
+
|
| 220 |
+
return {
|
| 221 |
+
'id': entry_id,
|
| 222 |
+
'raw_scores': raw_scores,
|
| 223 |
+
'weighted_scores': weighted_scores,
|
| 224 |
+
'total_score': total_score,
|
| 225 |
+
'raw_response': response
|
| 226 |
+
}
|
| 227 |
+
except Exception as e:
|
| 228 |
+
print(f"[ERROR] Error evaluating review {entry_id}: {e}")
|
| 229 |
+
return {
|
| 230 |
+
'id': entry_id,
|
| 231 |
+
'raw_scores': {},
|
| 232 |
+
'weighted_scores': {},
|
| 233 |
+
'total_score': 0.0,
|
| 234 |
+
'error': str(e),
|
| 235 |
+
'raw_response': ''
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def calculate_per_rubric_statistics(
|
| 240 |
+
valid_results: List[Dict[str, Any]],
|
| 241 |
+
rubric_titles: List[str]
|
| 242 |
+
) -> Dict[str, Dict[str, float]]:
|
| 243 |
+
"""Calculate per-rubric statistics from evaluation results."""
|
| 244 |
+
rubric_scores = {title: [] for title in rubric_titles}
|
| 245 |
+
|
| 246 |
+
for result in valid_results:
|
| 247 |
+
weighted_scores = result.get('weighted_scores', {})
|
| 248 |
+
if not isinstance(weighted_scores, dict):
|
| 249 |
+
continue
|
| 250 |
+
|
| 251 |
+
for rubric_title in rubric_titles:
|
| 252 |
+
if rubric_title in weighted_scores:
|
| 253 |
+
score = weighted_scores[rubric_title]
|
| 254 |
+
if isinstance(score, str):
|
| 255 |
+
try:
|
| 256 |
+
score = float(score)
|
| 257 |
+
except ValueError:
|
| 258 |
+
continue
|
| 259 |
+
elif isinstance(score, (int, float)):
|
| 260 |
+
score = float(score)
|
| 261 |
+
else:
|
| 262 |
+
continue
|
| 263 |
+
rubric_scores[rubric_title].append(score)
|
| 264 |
+
|
| 265 |
+
per_rubric_stats = {}
|
| 266 |
+
for rubric_title in rubric_titles:
|
| 267 |
+
scores = rubric_scores[rubric_title]
|
| 268 |
+
if not scores:
|
| 269 |
+
continue
|
| 270 |
+
|
| 271 |
+
mean_score = sum(scores) / len(scores)
|
| 272 |
+
min_score = min(scores)
|
| 273 |
+
max_score = max(scores)
|
| 274 |
+
count = len(scores)
|
| 275 |
+
|
| 276 |
+
if rubric_title == "False or Contradictory Claims":
|
| 277 |
+
pass_count = sum(1 for s in scores if s >= 0)
|
| 278 |
+
else:
|
| 279 |
+
pass_count = sum(1 for s in scores if s >= 1)
|
| 280 |
+
pass_rate = pass_count / count if count > 0 else 0.0
|
| 281 |
+
|
| 282 |
+
per_rubric_stats[rubric_title] = {
|
| 283 |
+
'mean': mean_score,
|
| 284 |
+
'min': min_score,
|
| 285 |
+
'max': max_score,
|
| 286 |
+
'count': count,
|
| 287 |
+
'pass_rate': pass_rate
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
return per_rubric_stats
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# ============================================================================
|
| 294 |
+
# Auto-Metric Evaluation Functions (from 3_rule_evaluate.py)
|
| 295 |
+
# ============================================================================
|
| 296 |
+
|
| 297 |
+
def extract_scores_from_review(review_text: str) -> Dict[str, Any]:
|
| 298 |
+
"""Extract numeric scores and decision from a review markdown text."""
|
| 299 |
+
if not review_text:
|
| 300 |
+
return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
|
| 301 |
+
|
| 302 |
+
try:
|
| 303 |
+
parsed = parse_review_markdown(review_text)
|
| 304 |
+
decision = parsed.get('decision', '')
|
| 305 |
+
if decision:
|
| 306 |
+
decision_lower = decision.lower().strip()
|
| 307 |
+
if 'accept' in decision_lower:
|
| 308 |
+
decision = 'accept'
|
| 309 |
+
elif 'reject' in decision_lower:
|
| 310 |
+
decision = 'reject'
|
| 311 |
+
elif 'undecided' in decision_lower:
|
| 312 |
+
decision = 'undecided'
|
| 313 |
+
else:
|
| 314 |
+
decision = decision_lower
|
| 315 |
+
else:
|
| 316 |
+
decision = None
|
| 317 |
+
|
| 318 |
+
return {
|
| 319 |
+
'soundness': parsed.get('soundness'),
|
| 320 |
+
'presentation': parsed.get('presentation'),
|
| 321 |
+
'rating': parsed.get('rating'),
|
| 322 |
+
'confidence': parsed.get('confidence'),
|
| 323 |
+
'decision': decision
|
| 324 |
+
}
|
| 325 |
+
except Exception as e:
|
| 326 |
+
print(f"Warning: Failed to parse review text: {e}")
|
| 327 |
+
return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def calculate_mse(predicted: float, ground_truth: float) -> Optional[float]:
|
| 331 |
+
"""Calculate Mean Squared Error for a single value."""
|
| 332 |
+
if predicted is None or ground_truth is None:
|
| 333 |
+
return None
|
| 334 |
+
return (predicted - ground_truth) ** 2
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def calculate_mae(predicted: float, ground_truth: float) -> Optional[float]:
|
| 338 |
+
"""Calculate Mean Absolute Error for a single value."""
|
| 339 |
+
if predicted is None or ground_truth is None:
|
| 340 |
+
return None
|
| 341 |
+
return abs(predicted - ground_truth)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def normalize_to_discrete_scale(score: Optional[float], scale_type: str) -> Optional[float]:
|
| 345 |
+
"""
|
| 346 |
+
Normalize a float score to the nearest discrete value based on scale type.
|
| 347 |
+
Uses round-half-up tie-breaking (e.g., 3.5 rounds to 4, 1.5 rounds to 2).
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
score: The float score to normalize (can be None)
|
| 351 |
+
scale_type: Either '0-5' for 0-5 scale (discrete: 0,1,2,3,4,5)
|
| 352 |
+
or '0-10' for 0-10 scale (discrete: 0,2,4,6,8,10)
|
| 353 |
+
|
| 354 |
+
Returns:
|
| 355 |
+
Normalized discrete score, or None if input is None
|
| 356 |
+
"""
|
| 357 |
+
if score is None:
|
| 358 |
+
return None
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
score = float(score)
|
| 362 |
+
except (ValueError, TypeError):
|
| 363 |
+
return None
|
| 364 |
+
|
| 365 |
+
if scale_type == '0-5':
|
| 366 |
+
# Discrete values: 0, 1, 2, 3, 4, 5
|
| 367 |
+
discrete_values = [0, 1, 2, 3, 4, 5]
|
| 368 |
+
# Clamp to valid range
|
| 369 |
+
score = max(0, min(5, score))
|
| 370 |
+
# Find nearest discrete value, with round-half-up tie-breaking
|
| 371 |
+
# For ties, prefer the higher value
|
| 372 |
+
best_value = None
|
| 373 |
+
best_distance = float('inf')
|
| 374 |
+
for val in discrete_values:
|
| 375 |
+
distance = abs(val - score)
|
| 376 |
+
if distance < best_distance:
|
| 377 |
+
best_distance = distance
|
| 378 |
+
best_value = val
|
| 379 |
+
elif distance == best_distance and val > best_value:
|
| 380 |
+
# Tie-breaking: prefer higher value (round-half-up)
|
| 381 |
+
best_value = val
|
| 382 |
+
return best_value
|
| 383 |
+
elif scale_type == '0-10':
|
| 384 |
+
# Discrete values: 0, 2, 4, 6, 8, 10
|
| 385 |
+
discrete_values = [0, 2, 4, 6, 8, 10]
|
| 386 |
+
# Clamp to valid range
|
| 387 |
+
score = max(0, min(10, score))
|
| 388 |
+
# Find nearest discrete value, with round-half-up tie-breaking
|
| 389 |
+
best_value = None
|
| 390 |
+
best_distance = float('inf')
|
| 391 |
+
for val in discrete_values:
|
| 392 |
+
distance = abs(val - score)
|
| 393 |
+
if distance < best_distance:
|
| 394 |
+
best_distance = distance
|
| 395 |
+
best_value = val
|
| 396 |
+
elif distance == best_distance and val > best_value:
|
| 397 |
+
# Tie-breaking: prefer higher value (round-half-up)
|
| 398 |
+
best_value = val
|
| 399 |
+
return best_value
|
| 400 |
+
else:
|
| 401 |
+
raise ValueError(f"Unknown scale_type: {scale_type}. Must be '0-5' or '0-10'")
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def normalize_scores_dict(scores: Dict[str, Optional[float]]) -> Dict[str, Optional[float]]:
|
| 405 |
+
"""
|
| 406 |
+
Normalize all scores in a dictionary to their appropriate discrete scales.
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
scores: Dictionary with keys 'soundness', 'presentation', 'rating', 'confidence'
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
Dictionary with normalized scores
|
| 413 |
+
"""
|
| 414 |
+
normalized = {}
|
| 415 |
+
|
| 416 |
+
# soundness, presentation, confidence use 0-5 scale
|
| 417 |
+
for key in ['soundness', 'presentation', 'confidence']:
|
| 418 |
+
normalized[key] = normalize_to_discrete_scale(scores.get(key), '0-5')
|
| 419 |
+
|
| 420 |
+
# rating uses 0-10 scale
|
| 421 |
+
normalized['rating'] = normalize_to_discrete_scale(scores.get('rating'), '0-10')
|
| 422 |
+
|
| 423 |
+
return normalized
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def calculate_score_metrics(
|
| 427 |
+
model_scores: Dict[str, float],
|
| 428 |
+
ground_truth_scores: Dict[str, float],
|
| 429 |
+
normalize: bool = False
|
| 430 |
+
) -> Dict[str, Any]:
|
| 431 |
+
"""
|
| 432 |
+
Calculate MSE and MAE metrics for each scoring dimension.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
model_scores: Dictionary with model scores
|
| 436 |
+
ground_truth_scores: Dictionary with ground truth scores
|
| 437 |
+
normalize: If True, normalize scores to discrete scales before computing metrics
|
| 438 |
+
|
| 439 |
+
Returns:
|
| 440 |
+
Dictionary with MSE, MAE metrics and optionally normalized scores
|
| 441 |
+
"""
|
| 442 |
+
dimensions = ['soundness', 'presentation', 'rating', 'confidence']
|
| 443 |
+
|
| 444 |
+
# Normalize scores to discrete scales if requested
|
| 445 |
+
if normalize:
|
| 446 |
+
model_scores_normalized = normalize_scores_dict(model_scores)
|
| 447 |
+
gt_scores_normalized = normalize_scores_dict(ground_truth_scores)
|
| 448 |
+
else:
|
| 449 |
+
model_scores_normalized = model_scores
|
| 450 |
+
gt_scores_normalized = ground_truth_scores
|
| 451 |
+
|
| 452 |
+
mse_values = {}
|
| 453 |
+
mae_values = {}
|
| 454 |
+
valid_count = 0
|
| 455 |
+
|
| 456 |
+
for dim in dimensions:
|
| 457 |
+
# Use normalized scores for metric calculation
|
| 458 |
+
mse = calculate_mse(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
|
| 459 |
+
mae = calculate_mae(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
|
| 460 |
+
mse_values[f'{dim}_mse'] = mse
|
| 461 |
+
mae_values[f'{dim}_mae'] = mae
|
| 462 |
+
if mse is not None:
|
| 463 |
+
valid_count += 1
|
| 464 |
+
|
| 465 |
+
overall_error = sum([v for v in mse_values.values() if v is not None])
|
| 466 |
+
|
| 467 |
+
result = {
|
| 468 |
+
**mse_values,
|
| 469 |
+
**mae_values,
|
| 470 |
+
'overall_error': overall_error if valid_count > 0 else None,
|
| 471 |
+
'valid_dimensions': valid_count
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
# Include normalized scores in result for transparency (only if normalize=True)
|
| 475 |
+
if normalize:
|
| 476 |
+
result['model_scores_normalized'] = model_scores_normalized
|
| 477 |
+
result['gt_scores_normalized'] = gt_scores_normalized
|
| 478 |
+
|
| 479 |
+
return result
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def normalize_score_value(value):
|
| 483 |
+
"""Normalize score value to float, handling string representations."""
|
| 484 |
+
if value is None:
|
| 485 |
+
return None
|
| 486 |
+
if isinstance(value, (int, float)):
|
| 487 |
+
return float(value)
|
| 488 |
+
if isinstance(value, str):
|
| 489 |
+
# Try to extract numeric value from string (e.g., "2.75" -> 2.75)
|
| 490 |
+
try:
|
| 491 |
+
import re
|
| 492 |
+
match = re.search(r'(\d+\.?\d*)', value)
|
| 493 |
+
if match:
|
| 494 |
+
return float(match.group(1))
|
| 495 |
+
except:
|
| 496 |
+
pass
|
| 497 |
+
return None
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def normalize_decision(decision):
|
| 501 |
+
"""Normalize decision string to standard format."""
|
| 502 |
+
if decision is None:
|
| 503 |
+
return None
|
| 504 |
+
decision_lower = str(decision).lower().strip()
|
| 505 |
+
if 'accept' in decision_lower:
|
| 506 |
+
return 'accept'
|
| 507 |
+
elif 'reject' in decision_lower:
|
| 508 |
+
return 'reject'
|
| 509 |
+
elif 'undecided' in decision_lower:
|
| 510 |
+
return 'undecided'
|
| 511 |
+
else:
|
| 512 |
+
return decision_lower
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def extract_scores_from_dict(scores_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 516 |
+
"""
|
| 517 |
+
Extract scores from a structured dictionary (scores or initial_scores format).
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
scores_dict: Dict containing scores (e.g., {'rating': 5.75, 'soundness': '2.75', ...})
|
| 521 |
+
|
| 522 |
+
Returns:
|
| 523 |
+
Dict with normalized scores: {'soundness', 'presentation', 'rating', 'confidence', 'decision'}
|
| 524 |
+
"""
|
| 525 |
+
if not scores_dict:
|
| 526 |
+
return {
|
| 527 |
+
'soundness': None,
|
| 528 |
+
'presentation': None,
|
| 529 |
+
'rating': None,
|
| 530 |
+
'confidence': None,
|
| 531 |
+
'decision': None
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
return {
|
| 535 |
+
'soundness': normalize_score_value(scores_dict.get('soundness')),
|
| 536 |
+
'presentation': normalize_score_value(scores_dict.get('presentation')),
|
| 537 |
+
'rating': normalize_score_value(scores_dict.get('rating')),
|
| 538 |
+
'confidence': normalize_score_value(scores_dict.get('confidence')),
|
| 539 |
+
'decision': normalize_decision(scores_dict.get('decision'))
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def evaluate_review_auto_metric(entry: Dict[str, Any], use_initial_scores: bool = False, strict_mode: bool = False) -> Dict[str, Any]:
|
| 544 |
+
"""
|
| 545 |
+
Evaluate a single entry by extracting scores and calculating metrics.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
entry: Evaluation entry containing model_review, scores, initial_scores, etc.
|
| 549 |
+
use_initial_scores: If True, use initial_scores instead of refined scores (for refined format)
|
| 550 |
+
|
| 551 |
+
Returns:
|
| 552 |
+
Dict containing evaluation metrics
|
| 553 |
+
"""
|
| 554 |
+
entry_id = entry.get('id', 'unknown')
|
| 555 |
+
model_review = entry.get('model_review', '')
|
| 556 |
+
format_type = entry.get('format', 'unknown')
|
| 557 |
+
|
| 558 |
+
# Extract scores based on format
|
| 559 |
+
model_scores = {}
|
| 560 |
+
model_decision = None
|
| 561 |
+
|
| 562 |
+
if format_type == 'refined' and not use_initial_scores:
|
| 563 |
+
# Use refined scores from structured data
|
| 564 |
+
scores_dict = entry.get('scores', {})
|
| 565 |
+
model_data = extract_scores_from_dict(scores_dict)
|
| 566 |
+
model_scores = {
|
| 567 |
+
'soundness': model_data.get('soundness'),
|
| 568 |
+
'presentation': model_data.get('presentation'),
|
| 569 |
+
'rating': model_data.get('rating'),
|
| 570 |
+
'confidence': model_data.get('confidence')
|
| 571 |
+
}
|
| 572 |
+
model_decision = model_data.get('decision')
|
| 573 |
+
elif format_type == 'refined' and use_initial_scores:
|
| 574 |
+
# Use initial scores from structured data
|
| 575 |
+
initial_scores_dict = entry.get('initial_scores', {})
|
| 576 |
+
model_data = extract_scores_from_dict(initial_scores_dict)
|
| 577 |
+
model_scores = {
|
| 578 |
+
'soundness': model_data.get('soundness'),
|
| 579 |
+
'presentation': model_data.get('presentation'),
|
| 580 |
+
'rating': model_data.get('rating'),
|
| 581 |
+
'confidence': model_data.get('confidence')
|
| 582 |
+
}
|
| 583 |
+
model_decision = model_data.get('decision')
|
| 584 |
+
elif format_type == 'original':
|
| 585 |
+
# Use initial scores from structured data
|
| 586 |
+
initial_scores_dict = entry.get('initial_scores', {})
|
| 587 |
+
model_data = extract_scores_from_dict(initial_scores_dict)
|
| 588 |
+
model_scores = {
|
| 589 |
+
'soundness': model_data.get('soundness'),
|
| 590 |
+
'presentation': model_data.get('presentation'),
|
| 591 |
+
'rating': model_data.get('rating'),
|
| 592 |
+
'confidence': model_data.get('confidence')
|
| 593 |
+
}
|
| 594 |
+
model_decision = model_data.get('decision')
|
| 595 |
+
|
| 596 |
+
# Fallback: If confidence is missing from structured data, try to extract from review text
|
| 597 |
+
# (meta_review may not have confidence field, but review text might)
|
| 598 |
+
if model_scores.get('confidence') is None and model_review:
|
| 599 |
+
try:
|
| 600 |
+
review_data = extract_scores_from_review(model_review)
|
| 601 |
+
if review_data.get('confidence') is not None:
|
| 602 |
+
model_scores['confidence'] = review_data.get('confidence')
|
| 603 |
+
except Exception:
|
| 604 |
+
pass # Keep confidence as None if extraction fails
|
| 605 |
+
else:
|
| 606 |
+
# Fallback: extract from markdown review text
|
| 607 |
+
model_data = extract_scores_from_review(model_review)
|
| 608 |
+
model_scores = {
|
| 609 |
+
'soundness': model_data.get('soundness'),
|
| 610 |
+
'presentation': model_data.get('presentation'),
|
| 611 |
+
'rating': model_data.get('rating'),
|
| 612 |
+
'confidence': model_data.get('confidence')
|
| 613 |
+
}
|
| 614 |
+
model_decision = model_data.get('decision')
|
| 615 |
+
|
| 616 |
+
# Get ground truth scores from golden_review ONLY
|
| 617 |
+
# Ground truth must ONLY come from golden_review, never from model output
|
| 618 |
+
# If extraction fails, leave fields as None (do not use model_review as fallback)
|
| 619 |
+
ground_truth_review = entry.get('golden_review', '')
|
| 620 |
+
ground_truth_scores = {}
|
| 621 |
+
gt_decision = None
|
| 622 |
+
|
| 623 |
+
if not ground_truth_review:
|
| 624 |
+
print(f"Warning: No golden_review found for entry {entry_id}. Ground truth scores will be empty.")
|
| 625 |
+
else:
|
| 626 |
+
try:
|
| 627 |
+
# Extract scores from golden_review markdown text
|
| 628 |
+
gt_data = extract_scores_from_review(ground_truth_review)
|
| 629 |
+
if not gt_data:
|
| 630 |
+
print(f"Warning: Failed to parse golden_review for entry {entry_id}. Ground truth scores will be empty.")
|
| 631 |
+
else:
|
| 632 |
+
ground_truth_scores = {
|
| 633 |
+
'soundness': gt_data.get('soundness'),
|
| 634 |
+
'presentation': gt_data.get('presentation'),
|
| 635 |
+
'rating': gt_data.get('rating'),
|
| 636 |
+
'confidence': gt_data.get('confidence')
|
| 637 |
+
}
|
| 638 |
+
gt_decision = normalize_decision(gt_data.get('decision'))
|
| 639 |
+
# Note: If any field is None, it stays None - we do NOT use model_review as fallback
|
| 640 |
+
# Using model output as ground truth would inflate evaluation scores
|
| 641 |
+
except Exception as e:
|
| 642 |
+
print(f"Warning: Failed to extract scores from golden_review for {entry_id}: {e}")
|
| 643 |
+
print(f" Ground truth scores will be empty. Error: {str(e)}")
|
| 644 |
+
|
| 645 |
+
# Calculate MSE and MAE metrics (with optional normalization in strict mode)
|
| 646 |
+
score_metrics = calculate_score_metrics(model_scores, ground_truth_scores, normalize=strict_mode)
|
| 647 |
+
|
| 648 |
+
# Calculate decision accuracy
|
| 649 |
+
decision_match = False
|
| 650 |
+
decision_accuracy = None
|
| 651 |
+
if model_decision is not None and gt_decision is not None:
|
| 652 |
+
model_decision_normalized = normalize_decision(model_decision)
|
| 653 |
+
decision_match = (model_decision_normalized == gt_decision)
|
| 654 |
+
decision_accuracy = 1.0 if decision_match else 0.0
|
| 655 |
+
|
| 656 |
+
result = {
|
| 657 |
+
'id': entry_id,
|
| 658 |
+
'format': format_type,
|
| 659 |
+
'model_soundness': model_scores.get('soundness'),
|
| 660 |
+
'model_presentation': model_scores.get('presentation'),
|
| 661 |
+
'model_rating': model_scores.get('rating'),
|
| 662 |
+
'model_confidence': model_scores.get('confidence'),
|
| 663 |
+
'model_decision': model_decision,
|
| 664 |
+
'gt_soundness': ground_truth_scores.get('soundness'),
|
| 665 |
+
'gt_presentation': ground_truth_scores.get('presentation'),
|
| 666 |
+
'gt_rating': ground_truth_scores.get('rating'),
|
| 667 |
+
'gt_confidence': ground_truth_scores.get('confidence'),
|
| 668 |
+
'gt_decision': gt_decision,
|
| 669 |
+
'decision_match': decision_match,
|
| 670 |
+
'decision_accuracy': decision_accuracy,
|
| 671 |
+
**score_metrics
|
| 672 |
+
}
|
| 673 |
+
|
| 674 |
+
# Add prefix to indicate which scores were used
|
| 675 |
+
if format_type == 'refined':
|
| 676 |
+
if use_initial_scores:
|
| 677 |
+
result['score_type'] = 'initial'
|
| 678 |
+
else:
|
| 679 |
+
result['score_type'] = 'refined'
|
| 680 |
+
else:
|
| 681 |
+
result['score_type'] = 'auto'
|
| 682 |
+
|
| 683 |
+
return result
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def calculate_pairwise_accuracies(paper_scores: List[Dict[str, float]]) -> Dict[str, float]:
|
| 687 |
+
"""Calculate pairwise accuracy for each metric by comparing rankings."""
|
| 688 |
+
if len(paper_scores) < 2:
|
| 689 |
+
return {}
|
| 690 |
+
|
| 691 |
+
total_valid_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
|
| 692 |
+
correct_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
|
| 693 |
+
|
| 694 |
+
for paper1, paper2 in combinations(paper_scores, 2):
|
| 695 |
+
# Check rating ranking
|
| 696 |
+
if (paper1.get('true_rating') is not None and paper2.get('true_rating') is not None and
|
| 697 |
+
paper1.get('pred_rating') is not None and paper2.get('pred_rating') is not None):
|
| 698 |
+
total_valid_pairs['rating'] += 1
|
| 699 |
+
true_order = paper1['true_rating'] > paper2['true_rating']
|
| 700 |
+
pred_order = paper1['pred_rating'] > paper2['pred_rating']
|
| 701 |
+
if true_order == pred_order:
|
| 702 |
+
correct_pairs['rating'] += 1
|
| 703 |
+
|
| 704 |
+
# Similar for other dimensions...
|
| 705 |
+
# (abbreviated for space, similar logic for soundness, presentation, confidence)
|
| 706 |
+
for metric in ['soundness', 'presentation', 'confidence']:
|
| 707 |
+
true_key = f'true_{metric}'
|
| 708 |
+
pred_key = f'pred_{metric}'
|
| 709 |
+
if (paper1.get(true_key) is not None and paper2.get(true_key) is not None and
|
| 710 |
+
paper1.get(pred_key) is not None and paper2.get(pred_key) is not None):
|
| 711 |
+
total_valid_pairs[metric] += 1
|
| 712 |
+
true_order = paper1[true_key] > paper2[true_key]
|
| 713 |
+
pred_order = paper1[pred_key] > paper2[pred_key]
|
| 714 |
+
if true_order == pred_order:
|
| 715 |
+
correct_pairs[metric] += 1
|
| 716 |
+
|
| 717 |
+
pairwise_accuracies = {
|
| 718 |
+
metric: correct_pairs[metric] / total_valid_pairs[metric] if total_valid_pairs[metric] > 0 else 0.0
|
| 719 |
+
for metric in ['rating', 'soundness', 'presentation', 'confidence']
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
return pairwise_accuracies
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
# ============================================================================
|
| 726 |
+
# Data Loading Functions
|
| 727 |
+
# ============================================================================
|
| 728 |
+
|
| 729 |
+
def load_rubrics_json(rubrics_path: str) -> Dict[str, Dict[str, Any]]:
|
| 730 |
+
"""Load rubrics JSON and create lookup by id."""
|
| 731 |
+
with open(rubrics_path, 'r', encoding='utf-8') as f:
|
| 732 |
+
data = json.load(f)
|
| 733 |
+
|
| 734 |
+
if isinstance(data, list):
|
| 735 |
+
return {item['id']: item for item in data}
|
| 736 |
+
elif isinstance(data, dict):
|
| 737 |
+
return data
|
| 738 |
+
else:
|
| 739 |
+
raise ValueError(f"Invalid rubrics JSON format: expected list or dict, got {type(data)}")
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def load_model_reviews_json(reviews_path: str, format_override: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
| 743 |
+
"""
|
| 744 |
+
Load model reviews JSON and extract reviews by id.
|
| 745 |
+
|
| 746 |
+
Supports two input formats:
|
| 747 |
+
1. Refined format: Contains 'scores' and 'initial_scores' fields (from refinement pipeline)
|
| 748 |
+
2. Original format: Contains 'model_prediction' with 'meta_review' and 'decision' (like ours.json)
|
| 749 |
+
|
| 750 |
+
Args:
|
| 751 |
+
reviews_path: Path to JSON file containing model reviews
|
| 752 |
+
format_override: Optional format override ('refined', 'original', or None for auto-detect)
|
| 753 |
+
|
| 754 |
+
Returns:
|
| 755 |
+
Dict mapping paper_id to dict containing:
|
| 756 |
+
- 'review': review text (markdown)
|
| 757 |
+
- 'scores': refined scores dict (if available)
|
| 758 |
+
- 'initial_scores': initial scores dict (if available)
|
| 759 |
+
- 'format': 'refined' or 'original'
|
| 760 |
+
"""
|
| 761 |
+
with open(reviews_path, 'r', encoding='utf-8') as f:
|
| 762 |
+
data = json.load(f)
|
| 763 |
+
|
| 764 |
+
if isinstance(data, dict):
|
| 765 |
+
data = list(data.values())
|
| 766 |
+
|
| 767 |
+
reviews_dict = {}
|
| 768 |
+
for item in data:
|
| 769 |
+
item_id = None
|
| 770 |
+
review_text = ''
|
| 771 |
+
scores = None
|
| 772 |
+
initial_scores = None
|
| 773 |
+
format_type = None
|
| 774 |
+
|
| 775 |
+
# Use format override if provided, otherwise auto-detect
|
| 776 |
+
if format_override and format_override != 'auto':
|
| 777 |
+
# Force use specified format
|
| 778 |
+
if format_override == 'refined':
|
| 779 |
+
item_id = item.get('paper_id') or item.get('id')
|
| 780 |
+
if not item_id:
|
| 781 |
+
continue
|
| 782 |
+
format_type = 'refined'
|
| 783 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 784 |
+
scores = item.get('scores', {})
|
| 785 |
+
initial_scores = item.get('initial_scores', {})
|
| 786 |
+
elif format_override == 'original':
|
| 787 |
+
item_id = item.get('id')
|
| 788 |
+
if not item_id:
|
| 789 |
+
continue
|
| 790 |
+
format_type = 'original'
|
| 791 |
+
model_prediction = item.get('model_prediction', {})
|
| 792 |
+
meta_review = model_prediction.get('meta_review', {})
|
| 793 |
+
review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
|
| 794 |
+
initial_scores = {
|
| 795 |
+
'rating': meta_review.get('rating'),
|
| 796 |
+
'soundness': meta_review.get('soundness'),
|
| 797 |
+
'presentation': meta_review.get('presentation'),
|
| 798 |
+
'contribution': meta_review.get('contribution'),
|
| 799 |
+
'decision': model_prediction.get('decision'),
|
| 800 |
+
}
|
| 801 |
+
else:
|
| 802 |
+
raise ValueError(f"Unknown format_override: {format_override}. Must be 'refined', 'original', or 'auto'")
|
| 803 |
+
else:
|
| 804 |
+
# Auto-detect format
|
| 805 |
+
if "paper_id" in item:
|
| 806 |
+
# Refined format (from refinement pipeline)
|
| 807 |
+
item_id = item.get('paper_id')
|
| 808 |
+
if not item_id:
|
| 809 |
+
continue
|
| 810 |
+
|
| 811 |
+
# Check if this is refined format (has scores and initial_scores)
|
| 812 |
+
if 'scores' in item and 'initial_scores' in item:
|
| 813 |
+
format_type = 'refined'
|
| 814 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 815 |
+
scores = item.get('scores', {})
|
| 816 |
+
initial_scores = item.get('initial_scores', {})
|
| 817 |
+
else:
|
| 818 |
+
# Standard format with paper_id
|
| 819 |
+
format_type = 'standard'
|
| 820 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 821 |
+
elif "model_prediction" in item:
|
| 822 |
+
# Original format (like ours.json)
|
| 823 |
+
item_id = item.get('id')
|
| 824 |
+
if not item_id:
|
| 825 |
+
continue
|
| 826 |
+
|
| 827 |
+
format_type = 'original'
|
| 828 |
+
model_prediction = item.get('model_prediction', {})
|
| 829 |
+
meta_review = model_prediction.get('meta_review', {})
|
| 830 |
+
|
| 831 |
+
# Extract review content (prefer meta_review.content, fallback to raw_text)
|
| 832 |
+
review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
|
| 833 |
+
|
| 834 |
+
# Extract initial scores
|
| 835 |
+
initial_scores = {
|
| 836 |
+
'rating': meta_review.get('rating'),
|
| 837 |
+
'soundness': meta_review.get('soundness'),
|
| 838 |
+
'presentation': meta_review.get('presentation'),
|
| 839 |
+
'contribution': meta_review.get('contribution'),
|
| 840 |
+
'decision': model_prediction.get('decision'),
|
| 841 |
+
}
|
| 842 |
+
else:
|
| 843 |
+
# Legacy format (pred_fast_mode)
|
| 844 |
+
item_id = item.get('id')
|
| 845 |
+
if not item_id:
|
| 846 |
+
continue
|
| 847 |
+
|
| 848 |
+
format_type = 'legacy'
|
| 849 |
+
review_dict = item.get('pred_fast_mode', {})
|
| 850 |
+
if isinstance(review_dict, dict):
|
| 851 |
+
# review_text = review_dict.get('raw_text', '')
|
| 852 |
+
review_text = review_dict
|
| 853 |
+
else:
|
| 854 |
+
review_text = str(review_dict)
|
| 855 |
+
|
| 856 |
+
# Extract review content from the review text field
|
| 857 |
+
try:
|
| 858 |
+
if review_text:
|
| 859 |
+
extracted_review = ReviewProcessor.extract_review_content(review_text)
|
| 860 |
+
else:
|
| 861 |
+
extracted_review = ''
|
| 862 |
+
|
| 863 |
+
reviews_dict[item_id] = {
|
| 864 |
+
'review': extracted_review,
|
| 865 |
+
'scores': scores,
|
| 866 |
+
'initial_scores': initial_scores,
|
| 867 |
+
'format': format_type
|
| 868 |
+
}
|
| 869 |
+
except Exception as e:
|
| 870 |
+
print(f"[WARN] Failed to extract review for {item_id}: {e}")
|
| 871 |
+
continue
|
| 872 |
+
|
| 873 |
+
return reviews_dict
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
def combine_rubrics_and_reviews(
|
| 877 |
+
rubrics_data: Dict[str, Dict[str, Any]],
|
| 878 |
+
reviews_dict: Dict[str, Dict[str, Any]]
|
| 879 |
+
) -> List[Dict[str, Any]]:
|
| 880 |
+
"""
|
| 881 |
+
Combine rubrics and reviews into evaluation entries.
|
| 882 |
+
|
| 883 |
+
Args:
|
| 884 |
+
rubrics_data: Dict mapping paper_id to rubric entry
|
| 885 |
+
reviews_dict: Dict mapping paper_id to dict containing 'review', 'scores', 'initial_scores', 'format'
|
| 886 |
+
|
| 887 |
+
Returns:
|
| 888 |
+
List of evaluation entries with model_review, scores, initial_scores, and format info
|
| 889 |
+
"""
|
| 890 |
+
combined = []
|
| 891 |
+
missing_reviews = []
|
| 892 |
+
|
| 893 |
+
for paper_id, rubric_entry in rubrics_data.items():
|
| 894 |
+
review_data = reviews_dict.get(paper_id)
|
| 895 |
+
if not review_data or not review_data.get('review'):
|
| 896 |
+
missing_reviews.append(paper_id)
|
| 897 |
+
continue
|
| 898 |
+
|
| 899 |
+
entry = {
|
| 900 |
+
'id': paper_id,
|
| 901 |
+
'paper_context': rubric_entry.get('paper_context', ''),
|
| 902 |
+
'decision': rubric_entry.get('decision', ''),
|
| 903 |
+
'golden_review': rubric_entry.get('golden_review', ''),
|
| 904 |
+
'rubrics': rubric_entry.get('rubrics', []),
|
| 905 |
+
'model_review': review_data.get('review', ''),
|
| 906 |
+
'scores': review_data.get('scores'), # Refined scores (if available)
|
| 907 |
+
'initial_scores': review_data.get('initial_scores'), # Initial scores (if available)
|
| 908 |
+
'format': review_data.get('format', 'unknown') # Format type
|
| 909 |
+
}
|
| 910 |
+
combined.append(entry)
|
| 911 |
+
|
| 912 |
+
if missing_reviews:
|
| 913 |
+
print(f"[WARN] {len(missing_reviews)} papers have no model review, skipping them")
|
| 914 |
+
|
| 915 |
+
return combined
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
# ============================================================================
|
| 919 |
+
# LLM Service Configuration
|
| 920 |
+
# ============================================================================
|
| 921 |
+
|
| 922 |
+
def load_llm_config(config_path: str) -> Dict[str, Any]:
|
| 923 |
+
"""Load LLM configuration from YAML file."""
|
| 924 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 925 |
+
config = yaml.safe_load(f)
|
| 926 |
+
return config
|
| 927 |
+
|
| 928 |
+
|
| 929 |
+
def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
|
| 930 |
+
"""Create LLM service from configuration."""
|
| 931 |
+
mode = config.get('mode', 'gpt').lower()
|
| 932 |
+
|
| 933 |
+
if mode == 'gpt':
|
| 934 |
+
gpt_config = config.get('gpt', {})
|
| 935 |
+
api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
|
| 936 |
+
if not api_key:
|
| 937 |
+
raise ValueError("GPT mode requires api_key in configs.yaml or OPENAI_API_KEY environment variable")
|
| 938 |
+
|
| 939 |
+
service = GPTService(
|
| 940 |
+
api_key=api_key,
|
| 941 |
+
model_name=gpt_config.get('model_name', 'gpt-4o'),
|
| 942 |
+
base_url=gpt_config.get('base_url'),
|
| 943 |
+
timeout=gpt_config.get('timeout', 300)
|
| 944 |
+
)
|
| 945 |
+
return service
|
| 946 |
+
|
| 947 |
+
elif mode == 'vllm':
|
| 948 |
+
vllm_config = config.get('vllm', {})
|
| 949 |
+
service = VLLMService(
|
| 950 |
+
base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
|
| 951 |
+
api_key=vllm_config.get('api_key', 'dummy-key'),
|
| 952 |
+
model_name=vllm_config.get('model_name'),
|
| 953 |
+
timeout=vllm_config.get('timeout', 300),
|
| 954 |
+
max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
|
| 955 |
+
max_retries=vllm_config.get('max_retries', 3),
|
| 956 |
+
retry_delay=vllm_config.get('retry_delay', 1.0),
|
| 957 |
+
retry_backoff=vllm_config.get('retry_backoff', 2.0)
|
| 958 |
+
)
|
| 959 |
+
return service
|
| 960 |
+
|
| 961 |
+
else:
|
| 962 |
+
raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
# ============================================================================
|
| 966 |
+
# Main Evaluation Functions
|
| 967 |
+
# ============================================================================
|
| 968 |
+
|
| 969 |
+
def run_semantic_evaluation(
|
| 970 |
+
evaluation_data: List[Dict[str, Any]],
|
| 971 |
+
prompt_template: str,
|
| 972 |
+
llm_service: LLMService,
|
| 973 |
+
max_workers: int
|
| 974 |
+
) -> tuple:
|
| 975 |
+
"""Run semantic evaluation and return results and summary."""
|
| 976 |
+
print(f"\n{'='*80}")
|
| 977 |
+
print("RUNNING SEMANTIC EVALUATION")
|
| 978 |
+
print(f"{'='*80}")
|
| 979 |
+
print(f"Evaluating {len(evaluation_data)} reviews using {max_workers} workers...")
|
| 980 |
+
|
| 981 |
+
results = []
|
| 982 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 983 |
+
future_to_entry = {
|
| 984 |
+
executor.submit(
|
| 985 |
+
evaluate_review_semantic,
|
| 986 |
+
entry,
|
| 987 |
+
entry['paper_context'],
|
| 988 |
+
prompt_template,
|
| 989 |
+
llm_service
|
| 990 |
+
): entry
|
| 991 |
+
for entry in evaluation_data
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
for future in tqdm(as_completed(future_to_entry), total=len(evaluation_data), desc="Semantic evaluation"):
|
| 995 |
+
try:
|
| 996 |
+
result = future.result()
|
| 997 |
+
results.append(result)
|
| 998 |
+
except Exception as e:
|
| 999 |
+
entry = future_to_entry[future]
|
| 1000 |
+
print(f"\n[ERROR] Failed to process entry {entry.get('id', 'unknown')}: {e}")
|
| 1001 |
+
results.append({
|
| 1002 |
+
'id': entry.get('id', 'unknown'),
|
| 1003 |
+
'raw_scores': {},
|
| 1004 |
+
'weighted_scores': {},
|
| 1005 |
+
'total_score': 0.0,
|
| 1006 |
+
'error': str(e),
|
| 1007 |
+
'raw_response': ''
|
| 1008 |
+
})
|
| 1009 |
+
|
| 1010 |
+
# Calculate statistics
|
| 1011 |
+
valid_results = [r for r in results if 'error' not in r and r.get('weighted_scores')]
|
| 1012 |
+
review_scores = [r.get('total_score', 0.0) for r in valid_results]
|
| 1013 |
+
|
| 1014 |
+
summary = {
|
| 1015 |
+
'total_entries': len(results),
|
| 1016 |
+
'valid_entries': len(valid_results),
|
| 1017 |
+
'failed_entries': len(results) - len(valid_results)
|
| 1018 |
+
}
|
| 1019 |
+
|
| 1020 |
+
if review_scores:
|
| 1021 |
+
summary['overall_score'] = {
|
| 1022 |
+
'mean': sum(review_scores) / len(review_scores),
|
| 1023 |
+
'min': min(review_scores),
|
| 1024 |
+
'max': max(review_scores)
|
| 1025 |
+
}
|
| 1026 |
+
|
| 1027 |
+
# Calculate per-rubric statistics (extract rubric titles from first entry)
|
| 1028 |
+
if evaluation_data and evaluation_data[0].get('rubrics'):
|
| 1029 |
+
rubric_titles = [r['title'] for r in evaluation_data[0]['rubrics']]
|
| 1030 |
+
per_rubric_stats = calculate_per_rubric_statistics(valid_results, rubric_titles)
|
| 1031 |
+
summary['per_rubric_statistics'] = per_rubric_stats
|
| 1032 |
+
|
| 1033 |
+
return results, summary
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
def run_auto_metric_evaluation(
|
| 1037 |
+
evaluation_data: List[Dict[str, Any]],
|
| 1038 |
+
strict_mode: bool = False
|
| 1039 |
+
) -> tuple:
|
| 1040 |
+
"""
|
| 1041 |
+
Run auto-metric evaluation and return results and summary.
|
| 1042 |
+
|
| 1043 |
+
For refined format (has scores and initial_scores), evaluates both:
|
| 1044 |
+
- Refined scores evaluation
|
| 1045 |
+
- Initial scores evaluation
|
| 1046 |
+
|
| 1047 |
+
For original format (only initial_scores), evaluates:
|
| 1048 |
+
- Initial scores evaluation only
|
| 1049 |
+
|
| 1050 |
+
Returns:
|
| 1051 |
+
Tuple of (results_list, summary_dict)
|
| 1052 |
+
- results_list: List of evaluation results (may contain both refined and initial results for refined format)
|
| 1053 |
+
- summary_dict: Summary statistics
|
| 1054 |
+
"""
|
| 1055 |
+
print(f"\n{'='*80}")
|
| 1056 |
+
print("RUNNING AUTO-METRIC EVALUATION")
|
| 1057 |
+
print(f"{'='*80}")
|
| 1058 |
+
print(f"Evaluating {len(evaluation_data)} entries...")
|
| 1059 |
+
|
| 1060 |
+
# Detect format types
|
| 1061 |
+
refined_format_count = sum(1 for e in evaluation_data if e.get('format') == 'refined')
|
| 1062 |
+
original_format_count = sum(1 for e in evaluation_data if e.get('format') == 'original')
|
| 1063 |
+
|
| 1064 |
+
if refined_format_count > 0:
|
| 1065 |
+
print(f"Detected {refined_format_count} entries in refined format (will evaluate both refined and initial scores)")
|
| 1066 |
+
if original_format_count > 0:
|
| 1067 |
+
print(f"Detected {original_format_count} entries in original format (will evaluate initial scores only)")
|
| 1068 |
+
|
| 1069 |
+
results = []
|
| 1070 |
+
for entry in tqdm(evaluation_data, desc="Auto-metric evaluation"):
|
| 1071 |
+
format_type = entry.get('format', 'unknown')
|
| 1072 |
+
|
| 1073 |
+
if format_type == 'refined':
|
| 1074 |
+
# Evaluate both refined scores and initial scores
|
| 1075 |
+
try:
|
| 1076 |
+
entry_id = entry.get('id', 'unknown')
|
| 1077 |
+
|
| 1078 |
+
# Evaluate refined scores
|
| 1079 |
+
refined_result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
|
| 1080 |
+
refined_result['paper_id'] = entry_id # Keep original paper_id
|
| 1081 |
+
refined_result['id'] = f"{entry_id}_refined"
|
| 1082 |
+
results.append(refined_result)
|
| 1083 |
+
|
| 1084 |
+
# Evaluate initial scores
|
| 1085 |
+
initial_result = evaluate_review_auto_metric(entry, use_initial_scores=True, strict_mode=strict_mode)
|
| 1086 |
+
initial_result['paper_id'] = entry_id # Keep original paper_id
|
| 1087 |
+
initial_result['id'] = f"{entry_id}_initial"
|
| 1088 |
+
results.append(initial_result)
|
| 1089 |
+
except Exception as e:
|
| 1090 |
+
print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
|
| 1091 |
+
results.append({
|
| 1092 |
+
'id': entry.get('id', 'unknown'),
|
| 1093 |
+
'error': str(e)
|
| 1094 |
+
})
|
| 1095 |
+
else:
|
| 1096 |
+
# Evaluate initial scores only (or extract from markdown)
|
| 1097 |
+
try:
|
| 1098 |
+
result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
|
| 1099 |
+
results.append(result)
|
| 1100 |
+
except Exception as e:
|
| 1101 |
+
print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
|
| 1102 |
+
results.append({
|
| 1103 |
+
'id': entry.get('id', 'unknown'),
|
| 1104 |
+
'error': str(e)
|
| 1105 |
+
})
|
| 1106 |
+
|
| 1107 |
+
# Calculate statistics
|
| 1108 |
+
valid_results = [r for r in results if 'error' not in r]
|
| 1109 |
+
mse_results = [r for r in valid_results if r.get('overall_error') is not None]
|
| 1110 |
+
|
| 1111 |
+
# Separate refined and initial results for refined format
|
| 1112 |
+
refined_results = [r for r in valid_results if r.get('score_type') == 'refined']
|
| 1113 |
+
initial_results = [r for r in valid_results if r.get('score_type') == 'initial']
|
| 1114 |
+
auto_results = [r for r in valid_results if r.get('score_type') == 'auto' or r.get('score_type') is None]
|
| 1115 |
+
|
| 1116 |
+
summary = {
|
| 1117 |
+
'total_entries': len(results),
|
| 1118 |
+
'valid_entries': len(valid_results),
|
| 1119 |
+
'mse_entries': len(mse_results),
|
| 1120 |
+
'refined_results_count': len(refined_results),
|
| 1121 |
+
'initial_results_count': len(initial_results),
|
| 1122 |
+
'auto_results_count': len(auto_results)
|
| 1123 |
+
}
|
| 1124 |
+
|
| 1125 |
+
# Calculate MSE/MAE statistics
|
| 1126 |
+
# For refined format, only use refined results for overall statistics (avoid double counting)
|
| 1127 |
+
# For other formats, use all results
|
| 1128 |
+
if refined_format_count > 0:
|
| 1129 |
+
# Refined format: use only refined results for overall statistics
|
| 1130 |
+
stats_results = [r for r in refined_results if r.get('overall_error') is not None]
|
| 1131 |
+
else:
|
| 1132 |
+
# Original/other formats: use all results
|
| 1133 |
+
stats_results = mse_results
|
| 1134 |
+
|
| 1135 |
+
if stats_results:
|
| 1136 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1137 |
+
mse_stats = {}
|
| 1138 |
+
mae_stats = {}
|
| 1139 |
+
|
| 1140 |
+
for dim in dimensions:
|
| 1141 |
+
mse_list = [r.get(f'{dim}_mse') for r in stats_results if r.get(f'{dim}_mse') is not None]
|
| 1142 |
+
mae_list = [r.get(f'{dim}_mae') for r in stats_results if r.get(f'{dim}_mae') is not None]
|
| 1143 |
+
|
| 1144 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1145 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1146 |
+
|
| 1147 |
+
if mse_clean:
|
| 1148 |
+
mse_stats[dim] = {
|
| 1149 |
+
'mean': sum(mse_clean) / len(mse_clean),
|
| 1150 |
+
'count': len(mse_clean)
|
| 1151 |
+
}
|
| 1152 |
+
if mae_clean:
|
| 1153 |
+
mae_stats[dim] = {
|
| 1154 |
+
'mean': sum(mae_clean) / len(mae_clean),
|
| 1155 |
+
'count': len(mae_clean)
|
| 1156 |
+
}
|
| 1157 |
+
|
| 1158 |
+
overall_errors = [r.get('overall_error') for r in stats_results if r.get('overall_error') is not None]
|
| 1159 |
+
overall_clean = [x for x in overall_errors if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1160 |
+
|
| 1161 |
+
if overall_clean:
|
| 1162 |
+
summary['overall_error'] = {
|
| 1163 |
+
'mean': sum(overall_clean) / len(overall_clean),
|
| 1164 |
+
'count': len(overall_clean)
|
| 1165 |
+
}
|
| 1166 |
+
|
| 1167 |
+
summary['mse_statistics'] = mse_stats
|
| 1168 |
+
summary['mae_statistics'] = mae_stats
|
| 1169 |
+
|
| 1170 |
+
# Calculate separate statistics for refined and initial results
|
| 1171 |
+
if refined_results:
|
| 1172 |
+
refined_mse_results = [r for r in refined_results if r.get('overall_error') is not None]
|
| 1173 |
+
if refined_mse_results:
|
| 1174 |
+
refined_mse_stats = {}
|
| 1175 |
+
refined_mae_stats = {}
|
| 1176 |
+
for dim in dimensions:
|
| 1177 |
+
mse_list = [r.get(f'{dim}_mse') for r in refined_mse_results if r.get(f'{dim}_mse') is not None]
|
| 1178 |
+
mae_list = [r.get(f'{dim}_mae') for r in refined_mse_results if r.get(f'{dim}_mae') is not None]
|
| 1179 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1180 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1181 |
+
if mse_clean:
|
| 1182 |
+
refined_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
|
| 1183 |
+
if mae_clean:
|
| 1184 |
+
refined_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
|
| 1185 |
+
summary['refined_mse_statistics'] = refined_mse_stats
|
| 1186 |
+
summary['refined_mae_statistics'] = refined_mae_stats
|
| 1187 |
+
|
| 1188 |
+
if initial_results:
|
| 1189 |
+
initial_mse_results = [r for r in initial_results if r.get('overall_error') is not None]
|
| 1190 |
+
if initial_mse_results:
|
| 1191 |
+
initial_mse_stats = {}
|
| 1192 |
+
initial_mae_stats = {}
|
| 1193 |
+
for dim in dimensions:
|
| 1194 |
+
mse_list = [r.get(f'{dim}_mse') for r in initial_mse_results if r.get(f'{dim}_mse') is not None]
|
| 1195 |
+
mae_list = [r.get(f'{dim}_mae') for r in initial_mse_results if r.get(f'{dim}_mae') is not None]
|
| 1196 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1197 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1198 |
+
if mse_clean:
|
| 1199 |
+
initial_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
|
| 1200 |
+
if mae_clean:
|
| 1201 |
+
initial_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
|
| 1202 |
+
summary['initial_mse_statistics'] = initial_mse_stats
|
| 1203 |
+
summary['initial_mae_statistics'] = initial_mae_stats
|
| 1204 |
+
|
| 1205 |
+
# Calculate Spearman correlations
|
| 1206 |
+
def filter_valid_pairs(true_list, pred_list):
|
| 1207 |
+
filtered_true = []
|
| 1208 |
+
filtered_pred = []
|
| 1209 |
+
for t, p in zip(true_list, pred_list):
|
| 1210 |
+
if (t is not None and p is not None and
|
| 1211 |
+
not (isinstance(t, float) and math.isnan(t)) and
|
| 1212 |
+
not (isinstance(p, float) and math.isnan(p))):
|
| 1213 |
+
filtered_true.append(t)
|
| 1214 |
+
filtered_pred.append(p)
|
| 1215 |
+
return filtered_true, filtered_pred
|
| 1216 |
+
|
| 1217 |
+
# Calculate Spearman correlations
|
| 1218 |
+
# For refined format, calculate separately for refined and initial, and use refined for overall
|
| 1219 |
+
# For other formats, use all results
|
| 1220 |
+
if refined_format_count > 0:
|
| 1221 |
+
# Calculate refined spearman correlations
|
| 1222 |
+
refined_spearman_stats = {}
|
| 1223 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1224 |
+
for dim in dimensions:
|
| 1225 |
+
true_values = [r.get(f'gt_{dim}') for r in refined_results]
|
| 1226 |
+
pred_values = [r.get(f'model_{dim}') for r in refined_results]
|
| 1227 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1228 |
+
|
| 1229 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1230 |
+
try:
|
| 1231 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1232 |
+
if not math.isnan(corr):
|
| 1233 |
+
refined_spearman_stats[dim] = {
|
| 1234 |
+
'correlation': corr,
|
| 1235 |
+
'count': len(true_clean)
|
| 1236 |
+
}
|
| 1237 |
+
except Exception:
|
| 1238 |
+
pass
|
| 1239 |
+
|
| 1240 |
+
# Calculate initial spearman correlations
|
| 1241 |
+
initial_spearman_stats = {}
|
| 1242 |
+
for dim in dimensions:
|
| 1243 |
+
true_values = [r.get(f'gt_{dim}') for r in initial_results]
|
| 1244 |
+
pred_values = [r.get(f'model_{dim}') for r in initial_results]
|
| 1245 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1246 |
+
|
| 1247 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1248 |
+
try:
|
| 1249 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1250 |
+
if not math.isnan(corr):
|
| 1251 |
+
initial_spearman_stats[dim] = {
|
| 1252 |
+
'correlation': corr,
|
| 1253 |
+
'count': len(true_clean)
|
| 1254 |
+
}
|
| 1255 |
+
except Exception:
|
| 1256 |
+
pass
|
| 1257 |
+
|
| 1258 |
+
# Use refined for overall statistics (avoid double counting)
|
| 1259 |
+
summary['spearman_correlations'] = refined_spearman_stats
|
| 1260 |
+
summary['refined_spearman_correlations'] = refined_spearman_stats
|
| 1261 |
+
summary['initial_spearman_correlations'] = initial_spearman_stats
|
| 1262 |
+
else:
|
| 1263 |
+
# Original/other formats: use all results
|
| 1264 |
+
correlation_results = valid_results
|
| 1265 |
+
spearman_stats = {}
|
| 1266 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1267 |
+
for dim in dimensions:
|
| 1268 |
+
true_values = [r.get(f'gt_{dim}') for r in correlation_results]
|
| 1269 |
+
pred_values = [r.get(f'model_{dim}') for r in correlation_results]
|
| 1270 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1271 |
+
|
| 1272 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1273 |
+
try:
|
| 1274 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1275 |
+
if not math.isnan(corr):
|
| 1276 |
+
spearman_stats[dim] = {
|
| 1277 |
+
'correlation': corr,
|
| 1278 |
+
'count': len(true_clean)
|
| 1279 |
+
}
|
| 1280 |
+
except Exception:
|
| 1281 |
+
pass
|
| 1282 |
+
|
| 1283 |
+
summary['spearman_correlations'] = spearman_stats
|
| 1284 |
+
|
| 1285 |
+
# Calculate Decision metrics
|
| 1286 |
+
# For refined format, calculate separately for refined and initial, and use refined for overall
|
| 1287 |
+
# For other formats, use all results
|
| 1288 |
+
if refined_format_count > 0:
|
| 1289 |
+
# Calculate refined decision metrics
|
| 1290 |
+
refined_decision_results = [r for r in refined_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1291 |
+
if refined_decision_results:
|
| 1292 |
+
true_decisions = []
|
| 1293 |
+
pred_decisions = []
|
| 1294 |
+
decision_acc = []
|
| 1295 |
+
|
| 1296 |
+
for r in refined_decision_results:
|
| 1297 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1298 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1299 |
+
|
| 1300 |
+
if 'accept' in pred_decision:
|
| 1301 |
+
pred_binary = 1
|
| 1302 |
+
else:
|
| 1303 |
+
pred_binary = 0
|
| 1304 |
+
|
| 1305 |
+
if 'accept' in gt_decision:
|
| 1306 |
+
gt_binary = 1
|
| 1307 |
+
else:
|
| 1308 |
+
gt_binary = 0
|
| 1309 |
+
|
| 1310 |
+
true_decisions.append(gt_binary)
|
| 1311 |
+
pred_decisions.append(pred_binary)
|
| 1312 |
+
|
| 1313 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1314 |
+
decision_acc.append(1.0)
|
| 1315 |
+
else:
|
| 1316 |
+
decision_acc.append(0.0)
|
| 1317 |
+
|
| 1318 |
+
if decision_acc:
|
| 1319 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1320 |
+
try:
|
| 1321 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1322 |
+
refined_decision_metrics = {
|
| 1323 |
+
'accuracy': decision_accuracy,
|
| 1324 |
+
'f1_macro': f1_score,
|
| 1325 |
+
'count': len(decision_acc)
|
| 1326 |
+
}
|
| 1327 |
+
except Exception:
|
| 1328 |
+
refined_decision_metrics = {
|
| 1329 |
+
'accuracy': decision_accuracy,
|
| 1330 |
+
'count': len(decision_acc)
|
| 1331 |
+
}
|
| 1332 |
+
summary['refined_decision_metrics'] = refined_decision_metrics
|
| 1333 |
+
summary['decision_metrics'] = refined_decision_metrics # Use refined for overall
|
| 1334 |
+
|
| 1335 |
+
# Calculate initial decision metrics
|
| 1336 |
+
initial_decision_results = [r for r in initial_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1337 |
+
if initial_decision_results:
|
| 1338 |
+
true_decisions = []
|
| 1339 |
+
pred_decisions = []
|
| 1340 |
+
decision_acc = []
|
| 1341 |
+
|
| 1342 |
+
for r in initial_decision_results:
|
| 1343 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1344 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1345 |
+
|
| 1346 |
+
if 'accept' in pred_decision:
|
| 1347 |
+
pred_binary = 1
|
| 1348 |
+
else:
|
| 1349 |
+
pred_binary = 0
|
| 1350 |
+
|
| 1351 |
+
if 'accept' in gt_decision:
|
| 1352 |
+
gt_binary = 1
|
| 1353 |
+
else:
|
| 1354 |
+
gt_binary = 0
|
| 1355 |
+
|
| 1356 |
+
true_decisions.append(gt_binary)
|
| 1357 |
+
pred_decisions.append(pred_binary)
|
| 1358 |
+
|
| 1359 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1360 |
+
decision_acc.append(1.0)
|
| 1361 |
+
else:
|
| 1362 |
+
decision_acc.append(0.0)
|
| 1363 |
+
|
| 1364 |
+
if decision_acc:
|
| 1365 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1366 |
+
try:
|
| 1367 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1368 |
+
initial_decision_metrics = {
|
| 1369 |
+
'accuracy': decision_accuracy,
|
| 1370 |
+
'f1_macro': f1_score,
|
| 1371 |
+
'count': len(decision_acc)
|
| 1372 |
+
}
|
| 1373 |
+
except Exception:
|
| 1374 |
+
initial_decision_metrics = {
|
| 1375 |
+
'accuracy': decision_accuracy,
|
| 1376 |
+
'count': len(decision_acc)
|
| 1377 |
+
}
|
| 1378 |
+
summary['initial_decision_metrics'] = initial_decision_metrics
|
| 1379 |
+
else:
|
| 1380 |
+
# Original/other formats: use all results
|
| 1381 |
+
decision_results = [r for r in valid_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1382 |
+
if decision_results:
|
| 1383 |
+
true_decisions = []
|
| 1384 |
+
pred_decisions = []
|
| 1385 |
+
decision_acc = []
|
| 1386 |
+
|
| 1387 |
+
for r in decision_results:
|
| 1388 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1389 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1390 |
+
|
| 1391 |
+
if 'accept' in pred_decision:
|
| 1392 |
+
pred_binary = 1
|
| 1393 |
+
else:
|
| 1394 |
+
pred_binary = 0
|
| 1395 |
+
|
| 1396 |
+
if 'accept' in gt_decision:
|
| 1397 |
+
gt_binary = 1
|
| 1398 |
+
else:
|
| 1399 |
+
gt_binary = 0
|
| 1400 |
+
|
| 1401 |
+
true_decisions.append(gt_binary)
|
| 1402 |
+
pred_decisions.append(pred_binary)
|
| 1403 |
+
|
| 1404 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1405 |
+
decision_acc.append(1.0)
|
| 1406 |
+
else:
|
| 1407 |
+
decision_acc.append(0.0)
|
| 1408 |
+
|
| 1409 |
+
if decision_acc:
|
| 1410 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1411 |
+
try:
|
| 1412 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1413 |
+
summary['decision_metrics'] = {
|
| 1414 |
+
'accuracy': decision_accuracy,
|
| 1415 |
+
'f1_macro': f1_score,
|
| 1416 |
+
'count': len(decision_acc)
|
| 1417 |
+
}
|
| 1418 |
+
except Exception:
|
| 1419 |
+
summary['decision_metrics'] = {
|
| 1420 |
+
'accuracy': decision_accuracy,
|
| 1421 |
+
'count': len(decision_acc)
|
| 1422 |
+
}
|
| 1423 |
+
|
| 1424 |
+
# Calculate Pairwise comparison
|
| 1425 |
+
# For refined format, only use refined results (avoid double counting)
|
| 1426 |
+
# For other formats, use all results
|
| 1427 |
+
if refined_format_count > 0:
|
| 1428 |
+
pairwise_results = refined_results
|
| 1429 |
+
else:
|
| 1430 |
+
pairwise_results = valid_results
|
| 1431 |
+
|
| 1432 |
+
paper_scores = []
|
| 1433 |
+
for r in pairwise_results:
|
| 1434 |
+
if (r.get('gt_rating') is not None and r.get('model_rating') is not None) or \
|
| 1435 |
+
(r.get('gt_soundness') is not None and r.get('model_soundness') is not None):
|
| 1436 |
+
paper_scores.append({
|
| 1437 |
+
'true_rating': r.get('gt_rating'),
|
| 1438 |
+
'pred_rating': r.get('model_rating'),
|
| 1439 |
+
'true_soundness': r.get('gt_soundness'),
|
| 1440 |
+
'pred_soundness': r.get('model_soundness'),
|
| 1441 |
+
'true_presentation': r.get('gt_presentation'),
|
| 1442 |
+
'pred_presentation': r.get('model_presentation'),
|
| 1443 |
+
'true_confidence': r.get('gt_confidence'),
|
| 1444 |
+
'pred_confidence': r.get('model_confidence')
|
| 1445 |
+
})
|
| 1446 |
+
|
| 1447 |
+
if len(paper_scores) >= 2:
|
| 1448 |
+
pairwise_accuracies = calculate_pairwise_accuracies(paper_scores)
|
| 1449 |
+
summary['pairwise_accuracies'] = pairwise_accuracies
|
| 1450 |
+
|
| 1451 |
+
return results, summary
|
| 1452 |
+
|
| 1453 |
+
|
| 1454 |
+
# ============================================================================
|
| 1455 |
+
# Main Function
|
| 1456 |
+
# ============================================================================
|
| 1457 |
+
|
| 1458 |
+
def parse_args():
|
| 1459 |
+
"""Parse command line arguments."""
|
| 1460 |
+
parser = argparse.ArgumentParser(description="Unified evaluation script for semantic and auto-metric evaluation")
|
| 1461 |
+
|
| 1462 |
+
# Input paths
|
| 1463 |
+
parser.add_argument("--rubrics_path", type=str, required=True,
|
| 1464 |
+
help="Path to eval_rubrics.json file (from 1_generate_review_based_rubrics.py)")
|
| 1465 |
+
parser.add_argument("--reviews_path", type=str, required=True,
|
| 1466 |
+
help="Path to JSON file with model reviews (contains pred_fast_mode)")
|
| 1467 |
+
|
| 1468 |
+
# Evaluation mode
|
| 1469 |
+
parser.add_argument("--mode", type=str, choices=["semantic", "auto_metric", "both"], default="both",
|
| 1470 |
+
help="Evaluation mode: semantic (LLM-based), auto_metric (rule-based), or both")
|
| 1471 |
+
|
| 1472 |
+
# Output paths
|
| 1473 |
+
parser.add_argument("--semantic_output", type=str, default=None,
|
| 1474 |
+
help="Path to output JSON file for semantic evaluation results (required if mode is semantic or both)")
|
| 1475 |
+
parser.add_argument("--auto_metric_output", type=str, default=None,
|
| 1476 |
+
help="Path to output JSON file for auto-metric evaluation results (required if mode is auto_metric or both)")
|
| 1477 |
+
|
| 1478 |
+
# Semantic evaluation settings
|
| 1479 |
+
parser.add_argument("--yaml_path", type=str, default=None,
|
| 1480 |
+
help="Path to prompts.yaml file (required for semantic evaluation)")
|
| 1481 |
+
parser.add_argument("--config_path", type=str, default=None,
|
| 1482 |
+
help="Path to configs.yaml file (required for semantic evaluation)")
|
| 1483 |
+
|
| 1484 |
+
# Multi-threading
|
| 1485 |
+
parser.add_argument("--max_workers", type=int, default=None,
|
| 1486 |
+
help="Maximum number of worker threads for semantic evaluation (default: 5)")
|
| 1487 |
+
|
| 1488 |
+
# Strict mode (normalize scores to discrete scales)
|
| 1489 |
+
parser.add_argument("--strict_mode", action="store_true", default=False,
|
| 1490 |
+
help="Enable strict mode: normalize scores to discrete scales before computing metrics (default: False)")
|
| 1491 |
+
|
| 1492 |
+
# Input format override
|
| 1493 |
+
parser.add_argument("--input_format", type=str, choices=['auto', 'refined', 'original'], default='auto',
|
| 1494 |
+
help="Manually specify input JSON format: 'refined' (has scores and initial_scores), 'original' (has model_prediction), or 'auto' for auto-detection (default: 'auto')")
|
| 1495 |
+
|
| 1496 |
+
return parser.parse_args()
|
| 1497 |
+
|
| 1498 |
+
|
| 1499 |
+
def main():
|
| 1500 |
+
"""Main execution function."""
|
| 1501 |
+
args = parse_args()
|
| 1502 |
+
|
| 1503 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 1504 |
+
|
| 1505 |
+
# Resolve paths
|
| 1506 |
+
rubrics_path = args.rubrics_path
|
| 1507 |
+
if not os.path.isabs(rubrics_path):
|
| 1508 |
+
rubrics_path = os.path.join(script_dir, rubrics_path)
|
| 1509 |
+
|
| 1510 |
+
reviews_path = args.reviews_path
|
| 1511 |
+
if not os.path.isabs(reviews_path):
|
| 1512 |
+
reviews_path = os.path.join(script_dir, reviews_path)
|
| 1513 |
+
|
| 1514 |
+
max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
|
| 1515 |
+
|
| 1516 |
+
# Validate mode and output paths
|
| 1517 |
+
if args.mode in ["semantic", "both"]:
|
| 1518 |
+
if not args.semantic_output:
|
| 1519 |
+
raise ValueError("--semantic_output is required when mode is 'semantic' or 'both'")
|
| 1520 |
+
if not args.yaml_path:
|
| 1521 |
+
raise ValueError("--yaml_path is required for semantic evaluation")
|
| 1522 |
+
if not args.config_path:
|
| 1523 |
+
raise ValueError("--config_path is required for semantic evaluation")
|
| 1524 |
+
|
| 1525 |
+
if args.mode in ["auto_metric", "both"]:
|
| 1526 |
+
if not args.auto_metric_output:
|
| 1527 |
+
raise ValueError("--auto_metric_output is required when mode is 'auto_metric' or 'both'")
|
| 1528 |
+
|
| 1529 |
+
# Check if files exist
|
| 1530 |
+
if not os.path.exists(rubrics_path):
|
| 1531 |
+
raise FileNotFoundError(f"Rubrics file not found: {rubrics_path}")
|
| 1532 |
+
if not os.path.exists(reviews_path):
|
| 1533 |
+
raise FileNotFoundError(f"Reviews file not found: {reviews_path}")
|
| 1534 |
+
|
| 1535 |
+
# Load data
|
| 1536 |
+
print(f"Loading rubrics from {rubrics_path}...")
|
| 1537 |
+
rubrics_data = load_rubrics_json(rubrics_path)
|
| 1538 |
+
print(f"Loaded {len(rubrics_data)} rubrics entries")
|
| 1539 |
+
|
| 1540 |
+
print(f"Loading model reviews from {reviews_path}...")
|
| 1541 |
+
if args.input_format != 'auto':
|
| 1542 |
+
print(f"Using manually specified format: {args.input_format}")
|
| 1543 |
+
else:
|
| 1544 |
+
print("Auto-detecting input format...")
|
| 1545 |
+
reviews_dict = load_model_reviews_json(reviews_path, format_override=args.input_format if args.input_format != 'auto' else None)
|
| 1546 |
+
print(f"Loaded {len(reviews_dict)} model reviews")
|
| 1547 |
+
|
| 1548 |
+
# Combine rubrics and reviews
|
| 1549 |
+
print("Combining rubrics and reviews...")
|
| 1550 |
+
evaluation_data = combine_rubrics_and_reviews(rubrics_data, reviews_dict)
|
| 1551 |
+
print(f"Prepared {len(evaluation_data)} entries for evaluation")
|
| 1552 |
+
|
| 1553 |
+
# Run evaluations based on mode
|
| 1554 |
+
if args.mode in ["semantic", "both"]:
|
| 1555 |
+
# Resolve semantic evaluation paths
|
| 1556 |
+
yaml_path = args.yaml_path
|
| 1557 |
+
if not os.path.isabs(yaml_path):
|
| 1558 |
+
yaml_path = os.path.join(script_dir, yaml_path)
|
| 1559 |
+
|
| 1560 |
+
config_path = args.config_path
|
| 1561 |
+
if not os.path.isabs(config_path):
|
| 1562 |
+
config_path = os.path.join(script_dir, config_path)
|
| 1563 |
+
|
| 1564 |
+
if not os.path.exists(yaml_path):
|
| 1565 |
+
raise FileNotFoundError(f"YAML file not found: {yaml_path}")
|
| 1566 |
+
if not os.path.exists(config_path):
|
| 1567 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 1568 |
+
|
| 1569 |
+
# Load prompt template
|
| 1570 |
+
print(f"Loading prompt template from {yaml_path}...")
|
| 1571 |
+
prompt_template = load_prompt_template(yaml_path)
|
| 1572 |
+
if not prompt_template:
|
| 1573 |
+
raise ValueError("Could not find 'v1_evaluator_prompt' in YAML file")
|
| 1574 |
+
|
| 1575 |
+
# Initialize LLM service
|
| 1576 |
+
print(f"Loading LLM configuration from {config_path}...")
|
| 1577 |
+
llm_config = load_llm_config(config_path)
|
| 1578 |
+
llm_service = create_llm_service_from_config(llm_config)
|
| 1579 |
+
mode = llm_config.get('mode', 'gpt')
|
| 1580 |
+
print(f"LLM service initialized (mode: {mode})")
|
| 1581 |
+
if hasattr(llm_service, 'model_name'):
|
| 1582 |
+
print(f"Using model: {llm_service.model_name}")
|
| 1583 |
+
|
| 1584 |
+
# Run semantic evaluation
|
| 1585 |
+
semantic_results, semantic_summary = run_semantic_evaluation(
|
| 1586 |
+
evaluation_data, prompt_template, llm_service, max_workers
|
| 1587 |
+
)
|
| 1588 |
+
|
| 1589 |
+
# Save semantic results
|
| 1590 |
+
semantic_output = args.semantic_output
|
| 1591 |
+
if not os.path.isabs(semantic_output):
|
| 1592 |
+
semantic_output = os.path.join(script_dir, semantic_output)
|
| 1593 |
+
|
| 1594 |
+
output_dir = os.path.dirname(semantic_output)
|
| 1595 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1596 |
+
|
| 1597 |
+
with open(semantic_output, 'w', encoding='utf-8') as f:
|
| 1598 |
+
json.dump(semantic_results, f, ensure_ascii=False, indent=2)
|
| 1599 |
+
print(f"\nSemantic evaluation results saved to {semantic_output}")
|
| 1600 |
+
|
| 1601 |
+
# Save semantic summary
|
| 1602 |
+
semantic_summary_path = semantic_output.replace('.json', '_summary.json')
|
| 1603 |
+
with open(semantic_summary_path, 'w', encoding='utf-8') as f:
|
| 1604 |
+
json.dump(semantic_summary, f, ensure_ascii=False, indent=2)
|
| 1605 |
+
print(f"Semantic evaluation summary saved to {semantic_summary_path}")
|
| 1606 |
+
|
| 1607 |
+
# Print semantic summary
|
| 1608 |
+
print("\n" + "="*80)
|
| 1609 |
+
print("SEMANTIC EVALUATION SUMMARY")
|
| 1610 |
+
print("="*80)
|
| 1611 |
+
print(f"Total entries: {semantic_summary['total_entries']}")
|
| 1612 |
+
print(f"Valid entries: {semantic_summary['valid_entries']}")
|
| 1613 |
+
print(f"Failed entries: {semantic_summary['failed_entries']}")
|
| 1614 |
+
if 'overall_score' in semantic_summary:
|
| 1615 |
+
score = semantic_summary['overall_score']
|
| 1616 |
+
print(f"\nOverall Score:")
|
| 1617 |
+
print(f" Mean: {score['mean']:.2f}")
|
| 1618 |
+
print(f" Min: {score['min']:.2f}")
|
| 1619 |
+
print(f" Max: {score['max']:.2f}")
|
| 1620 |
+
|
| 1621 |
+
if args.mode in ["auto_metric", "both"]:
|
| 1622 |
+
# Run auto-metric evaluation
|
| 1623 |
+
auto_metric_results, auto_metric_summary = run_auto_metric_evaluation(
|
| 1624 |
+
evaluation_data,
|
| 1625 |
+
strict_mode=args.strict_mode
|
| 1626 |
+
)
|
| 1627 |
+
|
| 1628 |
+
# Save auto-metric results
|
| 1629 |
+
auto_metric_output = args.auto_metric_output
|
| 1630 |
+
if not os.path.isabs(auto_metric_output):
|
| 1631 |
+
auto_metric_output = os.path.join(script_dir, auto_metric_output)
|
| 1632 |
+
|
| 1633 |
+
output_dir = os.path.dirname(auto_metric_output)
|
| 1634 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1635 |
+
|
| 1636 |
+
with open(auto_metric_output, 'w', encoding='utf-8') as f:
|
| 1637 |
+
json.dump(auto_metric_results, f, ensure_ascii=False, indent=2)
|
| 1638 |
+
print(f"\nAuto-metric evaluation results saved to {auto_metric_output}")
|
| 1639 |
+
|
| 1640 |
+
# Save auto-metric summary
|
| 1641 |
+
auto_metric_summary_path = auto_metric_output.replace('.json', '_summary.json')
|
| 1642 |
+
with open(auto_metric_summary_path, 'w', encoding='utf-8') as f:
|
| 1643 |
+
json.dump(auto_metric_summary, f, ensure_ascii=False, indent=2)
|
| 1644 |
+
print(f"Auto-metric evaluation summary saved to {auto_metric_summary_path}")
|
| 1645 |
+
|
| 1646 |
+
# Print auto-metric summary
|
| 1647 |
+
print("\n" + "="*80)
|
| 1648 |
+
print("AUTO-METRIC EVALUATION SUMMARY")
|
| 1649 |
+
print("="*80)
|
| 1650 |
+
print(f"Total entries: {auto_metric_summary['total_entries']}")
|
| 1651 |
+
print(f"Valid entries: {auto_metric_summary['valid_entries']}")
|
| 1652 |
+
print(f"MSE entries: {auto_metric_summary['mse_entries']}")
|
| 1653 |
+
|
| 1654 |
+
if 'mse_statistics' in auto_metric_summary:
|
| 1655 |
+
print("\nMSE Statistics:")
|
| 1656 |
+
for dim, stats in auto_metric_summary['mse_statistics'].items():
|
| 1657 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1658 |
+
|
| 1659 |
+
if 'mae_statistics' in auto_metric_summary:
|
| 1660 |
+
print("\nMAE Statistics:")
|
| 1661 |
+
for dim, stats in auto_metric_summary['mae_statistics'].items():
|
| 1662 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1663 |
+
|
| 1664 |
+
# Print refined and initial statistics if available
|
| 1665 |
+
if 'refined_mse_statistics' in auto_metric_summary:
|
| 1666 |
+
print("\nRefined Scores - MSE Statistics:")
|
| 1667 |
+
for dim, stats in auto_metric_summary['refined_mse_statistics'].items():
|
| 1668 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1669 |
+
|
| 1670 |
+
if 'refined_mae_statistics' in auto_metric_summary:
|
| 1671 |
+
print("\nRefined Scores - MAE Statistics:")
|
| 1672 |
+
for dim, stats in auto_metric_summary['refined_mae_statistics'].items():
|
| 1673 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1674 |
+
|
| 1675 |
+
if 'initial_mse_statistics' in auto_metric_summary:
|
| 1676 |
+
print("\nInitial Scores - MSE Statistics:")
|
| 1677 |
+
for dim, stats in auto_metric_summary['initial_mse_statistics'].items():
|
| 1678 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1679 |
+
|
| 1680 |
+
if 'initial_mae_statistics' in auto_metric_summary:
|
| 1681 |
+
print("\nInitial Scores - MAE Statistics:")
|
| 1682 |
+
for dim, stats in auto_metric_summary['initial_mae_statistics'].items():
|
| 1683 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1684 |
+
|
| 1685 |
+
if 'spearman_correlations' in auto_metric_summary:
|
| 1686 |
+
print("\nSpearman Correlations:")
|
| 1687 |
+
for dim, stats in auto_metric_summary['spearman_correlations'].items():
|
| 1688 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1689 |
+
|
| 1690 |
+
# Print refined and initial spearman correlations if available
|
| 1691 |
+
if 'refined_spearman_correlations' in auto_metric_summary:
|
| 1692 |
+
print("\nRefined Scores - Spearman Correlations:")
|
| 1693 |
+
for dim, stats in auto_metric_summary['refined_spearman_correlations'].items():
|
| 1694 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1695 |
+
|
| 1696 |
+
if 'initial_spearman_correlations' in auto_metric_summary:
|
| 1697 |
+
print("\nInitial Scores - Spearman Correlations:")
|
| 1698 |
+
for dim, stats in auto_metric_summary['initial_spearman_correlations'].items():
|
| 1699 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1700 |
+
|
| 1701 |
+
if 'decision_metrics' in auto_metric_summary:
|
| 1702 |
+
dm = auto_metric_summary['decision_metrics']
|
| 1703 |
+
print(f"\nDecision Metrics:")
|
| 1704 |
+
print(f" Accuracy: {dm['accuracy']:.4f} (n={dm['count']})")
|
| 1705 |
+
if 'f1_macro' in dm:
|
| 1706 |
+
print(f" F1 (macro): {dm['f1_macro']:.4f}")
|
| 1707 |
+
|
| 1708 |
+
# Print refined and initial decision metrics if available
|
| 1709 |
+
if 'refined_decision_metrics' in auto_metric_summary:
|
| 1710 |
+
print("\nRefined Scores - Decision Metrics:")
|
| 1711 |
+
rdm = auto_metric_summary['refined_decision_metrics']
|
| 1712 |
+
print(f" Accuracy: {rdm['accuracy']:.4f} (n={rdm['count']})")
|
| 1713 |
+
if 'f1_macro' in rdm:
|
| 1714 |
+
print(f" F1 (macro): {rdm['f1_macro']:.4f}")
|
| 1715 |
+
|
| 1716 |
+
if 'initial_decision_metrics' in auto_metric_summary:
|
| 1717 |
+
print("\nInitial Scores - Decision Metrics:")
|
| 1718 |
+
idm = auto_metric_summary['initial_decision_metrics']
|
| 1719 |
+
print(f" Accuracy: {idm['accuracy']:.4f} (n={idm['count']})")
|
| 1720 |
+
if 'f1_macro' in idm:
|
| 1721 |
+
print(f" F1 (macro): {idm['f1_macro']:.4f}")
|
| 1722 |
+
|
| 1723 |
+
print("\n" + "="*80)
|
| 1724 |
+
print("EVALUATION COMPLETE")
|
| 1725 |
+
print("="*80)
|
| 1726 |
+
|
| 1727 |
+
|
| 1728 |
+
if __name__ == "__main__":
|
| 1729 |
+
main()
|
| 1730 |
+
|
src/evaluator/2_evaluate_agenticreview.py
ADDED
|
@@ -0,0 +1,1866 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified evaluation script for semantic (LLM-based) and auto_metric (rule-based) evaluation.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Reads eval_rubrics.json (from 1_generate_review_based_rubrics.py) containing rubrics for each paper
|
| 6 |
+
2. Reads input JSON file containing model reviews (supports multiple formats)
|
| 7 |
+
3. Supports three evaluation modes:
|
| 8 |
+
- semantic: LLM-based rubrics evaluation (from 2_evaluate_direct.py)
|
| 9 |
+
- auto_metric: Rule-based metrics evaluation (from 3_rule_evaluate.py)
|
| 10 |
+
- both: Run both evaluations separately
|
| 11 |
+
4. Supports strict mode: normalize scores to discrete scales before computing metrics (--strict_mode)
|
| 12 |
+
5. Outputs separate JSON files for results and summaries
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Semantic evaluation only
|
| 16 |
+
python 2_evaluate.py \
|
| 17 |
+
--rubrics_path eval_rubrics.json \
|
| 18 |
+
--reviews_path model_reviews.json \
|
| 19 |
+
--mode semantic \
|
| 20 |
+
--yaml_path prompts.yaml \
|
| 21 |
+
--config_path configs.yaml \
|
| 22 |
+
--semantic_output semantic_results.json \
|
| 23 |
+
--max_workers 5
|
| 24 |
+
|
| 25 |
+
# Auto-metric evaluation only
|
| 26 |
+
python 2_evaluate.py \
|
| 27 |
+
--rubrics_path eval_rubrics.json \
|
| 28 |
+
--reviews_path model_reviews.json \
|
| 29 |
+
--mode auto_metric \
|
| 30 |
+
--auto_metric_output auto_metric_results.json
|
| 31 |
+
|
| 32 |
+
# Auto-metric evaluation with strict mode (normalize scores to discrete scales)
|
| 33 |
+
python 2_evaluate.py \
|
| 34 |
+
--rubrics_path eval_rubrics.json \
|
| 35 |
+
--reviews_path model_reviews.json \
|
| 36 |
+
--mode auto_metric \
|
| 37 |
+
--auto_metric_output auto_metric_results.json \
|
| 38 |
+
--strict_mode
|
| 39 |
+
|
| 40 |
+
# Auto-metric evaluation with manually specified input format (refined)
|
| 41 |
+
python 2_evaluate.py \
|
| 42 |
+
--rubrics_path eval_rubrics.json \
|
| 43 |
+
--reviews_path model_reviews.json \
|
| 44 |
+
--mode auto_metric \
|
| 45 |
+
--auto_metric_output auto_metric_results.json \
|
| 46 |
+
--input_format refined
|
| 47 |
+
|
| 48 |
+
# Auto-metric evaluation with manually specified input format (original)
|
| 49 |
+
python 2_evaluate.py \
|
| 50 |
+
--rubrics_path eval_rubrics.json \
|
| 51 |
+
--reviews_path ours.json \
|
| 52 |
+
--mode auto_metric \
|
| 53 |
+
--auto_metric_output auto_metric_results.json \
|
| 54 |
+
--input_format original
|
| 55 |
+
|
| 56 |
+
# Both evaluations
|
| 57 |
+
python 2_evaluate.py \
|
| 58 |
+
--rubrics_path eval_rubrics.json \
|
| 59 |
+
--reviews_path model_reviews.json \
|
| 60 |
+
--mode both \
|
| 61 |
+
--yaml_path prompts.yaml \
|
| 62 |
+
--config_path configs.yaml \
|
| 63 |
+
--semantic_output semantic_results.json \
|
| 64 |
+
--auto_metric_output auto_metric_results.json \
|
| 65 |
+
--max_workers 32
|
| 66 |
+
"""
|
| 67 |
+
from __future__ import annotations
|
| 68 |
+
|
| 69 |
+
import json
|
| 70 |
+
import os
|
| 71 |
+
import sys
|
| 72 |
+
import argparse
|
| 73 |
+
import yaml
|
| 74 |
+
import math
|
| 75 |
+
import re
|
| 76 |
+
from typing import Dict, List, Any, Optional
|
| 77 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 78 |
+
from tqdm import tqdm
|
| 79 |
+
from itertools import combinations
|
| 80 |
+
from scipy.stats import spearmanr
|
| 81 |
+
from sklearn.metrics import precision_recall_fscore_support
|
| 82 |
+
|
| 83 |
+
# Add parent directory to path
|
| 84 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 85 |
+
# Import parse_llm_response from local llm_service module
|
| 86 |
+
import llm_service as local_llm_service
|
| 87 |
+
parse_llm_response = local_llm_service.parse_llm_response
|
| 88 |
+
|
| 89 |
+
# Import from shared/utils for gpt/vllm support
|
| 90 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 91 |
+
if project_root not in sys.path:
|
| 92 |
+
sys.path.insert(0, project_root)
|
| 93 |
+
|
| 94 |
+
from shared.utils.llm_service import LLMService
|
| 95 |
+
from shared.utils.vllm_service import VLLMService
|
| 96 |
+
from shared.utils.gpt_service import GPTService
|
| 97 |
+
sys.path.insert(0, os.path.join(project_root, 'shared', 'utils'))
|
| 98 |
+
from json_parser import parse_review_markdown
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def convert_ai_researcher(review: dict) -> str:
|
| 102 |
+
"""
|
| 103 |
+
Convert the review text from ai-researcher format to unified review system format.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
summary = review["Summary"]
|
| 107 |
+
strengths = "\n".join(f"- {s}" for s in review["Strengths"])
|
| 108 |
+
weaknesses = "\n".join(f"- {w}" for w in review["Weaknesses"])
|
| 109 |
+
|
| 110 |
+
# scores
|
| 111 |
+
originality = review["Originality"]
|
| 112 |
+
quality = review["Quality"]
|
| 113 |
+
clarity = review["Clarity"]
|
| 114 |
+
significance = review["Significance"]
|
| 115 |
+
|
| 116 |
+
questions = "\n".join(f"- {q}" for q in review["Questions"])
|
| 117 |
+
limitations = "\n".join(f"- {l}" for l in review["Limitations"])
|
| 118 |
+
ethical_concerns = review["Ethical Concerns"]
|
| 119 |
+
|
| 120 |
+
# scores again
|
| 121 |
+
soundness = review["Soundness"]
|
| 122 |
+
presentation = review["Presentation"]
|
| 123 |
+
contribution = review["Contribution"]
|
| 124 |
+
overall = review["Overall"]
|
| 125 |
+
confidence = review["Confidence"]
|
| 126 |
+
|
| 127 |
+
# final decision
|
| 128 |
+
decision = review["Decision"]
|
| 129 |
+
|
| 130 |
+
meta_review = {
|
| 131 |
+
"rating": overall,
|
| 132 |
+
"soundness": soundness,
|
| 133 |
+
"presentation": presentation,
|
| 134 |
+
"contribution": contribution,
|
| 135 |
+
"confidence": confidence,
|
| 136 |
+
"decision": decision.lower().strip(),
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
return f"Summary: {summary}\nStrengths: {strengths}\nWeaknesses: {weaknesses}\nOriginality: {originality}\nQuality: {quality}\nClarity: {clarity}\nSignificance: {significance}\nQuestions: {questions}\nLimitations: {limitations}\nEthical Concerns: {ethical_concerns}\nSoundness: {soundness}\nPresentation: {presentation}\nContribution: {contribution}\nOverall: {overall}\nConfidence: {confidence}\nDecision: {decision}", meta_review
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def convert_agenticreview(review_text: str) -> tuple:
|
| 143 |
+
"""
|
| 144 |
+
Convert the review text from agenticreview format to unified review system format.
|
| 145 |
+
|
| 146 |
+
The agenticreview format has text like:
|
| 147 |
+
"Overall rating: 5\n\nSignificance and novelty: ..."
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
review_text: Raw review text string
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
Tuple of (formatted_review_text, meta_review_dict)
|
| 154 |
+
"""
|
| 155 |
+
# Extract rating from "Overall rating: x" format
|
| 156 |
+
rating = None
|
| 157 |
+
rating_match = re.search(r'Overall\s+rating\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
|
| 158 |
+
if rating_match:
|
| 159 |
+
try:
|
| 160 |
+
rating = float(rating_match.group(1))
|
| 161 |
+
except (ValueError, IndexError):
|
| 162 |
+
pass
|
| 163 |
+
|
| 164 |
+
# If not found, try alternative patterns
|
| 165 |
+
if rating is None:
|
| 166 |
+
rating_match = re.search(r'(?:rating|score)\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
|
| 167 |
+
if rating_match:
|
| 168 |
+
try:
|
| 169 |
+
rating = float(rating_match.group(1))
|
| 170 |
+
except (ValueError, IndexError):
|
| 171 |
+
pass
|
| 172 |
+
|
| 173 |
+
# Try to extract from parse_review_markdown as fallback
|
| 174 |
+
if rating is None:
|
| 175 |
+
try:
|
| 176 |
+
parsed = parse_review_markdown(review_text)
|
| 177 |
+
rating = parsed.get('rating')
|
| 178 |
+
except Exception:
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
# Create meta_review dict - agenticreview only has rating, no other scores
|
| 182 |
+
meta_review = {
|
| 183 |
+
"rating": rating,
|
| 184 |
+
"soundness": None,
|
| 185 |
+
"presentation": None,
|
| 186 |
+
"contribution": None,
|
| 187 |
+
"confidence": None,
|
| 188 |
+
"decision": None,
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
# Return the review text as-is (it's already in a readable format)
|
| 192 |
+
return review_text, meta_review
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ReviewProcessor:
|
| 196 |
+
"""Handles the extraction and processing of reviews from different sources."""
|
| 197 |
+
|
| 198 |
+
@staticmethod
|
| 199 |
+
def extract_review_content(pred_context):
|
| 200 |
+
"""
|
| 201 |
+
Extract the review content from the prediction context.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
pred_context: Raw prediction data that contains the review
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
str: Extracted review content
|
| 208 |
+
"""
|
| 209 |
+
try:
|
| 210 |
+
# First attempt to extract from boxed format
|
| 211 |
+
return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 212 |
+
except Exception:
|
| 213 |
+
# Alternative extraction if the first method fails
|
| 214 |
+
if isinstance(pred_context, dict) and 'output' in pred_context:
|
| 215 |
+
return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 216 |
+
else:
|
| 217 |
+
# Return as is if extraction fails
|
| 218 |
+
return pred_context
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ============================================================================
|
| 222 |
+
# Semantic Evaluation Functions (from 2_evaluate_direct.py)
|
| 223 |
+
# ============================================================================
|
| 224 |
+
|
| 225 |
+
def load_prompt_template(yaml_path: str) -> str:
|
| 226 |
+
"""Load the evaluator prompt from YAML file."""
|
| 227 |
+
with open(yaml_path, 'r', encoding='utf-8') as f:
|
| 228 |
+
prompts = yaml.safe_load(f)
|
| 229 |
+
return prompts.get('v1_evaluator_prompt', '')
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def build_evaluation_prompt(
|
| 233 |
+
rubrics: List[Dict[str, Any]],
|
| 234 |
+
paper_content: str,
|
| 235 |
+
review: str,
|
| 236 |
+
prompt_template: str
|
| 237 |
+
) -> str:
|
| 238 |
+
"""Build the evaluation prompt by replacing placeholders."""
|
| 239 |
+
rubrics_json = json.dumps(rubrics, indent=4, ensure_ascii=False)
|
| 240 |
+
prompt = prompt_template.replace('{rubrics_json}', rubrics_json)
|
| 241 |
+
prompt = prompt.replace('<<paper_content>>', paper_content)
|
| 242 |
+
prompt = prompt.replace('<<review>>', review)
|
| 243 |
+
return prompt
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def calculate_weighted_scores(
|
| 247 |
+
raw_scores: Dict[str, Dict[str, Any]],
|
| 248 |
+
rubrics: List[Dict[str, Any]]
|
| 249 |
+
) -> Dict[str, float]:
|
| 250 |
+
"""Calculate weighted scores for each rubric."""
|
| 251 |
+
rubric_weights = {r['title']: r['weight'] for r in rubrics}
|
| 252 |
+
weighted_scores = {}
|
| 253 |
+
|
| 254 |
+
for rubric_title, rubric_data in raw_scores.items():
|
| 255 |
+
if rubric_title not in rubric_weights:
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
+
rubric_score = rubric_data.get('score', 0)
|
| 259 |
+
if isinstance(rubric_score, str):
|
| 260 |
+
try:
|
| 261 |
+
rubric_score = int(rubric_score)
|
| 262 |
+
except ValueError:
|
| 263 |
+
rubric_score = 0
|
| 264 |
+
|
| 265 |
+
if rubric_score not in [0, 1]:
|
| 266 |
+
rubric_score = 1 if rubric_score > 0 else 0
|
| 267 |
+
|
| 268 |
+
weight = rubric_weights[rubric_title]
|
| 269 |
+
weighted_scores[rubric_title] = rubric_score * weight
|
| 270 |
+
|
| 271 |
+
return weighted_scores
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def calculate_scores(raw_scores: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
|
| 275 |
+
"""Calculate scores for each rubric."""
|
| 276 |
+
scores = {}
|
| 277 |
+
for rubric_title, rubric_data in raw_scores.items():
|
| 278 |
+
scores[rubric_title] = rubric_data.get('score', 0)
|
| 279 |
+
return scores
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def evaluate_review_semantic(
|
| 283 |
+
entry: Dict[str, Any],
|
| 284 |
+
paper_content: str,
|
| 285 |
+
prompt_template: str,
|
| 286 |
+
llm_service: LLMService
|
| 287 |
+
) -> Dict[str, Any]:
|
| 288 |
+
"""Evaluate a single review using article-specific rubrics."""
|
| 289 |
+
entry_id = entry.get('id', 'unknown')
|
| 290 |
+
rubrics = entry.get('rubrics', [])
|
| 291 |
+
model_review = entry.get('model_review', '')
|
| 292 |
+
|
| 293 |
+
if not rubrics:
|
| 294 |
+
return {
|
| 295 |
+
'id': entry_id,
|
| 296 |
+
'raw_scores': {},
|
| 297 |
+
'weighted_scores': {},
|
| 298 |
+
'total_score': 0.0,
|
| 299 |
+
'error': 'No valid rubrics found',
|
| 300 |
+
'raw_response': ''
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
# Build prompt
|
| 304 |
+
prompt = build_evaluation_prompt(rubrics, paper_content, model_review, prompt_template)
|
| 305 |
+
|
| 306 |
+
# Call LLM
|
| 307 |
+
try:
|
| 308 |
+
messages = [{"role": "user", "content": prompt}]
|
| 309 |
+
response = llm_service.generate(messages=messages)
|
| 310 |
+
|
| 311 |
+
# Parse response
|
| 312 |
+
raw_scores = parse_llm_response(response)
|
| 313 |
+
weighted_scores = calculate_scores(raw_scores)
|
| 314 |
+
total_score = sum(weighted_scores.values())
|
| 315 |
+
|
| 316 |
+
return {
|
| 317 |
+
'id': entry_id,
|
| 318 |
+
'raw_scores': raw_scores,
|
| 319 |
+
'weighted_scores': weighted_scores,
|
| 320 |
+
'total_score': total_score,
|
| 321 |
+
'raw_response': response
|
| 322 |
+
}
|
| 323 |
+
except Exception as e:
|
| 324 |
+
print(f"[ERROR] Error evaluating review {entry_id}: {e}")
|
| 325 |
+
return {
|
| 326 |
+
'id': entry_id,
|
| 327 |
+
'raw_scores': {},
|
| 328 |
+
'weighted_scores': {},
|
| 329 |
+
'total_score': 0.0,
|
| 330 |
+
'error': str(e),
|
| 331 |
+
'raw_response': ''
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def calculate_per_rubric_statistics(
|
| 336 |
+
valid_results: List[Dict[str, Any]],
|
| 337 |
+
rubric_titles: List[str]
|
| 338 |
+
) -> Dict[str, Dict[str, float]]:
|
| 339 |
+
"""Calculate per-rubric statistics from evaluation results."""
|
| 340 |
+
rubric_scores = {title: [] for title in rubric_titles}
|
| 341 |
+
|
| 342 |
+
for result in valid_results:
|
| 343 |
+
weighted_scores = result.get('weighted_scores', {})
|
| 344 |
+
if not isinstance(weighted_scores, dict):
|
| 345 |
+
continue
|
| 346 |
+
|
| 347 |
+
for rubric_title in rubric_titles:
|
| 348 |
+
if rubric_title in weighted_scores:
|
| 349 |
+
score = weighted_scores[rubric_title]
|
| 350 |
+
if isinstance(score, str):
|
| 351 |
+
try:
|
| 352 |
+
score = float(score)
|
| 353 |
+
except ValueError:
|
| 354 |
+
continue
|
| 355 |
+
elif isinstance(score, (int, float)):
|
| 356 |
+
score = float(score)
|
| 357 |
+
else:
|
| 358 |
+
continue
|
| 359 |
+
rubric_scores[rubric_title].append(score)
|
| 360 |
+
|
| 361 |
+
per_rubric_stats = {}
|
| 362 |
+
for rubric_title in rubric_titles:
|
| 363 |
+
scores = rubric_scores[rubric_title]
|
| 364 |
+
if not scores:
|
| 365 |
+
continue
|
| 366 |
+
|
| 367 |
+
mean_score = sum(scores) / len(scores)
|
| 368 |
+
min_score = min(scores)
|
| 369 |
+
max_score = max(scores)
|
| 370 |
+
count = len(scores)
|
| 371 |
+
|
| 372 |
+
if rubric_title == "False or Contradictory Claims":
|
| 373 |
+
pass_count = sum(1 for s in scores if s >= 0)
|
| 374 |
+
else:
|
| 375 |
+
pass_count = sum(1 for s in scores if s >= 1)
|
| 376 |
+
pass_rate = pass_count / count if count > 0 else 0.0
|
| 377 |
+
|
| 378 |
+
per_rubric_stats[rubric_title] = {
|
| 379 |
+
'mean': mean_score,
|
| 380 |
+
'min': min_score,
|
| 381 |
+
'max': max_score,
|
| 382 |
+
'count': count,
|
| 383 |
+
'pass_rate': pass_rate
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
return per_rubric_stats
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# ============================================================================
|
| 390 |
+
# Auto-Metric Evaluation Functions (from 3_rule_evaluate.py)
|
| 391 |
+
# ============================================================================
|
| 392 |
+
|
| 393 |
+
def extract_scores_from_review(review_text: str) -> Dict[str, Any]:
|
| 394 |
+
"""Extract numeric scores and decision from a review markdown text."""
|
| 395 |
+
if not review_text:
|
| 396 |
+
return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
|
| 397 |
+
|
| 398 |
+
try:
|
| 399 |
+
parsed = parse_review_markdown(review_text)
|
| 400 |
+
decision = parsed.get('decision', '')
|
| 401 |
+
if decision:
|
| 402 |
+
decision_lower = decision.lower().strip()
|
| 403 |
+
if 'accept' in decision_lower:
|
| 404 |
+
decision = 'accept'
|
| 405 |
+
elif 'reject' in decision_lower:
|
| 406 |
+
decision = 'reject'
|
| 407 |
+
elif 'undecided' in decision_lower:
|
| 408 |
+
decision = 'undecided'
|
| 409 |
+
else:
|
| 410 |
+
decision = decision_lower
|
| 411 |
+
else:
|
| 412 |
+
decision = None
|
| 413 |
+
|
| 414 |
+
return {
|
| 415 |
+
'soundness': parsed.get('soundness'),
|
| 416 |
+
'presentation': parsed.get('presentation'),
|
| 417 |
+
'rating': parsed.get('rating'),
|
| 418 |
+
'confidence': parsed.get('confidence'),
|
| 419 |
+
'decision': decision
|
| 420 |
+
}
|
| 421 |
+
except Exception as e:
|
| 422 |
+
print(f"Warning: Failed to parse review text: {e}")
|
| 423 |
+
return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def calculate_mse(predicted: float, ground_truth: float) -> Optional[float]:
|
| 427 |
+
"""Calculate Mean Squared Error for a single value."""
|
| 428 |
+
if predicted is None or ground_truth is None:
|
| 429 |
+
return None
|
| 430 |
+
return (predicted - ground_truth) ** 2
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def calculate_mae(predicted: float, ground_truth: float) -> Optional[float]:
|
| 434 |
+
"""Calculate Mean Absolute Error for a single value."""
|
| 435 |
+
if predicted is None or ground_truth is None:
|
| 436 |
+
return None
|
| 437 |
+
return abs(predicted - ground_truth)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def normalize_to_discrete_scale(score: Optional[float], scale_type: str) -> Optional[float]:
|
| 441 |
+
"""
|
| 442 |
+
Normalize a float score to the nearest discrete value based on scale type.
|
| 443 |
+
Uses round-half-up tie-breaking (e.g., 3.5 rounds to 4, 1.5 rounds to 2).
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
score: The float score to normalize (can be None)
|
| 447 |
+
scale_type: Either '0-5' for 0-5 scale (discrete: 0,1,2,3,4,5)
|
| 448 |
+
or '0-10' for 0-10 scale (discrete: 0,2,4,6,8,10)
|
| 449 |
+
|
| 450 |
+
Returns:
|
| 451 |
+
Normalized discrete score, or None if input is None
|
| 452 |
+
"""
|
| 453 |
+
if score is None:
|
| 454 |
+
return None
|
| 455 |
+
|
| 456 |
+
try:
|
| 457 |
+
score = float(score)
|
| 458 |
+
except (ValueError, TypeError):
|
| 459 |
+
return None
|
| 460 |
+
|
| 461 |
+
if scale_type == '0-5':
|
| 462 |
+
# Discrete values: 0, 1, 2, 3, 4, 5
|
| 463 |
+
discrete_values = [0, 1, 2, 3, 4, 5]
|
| 464 |
+
# Clamp to valid range
|
| 465 |
+
score = max(0, min(5, score))
|
| 466 |
+
# Find nearest discrete value, with round-half-up tie-breaking
|
| 467 |
+
# For ties, prefer the higher value
|
| 468 |
+
best_value = None
|
| 469 |
+
best_distance = float('inf')
|
| 470 |
+
for val in discrete_values:
|
| 471 |
+
distance = abs(val - score)
|
| 472 |
+
if distance < best_distance:
|
| 473 |
+
best_distance = distance
|
| 474 |
+
best_value = val
|
| 475 |
+
elif distance == best_distance and val > best_value:
|
| 476 |
+
# Tie-breaking: prefer higher value (round-half-up)
|
| 477 |
+
best_value = val
|
| 478 |
+
return best_value
|
| 479 |
+
elif scale_type == '0-10':
|
| 480 |
+
# Discrete values: 0, 2, 4, 6, 8, 10
|
| 481 |
+
discrete_values = [0, 2, 4, 6, 8, 10]
|
| 482 |
+
# Clamp to valid range
|
| 483 |
+
score = max(0, min(10, score))
|
| 484 |
+
# Find nearest discrete value, with round-half-up tie-breaking
|
| 485 |
+
best_value = None
|
| 486 |
+
best_distance = float('inf')
|
| 487 |
+
for val in discrete_values:
|
| 488 |
+
distance = abs(val - score)
|
| 489 |
+
if distance < best_distance:
|
| 490 |
+
best_distance = distance
|
| 491 |
+
best_value = val
|
| 492 |
+
elif distance == best_distance and val > best_value:
|
| 493 |
+
# Tie-breaking: prefer higher value (round-half-up)
|
| 494 |
+
best_value = val
|
| 495 |
+
return best_value
|
| 496 |
+
else:
|
| 497 |
+
raise ValueError(f"Unknown scale_type: {scale_type}. Must be '0-5' or '0-10'")
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def normalize_scores_dict(scores: Dict[str, Optional[float]]) -> Dict[str, Optional[float]]:
|
| 501 |
+
"""
|
| 502 |
+
Normalize all scores in a dictionary to their appropriate discrete scales.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
scores: Dictionary with keys 'soundness', 'presentation', 'rating', 'confidence'
|
| 506 |
+
|
| 507 |
+
Returns:
|
| 508 |
+
Dictionary with normalized scores
|
| 509 |
+
"""
|
| 510 |
+
normalized = {}
|
| 511 |
+
|
| 512 |
+
# soundness, presentation, confidence use 0-5 scale
|
| 513 |
+
for key in ['soundness', 'presentation', 'confidence']:
|
| 514 |
+
normalized[key] = normalize_to_discrete_scale(scores.get(key), '0-5')
|
| 515 |
+
|
| 516 |
+
# rating uses 0-10 scale
|
| 517 |
+
normalized['rating'] = normalize_to_discrete_scale(scores.get('rating'), '0-10')
|
| 518 |
+
|
| 519 |
+
return normalized
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def calculate_score_metrics(
|
| 523 |
+
model_scores: Dict[str, float],
|
| 524 |
+
ground_truth_scores: Dict[str, float],
|
| 525 |
+
normalize: bool = False
|
| 526 |
+
) -> Dict[str, Any]:
|
| 527 |
+
"""
|
| 528 |
+
Calculate MSE and MAE metrics for each scoring dimension.
|
| 529 |
+
|
| 530 |
+
Args:
|
| 531 |
+
model_scores: Dictionary with model scores
|
| 532 |
+
ground_truth_scores: Dictionary with ground truth scores
|
| 533 |
+
normalize: If True, normalize scores to discrete scales before computing metrics
|
| 534 |
+
|
| 535 |
+
Returns:
|
| 536 |
+
Dictionary with MSE, MAE metrics and optionally normalized scores
|
| 537 |
+
"""
|
| 538 |
+
dimensions = ['soundness', 'presentation', 'rating', 'confidence']
|
| 539 |
+
|
| 540 |
+
# Normalize scores to discrete scales if requested
|
| 541 |
+
if normalize:
|
| 542 |
+
model_scores_normalized = normalize_scores_dict(model_scores)
|
| 543 |
+
gt_scores_normalized = normalize_scores_dict(ground_truth_scores)
|
| 544 |
+
else:
|
| 545 |
+
model_scores_normalized = model_scores
|
| 546 |
+
gt_scores_normalized = ground_truth_scores
|
| 547 |
+
|
| 548 |
+
mse_values = {}
|
| 549 |
+
mae_values = {}
|
| 550 |
+
valid_count = 0
|
| 551 |
+
|
| 552 |
+
for dim in dimensions:
|
| 553 |
+
# Use normalized scores for metric calculation
|
| 554 |
+
mse = calculate_mse(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
|
| 555 |
+
mae = calculate_mae(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
|
| 556 |
+
mse_values[f'{dim}_mse'] = mse
|
| 557 |
+
mae_values[f'{dim}_mae'] = mae
|
| 558 |
+
if mse is not None:
|
| 559 |
+
valid_count += 1
|
| 560 |
+
|
| 561 |
+
overall_error = sum([v for v in mse_values.values() if v is not None])
|
| 562 |
+
|
| 563 |
+
result = {
|
| 564 |
+
**mse_values,
|
| 565 |
+
**mae_values,
|
| 566 |
+
'overall_error': overall_error if valid_count > 0 else None,
|
| 567 |
+
'valid_dimensions': valid_count
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
# Include normalized scores in result for transparency (only if normalize=True)
|
| 571 |
+
if normalize:
|
| 572 |
+
result['model_scores_normalized'] = model_scores_normalized
|
| 573 |
+
result['gt_scores_normalized'] = gt_scores_normalized
|
| 574 |
+
|
| 575 |
+
return result
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def normalize_score_value(value):
|
| 579 |
+
"""Normalize score value to float, handling string representations."""
|
| 580 |
+
if value is None:
|
| 581 |
+
return None
|
| 582 |
+
if isinstance(value, (int, float)):
|
| 583 |
+
return float(value)
|
| 584 |
+
if isinstance(value, str):
|
| 585 |
+
# Try to extract numeric value from string (e.g., "2.75" -> 2.75)
|
| 586 |
+
try:
|
| 587 |
+
import re
|
| 588 |
+
match = re.search(r'(\d+\.?\d*)', value)
|
| 589 |
+
if match:
|
| 590 |
+
return float(match.group(1))
|
| 591 |
+
except:
|
| 592 |
+
pass
|
| 593 |
+
return None
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def normalize_decision(decision):
|
| 597 |
+
"""Normalize decision string to standard format."""
|
| 598 |
+
if decision is None:
|
| 599 |
+
return None
|
| 600 |
+
decision_lower = str(decision).lower().strip()
|
| 601 |
+
if 'accept' in decision_lower:
|
| 602 |
+
return 'accept'
|
| 603 |
+
elif 'reject' in decision_lower:
|
| 604 |
+
return 'reject'
|
| 605 |
+
elif 'undecided' in decision_lower:
|
| 606 |
+
return 'undecided'
|
| 607 |
+
else:
|
| 608 |
+
return decision_lower
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def extract_scores_from_dict(scores_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 612 |
+
"""
|
| 613 |
+
Extract scores from a structured dictionary (scores or initial_scores format).
|
| 614 |
+
|
| 615 |
+
Args:
|
| 616 |
+
scores_dict: Dict containing scores (e.g., {'rating': 5.75, 'soundness': '2.75', ...})
|
| 617 |
+
|
| 618 |
+
Returns:
|
| 619 |
+
Dict with normalized scores: {'soundness', 'presentation', 'rating', 'confidence', 'decision'}
|
| 620 |
+
"""
|
| 621 |
+
if not scores_dict:
|
| 622 |
+
return {
|
| 623 |
+
'soundness': None,
|
| 624 |
+
'presentation': None,
|
| 625 |
+
'rating': None,
|
| 626 |
+
'confidence': None,
|
| 627 |
+
'decision': None
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
return {
|
| 631 |
+
'soundness': normalize_score_value(scores_dict.get('soundness')),
|
| 632 |
+
'presentation': normalize_score_value(scores_dict.get('presentation')),
|
| 633 |
+
'rating': normalize_score_value(scores_dict.get('rating')),
|
| 634 |
+
'confidence': normalize_score_value(scores_dict.get('confidence')),
|
| 635 |
+
'decision': normalize_decision(scores_dict.get('decision'))
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def evaluate_review_auto_metric(entry: Dict[str, Any], use_initial_scores: bool = False, strict_mode: bool = False) -> Dict[str, Any]:
|
| 640 |
+
"""
|
| 641 |
+
Evaluate a single entry by extracting scores and calculating metrics.
|
| 642 |
+
|
| 643 |
+
Args:
|
| 644 |
+
entry: Evaluation entry containing model_review, scores, initial_scores, etc.
|
| 645 |
+
use_initial_scores: If True, use initial_scores instead of refined scores (for refined format)
|
| 646 |
+
|
| 647 |
+
Returns:
|
| 648 |
+
Dict containing evaluation metrics
|
| 649 |
+
"""
|
| 650 |
+
entry_id = entry.get('id', 'unknown')
|
| 651 |
+
model_review = entry.get('model_review', '')
|
| 652 |
+
format_type = entry.get('format', 'unknown')
|
| 653 |
+
|
| 654 |
+
# Extract scores based on format
|
| 655 |
+
model_scores = {}
|
| 656 |
+
model_decision = None
|
| 657 |
+
|
| 658 |
+
if format_type == 'refined' and not use_initial_scores:
|
| 659 |
+
# Use refined scores from structured data
|
| 660 |
+
scores_dict = entry.get('scores', {})
|
| 661 |
+
model_data = extract_scores_from_dict(scores_dict)
|
| 662 |
+
model_scores = {
|
| 663 |
+
'soundness': model_data.get('soundness'),
|
| 664 |
+
'presentation': model_data.get('presentation'),
|
| 665 |
+
'rating': model_data.get('rating'),
|
| 666 |
+
'confidence': model_data.get('confidence')
|
| 667 |
+
}
|
| 668 |
+
model_decision = model_data.get('decision')
|
| 669 |
+
elif format_type == 'refined' and use_initial_scores:
|
| 670 |
+
# Use initial scores from structured data
|
| 671 |
+
initial_scores_dict = entry.get('initial_scores', {})
|
| 672 |
+
model_data = extract_scores_from_dict(initial_scores_dict)
|
| 673 |
+
model_scores = {
|
| 674 |
+
'soundness': model_data.get('soundness'),
|
| 675 |
+
'presentation': model_data.get('presentation'),
|
| 676 |
+
'rating': model_data.get('rating'),
|
| 677 |
+
'confidence': model_data.get('confidence')
|
| 678 |
+
}
|
| 679 |
+
model_decision = model_data.get('decision')
|
| 680 |
+
elif format_type == 'original':
|
| 681 |
+
# Use initial scores from structured data
|
| 682 |
+
initial_scores_dict = entry.get('initial_scores', {})
|
| 683 |
+
model_data = extract_scores_from_dict(initial_scores_dict)
|
| 684 |
+
model_scores = {
|
| 685 |
+
'soundness': model_data.get('soundness'),
|
| 686 |
+
'presentation': model_data.get('presentation'),
|
| 687 |
+
'rating': model_data.get('rating'),
|
| 688 |
+
'confidence': model_data.get('confidence')
|
| 689 |
+
}
|
| 690 |
+
model_decision = model_data.get('decision')
|
| 691 |
+
|
| 692 |
+
# Fallback: If confidence is missing from structured data, try to extract from review text
|
| 693 |
+
# (meta_review may not have confidence field, but review text might)
|
| 694 |
+
if model_scores.get('confidence') is None and model_review:
|
| 695 |
+
try:
|
| 696 |
+
review_data = extract_scores_from_review(model_review)
|
| 697 |
+
if review_data.get('confidence') is not None:
|
| 698 |
+
model_scores['confidence'] = review_data.get('confidence')
|
| 699 |
+
except Exception:
|
| 700 |
+
pass # Keep confidence as None if extraction fails
|
| 701 |
+
else:
|
| 702 |
+
# Fallback: extract from markdown review text
|
| 703 |
+
model_data = extract_scores_from_review(model_review)
|
| 704 |
+
model_scores = {
|
| 705 |
+
'soundness': model_data.get('soundness'),
|
| 706 |
+
'presentation': model_data.get('presentation'),
|
| 707 |
+
'rating': model_data.get('rating'),
|
| 708 |
+
'confidence': model_data.get('confidence')
|
| 709 |
+
}
|
| 710 |
+
model_decision = model_data.get('decision')
|
| 711 |
+
|
| 712 |
+
# Get ground truth scores from golden_review ONLY
|
| 713 |
+
# Ground truth must ONLY come from golden_review, never from model output
|
| 714 |
+
# If extraction fails, leave fields as None (do not use model_review as fallback)
|
| 715 |
+
ground_truth_review = entry.get('golden_review', '')
|
| 716 |
+
ground_truth_scores = {}
|
| 717 |
+
gt_decision = None
|
| 718 |
+
|
| 719 |
+
if not ground_truth_review:
|
| 720 |
+
print(f"Warning: No golden_review found for entry {entry_id}. Ground truth scores will be empty.")
|
| 721 |
+
else:
|
| 722 |
+
try:
|
| 723 |
+
# Extract scores from golden_review markdown text
|
| 724 |
+
gt_data = extract_scores_from_review(ground_truth_review)
|
| 725 |
+
if not gt_data:
|
| 726 |
+
print(f"Warning: Failed to parse golden_review for entry {entry_id}. Ground truth scores will be empty.")
|
| 727 |
+
else:
|
| 728 |
+
ground_truth_scores = {
|
| 729 |
+
'soundness': gt_data.get('soundness'),
|
| 730 |
+
'presentation': gt_data.get('presentation'),
|
| 731 |
+
'rating': gt_data.get('rating'),
|
| 732 |
+
'confidence': gt_data.get('confidence')
|
| 733 |
+
}
|
| 734 |
+
gt_decision = normalize_decision(gt_data.get('decision'))
|
| 735 |
+
# Note: If any field is None, it stays None - we do NOT use model_review as fallback
|
| 736 |
+
# Using model output as ground truth would inflate evaluation scores
|
| 737 |
+
except Exception as e:
|
| 738 |
+
print(f"Warning: Failed to extract scores from golden_review for {entry_id}: {e}")
|
| 739 |
+
print(f" Ground truth scores will be empty. Error: {str(e)}")
|
| 740 |
+
|
| 741 |
+
# Calculate MSE and MAE metrics (with optional normalization in strict mode)
|
| 742 |
+
score_metrics = calculate_score_metrics(model_scores, ground_truth_scores, normalize=strict_mode)
|
| 743 |
+
|
| 744 |
+
# Calculate decision accuracy
|
| 745 |
+
decision_match = False
|
| 746 |
+
decision_accuracy = None
|
| 747 |
+
if model_decision is not None and gt_decision is not None:
|
| 748 |
+
model_decision_normalized = normalize_decision(model_decision)
|
| 749 |
+
decision_match = (model_decision_normalized == gt_decision)
|
| 750 |
+
decision_accuracy = 1.0 if decision_match else 0.0
|
| 751 |
+
|
| 752 |
+
result = {
|
| 753 |
+
'id': entry_id,
|
| 754 |
+
'format': format_type,
|
| 755 |
+
'model_soundness': model_scores.get('soundness'),
|
| 756 |
+
'model_presentation': model_scores.get('presentation'),
|
| 757 |
+
'model_rating': model_scores.get('rating'),
|
| 758 |
+
'model_confidence': model_scores.get('confidence'),
|
| 759 |
+
'model_decision': model_decision,
|
| 760 |
+
'gt_soundness': ground_truth_scores.get('soundness'),
|
| 761 |
+
'gt_presentation': ground_truth_scores.get('presentation'),
|
| 762 |
+
'gt_rating': ground_truth_scores.get('rating'),
|
| 763 |
+
'gt_confidence': ground_truth_scores.get('confidence'),
|
| 764 |
+
'gt_decision': gt_decision,
|
| 765 |
+
'decision_match': decision_match,
|
| 766 |
+
'decision_accuracy': decision_accuracy,
|
| 767 |
+
**score_metrics
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
# Add prefix to indicate which scores were used
|
| 771 |
+
if format_type == 'refined':
|
| 772 |
+
if use_initial_scores:
|
| 773 |
+
result['score_type'] = 'initial'
|
| 774 |
+
else:
|
| 775 |
+
result['score_type'] = 'refined'
|
| 776 |
+
else:
|
| 777 |
+
result['score_type'] = 'auto'
|
| 778 |
+
|
| 779 |
+
return result
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
def calculate_pairwise_accuracies(paper_scores: List[Dict[str, float]]) -> Dict[str, float]:
|
| 783 |
+
"""Calculate pairwise accuracy for each metric by comparing rankings."""
|
| 784 |
+
if len(paper_scores) < 2:
|
| 785 |
+
return {}
|
| 786 |
+
|
| 787 |
+
total_valid_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
|
| 788 |
+
correct_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
|
| 789 |
+
|
| 790 |
+
for paper1, paper2 in combinations(paper_scores, 2):
|
| 791 |
+
# Check rating ranking
|
| 792 |
+
if (paper1.get('true_rating') is not None and paper2.get('true_rating') is not None and
|
| 793 |
+
paper1.get('pred_rating') is not None and paper2.get('pred_rating') is not None):
|
| 794 |
+
total_valid_pairs['rating'] += 1
|
| 795 |
+
true_order = paper1['true_rating'] > paper2['true_rating']
|
| 796 |
+
pred_order = paper1['pred_rating'] > paper2['pred_rating']
|
| 797 |
+
if true_order == pred_order:
|
| 798 |
+
correct_pairs['rating'] += 1
|
| 799 |
+
|
| 800 |
+
# Similar for other dimensions...
|
| 801 |
+
# (abbreviated for space, similar logic for soundness, presentation, confidence)
|
| 802 |
+
for metric in ['soundness', 'presentation', 'confidence']:
|
| 803 |
+
true_key = f'true_{metric}'
|
| 804 |
+
pred_key = f'pred_{metric}'
|
| 805 |
+
if (paper1.get(true_key) is not None and paper2.get(true_key) is not None and
|
| 806 |
+
paper1.get(pred_key) is not None and paper2.get(pred_key) is not None):
|
| 807 |
+
total_valid_pairs[metric] += 1
|
| 808 |
+
true_order = paper1[true_key] > paper2[true_key]
|
| 809 |
+
pred_order = paper1[pred_key] > paper2[pred_key]
|
| 810 |
+
if true_order == pred_order:
|
| 811 |
+
correct_pairs[metric] += 1
|
| 812 |
+
|
| 813 |
+
pairwise_accuracies = {
|
| 814 |
+
metric: correct_pairs[metric] / total_valid_pairs[metric] if total_valid_pairs[metric] > 0 else 0.0
|
| 815 |
+
for metric in ['rating', 'soundness', 'presentation', 'confidence']
|
| 816 |
+
}
|
| 817 |
+
|
| 818 |
+
return pairwise_accuracies
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
# ============================================================================
|
| 822 |
+
# Data Loading Functions
|
| 823 |
+
# ============================================================================
|
| 824 |
+
|
| 825 |
+
def load_rubrics_json(rubrics_path: str) -> Dict[str, Dict[str, Any]]:
|
| 826 |
+
"""Load rubrics JSON and create lookup by id."""
|
| 827 |
+
with open(rubrics_path, 'r', encoding='utf-8') as f:
|
| 828 |
+
data = json.load(f)
|
| 829 |
+
|
| 830 |
+
if isinstance(data, list):
|
| 831 |
+
return {item['id']: item for item in data}
|
| 832 |
+
elif isinstance(data, dict):
|
| 833 |
+
return data
|
| 834 |
+
else:
|
| 835 |
+
raise ValueError(f"Invalid rubrics JSON format: expected list or dict, got {type(data)}")
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def load_model_reviews_json(reviews_path: str, format_override: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
| 839 |
+
"""
|
| 840 |
+
Load model reviews JSON and extract reviews by id.
|
| 841 |
+
|
| 842 |
+
Supports two input formats:
|
| 843 |
+
1. Refined format: Contains 'scores' and 'initial_scores' fields (from refinement pipeline)
|
| 844 |
+
2. Original format: Contains 'model_prediction' with 'meta_review' and 'decision' (like ours.json)
|
| 845 |
+
|
| 846 |
+
Args:
|
| 847 |
+
reviews_path: Path to JSON file containing model reviews
|
| 848 |
+
format_override: Optional format override ('refined', 'original', or None for auto-detect)
|
| 849 |
+
|
| 850 |
+
Returns:
|
| 851 |
+
Dict mapping paper_id to dict containing:
|
| 852 |
+
- 'review': review text (markdown)
|
| 853 |
+
- 'scores': refined scores dict (if available)
|
| 854 |
+
- 'initial_scores': initial scores dict (if available)
|
| 855 |
+
- 'format': 'refined' or 'original'
|
| 856 |
+
"""
|
| 857 |
+
with open(reviews_path, 'r', encoding='utf-8') as f:
|
| 858 |
+
data = json.load(f)
|
| 859 |
+
|
| 860 |
+
if isinstance(data, dict):
|
| 861 |
+
data = list(data.values())
|
| 862 |
+
|
| 863 |
+
reviews_dict = {}
|
| 864 |
+
for item in data:
|
| 865 |
+
item_id = None
|
| 866 |
+
review_text = ''
|
| 867 |
+
scores = None
|
| 868 |
+
initial_scores = None
|
| 869 |
+
format_type = None
|
| 870 |
+
|
| 871 |
+
# Use format override if provided, otherwise auto-detect
|
| 872 |
+
if format_override and format_override != 'auto':
|
| 873 |
+
# Force use specified format
|
| 874 |
+
if format_override == 'refined':
|
| 875 |
+
item_id = item.get('paper_id') or item.get('id')
|
| 876 |
+
if not item_id:
|
| 877 |
+
continue
|
| 878 |
+
format_type = 'refined'
|
| 879 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 880 |
+
scores = item.get('scores', {})
|
| 881 |
+
initial_scores = item.get('initial_scores', {})
|
| 882 |
+
elif format_override == 'original':
|
| 883 |
+
item_id = item.get('id')
|
| 884 |
+
if not item_id:
|
| 885 |
+
continue
|
| 886 |
+
format_type = 'original'
|
| 887 |
+
model_prediction = item.get('model_prediction', {})
|
| 888 |
+
meta_review = model_prediction.get('meta_review', {})
|
| 889 |
+
review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
|
| 890 |
+
initial_scores = {
|
| 891 |
+
'rating': meta_review.get('rating'),
|
| 892 |
+
'soundness': meta_review.get('soundness'),
|
| 893 |
+
'presentation': meta_review.get('presentation'),
|
| 894 |
+
'contribution': meta_review.get('contribution'),
|
| 895 |
+
'decision': model_prediction.get('decision'),
|
| 896 |
+
}
|
| 897 |
+
else:
|
| 898 |
+
raise ValueError(f"Unknown format_override: {format_override}. Must be 'refined', 'original', or 'auto'")
|
| 899 |
+
else:
|
| 900 |
+
# Auto-detect format
|
| 901 |
+
if "paper_id" in item:
|
| 902 |
+
# Refined format (from refinement pipeline)
|
| 903 |
+
item_id = item.get('paper_id')
|
| 904 |
+
if not item_id:
|
| 905 |
+
continue
|
| 906 |
+
|
| 907 |
+
# Check if this is refined format (has scores and initial_scores)
|
| 908 |
+
if 'scores' in item and 'initial_scores' in item:
|
| 909 |
+
format_type = 'refined'
|
| 910 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 911 |
+
scores = item.get('scores', {})
|
| 912 |
+
initial_scores = item.get('initial_scores', {})
|
| 913 |
+
else:
|
| 914 |
+
# Standard format with paper_id
|
| 915 |
+
format_type = 'standard'
|
| 916 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 917 |
+
elif "model_prediction" in item:
|
| 918 |
+
# Original format (like ours.json) or agenticreview format
|
| 919 |
+
item_id = item.get('id')
|
| 920 |
+
if not item_id:
|
| 921 |
+
continue
|
| 922 |
+
|
| 923 |
+
format_type = 'original'
|
| 924 |
+
model_prediction = item.get('model_prediction', {})
|
| 925 |
+
|
| 926 |
+
review_text = model_prediction.get('raw_text', '')
|
| 927 |
+
|
| 928 |
+
if review_text is None:
|
| 929 |
+
continue
|
| 930 |
+
|
| 931 |
+
# Detect format: agenticreview has raw_text as string with "Overall rating: x"
|
| 932 |
+
# ai_researcher format has raw_text as dict or JSON string with structured fields
|
| 933 |
+
is_agenticreview = False
|
| 934 |
+
if isinstance(review_text, str):
|
| 935 |
+
# Check if it's a JSON string (ai_researcher format)
|
| 936 |
+
try:
|
| 937 |
+
parsed_json = json.loads(review_text)
|
| 938 |
+
if isinstance(parsed_json, dict) and any(key in parsed_json for key in ["Summary", "Strengths", "Overall", "Decision"]):
|
| 939 |
+
# It's ai_researcher format
|
| 940 |
+
review_text = parsed_json
|
| 941 |
+
review_text, meta_review = convert_ai_researcher(review_text)
|
| 942 |
+
else:
|
| 943 |
+
# It's agenticreview format (plain text with "Overall rating: x")
|
| 944 |
+
is_agenticreview = True
|
| 945 |
+
except (json.JSONDecodeError, TypeError):
|
| 946 |
+
# Not JSON, check if it contains "Overall rating:" pattern
|
| 947 |
+
if re.search(r'Overall\s+rating\s*[:=]', review_text, re.IGNORECASE):
|
| 948 |
+
is_agenticreview = True
|
| 949 |
+
else:
|
| 950 |
+
# Try to parse as ai_researcher anyway
|
| 951 |
+
try:
|
| 952 |
+
review_text = json.loads(review_text)
|
| 953 |
+
review_text, meta_review = convert_ai_researcher(review_text)
|
| 954 |
+
except:
|
| 955 |
+
review_text = 'Empty Review'
|
| 956 |
+
meta_review = {}
|
| 957 |
+
elif isinstance(review_text, dict):
|
| 958 |
+
# It's ai_researcher format (dict)
|
| 959 |
+
review_text, meta_review = convert_ai_researcher(review_text)
|
| 960 |
+
else:
|
| 961 |
+
review_text = 'Empty Review'
|
| 962 |
+
meta_review = {}
|
| 963 |
+
|
| 964 |
+
# Handle agenticreview format
|
| 965 |
+
if is_agenticreview:
|
| 966 |
+
review_text, meta_review = convert_agenticreview(review_text)
|
| 967 |
+
|
| 968 |
+
# Extract initial scores
|
| 969 |
+
# Use meta_review as primary source (from convert_ai_researcher or convert_agenticreview)
|
| 970 |
+
# Fallback to model_prediction.get('decision') if not in meta_review
|
| 971 |
+
initial_scores = {
|
| 972 |
+
'rating': meta_review.get('rating'),
|
| 973 |
+
'soundness': meta_review.get('soundness'),
|
| 974 |
+
'presentation': meta_review.get('presentation'),
|
| 975 |
+
'contribution': meta_review.get('contribution'),
|
| 976 |
+
'confidence': meta_review.get('confidence'),
|
| 977 |
+
'decision': meta_review.get('decision') or model_prediction.get('decision'),
|
| 978 |
+
}
|
| 979 |
+
else:
|
| 980 |
+
# Legacy format (pred_fast_mode)
|
| 981 |
+
item_id = item.get('id')
|
| 982 |
+
if not item_id:
|
| 983 |
+
continue
|
| 984 |
+
|
| 985 |
+
format_type = 'legacy'
|
| 986 |
+
review_dict = item.get('pred_fast_mode', {})
|
| 987 |
+
if isinstance(review_dict, dict):
|
| 988 |
+
review_text = review_dict.get('raw_text', '')
|
| 989 |
+
else:
|
| 990 |
+
review_text = str(review_dict)
|
| 991 |
+
|
| 992 |
+
# Extract review content from the review text field
|
| 993 |
+
try:
|
| 994 |
+
if review_text:
|
| 995 |
+
# extracted_review = ReviewProcessor.extract_review_content(review_text)
|
| 996 |
+
extracted_review = review_text
|
| 997 |
+
else:
|
| 998 |
+
extracted_review = ''
|
| 999 |
+
|
| 1000 |
+
reviews_dict[item_id] = {
|
| 1001 |
+
'review': extracted_review,
|
| 1002 |
+
'scores': scores,
|
| 1003 |
+
'initial_scores': initial_scores,
|
| 1004 |
+
'format': format_type
|
| 1005 |
+
}
|
| 1006 |
+
except Exception as e:
|
| 1007 |
+
print(f"[WARN] Failed to extract review for {item_id}: {e}")
|
| 1008 |
+
continue
|
| 1009 |
+
|
| 1010 |
+
return reviews_dict
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
def combine_rubrics_and_reviews(
|
| 1014 |
+
rubrics_data: Dict[str, Dict[str, Any]],
|
| 1015 |
+
reviews_dict: Dict[str, Dict[str, Any]]
|
| 1016 |
+
) -> List[Dict[str, Any]]:
|
| 1017 |
+
"""
|
| 1018 |
+
Combine rubrics and reviews into evaluation entries.
|
| 1019 |
+
|
| 1020 |
+
Args:
|
| 1021 |
+
rubrics_data: Dict mapping paper_id to rubric entry
|
| 1022 |
+
reviews_dict: Dict mapping paper_id to dict containing 'review', 'scores', 'initial_scores', 'format'
|
| 1023 |
+
|
| 1024 |
+
Returns:
|
| 1025 |
+
List of evaluation entries with model_review, scores, initial_scores, and format info
|
| 1026 |
+
"""
|
| 1027 |
+
combined = []
|
| 1028 |
+
missing_reviews = []
|
| 1029 |
+
|
| 1030 |
+
for paper_id, rubric_entry in rubrics_data.items():
|
| 1031 |
+
review_data = reviews_dict.get(paper_id)
|
| 1032 |
+
if not review_data or not review_data.get('review'):
|
| 1033 |
+
missing_reviews.append(paper_id)
|
| 1034 |
+
continue
|
| 1035 |
+
|
| 1036 |
+
entry = {
|
| 1037 |
+
'id': paper_id,
|
| 1038 |
+
'paper_context': rubric_entry.get('paper_context', ''),
|
| 1039 |
+
'decision': rubric_entry.get('decision', ''),
|
| 1040 |
+
'golden_review': rubric_entry.get('golden_review', ''),
|
| 1041 |
+
'rubrics': rubric_entry.get('rubrics', []),
|
| 1042 |
+
'model_review': review_data.get('review', ''),
|
| 1043 |
+
'scores': review_data.get('scores'), # Refined scores (if available)
|
| 1044 |
+
'initial_scores': review_data.get('initial_scores'), # Initial scores (if available)
|
| 1045 |
+
'format': review_data.get('format', 'unknown') # Format type
|
| 1046 |
+
}
|
| 1047 |
+
combined.append(entry)
|
| 1048 |
+
|
| 1049 |
+
if missing_reviews:
|
| 1050 |
+
print(f"[WARN] {len(missing_reviews)} papers have no model review, skipping them")
|
| 1051 |
+
|
| 1052 |
+
return combined
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
# ============================================================================
|
| 1056 |
+
# LLM Service Configuration
|
| 1057 |
+
# ============================================================================
|
| 1058 |
+
|
| 1059 |
+
def load_llm_config(config_path: str) -> Dict[str, Any]:
|
| 1060 |
+
"""Load LLM configuration from YAML file."""
|
| 1061 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 1062 |
+
config = yaml.safe_load(f)
|
| 1063 |
+
return config
|
| 1064 |
+
|
| 1065 |
+
|
| 1066 |
+
def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
|
| 1067 |
+
"""Create LLM service from configuration."""
|
| 1068 |
+
mode = config.get('mode', 'gpt').lower()
|
| 1069 |
+
|
| 1070 |
+
if mode == 'gpt':
|
| 1071 |
+
gpt_config = config.get('gpt', {})
|
| 1072 |
+
api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
|
| 1073 |
+
if not api_key:
|
| 1074 |
+
raise ValueError("GPT mode requires api_key in configs.yaml or OPENAI_API_KEY environment variable")
|
| 1075 |
+
|
| 1076 |
+
service = GPTService(
|
| 1077 |
+
api_key=api_key,
|
| 1078 |
+
model_name=gpt_config.get('model_name', 'gpt-4o'),
|
| 1079 |
+
base_url=gpt_config.get('base_url'),
|
| 1080 |
+
timeout=gpt_config.get('timeout', 300)
|
| 1081 |
+
)
|
| 1082 |
+
return service
|
| 1083 |
+
|
| 1084 |
+
elif mode == 'vllm':
|
| 1085 |
+
vllm_config = config.get('vllm', {})
|
| 1086 |
+
service = VLLMService(
|
| 1087 |
+
base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
|
| 1088 |
+
api_key=vllm_config.get('api_key', 'dummy-key'),
|
| 1089 |
+
model_name=vllm_config.get('model_name'),
|
| 1090 |
+
timeout=vllm_config.get('timeout', 300),
|
| 1091 |
+
max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
|
| 1092 |
+
max_retries=vllm_config.get('max_retries', 3),
|
| 1093 |
+
retry_delay=vllm_config.get('retry_delay', 1.0),
|
| 1094 |
+
retry_backoff=vllm_config.get('retry_backoff', 2.0)
|
| 1095 |
+
)
|
| 1096 |
+
return service
|
| 1097 |
+
|
| 1098 |
+
else:
|
| 1099 |
+
raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
+
# ============================================================================
|
| 1103 |
+
# Main Evaluation Functions
|
| 1104 |
+
# ============================================================================
|
| 1105 |
+
|
| 1106 |
+
def run_semantic_evaluation(
|
| 1107 |
+
evaluation_data: List[Dict[str, Any]],
|
| 1108 |
+
prompt_template: str,
|
| 1109 |
+
llm_service: LLMService,
|
| 1110 |
+
max_workers: int
|
| 1111 |
+
) -> tuple:
|
| 1112 |
+
"""Run semantic evaluation and return results and summary."""
|
| 1113 |
+
print(f"\n{'='*80}")
|
| 1114 |
+
print("RUNNING SEMANTIC EVALUATION")
|
| 1115 |
+
print(f"{'='*80}")
|
| 1116 |
+
print(f"Evaluating {len(evaluation_data)} reviews using {max_workers} workers...")
|
| 1117 |
+
|
| 1118 |
+
results = []
|
| 1119 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 1120 |
+
future_to_entry = {
|
| 1121 |
+
executor.submit(
|
| 1122 |
+
evaluate_review_semantic,
|
| 1123 |
+
entry,
|
| 1124 |
+
entry['paper_context'],
|
| 1125 |
+
prompt_template,
|
| 1126 |
+
llm_service
|
| 1127 |
+
): entry
|
| 1128 |
+
for entry in evaluation_data
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
for future in tqdm(as_completed(future_to_entry), total=len(evaluation_data), desc="Semantic evaluation"):
|
| 1132 |
+
try:
|
| 1133 |
+
result = future.result()
|
| 1134 |
+
results.append(result)
|
| 1135 |
+
except Exception as e:
|
| 1136 |
+
entry = future_to_entry[future]
|
| 1137 |
+
print(f"\n[ERROR] Failed to process entry {entry.get('id', 'unknown')}: {e}")
|
| 1138 |
+
results.append({
|
| 1139 |
+
'id': entry.get('id', 'unknown'),
|
| 1140 |
+
'raw_scores': {},
|
| 1141 |
+
'weighted_scores': {},
|
| 1142 |
+
'total_score': 0.0,
|
| 1143 |
+
'error': str(e),
|
| 1144 |
+
'raw_response': ''
|
| 1145 |
+
})
|
| 1146 |
+
|
| 1147 |
+
# Calculate statistics
|
| 1148 |
+
valid_results = [r for r in results if 'error' not in r and r.get('weighted_scores')]
|
| 1149 |
+
review_scores = [r.get('total_score', 0.0) for r in valid_results]
|
| 1150 |
+
|
| 1151 |
+
summary = {
|
| 1152 |
+
'total_entries': len(results),
|
| 1153 |
+
'valid_entries': len(valid_results),
|
| 1154 |
+
'failed_entries': len(results) - len(valid_results)
|
| 1155 |
+
}
|
| 1156 |
+
|
| 1157 |
+
if review_scores:
|
| 1158 |
+
summary['overall_score'] = {
|
| 1159 |
+
'mean': sum(review_scores) / len(review_scores),
|
| 1160 |
+
'min': min(review_scores),
|
| 1161 |
+
'max': max(review_scores)
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
# Calculate per-rubric statistics (extract rubric titles from first entry)
|
| 1165 |
+
if evaluation_data and evaluation_data[0].get('rubrics'):
|
| 1166 |
+
rubric_titles = [r['title'] for r in evaluation_data[0]['rubrics']]
|
| 1167 |
+
per_rubric_stats = calculate_per_rubric_statistics(valid_results, rubric_titles)
|
| 1168 |
+
summary['per_rubric_statistics'] = per_rubric_stats
|
| 1169 |
+
|
| 1170 |
+
return results, summary
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
def run_auto_metric_evaluation(
|
| 1174 |
+
evaluation_data: List[Dict[str, Any]],
|
| 1175 |
+
strict_mode: bool = False
|
| 1176 |
+
) -> tuple:
|
| 1177 |
+
"""
|
| 1178 |
+
Run auto-metric evaluation and return results and summary.
|
| 1179 |
+
|
| 1180 |
+
For refined format (has scores and initial_scores), evaluates both:
|
| 1181 |
+
- Refined scores evaluation
|
| 1182 |
+
- Initial scores evaluation
|
| 1183 |
+
|
| 1184 |
+
For original format (only initial_scores), evaluates:
|
| 1185 |
+
- Initial scores evaluation only
|
| 1186 |
+
|
| 1187 |
+
Returns:
|
| 1188 |
+
Tuple of (results_list, summary_dict)
|
| 1189 |
+
- results_list: List of evaluation results (may contain both refined and initial results for refined format)
|
| 1190 |
+
- summary_dict: Summary statistics
|
| 1191 |
+
"""
|
| 1192 |
+
print(f"\n{'='*80}")
|
| 1193 |
+
print("RUNNING AUTO-METRIC EVALUATION")
|
| 1194 |
+
print(f"{'='*80}")
|
| 1195 |
+
print(f"Evaluating {len(evaluation_data)} entries...")
|
| 1196 |
+
|
| 1197 |
+
# Detect format types
|
| 1198 |
+
refined_format_count = sum(1 for e in evaluation_data if e.get('format') == 'refined')
|
| 1199 |
+
original_format_count = sum(1 for e in evaluation_data if e.get('format') == 'original')
|
| 1200 |
+
|
| 1201 |
+
if refined_format_count > 0:
|
| 1202 |
+
print(f"Detected {refined_format_count} entries in refined format (will evaluate both refined and initial scores)")
|
| 1203 |
+
if original_format_count > 0:
|
| 1204 |
+
print(f"Detected {original_format_count} entries in original format (will evaluate initial scores only)")
|
| 1205 |
+
|
| 1206 |
+
results = []
|
| 1207 |
+
for entry in tqdm(evaluation_data, desc="Auto-metric evaluation"):
|
| 1208 |
+
format_type = entry.get('format', 'unknown')
|
| 1209 |
+
|
| 1210 |
+
if format_type == 'refined':
|
| 1211 |
+
# Evaluate both refined scores and initial scores
|
| 1212 |
+
try:
|
| 1213 |
+
entry_id = entry.get('id', 'unknown')
|
| 1214 |
+
|
| 1215 |
+
# Evaluate refined scores
|
| 1216 |
+
refined_result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
|
| 1217 |
+
refined_result['paper_id'] = entry_id # Keep original paper_id
|
| 1218 |
+
refined_result['id'] = f"{entry_id}_refined"
|
| 1219 |
+
results.append(refined_result)
|
| 1220 |
+
|
| 1221 |
+
# Evaluate initial scores
|
| 1222 |
+
initial_result = evaluate_review_auto_metric(entry, use_initial_scores=True, strict_mode=strict_mode)
|
| 1223 |
+
initial_result['paper_id'] = entry_id # Keep original paper_id
|
| 1224 |
+
initial_result['id'] = f"{entry_id}_initial"
|
| 1225 |
+
results.append(initial_result)
|
| 1226 |
+
except Exception as e:
|
| 1227 |
+
print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
|
| 1228 |
+
results.append({
|
| 1229 |
+
'id': entry.get('id', 'unknown'),
|
| 1230 |
+
'error': str(e)
|
| 1231 |
+
})
|
| 1232 |
+
else:
|
| 1233 |
+
# Evaluate initial scores only (or extract from markdown)
|
| 1234 |
+
try:
|
| 1235 |
+
result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
|
| 1236 |
+
results.append(result)
|
| 1237 |
+
except Exception as e:
|
| 1238 |
+
print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
|
| 1239 |
+
results.append({
|
| 1240 |
+
'id': entry.get('id', 'unknown'),
|
| 1241 |
+
'error': str(e)
|
| 1242 |
+
})
|
| 1243 |
+
|
| 1244 |
+
# Calculate statistics
|
| 1245 |
+
valid_results = [r for r in results if 'error' not in r]
|
| 1246 |
+
mse_results = [r for r in valid_results if r.get('overall_error') is not None]
|
| 1247 |
+
|
| 1248 |
+
# Separate refined and initial results for refined format
|
| 1249 |
+
refined_results = [r for r in valid_results if r.get('score_type') == 'refined']
|
| 1250 |
+
initial_results = [r for r in valid_results if r.get('score_type') == 'initial']
|
| 1251 |
+
auto_results = [r for r in valid_results if r.get('score_type') == 'auto' or r.get('score_type') is None]
|
| 1252 |
+
|
| 1253 |
+
summary = {
|
| 1254 |
+
'total_entries': len(results),
|
| 1255 |
+
'valid_entries': len(valid_results),
|
| 1256 |
+
'mse_entries': len(mse_results),
|
| 1257 |
+
'refined_results_count': len(refined_results),
|
| 1258 |
+
'initial_results_count': len(initial_results),
|
| 1259 |
+
'auto_results_count': len(auto_results)
|
| 1260 |
+
}
|
| 1261 |
+
|
| 1262 |
+
# Calculate MSE/MAE statistics
|
| 1263 |
+
# For refined format, only use refined results for overall statistics (avoid double counting)
|
| 1264 |
+
# For other formats, use all results
|
| 1265 |
+
if refined_format_count > 0:
|
| 1266 |
+
# Refined format: use only refined results for overall statistics
|
| 1267 |
+
stats_results = [r for r in refined_results if r.get('overall_error') is not None]
|
| 1268 |
+
else:
|
| 1269 |
+
# Original/other formats: use all results
|
| 1270 |
+
stats_results = mse_results
|
| 1271 |
+
|
| 1272 |
+
if stats_results:
|
| 1273 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1274 |
+
mse_stats = {}
|
| 1275 |
+
mae_stats = {}
|
| 1276 |
+
|
| 1277 |
+
for dim in dimensions:
|
| 1278 |
+
mse_list = [r.get(f'{dim}_mse') for r in stats_results if r.get(f'{dim}_mse') is not None]
|
| 1279 |
+
mae_list = [r.get(f'{dim}_mae') for r in stats_results if r.get(f'{dim}_mae') is not None]
|
| 1280 |
+
|
| 1281 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1282 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1283 |
+
|
| 1284 |
+
if mse_clean:
|
| 1285 |
+
mse_stats[dim] = {
|
| 1286 |
+
'mean': sum(mse_clean) / len(mse_clean),
|
| 1287 |
+
'count': len(mse_clean)
|
| 1288 |
+
}
|
| 1289 |
+
if mae_clean:
|
| 1290 |
+
mae_stats[dim] = {
|
| 1291 |
+
'mean': sum(mae_clean) / len(mae_clean),
|
| 1292 |
+
'count': len(mae_clean)
|
| 1293 |
+
}
|
| 1294 |
+
|
| 1295 |
+
overall_errors = [r.get('overall_error') for r in stats_results if r.get('overall_error') is not None]
|
| 1296 |
+
overall_clean = [x for x in overall_errors if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1297 |
+
|
| 1298 |
+
if overall_clean:
|
| 1299 |
+
summary['overall_error'] = {
|
| 1300 |
+
'mean': sum(overall_clean) / len(overall_clean),
|
| 1301 |
+
'count': len(overall_clean)
|
| 1302 |
+
}
|
| 1303 |
+
|
| 1304 |
+
summary['mse_statistics'] = mse_stats
|
| 1305 |
+
summary['mae_statistics'] = mae_stats
|
| 1306 |
+
|
| 1307 |
+
# Calculate separate statistics for refined and initial results
|
| 1308 |
+
if refined_results:
|
| 1309 |
+
refined_mse_results = [r for r in refined_results if r.get('overall_error') is not None]
|
| 1310 |
+
if refined_mse_results:
|
| 1311 |
+
refined_mse_stats = {}
|
| 1312 |
+
refined_mae_stats = {}
|
| 1313 |
+
for dim in dimensions:
|
| 1314 |
+
mse_list = [r.get(f'{dim}_mse') for r in refined_mse_results if r.get(f'{dim}_mse') is not None]
|
| 1315 |
+
mae_list = [r.get(f'{dim}_mae') for r in refined_mse_results if r.get(f'{dim}_mae') is not None]
|
| 1316 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1317 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1318 |
+
if mse_clean:
|
| 1319 |
+
refined_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
|
| 1320 |
+
if mae_clean:
|
| 1321 |
+
refined_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
|
| 1322 |
+
summary['refined_mse_statistics'] = refined_mse_stats
|
| 1323 |
+
summary['refined_mae_statistics'] = refined_mae_stats
|
| 1324 |
+
|
| 1325 |
+
if initial_results:
|
| 1326 |
+
initial_mse_results = [r for r in initial_results if r.get('overall_error') is not None]
|
| 1327 |
+
if initial_mse_results:
|
| 1328 |
+
initial_mse_stats = {}
|
| 1329 |
+
initial_mae_stats = {}
|
| 1330 |
+
for dim in dimensions:
|
| 1331 |
+
mse_list = [r.get(f'{dim}_mse') for r in initial_mse_results if r.get(f'{dim}_mse') is not None]
|
| 1332 |
+
mae_list = [r.get(f'{dim}_mae') for r in initial_mse_results if r.get(f'{dim}_mae') is not None]
|
| 1333 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1334 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1335 |
+
if mse_clean:
|
| 1336 |
+
initial_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
|
| 1337 |
+
if mae_clean:
|
| 1338 |
+
initial_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
|
| 1339 |
+
summary['initial_mse_statistics'] = initial_mse_stats
|
| 1340 |
+
summary['initial_mae_statistics'] = initial_mae_stats
|
| 1341 |
+
|
| 1342 |
+
# Calculate Spearman correlations
|
| 1343 |
+
def filter_valid_pairs(true_list, pred_list):
|
| 1344 |
+
filtered_true = []
|
| 1345 |
+
filtered_pred = []
|
| 1346 |
+
for t, p in zip(true_list, pred_list):
|
| 1347 |
+
if (t is not None and p is not None and
|
| 1348 |
+
not (isinstance(t, float) and math.isnan(t)) and
|
| 1349 |
+
not (isinstance(p, float) and math.isnan(p))):
|
| 1350 |
+
filtered_true.append(t)
|
| 1351 |
+
filtered_pred.append(p)
|
| 1352 |
+
return filtered_true, filtered_pred
|
| 1353 |
+
|
| 1354 |
+
# Calculate Spearman correlations
|
| 1355 |
+
# For refined format, calculate separately for refined and initial, and use refined for overall
|
| 1356 |
+
# For other formats, use all results
|
| 1357 |
+
if refined_format_count > 0:
|
| 1358 |
+
# Calculate refined spearman correlations
|
| 1359 |
+
refined_spearman_stats = {}
|
| 1360 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1361 |
+
for dim in dimensions:
|
| 1362 |
+
true_values = [r.get(f'gt_{dim}') for r in refined_results]
|
| 1363 |
+
pred_values = [r.get(f'model_{dim}') for r in refined_results]
|
| 1364 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1365 |
+
|
| 1366 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1367 |
+
try:
|
| 1368 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1369 |
+
if not math.isnan(corr):
|
| 1370 |
+
refined_spearman_stats[dim] = {
|
| 1371 |
+
'correlation': corr,
|
| 1372 |
+
'count': len(true_clean)
|
| 1373 |
+
}
|
| 1374 |
+
except Exception:
|
| 1375 |
+
pass
|
| 1376 |
+
|
| 1377 |
+
# Calculate initial spearman correlations
|
| 1378 |
+
initial_spearman_stats = {}
|
| 1379 |
+
for dim in dimensions:
|
| 1380 |
+
true_values = [r.get(f'gt_{dim}') for r in initial_results]
|
| 1381 |
+
pred_values = [r.get(f'model_{dim}') for r in initial_results]
|
| 1382 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1383 |
+
|
| 1384 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1385 |
+
try:
|
| 1386 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1387 |
+
if not math.isnan(corr):
|
| 1388 |
+
initial_spearman_stats[dim] = {
|
| 1389 |
+
'correlation': corr,
|
| 1390 |
+
'count': len(true_clean)
|
| 1391 |
+
}
|
| 1392 |
+
except Exception:
|
| 1393 |
+
pass
|
| 1394 |
+
|
| 1395 |
+
# Use refined for overall statistics (avoid double counting)
|
| 1396 |
+
summary['spearman_correlations'] = refined_spearman_stats
|
| 1397 |
+
summary['refined_spearman_correlations'] = refined_spearman_stats
|
| 1398 |
+
summary['initial_spearman_correlations'] = initial_spearman_stats
|
| 1399 |
+
else:
|
| 1400 |
+
# Original/other formats: use all results
|
| 1401 |
+
correlation_results = valid_results
|
| 1402 |
+
spearman_stats = {}
|
| 1403 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1404 |
+
for dim in dimensions:
|
| 1405 |
+
true_values = [r.get(f'gt_{dim}') for r in correlation_results]
|
| 1406 |
+
pred_values = [r.get(f'model_{dim}') for r in correlation_results]
|
| 1407 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1408 |
+
|
| 1409 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1410 |
+
try:
|
| 1411 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1412 |
+
if not math.isnan(corr):
|
| 1413 |
+
spearman_stats[dim] = {
|
| 1414 |
+
'correlation': corr,
|
| 1415 |
+
'count': len(true_clean)
|
| 1416 |
+
}
|
| 1417 |
+
except Exception:
|
| 1418 |
+
pass
|
| 1419 |
+
|
| 1420 |
+
summary['spearman_correlations'] = spearman_stats
|
| 1421 |
+
|
| 1422 |
+
# Calculate Decision metrics
|
| 1423 |
+
# For refined format, calculate separately for refined and initial, and use refined for overall
|
| 1424 |
+
# For other formats, use all results
|
| 1425 |
+
if refined_format_count > 0:
|
| 1426 |
+
# Calculate refined decision metrics
|
| 1427 |
+
refined_decision_results = [r for r in refined_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1428 |
+
if refined_decision_results:
|
| 1429 |
+
true_decisions = []
|
| 1430 |
+
pred_decisions = []
|
| 1431 |
+
decision_acc = []
|
| 1432 |
+
|
| 1433 |
+
for r in refined_decision_results:
|
| 1434 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1435 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1436 |
+
|
| 1437 |
+
if 'accept' in pred_decision:
|
| 1438 |
+
pred_binary = 1
|
| 1439 |
+
else:
|
| 1440 |
+
pred_binary = 0
|
| 1441 |
+
|
| 1442 |
+
if 'accept' in gt_decision:
|
| 1443 |
+
gt_binary = 1
|
| 1444 |
+
else:
|
| 1445 |
+
gt_binary = 0
|
| 1446 |
+
|
| 1447 |
+
true_decisions.append(gt_binary)
|
| 1448 |
+
pred_decisions.append(pred_binary)
|
| 1449 |
+
|
| 1450 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1451 |
+
decision_acc.append(1.0)
|
| 1452 |
+
else:
|
| 1453 |
+
decision_acc.append(0.0)
|
| 1454 |
+
|
| 1455 |
+
if decision_acc:
|
| 1456 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1457 |
+
try:
|
| 1458 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1459 |
+
refined_decision_metrics = {
|
| 1460 |
+
'accuracy': decision_accuracy,
|
| 1461 |
+
'f1_macro': f1_score,
|
| 1462 |
+
'count': len(decision_acc)
|
| 1463 |
+
}
|
| 1464 |
+
except Exception:
|
| 1465 |
+
refined_decision_metrics = {
|
| 1466 |
+
'accuracy': decision_accuracy,
|
| 1467 |
+
'count': len(decision_acc)
|
| 1468 |
+
}
|
| 1469 |
+
summary['refined_decision_metrics'] = refined_decision_metrics
|
| 1470 |
+
summary['decision_metrics'] = refined_decision_metrics # Use refined for overall
|
| 1471 |
+
|
| 1472 |
+
# Calculate initial decision metrics
|
| 1473 |
+
initial_decision_results = [r for r in initial_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1474 |
+
if initial_decision_results:
|
| 1475 |
+
true_decisions = []
|
| 1476 |
+
pred_decisions = []
|
| 1477 |
+
decision_acc = []
|
| 1478 |
+
|
| 1479 |
+
for r in initial_decision_results:
|
| 1480 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1481 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1482 |
+
|
| 1483 |
+
if 'accept' in pred_decision:
|
| 1484 |
+
pred_binary = 1
|
| 1485 |
+
else:
|
| 1486 |
+
pred_binary = 0
|
| 1487 |
+
|
| 1488 |
+
if 'accept' in gt_decision:
|
| 1489 |
+
gt_binary = 1
|
| 1490 |
+
else:
|
| 1491 |
+
gt_binary = 0
|
| 1492 |
+
|
| 1493 |
+
true_decisions.append(gt_binary)
|
| 1494 |
+
pred_decisions.append(pred_binary)
|
| 1495 |
+
|
| 1496 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1497 |
+
decision_acc.append(1.0)
|
| 1498 |
+
else:
|
| 1499 |
+
decision_acc.append(0.0)
|
| 1500 |
+
|
| 1501 |
+
if decision_acc:
|
| 1502 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1503 |
+
try:
|
| 1504 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1505 |
+
initial_decision_metrics = {
|
| 1506 |
+
'accuracy': decision_accuracy,
|
| 1507 |
+
'f1_macro': f1_score,
|
| 1508 |
+
'count': len(decision_acc)
|
| 1509 |
+
}
|
| 1510 |
+
except Exception:
|
| 1511 |
+
initial_decision_metrics = {
|
| 1512 |
+
'accuracy': decision_accuracy,
|
| 1513 |
+
'count': len(decision_acc)
|
| 1514 |
+
}
|
| 1515 |
+
summary['initial_decision_metrics'] = initial_decision_metrics
|
| 1516 |
+
else:
|
| 1517 |
+
# Original/other formats: use all results
|
| 1518 |
+
decision_results = [r for r in valid_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1519 |
+
if decision_results:
|
| 1520 |
+
true_decisions = []
|
| 1521 |
+
pred_decisions = []
|
| 1522 |
+
decision_acc = []
|
| 1523 |
+
|
| 1524 |
+
for r in decision_results:
|
| 1525 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1526 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1527 |
+
|
| 1528 |
+
if 'accept' in pred_decision:
|
| 1529 |
+
pred_binary = 1
|
| 1530 |
+
else:
|
| 1531 |
+
pred_binary = 0
|
| 1532 |
+
|
| 1533 |
+
if 'accept' in gt_decision:
|
| 1534 |
+
gt_binary = 1
|
| 1535 |
+
else:
|
| 1536 |
+
gt_binary = 0
|
| 1537 |
+
|
| 1538 |
+
true_decisions.append(gt_binary)
|
| 1539 |
+
pred_decisions.append(pred_binary)
|
| 1540 |
+
|
| 1541 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1542 |
+
decision_acc.append(1.0)
|
| 1543 |
+
else:
|
| 1544 |
+
decision_acc.append(0.0)
|
| 1545 |
+
|
| 1546 |
+
if decision_acc:
|
| 1547 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1548 |
+
try:
|
| 1549 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1550 |
+
summary['decision_metrics'] = {
|
| 1551 |
+
'accuracy': decision_accuracy,
|
| 1552 |
+
'f1_macro': f1_score,
|
| 1553 |
+
'count': len(decision_acc)
|
| 1554 |
+
}
|
| 1555 |
+
except Exception:
|
| 1556 |
+
summary['decision_metrics'] = {
|
| 1557 |
+
'accuracy': decision_accuracy,
|
| 1558 |
+
'count': len(decision_acc)
|
| 1559 |
+
}
|
| 1560 |
+
|
| 1561 |
+
# Calculate Pairwise comparison
|
| 1562 |
+
# For refined format, only use refined results (avoid double counting)
|
| 1563 |
+
# For other formats, use all results
|
| 1564 |
+
if refined_format_count > 0:
|
| 1565 |
+
pairwise_results = refined_results
|
| 1566 |
+
else:
|
| 1567 |
+
pairwise_results = valid_results
|
| 1568 |
+
|
| 1569 |
+
paper_scores = []
|
| 1570 |
+
for r in pairwise_results:
|
| 1571 |
+
if (r.get('gt_rating') is not None and r.get('model_rating') is not None) or \
|
| 1572 |
+
(r.get('gt_soundness') is not None and r.get('model_soundness') is not None):
|
| 1573 |
+
paper_scores.append({
|
| 1574 |
+
'true_rating': r.get('gt_rating'),
|
| 1575 |
+
'pred_rating': r.get('model_rating'),
|
| 1576 |
+
'true_soundness': r.get('gt_soundness'),
|
| 1577 |
+
'pred_soundness': r.get('model_soundness'),
|
| 1578 |
+
'true_presentation': r.get('gt_presentation'),
|
| 1579 |
+
'pred_presentation': r.get('model_presentation'),
|
| 1580 |
+
'true_confidence': r.get('gt_confidence'),
|
| 1581 |
+
'pred_confidence': r.get('model_confidence')
|
| 1582 |
+
})
|
| 1583 |
+
|
| 1584 |
+
if len(paper_scores) >= 2:
|
| 1585 |
+
pairwise_accuracies = calculate_pairwise_accuracies(paper_scores)
|
| 1586 |
+
summary['pairwise_accuracies'] = pairwise_accuracies
|
| 1587 |
+
|
| 1588 |
+
return results, summary
|
| 1589 |
+
|
| 1590 |
+
|
| 1591 |
+
# ============================================================================
|
| 1592 |
+
# Main Function
|
| 1593 |
+
# ============================================================================
|
| 1594 |
+
|
| 1595 |
+
def parse_args():
|
| 1596 |
+
"""Parse command line arguments."""
|
| 1597 |
+
parser = argparse.ArgumentParser(description="Unified evaluation script for semantic and auto-metric evaluation")
|
| 1598 |
+
|
| 1599 |
+
# Input paths
|
| 1600 |
+
parser.add_argument("--rubrics_path", type=str, required=True,
|
| 1601 |
+
help="Path to eval_rubrics.json file (from 1_generate_review_based_rubrics.py)")
|
| 1602 |
+
parser.add_argument("--reviews_path", type=str, required=True,
|
| 1603 |
+
help="Path to JSON file with model reviews (contains pred_fast_mode)")
|
| 1604 |
+
|
| 1605 |
+
# Evaluation mode
|
| 1606 |
+
parser.add_argument("--mode", type=str, choices=["semantic", "auto_metric", "both"], default="both",
|
| 1607 |
+
help="Evaluation mode: semantic (LLM-based), auto_metric (rule-based), or both")
|
| 1608 |
+
|
| 1609 |
+
# Output paths
|
| 1610 |
+
parser.add_argument("--semantic_output", type=str, default=None,
|
| 1611 |
+
help="Path to output JSON file for semantic evaluation results (required if mode is semantic or both)")
|
| 1612 |
+
parser.add_argument("--auto_metric_output", type=str, default=None,
|
| 1613 |
+
help="Path to output JSON file for auto-metric evaluation results (required if mode is auto_metric or both)")
|
| 1614 |
+
|
| 1615 |
+
# Semantic evaluation settings
|
| 1616 |
+
parser.add_argument("--yaml_path", type=str, default=None,
|
| 1617 |
+
help="Path to prompts.yaml file (required for semantic evaluation)")
|
| 1618 |
+
parser.add_argument("--config_path", type=str, default=None,
|
| 1619 |
+
help="Path to configs.yaml file (required for semantic evaluation)")
|
| 1620 |
+
|
| 1621 |
+
# Multi-threading
|
| 1622 |
+
parser.add_argument("--max_workers", type=int, default=None,
|
| 1623 |
+
help="Maximum number of worker threads for semantic evaluation (default: 5)")
|
| 1624 |
+
|
| 1625 |
+
# Strict mode (normalize scores to discrete scales)
|
| 1626 |
+
parser.add_argument("--strict_mode", action="store_true", default=False,
|
| 1627 |
+
help="Enable strict mode: normalize scores to discrete scales before computing metrics (default: False)")
|
| 1628 |
+
|
| 1629 |
+
# Input format override
|
| 1630 |
+
parser.add_argument("--input_format", type=str, choices=['auto', 'refined', 'original'], default='auto',
|
| 1631 |
+
help="Manually specify input JSON format: 'refined' (has scores and initial_scores), 'original' (has model_prediction), or 'auto' for auto-detection (default: 'auto')")
|
| 1632 |
+
|
| 1633 |
+
return parser.parse_args()
|
| 1634 |
+
|
| 1635 |
+
|
| 1636 |
+
def main():
|
| 1637 |
+
"""Main execution function."""
|
| 1638 |
+
args = parse_args()
|
| 1639 |
+
|
| 1640 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 1641 |
+
|
| 1642 |
+
# Resolve paths
|
| 1643 |
+
rubrics_path = args.rubrics_path
|
| 1644 |
+
if not os.path.isabs(rubrics_path):
|
| 1645 |
+
rubrics_path = os.path.join(script_dir, rubrics_path)
|
| 1646 |
+
|
| 1647 |
+
reviews_path = args.reviews_path
|
| 1648 |
+
if not os.path.isabs(reviews_path):
|
| 1649 |
+
reviews_path = os.path.join(script_dir, reviews_path)
|
| 1650 |
+
|
| 1651 |
+
max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
|
| 1652 |
+
|
| 1653 |
+
# Validate mode and output paths
|
| 1654 |
+
if args.mode in ["semantic", "both"]:
|
| 1655 |
+
if not args.semantic_output:
|
| 1656 |
+
raise ValueError("--semantic_output is required when mode is 'semantic' or 'both'")
|
| 1657 |
+
if not args.yaml_path:
|
| 1658 |
+
raise ValueError("--yaml_path is required for semantic evaluation")
|
| 1659 |
+
if not args.config_path:
|
| 1660 |
+
raise ValueError("--config_path is required for semantic evaluation")
|
| 1661 |
+
|
| 1662 |
+
if args.mode in ["auto_metric", "both"]:
|
| 1663 |
+
if not args.auto_metric_output:
|
| 1664 |
+
raise ValueError("--auto_metric_output is required when mode is 'auto_metric' or 'both'")
|
| 1665 |
+
|
| 1666 |
+
# Check if files exist
|
| 1667 |
+
if not os.path.exists(rubrics_path):
|
| 1668 |
+
raise FileNotFoundError(f"Rubrics file not found: {rubrics_path}")
|
| 1669 |
+
if not os.path.exists(reviews_path):
|
| 1670 |
+
raise FileNotFoundError(f"Reviews file not found: {reviews_path}")
|
| 1671 |
+
|
| 1672 |
+
# Load data
|
| 1673 |
+
print(f"Loading rubrics from {rubrics_path}...")
|
| 1674 |
+
rubrics_data = load_rubrics_json(rubrics_path)
|
| 1675 |
+
print(f"Loaded {len(rubrics_data)} rubrics entries")
|
| 1676 |
+
|
| 1677 |
+
print(f"Loading model reviews from {reviews_path}...")
|
| 1678 |
+
if args.input_format != 'auto':
|
| 1679 |
+
print(f"Using manually specified format: {args.input_format}")
|
| 1680 |
+
else:
|
| 1681 |
+
print("Auto-detecting input format...")
|
| 1682 |
+
reviews_dict = load_model_reviews_json(reviews_path, format_override=args.input_format if args.input_format != 'auto' else None)
|
| 1683 |
+
print(f"Loaded {len(reviews_dict)} model reviews")
|
| 1684 |
+
|
| 1685 |
+
# Combine rubrics and reviews
|
| 1686 |
+
print("Combining rubrics and reviews...")
|
| 1687 |
+
evaluation_data = combine_rubrics_and_reviews(rubrics_data, reviews_dict)
|
| 1688 |
+
print(f"Prepared {len(evaluation_data)} entries for evaluation")
|
| 1689 |
+
|
| 1690 |
+
# Run evaluations based on mode
|
| 1691 |
+
if args.mode in ["semantic", "both"]:
|
| 1692 |
+
# Resolve semantic evaluation paths
|
| 1693 |
+
yaml_path = args.yaml_path
|
| 1694 |
+
if not os.path.isabs(yaml_path):
|
| 1695 |
+
yaml_path = os.path.join(script_dir, yaml_path)
|
| 1696 |
+
|
| 1697 |
+
config_path = args.config_path
|
| 1698 |
+
if not os.path.isabs(config_path):
|
| 1699 |
+
config_path = os.path.join(script_dir, config_path)
|
| 1700 |
+
|
| 1701 |
+
if not os.path.exists(yaml_path):
|
| 1702 |
+
raise FileNotFoundError(f"YAML file not found: {yaml_path}")
|
| 1703 |
+
if not os.path.exists(config_path):
|
| 1704 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 1705 |
+
|
| 1706 |
+
# Load prompt template
|
| 1707 |
+
print(f"Loading prompt template from {yaml_path}...")
|
| 1708 |
+
prompt_template = load_prompt_template(yaml_path)
|
| 1709 |
+
if not prompt_template:
|
| 1710 |
+
raise ValueError("Could not find 'v1_evaluator_prompt' in YAML file")
|
| 1711 |
+
|
| 1712 |
+
# Initialize LLM service
|
| 1713 |
+
print(f"Loading LLM configuration from {config_path}...")
|
| 1714 |
+
llm_config = load_llm_config(config_path)
|
| 1715 |
+
llm_service = create_llm_service_from_config(llm_config)
|
| 1716 |
+
mode = llm_config.get('mode', 'gpt')
|
| 1717 |
+
print(f"LLM service initialized (mode: {mode})")
|
| 1718 |
+
if hasattr(llm_service, 'model_name'):
|
| 1719 |
+
print(f"Using model: {llm_service.model_name}")
|
| 1720 |
+
|
| 1721 |
+
# Run semantic evaluation
|
| 1722 |
+
semantic_results, semantic_summary = run_semantic_evaluation(
|
| 1723 |
+
evaluation_data, prompt_template, llm_service, max_workers
|
| 1724 |
+
)
|
| 1725 |
+
|
| 1726 |
+
# Save semantic results
|
| 1727 |
+
semantic_output = args.semantic_output
|
| 1728 |
+
if not os.path.isabs(semantic_output):
|
| 1729 |
+
semantic_output = os.path.join(script_dir, semantic_output)
|
| 1730 |
+
|
| 1731 |
+
output_dir = os.path.dirname(semantic_output)
|
| 1732 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1733 |
+
|
| 1734 |
+
with open(semantic_output, 'w', encoding='utf-8') as f:
|
| 1735 |
+
json.dump(semantic_results, f, ensure_ascii=False, indent=2)
|
| 1736 |
+
print(f"\nSemantic evaluation results saved to {semantic_output}")
|
| 1737 |
+
|
| 1738 |
+
# Save semantic summary
|
| 1739 |
+
semantic_summary_path = semantic_output.replace('.json', '_summary.json')
|
| 1740 |
+
with open(semantic_summary_path, 'w', encoding='utf-8') as f:
|
| 1741 |
+
json.dump(semantic_summary, f, ensure_ascii=False, indent=2)
|
| 1742 |
+
print(f"Semantic evaluation summary saved to {semantic_summary_path}")
|
| 1743 |
+
|
| 1744 |
+
# Print semantic summary
|
| 1745 |
+
print("\n" + "="*80)
|
| 1746 |
+
print("SEMANTIC EVALUATION SUMMARY")
|
| 1747 |
+
print("="*80)
|
| 1748 |
+
print(f"Total entries: {semantic_summary['total_entries']}")
|
| 1749 |
+
print(f"Valid entries: {semantic_summary['valid_entries']}")
|
| 1750 |
+
print(f"Failed entries: {semantic_summary['failed_entries']}")
|
| 1751 |
+
if 'overall_score' in semantic_summary:
|
| 1752 |
+
score = semantic_summary['overall_score']
|
| 1753 |
+
print(f"\nOverall Score:")
|
| 1754 |
+
print(f" Mean: {score['mean']:.2f}")
|
| 1755 |
+
print(f" Min: {score['min']:.2f}")
|
| 1756 |
+
print(f" Max: {score['max']:.2f}")
|
| 1757 |
+
|
| 1758 |
+
if args.mode in ["auto_metric", "both"]:
|
| 1759 |
+
# Run auto-metric evaluation
|
| 1760 |
+
auto_metric_results, auto_metric_summary = run_auto_metric_evaluation(
|
| 1761 |
+
evaluation_data,
|
| 1762 |
+
strict_mode=args.strict_mode
|
| 1763 |
+
)
|
| 1764 |
+
|
| 1765 |
+
# Save auto-metric results
|
| 1766 |
+
auto_metric_output = args.auto_metric_output
|
| 1767 |
+
if not os.path.isabs(auto_metric_output):
|
| 1768 |
+
auto_metric_output = os.path.join(script_dir, auto_metric_output)
|
| 1769 |
+
|
| 1770 |
+
output_dir = os.path.dirname(auto_metric_output)
|
| 1771 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1772 |
+
|
| 1773 |
+
with open(auto_metric_output, 'w', encoding='utf-8') as f:
|
| 1774 |
+
json.dump(auto_metric_results, f, ensure_ascii=False, indent=2)
|
| 1775 |
+
print(f"\nAuto-metric evaluation results saved to {auto_metric_output}")
|
| 1776 |
+
|
| 1777 |
+
# Save auto-metric summary
|
| 1778 |
+
auto_metric_summary_path = auto_metric_output.replace('.json', '_summary.json')
|
| 1779 |
+
with open(auto_metric_summary_path, 'w', encoding='utf-8') as f:
|
| 1780 |
+
json.dump(auto_metric_summary, f, ensure_ascii=False, indent=2)
|
| 1781 |
+
print(f"Auto-metric evaluation summary saved to {auto_metric_summary_path}")
|
| 1782 |
+
|
| 1783 |
+
# Print auto-metric summary
|
| 1784 |
+
print("\n" + "="*80)
|
| 1785 |
+
print("AUTO-METRIC EVALUATION SUMMARY")
|
| 1786 |
+
print("="*80)
|
| 1787 |
+
print(f"Total entries: {auto_metric_summary['total_entries']}")
|
| 1788 |
+
print(f"Valid entries: {auto_metric_summary['valid_entries']}")
|
| 1789 |
+
print(f"MSE entries: {auto_metric_summary['mse_entries']}")
|
| 1790 |
+
|
| 1791 |
+
if 'mse_statistics' in auto_metric_summary:
|
| 1792 |
+
print("\nMSE Statistics:")
|
| 1793 |
+
for dim, stats in auto_metric_summary['mse_statistics'].items():
|
| 1794 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1795 |
+
|
| 1796 |
+
if 'mae_statistics' in auto_metric_summary:
|
| 1797 |
+
print("\nMAE Statistics:")
|
| 1798 |
+
for dim, stats in auto_metric_summary['mae_statistics'].items():
|
| 1799 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1800 |
+
|
| 1801 |
+
# Print refined and initial statistics if available
|
| 1802 |
+
if 'refined_mse_statistics' in auto_metric_summary:
|
| 1803 |
+
print("\nRefined Scores - MSE Statistics:")
|
| 1804 |
+
for dim, stats in auto_metric_summary['refined_mse_statistics'].items():
|
| 1805 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1806 |
+
|
| 1807 |
+
if 'refined_mae_statistics' in auto_metric_summary:
|
| 1808 |
+
print("\nRefined Scores - MAE Statistics:")
|
| 1809 |
+
for dim, stats in auto_metric_summary['refined_mae_statistics'].items():
|
| 1810 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1811 |
+
|
| 1812 |
+
if 'initial_mse_statistics' in auto_metric_summary:
|
| 1813 |
+
print("\nInitial Scores - MSE Statistics:")
|
| 1814 |
+
for dim, stats in auto_metric_summary['initial_mse_statistics'].items():
|
| 1815 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1816 |
+
|
| 1817 |
+
if 'initial_mae_statistics' in auto_metric_summary:
|
| 1818 |
+
print("\nInitial Scores - MAE Statistics:")
|
| 1819 |
+
for dim, stats in auto_metric_summary['initial_mae_statistics'].items():
|
| 1820 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1821 |
+
|
| 1822 |
+
if 'spearman_correlations' in auto_metric_summary:
|
| 1823 |
+
print("\nSpearman Correlations:")
|
| 1824 |
+
for dim, stats in auto_metric_summary['spearman_correlations'].items():
|
| 1825 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1826 |
+
|
| 1827 |
+
# Print refined and initial spearman correlations if available
|
| 1828 |
+
if 'refined_spearman_correlations' in auto_metric_summary:
|
| 1829 |
+
print("\nRefined Scores - Spearman Correlations:")
|
| 1830 |
+
for dim, stats in auto_metric_summary['refined_spearman_correlations'].items():
|
| 1831 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1832 |
+
|
| 1833 |
+
if 'initial_spearman_correlations' in auto_metric_summary:
|
| 1834 |
+
print("\nInitial Scores - Spearman Correlations:")
|
| 1835 |
+
for dim, stats in auto_metric_summary['initial_spearman_correlations'].items():
|
| 1836 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1837 |
+
|
| 1838 |
+
if 'decision_metrics' in auto_metric_summary:
|
| 1839 |
+
dm = auto_metric_summary['decision_metrics']
|
| 1840 |
+
print(f"\nDecision Metrics:")
|
| 1841 |
+
print(f" Accuracy: {dm['accuracy']:.4f} (n={dm['count']})")
|
| 1842 |
+
if 'f1_macro' in dm:
|
| 1843 |
+
print(f" F1 (macro): {dm['f1_macro']:.4f}")
|
| 1844 |
+
|
| 1845 |
+
# Print refined and initial decision metrics if available
|
| 1846 |
+
if 'refined_decision_metrics' in auto_metric_summary:
|
| 1847 |
+
print("\nRefined Scores - Decision Metrics:")
|
| 1848 |
+
rdm = auto_metric_summary['refined_decision_metrics']
|
| 1849 |
+
print(f" Accuracy: {rdm['accuracy']:.4f} (n={rdm['count']})")
|
| 1850 |
+
if 'f1_macro' in rdm:
|
| 1851 |
+
print(f" F1 (macro): {rdm['f1_macro']:.4f}")
|
| 1852 |
+
|
| 1853 |
+
if 'initial_decision_metrics' in auto_metric_summary:
|
| 1854 |
+
print("\nInitial Scores - Decision Metrics:")
|
| 1855 |
+
idm = auto_metric_summary['initial_decision_metrics']
|
| 1856 |
+
print(f" Accuracy: {idm['accuracy']:.4f} (n={idm['count']})")
|
| 1857 |
+
if 'f1_macro' in idm:
|
| 1858 |
+
print(f" F1 (macro): {idm['f1_macro']:.4f}")
|
| 1859 |
+
|
| 1860 |
+
print("\n" + "="*80)
|
| 1861 |
+
print("EVALUATION COMPLETE")
|
| 1862 |
+
print("="*80)
|
| 1863 |
+
|
| 1864 |
+
|
| 1865 |
+
if __name__ == "__main__":
|
| 1866 |
+
main()
|
src/evaluator/2_evaluate_aiscientist.py
ADDED
|
@@ -0,0 +1,1866 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified evaluation script for semantic (LLM-based) and auto_metric (rule-based) evaluation.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Reads eval_rubrics.json (from 1_generate_review_based_rubrics.py) containing rubrics for each paper
|
| 6 |
+
2. Reads input JSON file containing model reviews (supports multiple formats)
|
| 7 |
+
3. Supports three evaluation modes:
|
| 8 |
+
- semantic: LLM-based rubrics evaluation (from 2_evaluate_direct.py)
|
| 9 |
+
- auto_metric: Rule-based metrics evaluation (from 3_rule_evaluate.py)
|
| 10 |
+
- both: Run both evaluations separately
|
| 11 |
+
4. Supports strict mode: normalize scores to discrete scales before computing metrics (--strict_mode)
|
| 12 |
+
5. Outputs separate JSON files for results and summaries
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Semantic evaluation only
|
| 16 |
+
python 2_evaluate.py \
|
| 17 |
+
--rubrics_path eval_rubrics.json \
|
| 18 |
+
--reviews_path model_reviews.json \
|
| 19 |
+
--mode semantic \
|
| 20 |
+
--yaml_path prompts.yaml \
|
| 21 |
+
--config_path configs.yaml \
|
| 22 |
+
--semantic_output semantic_results.json \
|
| 23 |
+
--max_workers 5
|
| 24 |
+
|
| 25 |
+
# Auto-metric evaluation only
|
| 26 |
+
python 2_evaluate.py \
|
| 27 |
+
--rubrics_path eval_rubrics.json \
|
| 28 |
+
--reviews_path model_reviews.json \
|
| 29 |
+
--mode auto_metric \
|
| 30 |
+
--auto_metric_output auto_metric_results.json
|
| 31 |
+
|
| 32 |
+
# Auto-metric evaluation with strict mode (normalize scores to discrete scales)
|
| 33 |
+
python 2_evaluate.py \
|
| 34 |
+
--rubrics_path eval_rubrics.json \
|
| 35 |
+
--reviews_path model_reviews.json \
|
| 36 |
+
--mode auto_metric \
|
| 37 |
+
--auto_metric_output auto_metric_results.json \
|
| 38 |
+
--strict_mode
|
| 39 |
+
|
| 40 |
+
# Auto-metric evaluation with manually specified input format (refined)
|
| 41 |
+
python 2_evaluate.py \
|
| 42 |
+
--rubrics_path eval_rubrics.json \
|
| 43 |
+
--reviews_path model_reviews.json \
|
| 44 |
+
--mode auto_metric \
|
| 45 |
+
--auto_metric_output auto_metric_results.json \
|
| 46 |
+
--input_format refined
|
| 47 |
+
|
| 48 |
+
# Auto-metric evaluation with manually specified input format (original)
|
| 49 |
+
python 2_evaluate.py \
|
| 50 |
+
--rubrics_path eval_rubrics.json \
|
| 51 |
+
--reviews_path ours.json \
|
| 52 |
+
--mode auto_metric \
|
| 53 |
+
--auto_metric_output auto_metric_results.json \
|
| 54 |
+
--input_format original
|
| 55 |
+
|
| 56 |
+
# Both evaluations
|
| 57 |
+
python 2_evaluate.py \
|
| 58 |
+
--rubrics_path eval_rubrics.json \
|
| 59 |
+
--reviews_path model_reviews.json \
|
| 60 |
+
--mode both \
|
| 61 |
+
--yaml_path prompts.yaml \
|
| 62 |
+
--config_path configs.yaml \
|
| 63 |
+
--semantic_output semantic_results.json \
|
| 64 |
+
--auto_metric_output auto_metric_results.json \
|
| 65 |
+
--max_workers 32
|
| 66 |
+
"""
|
| 67 |
+
from __future__ import annotations
|
| 68 |
+
|
| 69 |
+
import json
|
| 70 |
+
import os
|
| 71 |
+
import sys
|
| 72 |
+
import argparse
|
| 73 |
+
import yaml
|
| 74 |
+
import math
|
| 75 |
+
import re
|
| 76 |
+
from typing import Dict, List, Any, Optional
|
| 77 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 78 |
+
from tqdm import tqdm
|
| 79 |
+
from itertools import combinations
|
| 80 |
+
from scipy.stats import spearmanr
|
| 81 |
+
from sklearn.metrics import precision_recall_fscore_support
|
| 82 |
+
|
| 83 |
+
# Add parent directory to path
|
| 84 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 85 |
+
# Import parse_llm_response from local llm_service module
|
| 86 |
+
import llm_service as local_llm_service
|
| 87 |
+
parse_llm_response = local_llm_service.parse_llm_response
|
| 88 |
+
|
| 89 |
+
# Import from shared/utils for gpt/vllm support
|
| 90 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 91 |
+
if project_root not in sys.path:
|
| 92 |
+
sys.path.insert(0, project_root)
|
| 93 |
+
|
| 94 |
+
from shared.utils.llm_service import LLMService
|
| 95 |
+
from shared.utils.vllm_service import VLLMService
|
| 96 |
+
from shared.utils.gpt_service import GPTService
|
| 97 |
+
sys.path.insert(0, os.path.join(project_root, 'shared', 'utils'))
|
| 98 |
+
from json_parser import parse_review_markdown
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def convert_ai_researcher(review: dict) -> str:
|
| 102 |
+
"""
|
| 103 |
+
Convert the review text from ai-researcher format to unified review system format.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
summary = review["Summary"]
|
| 107 |
+
strengths = "\n".join(f"- {s}" for s in review["Strengths"])
|
| 108 |
+
weaknesses = "\n".join(f"- {w}" for w in review["Weaknesses"])
|
| 109 |
+
|
| 110 |
+
# scores
|
| 111 |
+
originality = review["Originality"]
|
| 112 |
+
quality = review["Quality"]
|
| 113 |
+
clarity = review["Clarity"]
|
| 114 |
+
significance = review["Significance"]
|
| 115 |
+
|
| 116 |
+
questions = "\n".join(f"- {q}" for q in review["Questions"])
|
| 117 |
+
limitations = "\n".join(f"- {l}" for l in review["Limitations"])
|
| 118 |
+
ethical_concerns = review["Ethical Concerns"]
|
| 119 |
+
|
| 120 |
+
# scores again
|
| 121 |
+
soundness = review["Soundness"]
|
| 122 |
+
presentation = review["Presentation"]
|
| 123 |
+
contribution = review["Contribution"]
|
| 124 |
+
overall = review["Overall"]
|
| 125 |
+
confidence = review["Confidence"]
|
| 126 |
+
|
| 127 |
+
# final decision
|
| 128 |
+
decision = review["Decision"]
|
| 129 |
+
|
| 130 |
+
meta_review = {
|
| 131 |
+
"rating": overall,
|
| 132 |
+
"soundness": soundness,
|
| 133 |
+
"presentation": presentation,
|
| 134 |
+
"contribution": contribution,
|
| 135 |
+
"confidence": confidence,
|
| 136 |
+
"decision": decision.lower().strip(),
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
return f"Summary: {summary}\nStrengths: {strengths}\nWeaknesses: {weaknesses}\nOriginality: {originality}\nQuality: {quality}\nClarity: {clarity}\nSignificance: {significance}\nQuestions: {questions}\nLimitations: {limitations}\nEthical Concerns: {ethical_concerns}\nSoundness: {soundness}\nPresentation: {presentation}\nContribution: {contribution}\nOverall: {overall}\nConfidence: {confidence}\nDecision: {decision}", meta_review
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def convert_agenticreview(review_text: str) -> tuple:
|
| 143 |
+
"""
|
| 144 |
+
Convert the review text from agenticreview format to unified review system format.
|
| 145 |
+
|
| 146 |
+
The agenticreview format has text like:
|
| 147 |
+
"Overall rating: 5\n\nSignificance and novelty: ..."
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
review_text: Raw review text string
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
Tuple of (formatted_review_text, meta_review_dict)
|
| 154 |
+
"""
|
| 155 |
+
# Extract rating from "Overall rating: x" format
|
| 156 |
+
rating = None
|
| 157 |
+
rating_match = re.search(r'Overall\s+rating\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
|
| 158 |
+
if rating_match:
|
| 159 |
+
try:
|
| 160 |
+
rating = float(rating_match.group(1))
|
| 161 |
+
except (ValueError, IndexError):
|
| 162 |
+
pass
|
| 163 |
+
|
| 164 |
+
# If not found, try alternative patterns
|
| 165 |
+
if rating is None:
|
| 166 |
+
rating_match = re.search(r'(?:rating|score)\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
|
| 167 |
+
if rating_match:
|
| 168 |
+
try:
|
| 169 |
+
rating = float(rating_match.group(1))
|
| 170 |
+
except (ValueError, IndexError):
|
| 171 |
+
pass
|
| 172 |
+
|
| 173 |
+
# Try to extract from parse_review_markdown as fallback
|
| 174 |
+
if rating is None:
|
| 175 |
+
try:
|
| 176 |
+
parsed = parse_review_markdown(review_text)
|
| 177 |
+
rating = parsed.get('rating')
|
| 178 |
+
except Exception:
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
# Create meta_review dict - agenticreview only has rating, no other scores
|
| 182 |
+
meta_review = {
|
| 183 |
+
"rating": rating,
|
| 184 |
+
"soundness": None,
|
| 185 |
+
"presentation": None,
|
| 186 |
+
"contribution": None,
|
| 187 |
+
"confidence": None,
|
| 188 |
+
"decision": None,
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
# Return the review text as-is (it's already in a readable format)
|
| 192 |
+
return review_text, meta_review
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ReviewProcessor:
|
| 196 |
+
"""Handles the extraction and processing of reviews from different sources."""
|
| 197 |
+
|
| 198 |
+
@staticmethod
|
| 199 |
+
def extract_review_content(pred_context):
|
| 200 |
+
"""
|
| 201 |
+
Extract the review content from the prediction context.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
pred_context: Raw prediction data that contains the review
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
str: Extracted review content
|
| 208 |
+
"""
|
| 209 |
+
try:
|
| 210 |
+
# First attempt to extract from boxed format
|
| 211 |
+
return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 212 |
+
except Exception:
|
| 213 |
+
# Alternative extraction if the first method fails
|
| 214 |
+
if isinstance(pred_context, dict) and 'output' in pred_context:
|
| 215 |
+
return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 216 |
+
else:
|
| 217 |
+
# Return as is if extraction fails
|
| 218 |
+
return pred_context
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ============================================================================
|
| 222 |
+
# Semantic Evaluation Functions (from 2_evaluate_direct.py)
|
| 223 |
+
# ============================================================================
|
| 224 |
+
|
| 225 |
+
def load_prompt_template(yaml_path: str) -> str:
|
| 226 |
+
"""Load the evaluator prompt from YAML file."""
|
| 227 |
+
with open(yaml_path, 'r', encoding='utf-8') as f:
|
| 228 |
+
prompts = yaml.safe_load(f)
|
| 229 |
+
return prompts.get('v1_evaluator_prompt', '')
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def build_evaluation_prompt(
|
| 233 |
+
rubrics: List[Dict[str, Any]],
|
| 234 |
+
paper_content: str,
|
| 235 |
+
review: str,
|
| 236 |
+
prompt_template: str
|
| 237 |
+
) -> str:
|
| 238 |
+
"""Build the evaluation prompt by replacing placeholders."""
|
| 239 |
+
rubrics_json = json.dumps(rubrics, indent=4, ensure_ascii=False)
|
| 240 |
+
prompt = prompt_template.replace('{rubrics_json}', rubrics_json)
|
| 241 |
+
prompt = prompt.replace('<<paper_content>>', paper_content)
|
| 242 |
+
prompt = prompt.replace('<<review>>', review)
|
| 243 |
+
return prompt
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def calculate_weighted_scores(
|
| 247 |
+
raw_scores: Dict[str, Dict[str, Any]],
|
| 248 |
+
rubrics: List[Dict[str, Any]]
|
| 249 |
+
) -> Dict[str, float]:
|
| 250 |
+
"""Calculate weighted scores for each rubric."""
|
| 251 |
+
rubric_weights = {r['title']: r['weight'] for r in rubrics}
|
| 252 |
+
weighted_scores = {}
|
| 253 |
+
|
| 254 |
+
for rubric_title, rubric_data in raw_scores.items():
|
| 255 |
+
if rubric_title not in rubric_weights:
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
+
rubric_score = rubric_data.get('score', 0)
|
| 259 |
+
if isinstance(rubric_score, str):
|
| 260 |
+
try:
|
| 261 |
+
rubric_score = int(rubric_score)
|
| 262 |
+
except ValueError:
|
| 263 |
+
rubric_score = 0
|
| 264 |
+
|
| 265 |
+
if rubric_score not in [0, 1]:
|
| 266 |
+
rubric_score = 1 if rubric_score > 0 else 0
|
| 267 |
+
|
| 268 |
+
weight = rubric_weights[rubric_title]
|
| 269 |
+
weighted_scores[rubric_title] = rubric_score * weight
|
| 270 |
+
|
| 271 |
+
return weighted_scores
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def calculate_scores(raw_scores: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
|
| 275 |
+
"""Calculate scores for each rubric."""
|
| 276 |
+
scores = {}
|
| 277 |
+
for rubric_title, rubric_data in raw_scores.items():
|
| 278 |
+
scores[rubric_title] = rubric_data.get('score', 0)
|
| 279 |
+
return scores
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def evaluate_review_semantic(
|
| 283 |
+
entry: Dict[str, Any],
|
| 284 |
+
paper_content: str,
|
| 285 |
+
prompt_template: str,
|
| 286 |
+
llm_service: LLMService
|
| 287 |
+
) -> Dict[str, Any]:
|
| 288 |
+
"""Evaluate a single review using article-specific rubrics."""
|
| 289 |
+
entry_id = entry.get('id', 'unknown')
|
| 290 |
+
rubrics = entry.get('rubrics', [])
|
| 291 |
+
model_review = entry.get('model_review', '')
|
| 292 |
+
|
| 293 |
+
if not rubrics:
|
| 294 |
+
return {
|
| 295 |
+
'id': entry_id,
|
| 296 |
+
'raw_scores': {},
|
| 297 |
+
'weighted_scores': {},
|
| 298 |
+
'total_score': 0.0,
|
| 299 |
+
'error': 'No valid rubrics found',
|
| 300 |
+
'raw_response': ''
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
# Build prompt
|
| 304 |
+
prompt = build_evaluation_prompt(rubrics, paper_content, model_review, prompt_template)
|
| 305 |
+
|
| 306 |
+
# Call LLM
|
| 307 |
+
try:
|
| 308 |
+
messages = [{"role": "user", "content": prompt}]
|
| 309 |
+
response = llm_service.generate(messages=messages)
|
| 310 |
+
|
| 311 |
+
# Parse response
|
| 312 |
+
raw_scores = parse_llm_response(response)
|
| 313 |
+
weighted_scores = calculate_scores(raw_scores)
|
| 314 |
+
total_score = sum(weighted_scores.values())
|
| 315 |
+
|
| 316 |
+
return {
|
| 317 |
+
'id': entry_id,
|
| 318 |
+
'raw_scores': raw_scores,
|
| 319 |
+
'weighted_scores': weighted_scores,
|
| 320 |
+
'total_score': total_score,
|
| 321 |
+
'raw_response': response
|
| 322 |
+
}
|
| 323 |
+
except Exception as e:
|
| 324 |
+
print(f"[ERROR] Error evaluating review {entry_id}: {e}")
|
| 325 |
+
return {
|
| 326 |
+
'id': entry_id,
|
| 327 |
+
'raw_scores': {},
|
| 328 |
+
'weighted_scores': {},
|
| 329 |
+
'total_score': 0.0,
|
| 330 |
+
'error': str(e),
|
| 331 |
+
'raw_response': ''
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def calculate_per_rubric_statistics(
|
| 336 |
+
valid_results: List[Dict[str, Any]],
|
| 337 |
+
rubric_titles: List[str]
|
| 338 |
+
) -> Dict[str, Dict[str, float]]:
|
| 339 |
+
"""Calculate per-rubric statistics from evaluation results."""
|
| 340 |
+
rubric_scores = {title: [] for title in rubric_titles}
|
| 341 |
+
|
| 342 |
+
for result in valid_results:
|
| 343 |
+
weighted_scores = result.get('weighted_scores', {})
|
| 344 |
+
if not isinstance(weighted_scores, dict):
|
| 345 |
+
continue
|
| 346 |
+
|
| 347 |
+
for rubric_title in rubric_titles:
|
| 348 |
+
if rubric_title in weighted_scores:
|
| 349 |
+
score = weighted_scores[rubric_title]
|
| 350 |
+
if isinstance(score, str):
|
| 351 |
+
try:
|
| 352 |
+
score = float(score)
|
| 353 |
+
except ValueError:
|
| 354 |
+
continue
|
| 355 |
+
elif isinstance(score, (int, float)):
|
| 356 |
+
score = float(score)
|
| 357 |
+
else:
|
| 358 |
+
continue
|
| 359 |
+
rubric_scores[rubric_title].append(score)
|
| 360 |
+
|
| 361 |
+
per_rubric_stats = {}
|
| 362 |
+
for rubric_title in rubric_titles:
|
| 363 |
+
scores = rubric_scores[rubric_title]
|
| 364 |
+
if not scores:
|
| 365 |
+
continue
|
| 366 |
+
|
| 367 |
+
mean_score = sum(scores) / len(scores)
|
| 368 |
+
min_score = min(scores)
|
| 369 |
+
max_score = max(scores)
|
| 370 |
+
count = len(scores)
|
| 371 |
+
|
| 372 |
+
if rubric_title == "False or Contradictory Claims":
|
| 373 |
+
pass_count = sum(1 for s in scores if s >= 0)
|
| 374 |
+
else:
|
| 375 |
+
pass_count = sum(1 for s in scores if s >= 1)
|
| 376 |
+
pass_rate = pass_count / count if count > 0 else 0.0
|
| 377 |
+
|
| 378 |
+
per_rubric_stats[rubric_title] = {
|
| 379 |
+
'mean': mean_score,
|
| 380 |
+
'min': min_score,
|
| 381 |
+
'max': max_score,
|
| 382 |
+
'count': count,
|
| 383 |
+
'pass_rate': pass_rate
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
return per_rubric_stats
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# ============================================================================
|
| 390 |
+
# Auto-Metric Evaluation Functions (from 3_rule_evaluate.py)
|
| 391 |
+
# ============================================================================
|
| 392 |
+
|
| 393 |
+
def extract_scores_from_review(review_text: str) -> Dict[str, Any]:
|
| 394 |
+
"""Extract numeric scores and decision from a review markdown text."""
|
| 395 |
+
if not review_text:
|
| 396 |
+
return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
|
| 397 |
+
|
| 398 |
+
try:
|
| 399 |
+
parsed = parse_review_markdown(review_text)
|
| 400 |
+
decision = parsed.get('decision', '')
|
| 401 |
+
if decision:
|
| 402 |
+
decision_lower = decision.lower().strip()
|
| 403 |
+
if 'accept' in decision_lower:
|
| 404 |
+
decision = 'accept'
|
| 405 |
+
elif 'reject' in decision_lower:
|
| 406 |
+
decision = 'reject'
|
| 407 |
+
elif 'undecided' in decision_lower:
|
| 408 |
+
decision = 'undecided'
|
| 409 |
+
else:
|
| 410 |
+
decision = decision_lower
|
| 411 |
+
else:
|
| 412 |
+
decision = None
|
| 413 |
+
|
| 414 |
+
return {
|
| 415 |
+
'soundness': parsed.get('soundness'),
|
| 416 |
+
'presentation': parsed.get('presentation'),
|
| 417 |
+
'rating': parsed.get('rating'),
|
| 418 |
+
'confidence': parsed.get('confidence'),
|
| 419 |
+
'decision': decision
|
| 420 |
+
}
|
| 421 |
+
except Exception as e:
|
| 422 |
+
print(f"Warning: Failed to parse review text: {e}")
|
| 423 |
+
return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def calculate_mse(predicted: float, ground_truth: float) -> Optional[float]:
|
| 427 |
+
"""Calculate Mean Squared Error for a single value."""
|
| 428 |
+
if predicted is None or ground_truth is None:
|
| 429 |
+
return None
|
| 430 |
+
return (predicted - ground_truth) ** 2
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def calculate_mae(predicted: float, ground_truth: float) -> Optional[float]:
|
| 434 |
+
"""Calculate Mean Absolute Error for a single value."""
|
| 435 |
+
if predicted is None or ground_truth is None:
|
| 436 |
+
return None
|
| 437 |
+
return abs(predicted - ground_truth)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def normalize_to_discrete_scale(score: Optional[float], scale_type: str) -> Optional[float]:
|
| 441 |
+
"""
|
| 442 |
+
Normalize a float score to the nearest discrete value based on scale type.
|
| 443 |
+
Uses round-half-up tie-breaking (e.g., 3.5 rounds to 4, 1.5 rounds to 2).
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
score: The float score to normalize (can be None)
|
| 447 |
+
scale_type: Either '0-5' for 0-5 scale (discrete: 0,1,2,3,4,5)
|
| 448 |
+
or '0-10' for 0-10 scale (discrete: 0,2,4,6,8,10)
|
| 449 |
+
|
| 450 |
+
Returns:
|
| 451 |
+
Normalized discrete score, or None if input is None
|
| 452 |
+
"""
|
| 453 |
+
if score is None:
|
| 454 |
+
return None
|
| 455 |
+
|
| 456 |
+
try:
|
| 457 |
+
score = float(score)
|
| 458 |
+
except (ValueError, TypeError):
|
| 459 |
+
return None
|
| 460 |
+
|
| 461 |
+
if scale_type == '0-5':
|
| 462 |
+
# Discrete values: 0, 1, 2, 3, 4, 5
|
| 463 |
+
discrete_values = [0, 1, 2, 3, 4, 5]
|
| 464 |
+
# Clamp to valid range
|
| 465 |
+
score = max(0, min(5, score))
|
| 466 |
+
# Find nearest discrete value, with round-half-up tie-breaking
|
| 467 |
+
# For ties, prefer the higher value
|
| 468 |
+
best_value = None
|
| 469 |
+
best_distance = float('inf')
|
| 470 |
+
for val in discrete_values:
|
| 471 |
+
distance = abs(val - score)
|
| 472 |
+
if distance < best_distance:
|
| 473 |
+
best_distance = distance
|
| 474 |
+
best_value = val
|
| 475 |
+
elif distance == best_distance and val > best_value:
|
| 476 |
+
# Tie-breaking: prefer higher value (round-half-up)
|
| 477 |
+
best_value = val
|
| 478 |
+
return best_value
|
| 479 |
+
elif scale_type == '0-10':
|
| 480 |
+
# Discrete values: 0, 2, 4, 6, 8, 10
|
| 481 |
+
discrete_values = [0, 2, 4, 6, 8, 10]
|
| 482 |
+
# Clamp to valid range
|
| 483 |
+
score = max(0, min(10, score))
|
| 484 |
+
# Find nearest discrete value, with round-half-up tie-breaking
|
| 485 |
+
best_value = None
|
| 486 |
+
best_distance = float('inf')
|
| 487 |
+
for val in discrete_values:
|
| 488 |
+
distance = abs(val - score)
|
| 489 |
+
if distance < best_distance:
|
| 490 |
+
best_distance = distance
|
| 491 |
+
best_value = val
|
| 492 |
+
elif distance == best_distance and val > best_value:
|
| 493 |
+
# Tie-breaking: prefer higher value (round-half-up)
|
| 494 |
+
best_value = val
|
| 495 |
+
return best_value
|
| 496 |
+
else:
|
| 497 |
+
raise ValueError(f"Unknown scale_type: {scale_type}. Must be '0-5' or '0-10'")
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def normalize_scores_dict(scores: Dict[str, Optional[float]]) -> Dict[str, Optional[float]]:
|
| 501 |
+
"""
|
| 502 |
+
Normalize all scores in a dictionary to their appropriate discrete scales.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
scores: Dictionary with keys 'soundness', 'presentation', 'rating', 'confidence'
|
| 506 |
+
|
| 507 |
+
Returns:
|
| 508 |
+
Dictionary with normalized scores
|
| 509 |
+
"""
|
| 510 |
+
normalized = {}
|
| 511 |
+
|
| 512 |
+
# soundness, presentation, confidence use 0-5 scale
|
| 513 |
+
for key in ['soundness', 'presentation', 'confidence']:
|
| 514 |
+
normalized[key] = normalize_to_discrete_scale(scores.get(key), '0-5')
|
| 515 |
+
|
| 516 |
+
# rating uses 0-10 scale
|
| 517 |
+
normalized['rating'] = normalize_to_discrete_scale(scores.get('rating'), '0-10')
|
| 518 |
+
|
| 519 |
+
return normalized
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def calculate_score_metrics(
|
| 523 |
+
model_scores: Dict[str, float],
|
| 524 |
+
ground_truth_scores: Dict[str, float],
|
| 525 |
+
normalize: bool = False
|
| 526 |
+
) -> Dict[str, Any]:
|
| 527 |
+
"""
|
| 528 |
+
Calculate MSE and MAE metrics for each scoring dimension.
|
| 529 |
+
|
| 530 |
+
Args:
|
| 531 |
+
model_scores: Dictionary with model scores
|
| 532 |
+
ground_truth_scores: Dictionary with ground truth scores
|
| 533 |
+
normalize: If True, normalize scores to discrete scales before computing metrics
|
| 534 |
+
|
| 535 |
+
Returns:
|
| 536 |
+
Dictionary with MSE, MAE metrics and optionally normalized scores
|
| 537 |
+
"""
|
| 538 |
+
dimensions = ['soundness', 'presentation', 'rating', 'confidence']
|
| 539 |
+
|
| 540 |
+
# Normalize scores to discrete scales if requested
|
| 541 |
+
if normalize:
|
| 542 |
+
model_scores_normalized = normalize_scores_dict(model_scores)
|
| 543 |
+
gt_scores_normalized = normalize_scores_dict(ground_truth_scores)
|
| 544 |
+
else:
|
| 545 |
+
model_scores_normalized = model_scores
|
| 546 |
+
gt_scores_normalized = ground_truth_scores
|
| 547 |
+
|
| 548 |
+
mse_values = {}
|
| 549 |
+
mae_values = {}
|
| 550 |
+
valid_count = 0
|
| 551 |
+
|
| 552 |
+
for dim in dimensions:
|
| 553 |
+
# Use normalized scores for metric calculation
|
| 554 |
+
mse = calculate_mse(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
|
| 555 |
+
mae = calculate_mae(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
|
| 556 |
+
mse_values[f'{dim}_mse'] = mse
|
| 557 |
+
mae_values[f'{dim}_mae'] = mae
|
| 558 |
+
if mse is not None:
|
| 559 |
+
valid_count += 1
|
| 560 |
+
|
| 561 |
+
overall_error = sum([v for v in mse_values.values() if v is not None])
|
| 562 |
+
|
| 563 |
+
result = {
|
| 564 |
+
**mse_values,
|
| 565 |
+
**mae_values,
|
| 566 |
+
'overall_error': overall_error if valid_count > 0 else None,
|
| 567 |
+
'valid_dimensions': valid_count
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
# Include normalized scores in result for transparency (only if normalize=True)
|
| 571 |
+
if normalize:
|
| 572 |
+
result['model_scores_normalized'] = model_scores_normalized
|
| 573 |
+
result['gt_scores_normalized'] = gt_scores_normalized
|
| 574 |
+
|
| 575 |
+
return result
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def normalize_score_value(value):
|
| 579 |
+
"""Normalize score value to float, handling string representations."""
|
| 580 |
+
if value is None:
|
| 581 |
+
return None
|
| 582 |
+
if isinstance(value, (int, float)):
|
| 583 |
+
return float(value)
|
| 584 |
+
if isinstance(value, str):
|
| 585 |
+
# Try to extract numeric value from string (e.g., "2.75" -> 2.75)
|
| 586 |
+
try:
|
| 587 |
+
import re
|
| 588 |
+
match = re.search(r'(\d+\.?\d*)', value)
|
| 589 |
+
if match:
|
| 590 |
+
return float(match.group(1))
|
| 591 |
+
except:
|
| 592 |
+
pass
|
| 593 |
+
return None
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def normalize_decision(decision):
|
| 597 |
+
"""Normalize decision string to standard format."""
|
| 598 |
+
if decision is None:
|
| 599 |
+
return None
|
| 600 |
+
decision_lower = str(decision).lower().strip()
|
| 601 |
+
if 'accept' in decision_lower:
|
| 602 |
+
return 'accept'
|
| 603 |
+
elif 'reject' in decision_lower:
|
| 604 |
+
return 'reject'
|
| 605 |
+
elif 'undecided' in decision_lower:
|
| 606 |
+
return 'undecided'
|
| 607 |
+
else:
|
| 608 |
+
return decision_lower
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def extract_scores_from_dict(scores_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 612 |
+
"""
|
| 613 |
+
Extract scores from a structured dictionary (scores or initial_scores format).
|
| 614 |
+
|
| 615 |
+
Args:
|
| 616 |
+
scores_dict: Dict containing scores (e.g., {'rating': 5.75, 'soundness': '2.75', ...})
|
| 617 |
+
|
| 618 |
+
Returns:
|
| 619 |
+
Dict with normalized scores: {'soundness', 'presentation', 'rating', 'confidence', 'decision'}
|
| 620 |
+
"""
|
| 621 |
+
if not scores_dict:
|
| 622 |
+
return {
|
| 623 |
+
'soundness': None,
|
| 624 |
+
'presentation': None,
|
| 625 |
+
'rating': None,
|
| 626 |
+
'confidence': None,
|
| 627 |
+
'decision': None
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
return {
|
| 631 |
+
'soundness': normalize_score_value(scores_dict.get('soundness')),
|
| 632 |
+
'presentation': normalize_score_value(scores_dict.get('presentation')),
|
| 633 |
+
'rating': normalize_score_value(scores_dict.get('rating')),
|
| 634 |
+
'confidence': normalize_score_value(scores_dict.get('confidence')),
|
| 635 |
+
'decision': normalize_decision(scores_dict.get('decision'))
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def evaluate_review_auto_metric(entry: Dict[str, Any], use_initial_scores: bool = False, strict_mode: bool = False) -> Dict[str, Any]:
|
| 640 |
+
"""
|
| 641 |
+
Evaluate a single entry by extracting scores and calculating metrics.
|
| 642 |
+
|
| 643 |
+
Args:
|
| 644 |
+
entry: Evaluation entry containing model_review, scores, initial_scores, etc.
|
| 645 |
+
use_initial_scores: If True, use initial_scores instead of refined scores (for refined format)
|
| 646 |
+
|
| 647 |
+
Returns:
|
| 648 |
+
Dict containing evaluation metrics
|
| 649 |
+
"""
|
| 650 |
+
entry_id = entry.get('id', 'unknown')
|
| 651 |
+
model_review = entry.get('model_review', '')
|
| 652 |
+
format_type = entry.get('format', 'unknown')
|
| 653 |
+
|
| 654 |
+
# Extract scores based on format
|
| 655 |
+
model_scores = {}
|
| 656 |
+
model_decision = None
|
| 657 |
+
|
| 658 |
+
if format_type == 'refined' and not use_initial_scores:
|
| 659 |
+
# Use refined scores from structured data
|
| 660 |
+
scores_dict = entry.get('scores', {})
|
| 661 |
+
model_data = extract_scores_from_dict(scores_dict)
|
| 662 |
+
model_scores = {
|
| 663 |
+
'soundness': model_data.get('soundness'),
|
| 664 |
+
'presentation': model_data.get('presentation'),
|
| 665 |
+
'rating': model_data.get('rating'),
|
| 666 |
+
'confidence': model_data.get('confidence')
|
| 667 |
+
}
|
| 668 |
+
model_decision = model_data.get('decision')
|
| 669 |
+
elif format_type == 'refined' and use_initial_scores:
|
| 670 |
+
# Use initial scores from structured data
|
| 671 |
+
initial_scores_dict = entry.get('initial_scores', {})
|
| 672 |
+
model_data = extract_scores_from_dict(initial_scores_dict)
|
| 673 |
+
model_scores = {
|
| 674 |
+
'soundness': model_data.get('soundness'),
|
| 675 |
+
'presentation': model_data.get('presentation'),
|
| 676 |
+
'rating': model_data.get('rating'),
|
| 677 |
+
'confidence': model_data.get('confidence')
|
| 678 |
+
}
|
| 679 |
+
model_decision = model_data.get('decision')
|
| 680 |
+
elif format_type == 'original':
|
| 681 |
+
# Use initial scores from structured data
|
| 682 |
+
initial_scores_dict = entry.get('initial_scores', {})
|
| 683 |
+
model_data = extract_scores_from_dict(initial_scores_dict)
|
| 684 |
+
model_scores = {
|
| 685 |
+
'soundness': model_data.get('soundness'),
|
| 686 |
+
'presentation': model_data.get('presentation'),
|
| 687 |
+
'rating': model_data.get('rating'),
|
| 688 |
+
'confidence': model_data.get('confidence')
|
| 689 |
+
}
|
| 690 |
+
model_decision = model_data.get('decision')
|
| 691 |
+
|
| 692 |
+
# Fallback: If confidence is missing from structured data, try to extract from review text
|
| 693 |
+
# (meta_review may not have confidence field, but review text might)
|
| 694 |
+
if model_scores.get('confidence') is None and model_review:
|
| 695 |
+
try:
|
| 696 |
+
review_data = extract_scores_from_review(model_review)
|
| 697 |
+
if review_data.get('confidence') is not None:
|
| 698 |
+
model_scores['confidence'] = review_data.get('confidence')
|
| 699 |
+
except Exception:
|
| 700 |
+
pass # Keep confidence as None if extraction fails
|
| 701 |
+
else:
|
| 702 |
+
# Fallback: extract from markdown review text
|
| 703 |
+
model_data = extract_scores_from_review(model_review)
|
| 704 |
+
model_scores = {
|
| 705 |
+
'soundness': model_data.get('soundness'),
|
| 706 |
+
'presentation': model_data.get('presentation'),
|
| 707 |
+
'rating': model_data.get('rating'),
|
| 708 |
+
'confidence': model_data.get('confidence')
|
| 709 |
+
}
|
| 710 |
+
model_decision = model_data.get('decision')
|
| 711 |
+
|
| 712 |
+
# Get ground truth scores from golden_review ONLY
|
| 713 |
+
# Ground truth must ONLY come from golden_review, never from model output
|
| 714 |
+
# If extraction fails, leave fields as None (do not use model_review as fallback)
|
| 715 |
+
ground_truth_review = entry.get('golden_review', '')
|
| 716 |
+
ground_truth_scores = {}
|
| 717 |
+
gt_decision = None
|
| 718 |
+
|
| 719 |
+
if not ground_truth_review:
|
| 720 |
+
print(f"Warning: No golden_review found for entry {entry_id}. Ground truth scores will be empty.")
|
| 721 |
+
else:
|
| 722 |
+
try:
|
| 723 |
+
# Extract scores from golden_review markdown text
|
| 724 |
+
gt_data = extract_scores_from_review(ground_truth_review)
|
| 725 |
+
if not gt_data:
|
| 726 |
+
print(f"Warning: Failed to parse golden_review for entry {entry_id}. Ground truth scores will be empty.")
|
| 727 |
+
else:
|
| 728 |
+
ground_truth_scores = {
|
| 729 |
+
'soundness': gt_data.get('soundness'),
|
| 730 |
+
'presentation': gt_data.get('presentation'),
|
| 731 |
+
'rating': gt_data.get('rating'),
|
| 732 |
+
'confidence': gt_data.get('confidence')
|
| 733 |
+
}
|
| 734 |
+
gt_decision = normalize_decision(gt_data.get('decision'))
|
| 735 |
+
# Note: If any field is None, it stays None - we do NOT use model_review as fallback
|
| 736 |
+
# Using model output as ground truth would inflate evaluation scores
|
| 737 |
+
except Exception as e:
|
| 738 |
+
print(f"Warning: Failed to extract scores from golden_review for {entry_id}: {e}")
|
| 739 |
+
print(f" Ground truth scores will be empty. Error: {str(e)}")
|
| 740 |
+
|
| 741 |
+
# Calculate MSE and MAE metrics (with optional normalization in strict mode)
|
| 742 |
+
score_metrics = calculate_score_metrics(model_scores, ground_truth_scores, normalize=strict_mode)
|
| 743 |
+
|
| 744 |
+
# Calculate decision accuracy
|
| 745 |
+
decision_match = False
|
| 746 |
+
decision_accuracy = None
|
| 747 |
+
if model_decision is not None and gt_decision is not None:
|
| 748 |
+
model_decision_normalized = normalize_decision(model_decision)
|
| 749 |
+
decision_match = (model_decision_normalized == gt_decision)
|
| 750 |
+
decision_accuracy = 1.0 if decision_match else 0.0
|
| 751 |
+
|
| 752 |
+
result = {
|
| 753 |
+
'id': entry_id,
|
| 754 |
+
'format': format_type,
|
| 755 |
+
'model_soundness': model_scores.get('soundness'),
|
| 756 |
+
'model_presentation': model_scores.get('presentation'),
|
| 757 |
+
'model_rating': model_scores.get('rating'),
|
| 758 |
+
'model_confidence': model_scores.get('confidence'),
|
| 759 |
+
'model_decision': model_decision,
|
| 760 |
+
'gt_soundness': ground_truth_scores.get('soundness'),
|
| 761 |
+
'gt_presentation': ground_truth_scores.get('presentation'),
|
| 762 |
+
'gt_rating': ground_truth_scores.get('rating'),
|
| 763 |
+
'gt_confidence': ground_truth_scores.get('confidence'),
|
| 764 |
+
'gt_decision': gt_decision,
|
| 765 |
+
'decision_match': decision_match,
|
| 766 |
+
'decision_accuracy': decision_accuracy,
|
| 767 |
+
**score_metrics
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
# Add prefix to indicate which scores were used
|
| 771 |
+
if format_type == 'refined':
|
| 772 |
+
if use_initial_scores:
|
| 773 |
+
result['score_type'] = 'initial'
|
| 774 |
+
else:
|
| 775 |
+
result['score_type'] = 'refined'
|
| 776 |
+
else:
|
| 777 |
+
result['score_type'] = 'auto'
|
| 778 |
+
|
| 779 |
+
return result
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
def calculate_pairwise_accuracies(paper_scores: List[Dict[str, float]]) -> Dict[str, float]:
|
| 783 |
+
"""Calculate pairwise accuracy for each metric by comparing rankings."""
|
| 784 |
+
if len(paper_scores) < 2:
|
| 785 |
+
return {}
|
| 786 |
+
|
| 787 |
+
total_valid_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
|
| 788 |
+
correct_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
|
| 789 |
+
|
| 790 |
+
for paper1, paper2 in combinations(paper_scores, 2):
|
| 791 |
+
# Check rating ranking
|
| 792 |
+
if (paper1.get('true_rating') is not None and paper2.get('true_rating') is not None and
|
| 793 |
+
paper1.get('pred_rating') is not None and paper2.get('pred_rating') is not None):
|
| 794 |
+
total_valid_pairs['rating'] += 1
|
| 795 |
+
true_order = paper1['true_rating'] > paper2['true_rating']
|
| 796 |
+
pred_order = paper1['pred_rating'] > paper2['pred_rating']
|
| 797 |
+
if true_order == pred_order:
|
| 798 |
+
correct_pairs['rating'] += 1
|
| 799 |
+
|
| 800 |
+
# Similar for other dimensions...
|
| 801 |
+
# (abbreviated for space, similar logic for soundness, presentation, confidence)
|
| 802 |
+
for metric in ['soundness', 'presentation', 'confidence']:
|
| 803 |
+
true_key = f'true_{metric}'
|
| 804 |
+
pred_key = f'pred_{metric}'
|
| 805 |
+
if (paper1.get(true_key) is not None and paper2.get(true_key) is not None and
|
| 806 |
+
paper1.get(pred_key) is not None and paper2.get(pred_key) is not None):
|
| 807 |
+
total_valid_pairs[metric] += 1
|
| 808 |
+
true_order = paper1[true_key] > paper2[true_key]
|
| 809 |
+
pred_order = paper1[pred_key] > paper2[pred_key]
|
| 810 |
+
if true_order == pred_order:
|
| 811 |
+
correct_pairs[metric] += 1
|
| 812 |
+
|
| 813 |
+
pairwise_accuracies = {
|
| 814 |
+
metric: correct_pairs[metric] / total_valid_pairs[metric] if total_valid_pairs[metric] > 0 else 0.0
|
| 815 |
+
for metric in ['rating', 'soundness', 'presentation', 'confidence']
|
| 816 |
+
}
|
| 817 |
+
|
| 818 |
+
return pairwise_accuracies
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
# ============================================================================
|
| 822 |
+
# Data Loading Functions
|
| 823 |
+
# ============================================================================
|
| 824 |
+
|
| 825 |
+
def load_rubrics_json(rubrics_path: str) -> Dict[str, Dict[str, Any]]:
|
| 826 |
+
"""Load rubrics JSON and create lookup by id."""
|
| 827 |
+
with open(rubrics_path, 'r', encoding='utf-8') as f:
|
| 828 |
+
data = json.load(f)
|
| 829 |
+
|
| 830 |
+
if isinstance(data, list):
|
| 831 |
+
return {item['id']: item for item in data}
|
| 832 |
+
elif isinstance(data, dict):
|
| 833 |
+
return data
|
| 834 |
+
else:
|
| 835 |
+
raise ValueError(f"Invalid rubrics JSON format: expected list or dict, got {type(data)}")
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def load_model_reviews_json(reviews_path: str, format_override: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
| 839 |
+
"""
|
| 840 |
+
Load model reviews JSON and extract reviews by id.
|
| 841 |
+
|
| 842 |
+
Supports two input formats:
|
| 843 |
+
1. Refined format: Contains 'scores' and 'initial_scores' fields (from refinement pipeline)
|
| 844 |
+
2. Original format: Contains 'model_prediction' with 'meta_review' and 'decision' (like ours.json)
|
| 845 |
+
|
| 846 |
+
Args:
|
| 847 |
+
reviews_path: Path to JSON file containing model reviews
|
| 848 |
+
format_override: Optional format override ('refined', 'original', or None for auto-detect)
|
| 849 |
+
|
| 850 |
+
Returns:
|
| 851 |
+
Dict mapping paper_id to dict containing:
|
| 852 |
+
- 'review': review text (markdown)
|
| 853 |
+
- 'scores': refined scores dict (if available)
|
| 854 |
+
- 'initial_scores': initial scores dict (if available)
|
| 855 |
+
- 'format': 'refined' or 'original'
|
| 856 |
+
"""
|
| 857 |
+
with open(reviews_path, 'r', encoding='utf-8') as f:
|
| 858 |
+
data = json.load(f)
|
| 859 |
+
|
| 860 |
+
if isinstance(data, dict):
|
| 861 |
+
data = list(data.values())
|
| 862 |
+
|
| 863 |
+
reviews_dict = {}
|
| 864 |
+
for item in data:
|
| 865 |
+
item_id = None
|
| 866 |
+
review_text = ''
|
| 867 |
+
scores = None
|
| 868 |
+
initial_scores = None
|
| 869 |
+
format_type = None
|
| 870 |
+
|
| 871 |
+
# Use format override if provided, otherwise auto-detect
|
| 872 |
+
if format_override and format_override != 'auto':
|
| 873 |
+
# Force use specified format
|
| 874 |
+
if format_override == 'refined':
|
| 875 |
+
item_id = item.get('paper_id') or item.get('id')
|
| 876 |
+
if not item_id:
|
| 877 |
+
continue
|
| 878 |
+
format_type = 'refined'
|
| 879 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 880 |
+
scores = item.get('scores', {})
|
| 881 |
+
initial_scores = item.get('initial_scores', {})
|
| 882 |
+
elif format_override == 'original':
|
| 883 |
+
item_id = item.get('id')
|
| 884 |
+
if not item_id:
|
| 885 |
+
continue
|
| 886 |
+
format_type = 'original'
|
| 887 |
+
model_prediction = item.get('model_prediction', {})
|
| 888 |
+
meta_review = model_prediction.get('meta_review', {})
|
| 889 |
+
review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
|
| 890 |
+
initial_scores = {
|
| 891 |
+
'rating': meta_review.get('rating'),
|
| 892 |
+
'soundness': meta_review.get('soundness'),
|
| 893 |
+
'presentation': meta_review.get('presentation'),
|
| 894 |
+
'contribution': meta_review.get('contribution'),
|
| 895 |
+
'decision': model_prediction.get('decision'),
|
| 896 |
+
}
|
| 897 |
+
else:
|
| 898 |
+
raise ValueError(f"Unknown format_override: {format_override}. Must be 'refined', 'original', or 'auto'")
|
| 899 |
+
else:
|
| 900 |
+
# Auto-detect format
|
| 901 |
+
if "paper_id" in item:
|
| 902 |
+
# Refined format (from refinement pipeline)
|
| 903 |
+
item_id = item.get('paper_id')
|
| 904 |
+
if not item_id:
|
| 905 |
+
continue
|
| 906 |
+
|
| 907 |
+
# Check if this is refined format (has scores and initial_scores)
|
| 908 |
+
if 'scores' in item and 'initial_scores' in item:
|
| 909 |
+
format_type = 'refined'
|
| 910 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 911 |
+
scores = item.get('scores', {})
|
| 912 |
+
initial_scores = item.get('initial_scores', {})
|
| 913 |
+
else:
|
| 914 |
+
# Standard format with paper_id
|
| 915 |
+
format_type = 'standard'
|
| 916 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 917 |
+
elif "model_prediction" in item:
|
| 918 |
+
# Original format (like ours.json)
|
| 919 |
+
item_id = item.get('id')
|
| 920 |
+
if not item_id:
|
| 921 |
+
continue
|
| 922 |
+
|
| 923 |
+
format_type = 'original'
|
| 924 |
+
model_prediction = item.get('model_prediction', {})
|
| 925 |
+
|
| 926 |
+
review_text = model_prediction.get('raw_text', '')
|
| 927 |
+
|
| 928 |
+
if review_text is None:
|
| 929 |
+
continue
|
| 930 |
+
|
| 931 |
+
# Detect format: agenticreview has raw_text as string with "Overall rating: x"
|
| 932 |
+
# ai_researcher format has raw_text as dict or JSON string with structured fields
|
| 933 |
+
is_agenticreview = False
|
| 934 |
+
if isinstance(review_text, str):
|
| 935 |
+
# Check if it's a JSON string (ai_researcher format)
|
| 936 |
+
try:
|
| 937 |
+
parsed_json = json.loads(review_text)
|
| 938 |
+
if isinstance(parsed_json, dict) and any(key in parsed_json for key in ["Summary", "Strengths", "Overall", "Decision"]):
|
| 939 |
+
# It's ai_researcher format
|
| 940 |
+
review_text = parsed_json
|
| 941 |
+
review_text, meta_review = convert_ai_researcher(review_text)
|
| 942 |
+
else:
|
| 943 |
+
# It's agenticreview format (plain text with "Overall rating: x")
|
| 944 |
+
is_agenticreview = True
|
| 945 |
+
except (json.JSONDecodeError, TypeError):
|
| 946 |
+
# Not JSON, check if it contains "Overall rating:" pattern
|
| 947 |
+
if re.search(r'Overall\s+rating\s*[:=]', review_text, re.IGNORECASE):
|
| 948 |
+
is_agenticreview = True
|
| 949 |
+
else:
|
| 950 |
+
# Try to parse as ai_researcher anyway
|
| 951 |
+
try:
|
| 952 |
+
review_text = json.loads(review_text)
|
| 953 |
+
review_text, meta_review = convert_ai_researcher(review_text)
|
| 954 |
+
except:
|
| 955 |
+
review_text = 'Empty Review'
|
| 956 |
+
meta_review = {}
|
| 957 |
+
elif isinstance(review_text, dict):
|
| 958 |
+
# It's ai_researcher format (dict)
|
| 959 |
+
review_text, meta_review = convert_ai_researcher(review_text)
|
| 960 |
+
else:
|
| 961 |
+
review_text = 'Empty Review'
|
| 962 |
+
meta_review = {}
|
| 963 |
+
|
| 964 |
+
# Handle agenticreview format
|
| 965 |
+
if is_agenticreview:
|
| 966 |
+
review_text, meta_review = convert_agenticreview(review_text)
|
| 967 |
+
|
| 968 |
+
# Extract initial scores
|
| 969 |
+
# Use meta_review as primary source (from convert_ai_researcher or convert_agenticreview)
|
| 970 |
+
# Fallback to model_prediction.get('decision') if not in meta_review
|
| 971 |
+
initial_scores = {
|
| 972 |
+
'rating': meta_review.get('rating'),
|
| 973 |
+
'soundness': meta_review.get('soundness'),
|
| 974 |
+
'presentation': meta_review.get('presentation'),
|
| 975 |
+
'contribution': meta_review.get('contribution'),
|
| 976 |
+
'confidence': meta_review.get('confidence'),
|
| 977 |
+
'decision': meta_review.get('decision') or model_prediction.get('decision'),
|
| 978 |
+
}
|
| 979 |
+
else:
|
| 980 |
+
# Legacy format (pred_fast_mode)
|
| 981 |
+
item_id = item.get('id')
|
| 982 |
+
if not item_id:
|
| 983 |
+
continue
|
| 984 |
+
|
| 985 |
+
format_type = 'legacy'
|
| 986 |
+
review_dict = item.get('pred_fast_mode', {})
|
| 987 |
+
if isinstance(review_dict, dict):
|
| 988 |
+
review_text = review_dict.get('raw_text', '')
|
| 989 |
+
else:
|
| 990 |
+
review_text = str(review_dict)
|
| 991 |
+
|
| 992 |
+
# Extract review content from the review text field
|
| 993 |
+
try:
|
| 994 |
+
if review_text:
|
| 995 |
+
# extracted_review = ReviewProcessor.extract_review_content(review_text)
|
| 996 |
+
extracted_review = review_text
|
| 997 |
+
else:
|
| 998 |
+
extracted_review = ''
|
| 999 |
+
|
| 1000 |
+
reviews_dict[item_id] = {
|
| 1001 |
+
'review': extracted_review,
|
| 1002 |
+
'scores': scores,
|
| 1003 |
+
'initial_scores': initial_scores,
|
| 1004 |
+
'format': format_type
|
| 1005 |
+
}
|
| 1006 |
+
except Exception as e:
|
| 1007 |
+
print(f"[WARN] Failed to extract review for {item_id}: {e}")
|
| 1008 |
+
continue
|
| 1009 |
+
|
| 1010 |
+
return reviews_dict
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
def combine_rubrics_and_reviews(
|
| 1014 |
+
rubrics_data: Dict[str, Dict[str, Any]],
|
| 1015 |
+
reviews_dict: Dict[str, Dict[str, Any]]
|
| 1016 |
+
) -> List[Dict[str, Any]]:
|
| 1017 |
+
"""
|
| 1018 |
+
Combine rubrics and reviews into evaluation entries.
|
| 1019 |
+
|
| 1020 |
+
Args:
|
| 1021 |
+
rubrics_data: Dict mapping paper_id to rubric entry
|
| 1022 |
+
reviews_dict: Dict mapping paper_id to dict containing 'review', 'scores', 'initial_scores', 'format'
|
| 1023 |
+
|
| 1024 |
+
Returns:
|
| 1025 |
+
List of evaluation entries with model_review, scores, initial_scores, and format info
|
| 1026 |
+
"""
|
| 1027 |
+
combined = []
|
| 1028 |
+
missing_reviews = []
|
| 1029 |
+
|
| 1030 |
+
for paper_id, rubric_entry in rubrics_data.items():
|
| 1031 |
+
review_data = reviews_dict.get(paper_id)
|
| 1032 |
+
if not review_data or not review_data.get('review'):
|
| 1033 |
+
missing_reviews.append(paper_id)
|
| 1034 |
+
continue
|
| 1035 |
+
|
| 1036 |
+
entry = {
|
| 1037 |
+
'id': paper_id,
|
| 1038 |
+
'paper_context': rubric_entry.get('paper_context', ''),
|
| 1039 |
+
'decision': rubric_entry.get('decision', ''),
|
| 1040 |
+
'golden_review': rubric_entry.get('golden_review', ''),
|
| 1041 |
+
'rubrics': rubric_entry.get('rubrics', []),
|
| 1042 |
+
'model_review': review_data.get('review', ''),
|
| 1043 |
+
'scores': review_data.get('scores'), # Refined scores (if available)
|
| 1044 |
+
'initial_scores': review_data.get('initial_scores'), # Initial scores (if available)
|
| 1045 |
+
'format': review_data.get('format', 'unknown') # Format type
|
| 1046 |
+
}
|
| 1047 |
+
combined.append(entry)
|
| 1048 |
+
|
| 1049 |
+
if missing_reviews:
|
| 1050 |
+
print(f"[WARN] {len(missing_reviews)} papers have no model review, skipping them")
|
| 1051 |
+
|
| 1052 |
+
return combined
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
# ============================================================================
|
| 1056 |
+
# LLM Service Configuration
|
| 1057 |
+
# ============================================================================
|
| 1058 |
+
|
| 1059 |
+
def load_llm_config(config_path: str) -> Dict[str, Any]:
|
| 1060 |
+
"""Load LLM configuration from YAML file."""
|
| 1061 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 1062 |
+
config = yaml.safe_load(f)
|
| 1063 |
+
return config
|
| 1064 |
+
|
| 1065 |
+
|
| 1066 |
+
def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
|
| 1067 |
+
"""Create LLM service from configuration."""
|
| 1068 |
+
mode = config.get('mode', 'gpt').lower()
|
| 1069 |
+
|
| 1070 |
+
if mode == 'gpt':
|
| 1071 |
+
gpt_config = config.get('gpt', {})
|
| 1072 |
+
api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
|
| 1073 |
+
if not api_key:
|
| 1074 |
+
raise ValueError("GPT mode requires api_key in configs.yaml or OPENAI_API_KEY environment variable")
|
| 1075 |
+
|
| 1076 |
+
service = GPTService(
|
| 1077 |
+
api_key=api_key,
|
| 1078 |
+
model_name=gpt_config.get('model_name', 'gpt-4o'),
|
| 1079 |
+
base_url=gpt_config.get('base_url'),
|
| 1080 |
+
timeout=gpt_config.get('timeout', 300)
|
| 1081 |
+
)
|
| 1082 |
+
return service
|
| 1083 |
+
|
| 1084 |
+
elif mode == 'vllm':
|
| 1085 |
+
vllm_config = config.get('vllm', {})
|
| 1086 |
+
service = VLLMService(
|
| 1087 |
+
base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
|
| 1088 |
+
api_key=vllm_config.get('api_key', 'dummy-key'),
|
| 1089 |
+
model_name=vllm_config.get('model_name'),
|
| 1090 |
+
timeout=vllm_config.get('timeout', 300),
|
| 1091 |
+
max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
|
| 1092 |
+
max_retries=vllm_config.get('max_retries', 3),
|
| 1093 |
+
retry_delay=vllm_config.get('retry_delay', 1.0),
|
| 1094 |
+
retry_backoff=vllm_config.get('retry_backoff', 2.0)
|
| 1095 |
+
)
|
| 1096 |
+
return service
|
| 1097 |
+
|
| 1098 |
+
else:
|
| 1099 |
+
raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
+
# ============================================================================
|
| 1103 |
+
# Main Evaluation Functions
|
| 1104 |
+
# ============================================================================
|
| 1105 |
+
|
| 1106 |
+
def run_semantic_evaluation(
|
| 1107 |
+
evaluation_data: List[Dict[str, Any]],
|
| 1108 |
+
prompt_template: str,
|
| 1109 |
+
llm_service: LLMService,
|
| 1110 |
+
max_workers: int
|
| 1111 |
+
) -> tuple:
|
| 1112 |
+
"""Run semantic evaluation and return results and summary."""
|
| 1113 |
+
print(f"\n{'='*80}")
|
| 1114 |
+
print("RUNNING SEMANTIC EVALUATION")
|
| 1115 |
+
print(f"{'='*80}")
|
| 1116 |
+
print(f"Evaluating {len(evaluation_data)} reviews using {max_workers} workers...")
|
| 1117 |
+
|
| 1118 |
+
results = []
|
| 1119 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 1120 |
+
future_to_entry = {
|
| 1121 |
+
executor.submit(
|
| 1122 |
+
evaluate_review_semantic,
|
| 1123 |
+
entry,
|
| 1124 |
+
entry['paper_context'],
|
| 1125 |
+
prompt_template,
|
| 1126 |
+
llm_service
|
| 1127 |
+
): entry
|
| 1128 |
+
for entry in evaluation_data
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
for future in tqdm(as_completed(future_to_entry), total=len(evaluation_data), desc="Semantic evaluation"):
|
| 1132 |
+
try:
|
| 1133 |
+
result = future.result()
|
| 1134 |
+
results.append(result)
|
| 1135 |
+
except Exception as e:
|
| 1136 |
+
entry = future_to_entry[future]
|
| 1137 |
+
print(f"\n[ERROR] Failed to process entry {entry.get('id', 'unknown')}: {e}")
|
| 1138 |
+
results.append({
|
| 1139 |
+
'id': entry.get('id', 'unknown'),
|
| 1140 |
+
'raw_scores': {},
|
| 1141 |
+
'weighted_scores': {},
|
| 1142 |
+
'total_score': 0.0,
|
| 1143 |
+
'error': str(e),
|
| 1144 |
+
'raw_response': ''
|
| 1145 |
+
})
|
| 1146 |
+
|
| 1147 |
+
# Calculate statistics
|
| 1148 |
+
valid_results = [r for r in results if 'error' not in r and r.get('weighted_scores')]
|
| 1149 |
+
review_scores = [r.get('total_score', 0.0) for r in valid_results]
|
| 1150 |
+
|
| 1151 |
+
summary = {
|
| 1152 |
+
'total_entries': len(results),
|
| 1153 |
+
'valid_entries': len(valid_results),
|
| 1154 |
+
'failed_entries': len(results) - len(valid_results)
|
| 1155 |
+
}
|
| 1156 |
+
|
| 1157 |
+
if review_scores:
|
| 1158 |
+
summary['overall_score'] = {
|
| 1159 |
+
'mean': sum(review_scores) / len(review_scores),
|
| 1160 |
+
'min': min(review_scores),
|
| 1161 |
+
'max': max(review_scores)
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
# Calculate per-rubric statistics (extract rubric titles from first entry)
|
| 1165 |
+
if evaluation_data and evaluation_data[0].get('rubrics'):
|
| 1166 |
+
rubric_titles = [r['title'] for r in evaluation_data[0]['rubrics']]
|
| 1167 |
+
per_rubric_stats = calculate_per_rubric_statistics(valid_results, rubric_titles)
|
| 1168 |
+
summary['per_rubric_statistics'] = per_rubric_stats
|
| 1169 |
+
|
| 1170 |
+
return results, summary
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
def run_auto_metric_evaluation(
|
| 1174 |
+
evaluation_data: List[Dict[str, Any]],
|
| 1175 |
+
strict_mode: bool = False
|
| 1176 |
+
) -> tuple:
|
| 1177 |
+
"""
|
| 1178 |
+
Run auto-metric evaluation and return results and summary.
|
| 1179 |
+
|
| 1180 |
+
For refined format (has scores and initial_scores), evaluates both:
|
| 1181 |
+
- Refined scores evaluation
|
| 1182 |
+
- Initial scores evaluation
|
| 1183 |
+
|
| 1184 |
+
For original format (only initial_scores), evaluates:
|
| 1185 |
+
- Initial scores evaluation only
|
| 1186 |
+
|
| 1187 |
+
Returns:
|
| 1188 |
+
Tuple of (results_list, summary_dict)
|
| 1189 |
+
- results_list: List of evaluation results (may contain both refined and initial results for refined format)
|
| 1190 |
+
- summary_dict: Summary statistics
|
| 1191 |
+
"""
|
| 1192 |
+
print(f"\n{'='*80}")
|
| 1193 |
+
print("RUNNING AUTO-METRIC EVALUATION")
|
| 1194 |
+
print(f"{'='*80}")
|
| 1195 |
+
print(f"Evaluating {len(evaluation_data)} entries...")
|
| 1196 |
+
|
| 1197 |
+
# Detect format types
|
| 1198 |
+
refined_format_count = sum(1 for e in evaluation_data if e.get('format') == 'refined')
|
| 1199 |
+
original_format_count = sum(1 for e in evaluation_data if e.get('format') == 'original')
|
| 1200 |
+
|
| 1201 |
+
if refined_format_count > 0:
|
| 1202 |
+
print(f"Detected {refined_format_count} entries in refined format (will evaluate both refined and initial scores)")
|
| 1203 |
+
if original_format_count > 0:
|
| 1204 |
+
print(f"Detected {original_format_count} entries in original format (will evaluate initial scores only)")
|
| 1205 |
+
|
| 1206 |
+
results = []
|
| 1207 |
+
for entry in tqdm(evaluation_data, desc="Auto-metric evaluation"):
|
| 1208 |
+
format_type = entry.get('format', 'unknown')
|
| 1209 |
+
|
| 1210 |
+
if format_type == 'refined':
|
| 1211 |
+
# Evaluate both refined scores and initial scores
|
| 1212 |
+
try:
|
| 1213 |
+
entry_id = entry.get('id', 'unknown')
|
| 1214 |
+
|
| 1215 |
+
# Evaluate refined scores
|
| 1216 |
+
refined_result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
|
| 1217 |
+
refined_result['paper_id'] = entry_id # Keep original paper_id
|
| 1218 |
+
refined_result['id'] = f"{entry_id}_refined"
|
| 1219 |
+
results.append(refined_result)
|
| 1220 |
+
|
| 1221 |
+
# Evaluate initial scores
|
| 1222 |
+
initial_result = evaluate_review_auto_metric(entry, use_initial_scores=True, strict_mode=strict_mode)
|
| 1223 |
+
initial_result['paper_id'] = entry_id # Keep original paper_id
|
| 1224 |
+
initial_result['id'] = f"{entry_id}_initial"
|
| 1225 |
+
results.append(initial_result)
|
| 1226 |
+
except Exception as e:
|
| 1227 |
+
print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
|
| 1228 |
+
results.append({
|
| 1229 |
+
'id': entry.get('id', 'unknown'),
|
| 1230 |
+
'error': str(e)
|
| 1231 |
+
})
|
| 1232 |
+
else:
|
| 1233 |
+
# Evaluate initial scores only (or extract from markdown)
|
| 1234 |
+
try:
|
| 1235 |
+
result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
|
| 1236 |
+
results.append(result)
|
| 1237 |
+
except Exception as e:
|
| 1238 |
+
print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
|
| 1239 |
+
results.append({
|
| 1240 |
+
'id': entry.get('id', 'unknown'),
|
| 1241 |
+
'error': str(e)
|
| 1242 |
+
})
|
| 1243 |
+
|
| 1244 |
+
# Calculate statistics
|
| 1245 |
+
valid_results = [r for r in results if 'error' not in r]
|
| 1246 |
+
mse_results = [r for r in valid_results if r.get('overall_error') is not None]
|
| 1247 |
+
|
| 1248 |
+
# Separate refined and initial results for refined format
|
| 1249 |
+
refined_results = [r for r in valid_results if r.get('score_type') == 'refined']
|
| 1250 |
+
initial_results = [r for r in valid_results if r.get('score_type') == 'initial']
|
| 1251 |
+
auto_results = [r for r in valid_results if r.get('score_type') == 'auto' or r.get('score_type') is None]
|
| 1252 |
+
|
| 1253 |
+
summary = {
|
| 1254 |
+
'total_entries': len(results),
|
| 1255 |
+
'valid_entries': len(valid_results),
|
| 1256 |
+
'mse_entries': len(mse_results),
|
| 1257 |
+
'refined_results_count': len(refined_results),
|
| 1258 |
+
'initial_results_count': len(initial_results),
|
| 1259 |
+
'auto_results_count': len(auto_results)
|
| 1260 |
+
}
|
| 1261 |
+
|
| 1262 |
+
# Calculate MSE/MAE statistics
|
| 1263 |
+
# For refined format, only use refined results for overall statistics (avoid double counting)
|
| 1264 |
+
# For other formats, use all results
|
| 1265 |
+
if refined_format_count > 0:
|
| 1266 |
+
# Refined format: use only refined results for overall statistics
|
| 1267 |
+
stats_results = [r for r in refined_results if r.get('overall_error') is not None]
|
| 1268 |
+
else:
|
| 1269 |
+
# Original/other formats: use all results
|
| 1270 |
+
stats_results = mse_results
|
| 1271 |
+
|
| 1272 |
+
if stats_results:
|
| 1273 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1274 |
+
mse_stats = {}
|
| 1275 |
+
mae_stats = {}
|
| 1276 |
+
|
| 1277 |
+
for dim in dimensions:
|
| 1278 |
+
mse_list = [r.get(f'{dim}_mse') for r in stats_results if r.get(f'{dim}_mse') is not None]
|
| 1279 |
+
mae_list = [r.get(f'{dim}_mae') for r in stats_results if r.get(f'{dim}_mae') is not None]
|
| 1280 |
+
|
| 1281 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1282 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1283 |
+
|
| 1284 |
+
if mse_clean:
|
| 1285 |
+
mse_stats[dim] = {
|
| 1286 |
+
'mean': sum(mse_clean) / len(mse_clean),
|
| 1287 |
+
'count': len(mse_clean)
|
| 1288 |
+
}
|
| 1289 |
+
if mae_clean:
|
| 1290 |
+
mae_stats[dim] = {
|
| 1291 |
+
'mean': sum(mae_clean) / len(mae_clean),
|
| 1292 |
+
'count': len(mae_clean)
|
| 1293 |
+
}
|
| 1294 |
+
|
| 1295 |
+
overall_errors = [r.get('overall_error') for r in stats_results if r.get('overall_error') is not None]
|
| 1296 |
+
overall_clean = [x for x in overall_errors if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1297 |
+
|
| 1298 |
+
if overall_clean:
|
| 1299 |
+
summary['overall_error'] = {
|
| 1300 |
+
'mean': sum(overall_clean) / len(overall_clean),
|
| 1301 |
+
'count': len(overall_clean)
|
| 1302 |
+
}
|
| 1303 |
+
|
| 1304 |
+
summary['mse_statistics'] = mse_stats
|
| 1305 |
+
summary['mae_statistics'] = mae_stats
|
| 1306 |
+
|
| 1307 |
+
# Calculate separate statistics for refined and initial results
|
| 1308 |
+
if refined_results:
|
| 1309 |
+
refined_mse_results = [r for r in refined_results if r.get('overall_error') is not None]
|
| 1310 |
+
if refined_mse_results:
|
| 1311 |
+
refined_mse_stats = {}
|
| 1312 |
+
refined_mae_stats = {}
|
| 1313 |
+
for dim in dimensions:
|
| 1314 |
+
mse_list = [r.get(f'{dim}_mse') for r in refined_mse_results if r.get(f'{dim}_mse') is not None]
|
| 1315 |
+
mae_list = [r.get(f'{dim}_mae') for r in refined_mse_results if r.get(f'{dim}_mae') is not None]
|
| 1316 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1317 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1318 |
+
if mse_clean:
|
| 1319 |
+
refined_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
|
| 1320 |
+
if mae_clean:
|
| 1321 |
+
refined_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
|
| 1322 |
+
summary['refined_mse_statistics'] = refined_mse_stats
|
| 1323 |
+
summary['refined_mae_statistics'] = refined_mae_stats
|
| 1324 |
+
|
| 1325 |
+
if initial_results:
|
| 1326 |
+
initial_mse_results = [r for r in initial_results if r.get('overall_error') is not None]
|
| 1327 |
+
if initial_mse_results:
|
| 1328 |
+
initial_mse_stats = {}
|
| 1329 |
+
initial_mae_stats = {}
|
| 1330 |
+
for dim in dimensions:
|
| 1331 |
+
mse_list = [r.get(f'{dim}_mse') for r in initial_mse_results if r.get(f'{dim}_mse') is not None]
|
| 1332 |
+
mae_list = [r.get(f'{dim}_mae') for r in initial_mse_results if r.get(f'{dim}_mae') is not None]
|
| 1333 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1334 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1335 |
+
if mse_clean:
|
| 1336 |
+
initial_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
|
| 1337 |
+
if mae_clean:
|
| 1338 |
+
initial_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
|
| 1339 |
+
summary['initial_mse_statistics'] = initial_mse_stats
|
| 1340 |
+
summary['initial_mae_statistics'] = initial_mae_stats
|
| 1341 |
+
|
| 1342 |
+
# Calculate Spearman correlations
|
| 1343 |
+
def filter_valid_pairs(true_list, pred_list):
|
| 1344 |
+
filtered_true = []
|
| 1345 |
+
filtered_pred = []
|
| 1346 |
+
for t, p in zip(true_list, pred_list):
|
| 1347 |
+
if (t is not None and p is not None and
|
| 1348 |
+
not (isinstance(t, float) and math.isnan(t)) and
|
| 1349 |
+
not (isinstance(p, float) and math.isnan(p))):
|
| 1350 |
+
filtered_true.append(t)
|
| 1351 |
+
filtered_pred.append(p)
|
| 1352 |
+
return filtered_true, filtered_pred
|
| 1353 |
+
|
| 1354 |
+
# Calculate Spearman correlations
|
| 1355 |
+
# For refined format, calculate separately for refined and initial, and use refined for overall
|
| 1356 |
+
# For other formats, use all results
|
| 1357 |
+
if refined_format_count > 0:
|
| 1358 |
+
# Calculate refined spearman correlations
|
| 1359 |
+
refined_spearman_stats = {}
|
| 1360 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1361 |
+
for dim in dimensions:
|
| 1362 |
+
true_values = [r.get(f'gt_{dim}') for r in refined_results]
|
| 1363 |
+
pred_values = [r.get(f'model_{dim}') for r in refined_results]
|
| 1364 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1365 |
+
|
| 1366 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1367 |
+
try:
|
| 1368 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1369 |
+
if not math.isnan(corr):
|
| 1370 |
+
refined_spearman_stats[dim] = {
|
| 1371 |
+
'correlation': corr,
|
| 1372 |
+
'count': len(true_clean)
|
| 1373 |
+
}
|
| 1374 |
+
except Exception:
|
| 1375 |
+
pass
|
| 1376 |
+
|
| 1377 |
+
# Calculate initial spearman correlations
|
| 1378 |
+
initial_spearman_stats = {}
|
| 1379 |
+
for dim in dimensions:
|
| 1380 |
+
true_values = [r.get(f'gt_{dim}') for r in initial_results]
|
| 1381 |
+
pred_values = [r.get(f'model_{dim}') for r in initial_results]
|
| 1382 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1383 |
+
|
| 1384 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1385 |
+
try:
|
| 1386 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1387 |
+
if not math.isnan(corr):
|
| 1388 |
+
initial_spearman_stats[dim] = {
|
| 1389 |
+
'correlation': corr,
|
| 1390 |
+
'count': len(true_clean)
|
| 1391 |
+
}
|
| 1392 |
+
except Exception:
|
| 1393 |
+
pass
|
| 1394 |
+
|
| 1395 |
+
# Use refined for overall statistics (avoid double counting)
|
| 1396 |
+
summary['spearman_correlations'] = refined_spearman_stats
|
| 1397 |
+
summary['refined_spearman_correlations'] = refined_spearman_stats
|
| 1398 |
+
summary['initial_spearman_correlations'] = initial_spearman_stats
|
| 1399 |
+
else:
|
| 1400 |
+
# Original/other formats: use all results
|
| 1401 |
+
correlation_results = valid_results
|
| 1402 |
+
spearman_stats = {}
|
| 1403 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1404 |
+
for dim in dimensions:
|
| 1405 |
+
true_values = [r.get(f'gt_{dim}') for r in correlation_results]
|
| 1406 |
+
pred_values = [r.get(f'model_{dim}') for r in correlation_results]
|
| 1407 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1408 |
+
|
| 1409 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1410 |
+
try:
|
| 1411 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1412 |
+
if not math.isnan(corr):
|
| 1413 |
+
spearman_stats[dim] = {
|
| 1414 |
+
'correlation': corr,
|
| 1415 |
+
'count': len(true_clean)
|
| 1416 |
+
}
|
| 1417 |
+
except Exception:
|
| 1418 |
+
pass
|
| 1419 |
+
|
| 1420 |
+
summary['spearman_correlations'] = spearman_stats
|
| 1421 |
+
|
| 1422 |
+
# Calculate Decision metrics
|
| 1423 |
+
# For refined format, calculate separately for refined and initial, and use refined for overall
|
| 1424 |
+
# For other formats, use all results
|
| 1425 |
+
if refined_format_count > 0:
|
| 1426 |
+
# Calculate refined decision metrics
|
| 1427 |
+
refined_decision_results = [r for r in refined_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1428 |
+
if refined_decision_results:
|
| 1429 |
+
true_decisions = []
|
| 1430 |
+
pred_decisions = []
|
| 1431 |
+
decision_acc = []
|
| 1432 |
+
|
| 1433 |
+
for r in refined_decision_results:
|
| 1434 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1435 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1436 |
+
|
| 1437 |
+
if 'accept' in pred_decision:
|
| 1438 |
+
pred_binary = 1
|
| 1439 |
+
else:
|
| 1440 |
+
pred_binary = 0
|
| 1441 |
+
|
| 1442 |
+
if 'accept' in gt_decision:
|
| 1443 |
+
gt_binary = 1
|
| 1444 |
+
else:
|
| 1445 |
+
gt_binary = 0
|
| 1446 |
+
|
| 1447 |
+
true_decisions.append(gt_binary)
|
| 1448 |
+
pred_decisions.append(pred_binary)
|
| 1449 |
+
|
| 1450 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1451 |
+
decision_acc.append(1.0)
|
| 1452 |
+
else:
|
| 1453 |
+
decision_acc.append(0.0)
|
| 1454 |
+
|
| 1455 |
+
if decision_acc:
|
| 1456 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1457 |
+
try:
|
| 1458 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1459 |
+
refined_decision_metrics = {
|
| 1460 |
+
'accuracy': decision_accuracy,
|
| 1461 |
+
'f1_macro': f1_score,
|
| 1462 |
+
'count': len(decision_acc)
|
| 1463 |
+
}
|
| 1464 |
+
except Exception:
|
| 1465 |
+
refined_decision_metrics = {
|
| 1466 |
+
'accuracy': decision_accuracy,
|
| 1467 |
+
'count': len(decision_acc)
|
| 1468 |
+
}
|
| 1469 |
+
summary['refined_decision_metrics'] = refined_decision_metrics
|
| 1470 |
+
summary['decision_metrics'] = refined_decision_metrics # Use refined for overall
|
| 1471 |
+
|
| 1472 |
+
# Calculate initial decision metrics
|
| 1473 |
+
initial_decision_results = [r for r in initial_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1474 |
+
if initial_decision_results:
|
| 1475 |
+
true_decisions = []
|
| 1476 |
+
pred_decisions = []
|
| 1477 |
+
decision_acc = []
|
| 1478 |
+
|
| 1479 |
+
for r in initial_decision_results:
|
| 1480 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1481 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1482 |
+
|
| 1483 |
+
if 'accept' in pred_decision:
|
| 1484 |
+
pred_binary = 1
|
| 1485 |
+
else:
|
| 1486 |
+
pred_binary = 0
|
| 1487 |
+
|
| 1488 |
+
if 'accept' in gt_decision:
|
| 1489 |
+
gt_binary = 1
|
| 1490 |
+
else:
|
| 1491 |
+
gt_binary = 0
|
| 1492 |
+
|
| 1493 |
+
true_decisions.append(gt_binary)
|
| 1494 |
+
pred_decisions.append(pred_binary)
|
| 1495 |
+
|
| 1496 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1497 |
+
decision_acc.append(1.0)
|
| 1498 |
+
else:
|
| 1499 |
+
decision_acc.append(0.0)
|
| 1500 |
+
|
| 1501 |
+
if decision_acc:
|
| 1502 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1503 |
+
try:
|
| 1504 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1505 |
+
initial_decision_metrics = {
|
| 1506 |
+
'accuracy': decision_accuracy,
|
| 1507 |
+
'f1_macro': f1_score,
|
| 1508 |
+
'count': len(decision_acc)
|
| 1509 |
+
}
|
| 1510 |
+
except Exception:
|
| 1511 |
+
initial_decision_metrics = {
|
| 1512 |
+
'accuracy': decision_accuracy,
|
| 1513 |
+
'count': len(decision_acc)
|
| 1514 |
+
}
|
| 1515 |
+
summary['initial_decision_metrics'] = initial_decision_metrics
|
| 1516 |
+
else:
|
| 1517 |
+
# Original/other formats: use all results
|
| 1518 |
+
decision_results = [r for r in valid_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1519 |
+
if decision_results:
|
| 1520 |
+
true_decisions = []
|
| 1521 |
+
pred_decisions = []
|
| 1522 |
+
decision_acc = []
|
| 1523 |
+
|
| 1524 |
+
for r in decision_results:
|
| 1525 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1526 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1527 |
+
|
| 1528 |
+
if 'accept' in pred_decision:
|
| 1529 |
+
pred_binary = 1
|
| 1530 |
+
else:
|
| 1531 |
+
pred_binary = 0
|
| 1532 |
+
|
| 1533 |
+
if 'accept' in gt_decision:
|
| 1534 |
+
gt_binary = 1
|
| 1535 |
+
else:
|
| 1536 |
+
gt_binary = 0
|
| 1537 |
+
|
| 1538 |
+
true_decisions.append(gt_binary)
|
| 1539 |
+
pred_decisions.append(pred_binary)
|
| 1540 |
+
|
| 1541 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1542 |
+
decision_acc.append(1.0)
|
| 1543 |
+
else:
|
| 1544 |
+
decision_acc.append(0.0)
|
| 1545 |
+
|
| 1546 |
+
if decision_acc:
|
| 1547 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1548 |
+
try:
|
| 1549 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1550 |
+
summary['decision_metrics'] = {
|
| 1551 |
+
'accuracy': decision_accuracy,
|
| 1552 |
+
'f1_macro': f1_score,
|
| 1553 |
+
'count': len(decision_acc)
|
| 1554 |
+
}
|
| 1555 |
+
except Exception:
|
| 1556 |
+
summary['decision_metrics'] = {
|
| 1557 |
+
'accuracy': decision_accuracy,
|
| 1558 |
+
'count': len(decision_acc)
|
| 1559 |
+
}
|
| 1560 |
+
|
| 1561 |
+
# Calculate Pairwise comparison
|
| 1562 |
+
# For refined format, only use refined results (avoid double counting)
|
| 1563 |
+
# For other formats, use all results
|
| 1564 |
+
if refined_format_count > 0:
|
| 1565 |
+
pairwise_results = refined_results
|
| 1566 |
+
else:
|
| 1567 |
+
pairwise_results = valid_results
|
| 1568 |
+
|
| 1569 |
+
paper_scores = []
|
| 1570 |
+
for r in pairwise_results:
|
| 1571 |
+
if (r.get('gt_rating') is not None and r.get('model_rating') is not None) or \
|
| 1572 |
+
(r.get('gt_soundness') is not None and r.get('model_soundness') is not None):
|
| 1573 |
+
paper_scores.append({
|
| 1574 |
+
'true_rating': r.get('gt_rating'),
|
| 1575 |
+
'pred_rating': r.get('model_rating'),
|
| 1576 |
+
'true_soundness': r.get('gt_soundness'),
|
| 1577 |
+
'pred_soundness': r.get('model_soundness'),
|
| 1578 |
+
'true_presentation': r.get('gt_presentation'),
|
| 1579 |
+
'pred_presentation': r.get('model_presentation'),
|
| 1580 |
+
'true_confidence': r.get('gt_confidence'),
|
| 1581 |
+
'pred_confidence': r.get('model_confidence')
|
| 1582 |
+
})
|
| 1583 |
+
|
| 1584 |
+
if len(paper_scores) >= 2:
|
| 1585 |
+
pairwise_accuracies = calculate_pairwise_accuracies(paper_scores)
|
| 1586 |
+
summary['pairwise_accuracies'] = pairwise_accuracies
|
| 1587 |
+
|
| 1588 |
+
return results, summary
|
| 1589 |
+
|
| 1590 |
+
|
| 1591 |
+
# ============================================================================
|
| 1592 |
+
# Main Function
|
| 1593 |
+
# ============================================================================
|
| 1594 |
+
|
| 1595 |
+
def parse_args():
|
| 1596 |
+
"""Parse command line arguments."""
|
| 1597 |
+
parser = argparse.ArgumentParser(description="Unified evaluation script for semantic and auto-metric evaluation")
|
| 1598 |
+
|
| 1599 |
+
# Input paths
|
| 1600 |
+
parser.add_argument("--rubrics_path", type=str, required=True,
|
| 1601 |
+
help="Path to eval_rubrics.json file (from 1_generate_review_based_rubrics.py)")
|
| 1602 |
+
parser.add_argument("--reviews_path", type=str, required=True,
|
| 1603 |
+
help="Path to JSON file with model reviews (contains pred_fast_mode)")
|
| 1604 |
+
|
| 1605 |
+
# Evaluation mode
|
| 1606 |
+
parser.add_argument("--mode", type=str, choices=["semantic", "auto_metric", "both"], default="both",
|
| 1607 |
+
help="Evaluation mode: semantic (LLM-based), auto_metric (rule-based), or both")
|
| 1608 |
+
|
| 1609 |
+
# Output paths
|
| 1610 |
+
parser.add_argument("--semantic_output", type=str, default=None,
|
| 1611 |
+
help="Path to output JSON file for semantic evaluation results (required if mode is semantic or both)")
|
| 1612 |
+
parser.add_argument("--auto_metric_output", type=str, default=None,
|
| 1613 |
+
help="Path to output JSON file for auto-metric evaluation results (required if mode is auto_metric or both)")
|
| 1614 |
+
|
| 1615 |
+
# Semantic evaluation settings
|
| 1616 |
+
parser.add_argument("--yaml_path", type=str, default=None,
|
| 1617 |
+
help="Path to prompts.yaml file (required for semantic evaluation)")
|
| 1618 |
+
parser.add_argument("--config_path", type=str, default=None,
|
| 1619 |
+
help="Path to configs.yaml file (required for semantic evaluation)")
|
| 1620 |
+
|
| 1621 |
+
# Multi-threading
|
| 1622 |
+
parser.add_argument("--max_workers", type=int, default=None,
|
| 1623 |
+
help="Maximum number of worker threads for semantic evaluation (default: 5)")
|
| 1624 |
+
|
| 1625 |
+
# Strict mode (normalize scores to discrete scales)
|
| 1626 |
+
parser.add_argument("--strict_mode", action="store_true", default=False,
|
| 1627 |
+
help="Enable strict mode: normalize scores to discrete scales before computing metrics (default: False)")
|
| 1628 |
+
|
| 1629 |
+
# Input format override
|
| 1630 |
+
parser.add_argument("--input_format", type=str, choices=['auto', 'refined', 'original'], default='auto',
|
| 1631 |
+
help="Manually specify input JSON format: 'refined' (has scores and initial_scores), 'original' (has model_prediction), or 'auto' for auto-detection (default: 'auto')")
|
| 1632 |
+
|
| 1633 |
+
return parser.parse_args()
|
| 1634 |
+
|
| 1635 |
+
|
| 1636 |
+
def main():
|
| 1637 |
+
"""Main execution function."""
|
| 1638 |
+
args = parse_args()
|
| 1639 |
+
|
| 1640 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 1641 |
+
|
| 1642 |
+
# Resolve paths
|
| 1643 |
+
rubrics_path = args.rubrics_path
|
| 1644 |
+
if not os.path.isabs(rubrics_path):
|
| 1645 |
+
rubrics_path = os.path.join(script_dir, rubrics_path)
|
| 1646 |
+
|
| 1647 |
+
reviews_path = args.reviews_path
|
| 1648 |
+
if not os.path.isabs(reviews_path):
|
| 1649 |
+
reviews_path = os.path.join(script_dir, reviews_path)
|
| 1650 |
+
|
| 1651 |
+
max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
|
| 1652 |
+
|
| 1653 |
+
# Validate mode and output paths
|
| 1654 |
+
if args.mode in ["semantic", "both"]:
|
| 1655 |
+
if not args.semantic_output:
|
| 1656 |
+
raise ValueError("--semantic_output is required when mode is 'semantic' or 'both'")
|
| 1657 |
+
if not args.yaml_path:
|
| 1658 |
+
raise ValueError("--yaml_path is required for semantic evaluation")
|
| 1659 |
+
if not args.config_path:
|
| 1660 |
+
raise ValueError("--config_path is required for semantic evaluation")
|
| 1661 |
+
|
| 1662 |
+
if args.mode in ["auto_metric", "both"]:
|
| 1663 |
+
if not args.auto_metric_output:
|
| 1664 |
+
raise ValueError("--auto_metric_output is required when mode is 'auto_metric' or 'both'")
|
| 1665 |
+
|
| 1666 |
+
# Check if files exist
|
| 1667 |
+
if not os.path.exists(rubrics_path):
|
| 1668 |
+
raise FileNotFoundError(f"Rubrics file not found: {rubrics_path}")
|
| 1669 |
+
if not os.path.exists(reviews_path):
|
| 1670 |
+
raise FileNotFoundError(f"Reviews file not found: {reviews_path}")
|
| 1671 |
+
|
| 1672 |
+
# Load data
|
| 1673 |
+
print(f"Loading rubrics from {rubrics_path}...")
|
| 1674 |
+
rubrics_data = load_rubrics_json(rubrics_path)
|
| 1675 |
+
print(f"Loaded {len(rubrics_data)} rubrics entries")
|
| 1676 |
+
|
| 1677 |
+
print(f"Loading model reviews from {reviews_path}...")
|
| 1678 |
+
if args.input_format != 'auto':
|
| 1679 |
+
print(f"Using manually specified format: {args.input_format}")
|
| 1680 |
+
else:
|
| 1681 |
+
print("Auto-detecting input format...")
|
| 1682 |
+
reviews_dict = load_model_reviews_json(reviews_path, format_override=args.input_format if args.input_format != 'auto' else None)
|
| 1683 |
+
print(f"Loaded {len(reviews_dict)} model reviews")
|
| 1684 |
+
|
| 1685 |
+
# Combine rubrics and reviews
|
| 1686 |
+
print("Combining rubrics and reviews...")
|
| 1687 |
+
evaluation_data = combine_rubrics_and_reviews(rubrics_data, reviews_dict)
|
| 1688 |
+
print(f"Prepared {len(evaluation_data)} entries for evaluation")
|
| 1689 |
+
|
| 1690 |
+
# Run evaluations based on mode
|
| 1691 |
+
if args.mode in ["semantic", "both"]:
|
| 1692 |
+
# Resolve semantic evaluation paths
|
| 1693 |
+
yaml_path = args.yaml_path
|
| 1694 |
+
if not os.path.isabs(yaml_path):
|
| 1695 |
+
yaml_path = os.path.join(script_dir, yaml_path)
|
| 1696 |
+
|
| 1697 |
+
config_path = args.config_path
|
| 1698 |
+
if not os.path.isabs(config_path):
|
| 1699 |
+
config_path = os.path.join(script_dir, config_path)
|
| 1700 |
+
|
| 1701 |
+
if not os.path.exists(yaml_path):
|
| 1702 |
+
raise FileNotFoundError(f"YAML file not found: {yaml_path}")
|
| 1703 |
+
if not os.path.exists(config_path):
|
| 1704 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 1705 |
+
|
| 1706 |
+
# Load prompt template
|
| 1707 |
+
print(f"Loading prompt template from {yaml_path}...")
|
| 1708 |
+
prompt_template = load_prompt_template(yaml_path)
|
| 1709 |
+
if not prompt_template:
|
| 1710 |
+
raise ValueError("Could not find 'v1_evaluator_prompt' in YAML file")
|
| 1711 |
+
|
| 1712 |
+
# Initialize LLM service
|
| 1713 |
+
print(f"Loading LLM configuration from {config_path}...")
|
| 1714 |
+
llm_config = load_llm_config(config_path)
|
| 1715 |
+
llm_service = create_llm_service_from_config(llm_config)
|
| 1716 |
+
mode = llm_config.get('mode', 'gpt')
|
| 1717 |
+
print(f"LLM service initialized (mode: {mode})")
|
| 1718 |
+
if hasattr(llm_service, 'model_name'):
|
| 1719 |
+
print(f"Using model: {llm_service.model_name}")
|
| 1720 |
+
|
| 1721 |
+
# Run semantic evaluation
|
| 1722 |
+
semantic_results, semantic_summary = run_semantic_evaluation(
|
| 1723 |
+
evaluation_data, prompt_template, llm_service, max_workers
|
| 1724 |
+
)
|
| 1725 |
+
|
| 1726 |
+
# Save semantic results
|
| 1727 |
+
semantic_output = args.semantic_output
|
| 1728 |
+
if not os.path.isabs(semantic_output):
|
| 1729 |
+
semantic_output = os.path.join(script_dir, semantic_output)
|
| 1730 |
+
|
| 1731 |
+
output_dir = os.path.dirname(semantic_output)
|
| 1732 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1733 |
+
|
| 1734 |
+
with open(semantic_output, 'w', encoding='utf-8') as f:
|
| 1735 |
+
json.dump(semantic_results, f, ensure_ascii=False, indent=2)
|
| 1736 |
+
print(f"\nSemantic evaluation results saved to {semantic_output}")
|
| 1737 |
+
|
| 1738 |
+
# Save semantic summary
|
| 1739 |
+
semantic_summary_path = semantic_output.replace('.json', '_summary.json')
|
| 1740 |
+
with open(semantic_summary_path, 'w', encoding='utf-8') as f:
|
| 1741 |
+
json.dump(semantic_summary, f, ensure_ascii=False, indent=2)
|
| 1742 |
+
print(f"Semantic evaluation summary saved to {semantic_summary_path}")
|
| 1743 |
+
|
| 1744 |
+
# Print semantic summary
|
| 1745 |
+
print("\n" + "="*80)
|
| 1746 |
+
print("SEMANTIC EVALUATION SUMMARY")
|
| 1747 |
+
print("="*80)
|
| 1748 |
+
print(f"Total entries: {semantic_summary['total_entries']}")
|
| 1749 |
+
print(f"Valid entries: {semantic_summary['valid_entries']}")
|
| 1750 |
+
print(f"Failed entries: {semantic_summary['failed_entries']}")
|
| 1751 |
+
if 'overall_score' in semantic_summary:
|
| 1752 |
+
score = semantic_summary['overall_score']
|
| 1753 |
+
print(f"\nOverall Score:")
|
| 1754 |
+
print(f" Mean: {score['mean']:.2f}")
|
| 1755 |
+
print(f" Min: {score['min']:.2f}")
|
| 1756 |
+
print(f" Max: {score['max']:.2f}")
|
| 1757 |
+
|
| 1758 |
+
if args.mode in ["auto_metric", "both"]:
|
| 1759 |
+
# Run auto-metric evaluation
|
| 1760 |
+
auto_metric_results, auto_metric_summary = run_auto_metric_evaluation(
|
| 1761 |
+
evaluation_data,
|
| 1762 |
+
strict_mode=args.strict_mode
|
| 1763 |
+
)
|
| 1764 |
+
|
| 1765 |
+
# Save auto-metric results
|
| 1766 |
+
auto_metric_output = args.auto_metric_output
|
| 1767 |
+
if not os.path.isabs(auto_metric_output):
|
| 1768 |
+
auto_metric_output = os.path.join(script_dir, auto_metric_output)
|
| 1769 |
+
|
| 1770 |
+
output_dir = os.path.dirname(auto_metric_output)
|
| 1771 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1772 |
+
|
| 1773 |
+
with open(auto_metric_output, 'w', encoding='utf-8') as f:
|
| 1774 |
+
json.dump(auto_metric_results, f, ensure_ascii=False, indent=2)
|
| 1775 |
+
print(f"\nAuto-metric evaluation results saved to {auto_metric_output}")
|
| 1776 |
+
|
| 1777 |
+
# Save auto-metric summary
|
| 1778 |
+
auto_metric_summary_path = auto_metric_output.replace('.json', '_summary.json')
|
| 1779 |
+
with open(auto_metric_summary_path, 'w', encoding='utf-8') as f:
|
| 1780 |
+
json.dump(auto_metric_summary, f, ensure_ascii=False, indent=2)
|
| 1781 |
+
print(f"Auto-metric evaluation summary saved to {auto_metric_summary_path}")
|
| 1782 |
+
|
| 1783 |
+
# Print auto-metric summary
|
| 1784 |
+
print("\n" + "="*80)
|
| 1785 |
+
print("AUTO-METRIC EVALUATION SUMMARY")
|
| 1786 |
+
print("="*80)
|
| 1787 |
+
print(f"Total entries: {auto_metric_summary['total_entries']}")
|
| 1788 |
+
print(f"Valid entries: {auto_metric_summary['valid_entries']}")
|
| 1789 |
+
print(f"MSE entries: {auto_metric_summary['mse_entries']}")
|
| 1790 |
+
|
| 1791 |
+
if 'mse_statistics' in auto_metric_summary:
|
| 1792 |
+
print("\nMSE Statistics:")
|
| 1793 |
+
for dim, stats in auto_metric_summary['mse_statistics'].items():
|
| 1794 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1795 |
+
|
| 1796 |
+
if 'mae_statistics' in auto_metric_summary:
|
| 1797 |
+
print("\nMAE Statistics:")
|
| 1798 |
+
for dim, stats in auto_metric_summary['mae_statistics'].items():
|
| 1799 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1800 |
+
|
| 1801 |
+
# Print refined and initial statistics if available
|
| 1802 |
+
if 'refined_mse_statistics' in auto_metric_summary:
|
| 1803 |
+
print("\nRefined Scores - MSE Statistics:")
|
| 1804 |
+
for dim, stats in auto_metric_summary['refined_mse_statistics'].items():
|
| 1805 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1806 |
+
|
| 1807 |
+
if 'refined_mae_statistics' in auto_metric_summary:
|
| 1808 |
+
print("\nRefined Scores - MAE Statistics:")
|
| 1809 |
+
for dim, stats in auto_metric_summary['refined_mae_statistics'].items():
|
| 1810 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1811 |
+
|
| 1812 |
+
if 'initial_mse_statistics' in auto_metric_summary:
|
| 1813 |
+
print("\nInitial Scores - MSE Statistics:")
|
| 1814 |
+
for dim, stats in auto_metric_summary['initial_mse_statistics'].items():
|
| 1815 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1816 |
+
|
| 1817 |
+
if 'initial_mae_statistics' in auto_metric_summary:
|
| 1818 |
+
print("\nInitial Scores - MAE Statistics:")
|
| 1819 |
+
for dim, stats in auto_metric_summary['initial_mae_statistics'].items():
|
| 1820 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1821 |
+
|
| 1822 |
+
if 'spearman_correlations' in auto_metric_summary:
|
| 1823 |
+
print("\nSpearman Correlations:")
|
| 1824 |
+
for dim, stats in auto_metric_summary['spearman_correlations'].items():
|
| 1825 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1826 |
+
|
| 1827 |
+
# Print refined and initial spearman correlations if available
|
| 1828 |
+
if 'refined_spearman_correlations' in auto_metric_summary:
|
| 1829 |
+
print("\nRefined Scores - Spearman Correlations:")
|
| 1830 |
+
for dim, stats in auto_metric_summary['refined_spearman_correlations'].items():
|
| 1831 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1832 |
+
|
| 1833 |
+
if 'initial_spearman_correlations' in auto_metric_summary:
|
| 1834 |
+
print("\nInitial Scores - Spearman Correlations:")
|
| 1835 |
+
for dim, stats in auto_metric_summary['initial_spearman_correlations'].items():
|
| 1836 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1837 |
+
|
| 1838 |
+
if 'decision_metrics' in auto_metric_summary:
|
| 1839 |
+
dm = auto_metric_summary['decision_metrics']
|
| 1840 |
+
print(f"\nDecision Metrics:")
|
| 1841 |
+
print(f" Accuracy: {dm['accuracy']:.4f} (n={dm['count']})")
|
| 1842 |
+
if 'f1_macro' in dm:
|
| 1843 |
+
print(f" F1 (macro): {dm['f1_macro']:.4f}")
|
| 1844 |
+
|
| 1845 |
+
# Print refined and initial decision metrics if available
|
| 1846 |
+
if 'refined_decision_metrics' in auto_metric_summary:
|
| 1847 |
+
print("\nRefined Scores - Decision Metrics:")
|
| 1848 |
+
rdm = auto_metric_summary['refined_decision_metrics']
|
| 1849 |
+
print(f" Accuracy: {rdm['accuracy']:.4f} (n={rdm['count']})")
|
| 1850 |
+
if 'f1_macro' in rdm:
|
| 1851 |
+
print(f" F1 (macro): {rdm['f1_macro']:.4f}")
|
| 1852 |
+
|
| 1853 |
+
if 'initial_decision_metrics' in auto_metric_summary:
|
| 1854 |
+
print("\nInitial Scores - Decision Metrics:")
|
| 1855 |
+
idm = auto_metric_summary['initial_decision_metrics']
|
| 1856 |
+
print(f" Accuracy: {idm['accuracy']:.4f} (n={idm['count']})")
|
| 1857 |
+
if 'f1_macro' in idm:
|
| 1858 |
+
print(f" F1 (macro): {idm['f1_macro']:.4f}")
|
| 1859 |
+
|
| 1860 |
+
print("\n" + "="*80)
|
| 1861 |
+
print("EVALUATION COMPLETE")
|
| 1862 |
+
print("="*80)
|
| 1863 |
+
|
| 1864 |
+
|
| 1865 |
+
if __name__ == "__main__":
|
| 1866 |
+
main()
|
src/evaluator/2_evaluate_cyclereviewer.py
ADDED
|
@@ -0,0 +1,1837 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified evaluation script for semantic (LLM-based) and auto_metric (rule-based) evaluation.
|
| 3 |
+
|
| 4 |
+
This script:
|
| 5 |
+
1. Reads eval_rubrics.json (from 1_generate_review_based_rubrics.py) containing rubrics for each paper
|
| 6 |
+
2. Reads input JSON file containing model reviews (supports multiple formats)
|
| 7 |
+
3. Supports three evaluation modes:
|
| 8 |
+
- semantic: LLM-based rubrics evaluation (from 2_evaluate_direct.py)
|
| 9 |
+
- auto_metric: Rule-based metrics evaluation (from 3_rule_evaluate.py)
|
| 10 |
+
- both: Run both evaluations separately
|
| 11 |
+
4. Supports strict mode: normalize scores to discrete scales before computing metrics (--strict_mode)
|
| 12 |
+
5. Outputs separate JSON files for results and summaries
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Semantic evaluation only
|
| 16 |
+
python 2_evaluate.py \
|
| 17 |
+
--rubrics_path eval_rubrics.json \
|
| 18 |
+
--reviews_path model_reviews.json \
|
| 19 |
+
--mode semantic \
|
| 20 |
+
--yaml_path prompts.yaml \
|
| 21 |
+
--config_path configs.yaml \
|
| 22 |
+
--semantic_output semantic_results.json \
|
| 23 |
+
--max_workers 5
|
| 24 |
+
|
| 25 |
+
# Auto-metric evaluation only
|
| 26 |
+
python 2_evaluate.py \
|
| 27 |
+
--rubrics_path eval_rubrics.json \
|
| 28 |
+
--reviews_path model_reviews.json \
|
| 29 |
+
--mode auto_metric \
|
| 30 |
+
--auto_metric_output auto_metric_results.json
|
| 31 |
+
|
| 32 |
+
# Auto-metric evaluation with strict mode (normalize scores to discrete scales)
|
| 33 |
+
python 2_evaluate.py \
|
| 34 |
+
--rubrics_path eval_rubrics.json \
|
| 35 |
+
--reviews_path model_reviews.json \
|
| 36 |
+
--mode auto_metric \
|
| 37 |
+
--auto_metric_output auto_metric_results.json \
|
| 38 |
+
--strict_mode
|
| 39 |
+
|
| 40 |
+
# Auto-metric evaluation with manually specified input format (refined)
|
| 41 |
+
python 2_evaluate.py \
|
| 42 |
+
--rubrics_path eval_rubrics.json \
|
| 43 |
+
--reviews_path model_reviews.json \
|
| 44 |
+
--mode auto_metric \
|
| 45 |
+
--auto_metric_output auto_metric_results.json \
|
| 46 |
+
--input_format refined
|
| 47 |
+
|
| 48 |
+
# Auto-metric evaluation with manually specified input format (original)
|
| 49 |
+
python 2_evaluate.py \
|
| 50 |
+
--rubrics_path eval_rubrics.json \
|
| 51 |
+
--reviews_path ours.json \
|
| 52 |
+
--mode auto_metric \
|
| 53 |
+
--auto_metric_output auto_metric_results.json \
|
| 54 |
+
--input_format original
|
| 55 |
+
|
| 56 |
+
# Both evaluations
|
| 57 |
+
python 2_evaluate.py \
|
| 58 |
+
--rubrics_path eval_rubrics.json \
|
| 59 |
+
--reviews_path model_reviews.json \
|
| 60 |
+
--mode both \
|
| 61 |
+
--yaml_path prompts.yaml \
|
| 62 |
+
--config_path configs.yaml \
|
| 63 |
+
--semantic_output semantic_results.json \
|
| 64 |
+
--auto_metric_output auto_metric_results.json \
|
| 65 |
+
--max_workers 32
|
| 66 |
+
"""
|
| 67 |
+
from __future__ import annotations
|
| 68 |
+
|
| 69 |
+
import json
|
| 70 |
+
import os
|
| 71 |
+
import sys
|
| 72 |
+
import argparse
|
| 73 |
+
import yaml
|
| 74 |
+
import math
|
| 75 |
+
import re
|
| 76 |
+
from typing import Dict, List, Any, Optional
|
| 77 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 78 |
+
from tqdm import tqdm
|
| 79 |
+
from itertools import combinations
|
| 80 |
+
from scipy.stats import spearmanr
|
| 81 |
+
from sklearn.metrics import precision_recall_fscore_support
|
| 82 |
+
|
| 83 |
+
# Add parent directory to path
|
| 84 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 85 |
+
# Import parse_llm_response from local llm_service module
|
| 86 |
+
import llm_service as local_llm_service
|
| 87 |
+
parse_llm_response = local_llm_service.parse_llm_response
|
| 88 |
+
|
| 89 |
+
# Import from shared/utils for gpt/vllm support
|
| 90 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 91 |
+
if project_root not in sys.path:
|
| 92 |
+
sys.path.insert(0, project_root)
|
| 93 |
+
|
| 94 |
+
from shared.utils.llm_service import LLMService
|
| 95 |
+
from shared.utils.vllm_service import VLLMService
|
| 96 |
+
from shared.utils.gpt_service import GPTService
|
| 97 |
+
sys.path.insert(0, os.path.join(project_root, 'shared', 'utils'))
|
| 98 |
+
from json_parser import parse_review_markdown
|
| 99 |
+
|
| 100 |
+
def convert_cyclereviewer(review_text: str) -> tuple:
|
| 101 |
+
"""
|
| 102 |
+
Convert the review text from cyclereviewer format to unified review system format.
|
| 103 |
+
|
| 104 |
+
The cyclereviewer format has markdown sections like:
|
| 105 |
+
"## Rating\n\n3: reject, not good enough\n\n## Confidence\n\n4: You are confident...\n\n"
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
review_text: Raw review text string (markdown format)
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
Tuple of (formatted_review_text, meta_review_dict)
|
| 112 |
+
"""
|
| 113 |
+
# Use parse_review_markdown to extract scores from markdown sections
|
| 114 |
+
parsed = {}
|
| 115 |
+
try:
|
| 116 |
+
parsed = parse_review_markdown(review_text)
|
| 117 |
+
except Exception:
|
| 118 |
+
pass
|
| 119 |
+
|
| 120 |
+
# Extract rating - can be from "## Rating\n\n3: reject..." or "## Score: 6: ..."
|
| 121 |
+
rating = parsed.get('rating')
|
| 122 |
+
if rating is None:
|
| 123 |
+
# Try to extract from "## Rating\n\n3: reject..." format
|
| 124 |
+
rating_match = re.search(r'##\s*Rating\s*\n\n\s*(\d+\.?\d*)\s*:', review_text, re.IGNORECASE | re.MULTILINE)
|
| 125 |
+
if rating_match:
|
| 126 |
+
try:
|
| 127 |
+
rating = float(rating_match.group(1))
|
| 128 |
+
except (ValueError, IndexError):
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
# Try "## Score: 6: ..." format
|
| 132 |
+
if rating is None:
|
| 133 |
+
score_match = re.search(r'##\s*Score\s*:\s*(\d+\.?\d*)\s*:', review_text, re.IGNORECASE | re.MULTILINE)
|
| 134 |
+
if score_match:
|
| 135 |
+
try:
|
| 136 |
+
rating = float(score_match.group(1))
|
| 137 |
+
except (ValueError, IndexError):
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
# Extract confidence
|
| 141 |
+
confidence = parsed.get('confidence')
|
| 142 |
+
if confidence is None:
|
| 143 |
+
# Try to extract from "## Confidence\n\n4: You are confident..." format
|
| 144 |
+
confidence_match = re.search(r'##\s*Confidence\s*\n\n\s*(\d+\.?\d*)\s*:', review_text, re.IGNORECASE | re.MULTILINE)
|
| 145 |
+
if confidence_match:
|
| 146 |
+
try:
|
| 147 |
+
confidence = float(confidence_match.group(1))
|
| 148 |
+
except (ValueError, IndexError):
|
| 149 |
+
pass
|
| 150 |
+
|
| 151 |
+
# Extract decision from rating text (e.g., "3: reject, not good enough")
|
| 152 |
+
decision = None
|
| 153 |
+
if rating is not None:
|
| 154 |
+
# Look for decision in rating section
|
| 155 |
+
rating_section_match = re.search(r'##\s*Rating\s*\n\n(.*?)(?=\n##|$)', review_text, re.IGNORECASE | re.DOTALL)
|
| 156 |
+
if rating_section_match:
|
| 157 |
+
rating_content = rating_section_match.group(1)
|
| 158 |
+
# Extract decision from text like "3: reject, not good enough"
|
| 159 |
+
decision_match = re.search(r':\s*(accept|reject|undecided)', rating_content, re.IGNORECASE)
|
| 160 |
+
if decision_match:
|
| 161 |
+
decision = decision_match.group(1).lower()
|
| 162 |
+
|
| 163 |
+
# Also try Score section
|
| 164 |
+
if decision is None:
|
| 165 |
+
score_section_match = re.search(r'##\s*Score\s*:\s*\d+\s*:\s*(.*?)(?=\n##|$)', review_text, re.IGNORECASE | re.DOTALL)
|
| 166 |
+
if score_section_match:
|
| 167 |
+
score_content = score_section_match.group(1)
|
| 168 |
+
decision_match = re.search(r'(accept|reject|undecided)', score_content, re.IGNORECASE)
|
| 169 |
+
if decision_match:
|
| 170 |
+
decision = decision_match.group(1).lower()
|
| 171 |
+
|
| 172 |
+
# Extract soundness, presentation from parsed data or markdown
|
| 173 |
+
soundness = parsed.get('soundness')
|
| 174 |
+
presentation = parsed.get('presentation')
|
| 175 |
+
contribution = parsed.get('contribution')
|
| 176 |
+
|
| 177 |
+
# Create meta_review dict
|
| 178 |
+
meta_review = {
|
| 179 |
+
"rating": rating,
|
| 180 |
+
"soundness": soundness,
|
| 181 |
+
"presentation": presentation,
|
| 182 |
+
"contribution": contribution,
|
| 183 |
+
"confidence": confidence,
|
| 184 |
+
"decision": decision,
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
# Return the review text as-is (it's already in markdown format)
|
| 188 |
+
return review_text, meta_review
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class ReviewProcessor:
|
| 192 |
+
"""Handles the extraction and processing of reviews from different sources."""
|
| 193 |
+
|
| 194 |
+
@staticmethod
|
| 195 |
+
def extract_review_content(pred_context):
|
| 196 |
+
"""
|
| 197 |
+
Extract the review content from the prediction context.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
pred_context: Raw prediction data that contains the review
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
str: Extracted review content
|
| 204 |
+
"""
|
| 205 |
+
try:
|
| 206 |
+
# First attempt to extract from boxed format
|
| 207 |
+
return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 208 |
+
except Exception:
|
| 209 |
+
# Alternative extraction if the first method fails
|
| 210 |
+
if isinstance(pred_context, dict) and 'output' in pred_context:
|
| 211 |
+
return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
|
| 212 |
+
else:
|
| 213 |
+
# Return as is if extraction fails
|
| 214 |
+
return pred_context
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ============================================================================
|
| 218 |
+
# Semantic Evaluation Functions (from 2_evaluate_direct.py)
|
| 219 |
+
# ============================================================================
|
| 220 |
+
|
| 221 |
+
def load_prompt_template(yaml_path: str) -> str:
|
| 222 |
+
"""Load the evaluator prompt from YAML file."""
|
| 223 |
+
with open(yaml_path, 'r', encoding='utf-8') as f:
|
| 224 |
+
prompts = yaml.safe_load(f)
|
| 225 |
+
return prompts.get('v1_evaluator_prompt', '')
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def build_evaluation_prompt(
|
| 229 |
+
rubrics: List[Dict[str, Any]],
|
| 230 |
+
paper_content: str,
|
| 231 |
+
review: str,
|
| 232 |
+
prompt_template: str
|
| 233 |
+
) -> str:
|
| 234 |
+
"""Build the evaluation prompt by replacing placeholders."""
|
| 235 |
+
rubrics_json = json.dumps(rubrics, indent=4, ensure_ascii=False)
|
| 236 |
+
prompt = prompt_template.replace('{rubrics_json}', rubrics_json)
|
| 237 |
+
prompt = prompt.replace('<<paper_content>>', paper_content)
|
| 238 |
+
prompt = prompt.replace('<<review>>', review)
|
| 239 |
+
return prompt
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def calculate_weighted_scores(
|
| 243 |
+
raw_scores: Dict[str, Dict[str, Any]],
|
| 244 |
+
rubrics: List[Dict[str, Any]]
|
| 245 |
+
) -> Dict[str, float]:
|
| 246 |
+
"""Calculate weighted scores for each rubric."""
|
| 247 |
+
rubric_weights = {r['title']: r['weight'] for r in rubrics}
|
| 248 |
+
weighted_scores = {}
|
| 249 |
+
|
| 250 |
+
for rubric_title, rubric_data in raw_scores.items():
|
| 251 |
+
if rubric_title not in rubric_weights:
|
| 252 |
+
continue
|
| 253 |
+
|
| 254 |
+
rubric_score = rubric_data.get('score', 0)
|
| 255 |
+
if isinstance(rubric_score, str):
|
| 256 |
+
try:
|
| 257 |
+
rubric_score = int(rubric_score)
|
| 258 |
+
except ValueError:
|
| 259 |
+
rubric_score = 0
|
| 260 |
+
|
| 261 |
+
if rubric_score not in [0, 1]:
|
| 262 |
+
rubric_score = 1 if rubric_score > 0 else 0
|
| 263 |
+
|
| 264 |
+
weight = rubric_weights[rubric_title]
|
| 265 |
+
weighted_scores[rubric_title] = rubric_score * weight
|
| 266 |
+
|
| 267 |
+
return weighted_scores
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def calculate_scores(raw_scores: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
|
| 271 |
+
"""Calculate scores for each rubric."""
|
| 272 |
+
scores = {}
|
| 273 |
+
for rubric_title, rubric_data in raw_scores.items():
|
| 274 |
+
scores[rubric_title] = rubric_data.get('score', 0)
|
| 275 |
+
return scores
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def evaluate_review_semantic(
|
| 279 |
+
entry: Dict[str, Any],
|
| 280 |
+
paper_content: str,
|
| 281 |
+
prompt_template: str,
|
| 282 |
+
llm_service: LLMService
|
| 283 |
+
) -> Dict[str, Any]:
|
| 284 |
+
"""Evaluate a single review using article-specific rubrics."""
|
| 285 |
+
entry_id = entry.get('id', 'unknown')
|
| 286 |
+
rubrics = entry.get('rubrics', [])
|
| 287 |
+
model_review = entry.get('model_review', '')
|
| 288 |
+
|
| 289 |
+
if not rubrics:
|
| 290 |
+
return {
|
| 291 |
+
'id': entry_id,
|
| 292 |
+
'raw_scores': {},
|
| 293 |
+
'weighted_scores': {},
|
| 294 |
+
'total_score': 0.0,
|
| 295 |
+
'error': 'No valid rubrics found',
|
| 296 |
+
'raw_response': ''
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
# Build prompt
|
| 300 |
+
prompt = build_evaluation_prompt(rubrics, paper_content, model_review, prompt_template)
|
| 301 |
+
|
| 302 |
+
# Call LLM
|
| 303 |
+
try:
|
| 304 |
+
messages = [{"role": "user", "content": prompt}]
|
| 305 |
+
response = llm_service.generate(messages=messages)
|
| 306 |
+
|
| 307 |
+
# Parse response
|
| 308 |
+
raw_scores = parse_llm_response(response)
|
| 309 |
+
weighted_scores = calculate_scores(raw_scores)
|
| 310 |
+
total_score = sum(weighted_scores.values())
|
| 311 |
+
|
| 312 |
+
return {
|
| 313 |
+
'id': entry_id,
|
| 314 |
+
'raw_scores': raw_scores,
|
| 315 |
+
'weighted_scores': weighted_scores,
|
| 316 |
+
'total_score': total_score,
|
| 317 |
+
'raw_response': response
|
| 318 |
+
}
|
| 319 |
+
except Exception as e:
|
| 320 |
+
print(f"[ERROR] Error evaluating review {entry_id}: {e}")
|
| 321 |
+
return {
|
| 322 |
+
'id': entry_id,
|
| 323 |
+
'raw_scores': {},
|
| 324 |
+
'weighted_scores': {},
|
| 325 |
+
'total_score': 0.0,
|
| 326 |
+
'error': str(e),
|
| 327 |
+
'raw_response': ''
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def calculate_per_rubric_statistics(
|
| 332 |
+
valid_results: List[Dict[str, Any]],
|
| 333 |
+
rubric_titles: List[str]
|
| 334 |
+
) -> Dict[str, Dict[str, float]]:
|
| 335 |
+
"""Calculate per-rubric statistics from evaluation results."""
|
| 336 |
+
rubric_scores = {title: [] for title in rubric_titles}
|
| 337 |
+
|
| 338 |
+
for result in valid_results:
|
| 339 |
+
weighted_scores = result.get('weighted_scores', {})
|
| 340 |
+
if not isinstance(weighted_scores, dict):
|
| 341 |
+
continue
|
| 342 |
+
|
| 343 |
+
for rubric_title in rubric_titles:
|
| 344 |
+
if rubric_title in weighted_scores:
|
| 345 |
+
score = weighted_scores[rubric_title]
|
| 346 |
+
if isinstance(score, str):
|
| 347 |
+
try:
|
| 348 |
+
score = float(score)
|
| 349 |
+
except ValueError:
|
| 350 |
+
continue
|
| 351 |
+
elif isinstance(score, (int, float)):
|
| 352 |
+
score = float(score)
|
| 353 |
+
else:
|
| 354 |
+
continue
|
| 355 |
+
rubric_scores[rubric_title].append(score)
|
| 356 |
+
|
| 357 |
+
per_rubric_stats = {}
|
| 358 |
+
for rubric_title in rubric_titles:
|
| 359 |
+
scores = rubric_scores[rubric_title]
|
| 360 |
+
if not scores:
|
| 361 |
+
continue
|
| 362 |
+
|
| 363 |
+
mean_score = sum(scores) / len(scores)
|
| 364 |
+
min_score = min(scores)
|
| 365 |
+
max_score = max(scores)
|
| 366 |
+
count = len(scores)
|
| 367 |
+
|
| 368 |
+
if rubric_title == "False or Contradictory Claims":
|
| 369 |
+
pass_count = sum(1 for s in scores if s >= 0)
|
| 370 |
+
else:
|
| 371 |
+
pass_count = sum(1 for s in scores if s >= 1)
|
| 372 |
+
pass_rate = pass_count / count if count > 0 else 0.0
|
| 373 |
+
|
| 374 |
+
per_rubric_stats[rubric_title] = {
|
| 375 |
+
'mean': mean_score,
|
| 376 |
+
'min': min_score,
|
| 377 |
+
'max': max_score,
|
| 378 |
+
'count': count,
|
| 379 |
+
'pass_rate': pass_rate
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
return per_rubric_stats
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# ============================================================================
|
| 386 |
+
# Auto-Metric Evaluation Functions (from 3_rule_evaluate.py)
|
| 387 |
+
# ============================================================================
|
| 388 |
+
|
| 389 |
+
def extract_scores_from_review(review_text: str) -> Dict[str, Any]:
|
| 390 |
+
"""Extract numeric scores and decision from a review markdown text."""
|
| 391 |
+
if not review_text:
|
| 392 |
+
return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
|
| 393 |
+
|
| 394 |
+
try:
|
| 395 |
+
parsed = parse_review_markdown(review_text)
|
| 396 |
+
decision = parsed.get('decision', '')
|
| 397 |
+
if decision:
|
| 398 |
+
decision_lower = decision.lower().strip()
|
| 399 |
+
if 'accept' in decision_lower:
|
| 400 |
+
decision = 'accept'
|
| 401 |
+
elif 'reject' in decision_lower:
|
| 402 |
+
decision = 'reject'
|
| 403 |
+
elif 'undecided' in decision_lower:
|
| 404 |
+
decision = 'undecided'
|
| 405 |
+
else:
|
| 406 |
+
decision = decision_lower
|
| 407 |
+
else:
|
| 408 |
+
decision = None
|
| 409 |
+
|
| 410 |
+
return {
|
| 411 |
+
'soundness': parsed.get('soundness'),
|
| 412 |
+
'presentation': parsed.get('presentation'),
|
| 413 |
+
'rating': parsed.get('rating'),
|
| 414 |
+
'confidence': parsed.get('confidence'),
|
| 415 |
+
'decision': decision
|
| 416 |
+
}
|
| 417 |
+
except Exception as e:
|
| 418 |
+
print(f"Warning: Failed to parse review text: {e}")
|
| 419 |
+
return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def calculate_mse(predicted: float, ground_truth: float) -> Optional[float]:
|
| 423 |
+
"""Calculate Mean Squared Error for a single value."""
|
| 424 |
+
if predicted is None or ground_truth is None:
|
| 425 |
+
return None
|
| 426 |
+
return (predicted - ground_truth) ** 2
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def calculate_mae(predicted: float, ground_truth: float) -> Optional[float]:
|
| 430 |
+
"""Calculate Mean Absolute Error for a single value."""
|
| 431 |
+
if predicted is None or ground_truth is None:
|
| 432 |
+
return None
|
| 433 |
+
return abs(predicted - ground_truth)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def normalize_to_discrete_scale(score: Optional[float], scale_type: str) -> Optional[float]:
|
| 437 |
+
"""
|
| 438 |
+
Normalize a float score to the nearest discrete value based on scale type.
|
| 439 |
+
Uses round-half-up tie-breaking (e.g., 3.5 rounds to 4, 1.5 rounds to 2).
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
score: The float score to normalize (can be None)
|
| 443 |
+
scale_type: Either '0-5' for 0-5 scale (discrete: 0,1,2,3,4,5)
|
| 444 |
+
or '0-10' for 0-10 scale (discrete: 0,2,4,6,8,10)
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
Normalized discrete score, or None if input is None
|
| 448 |
+
"""
|
| 449 |
+
if score is None:
|
| 450 |
+
return None
|
| 451 |
+
|
| 452 |
+
try:
|
| 453 |
+
score = float(score)
|
| 454 |
+
except (ValueError, TypeError):
|
| 455 |
+
return None
|
| 456 |
+
|
| 457 |
+
if scale_type == '0-5':
|
| 458 |
+
# Discrete values: 0, 1, 2, 3, 4, 5
|
| 459 |
+
discrete_values = [0, 1, 2, 3, 4, 5]
|
| 460 |
+
# Clamp to valid range
|
| 461 |
+
score = max(0, min(5, score))
|
| 462 |
+
# Find nearest discrete value, with round-half-up tie-breaking
|
| 463 |
+
# For ties, prefer the higher value
|
| 464 |
+
best_value = None
|
| 465 |
+
best_distance = float('inf')
|
| 466 |
+
for val in discrete_values:
|
| 467 |
+
distance = abs(val - score)
|
| 468 |
+
if distance < best_distance:
|
| 469 |
+
best_distance = distance
|
| 470 |
+
best_value = val
|
| 471 |
+
elif distance == best_distance and val > best_value:
|
| 472 |
+
# Tie-breaking: prefer higher value (round-half-up)
|
| 473 |
+
best_value = val
|
| 474 |
+
return best_value
|
| 475 |
+
elif scale_type == '0-10':
|
| 476 |
+
# Discrete values: 0, 2, 4, 6, 8, 10
|
| 477 |
+
discrete_values = [0, 2, 4, 6, 8, 10]
|
| 478 |
+
# Clamp to valid range
|
| 479 |
+
score = max(0, min(10, score))
|
| 480 |
+
# Find nearest discrete value, with round-half-up tie-breaking
|
| 481 |
+
best_value = None
|
| 482 |
+
best_distance = float('inf')
|
| 483 |
+
for val in discrete_values:
|
| 484 |
+
distance = abs(val - score)
|
| 485 |
+
if distance < best_distance:
|
| 486 |
+
best_distance = distance
|
| 487 |
+
best_value = val
|
| 488 |
+
elif distance == best_distance and val > best_value:
|
| 489 |
+
# Tie-breaking: prefer higher value (round-half-up)
|
| 490 |
+
best_value = val
|
| 491 |
+
return best_value
|
| 492 |
+
else:
|
| 493 |
+
raise ValueError(f"Unknown scale_type: {scale_type}. Must be '0-5' or '0-10'")
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def normalize_scores_dict(scores: Dict[str, Optional[float]]) -> Dict[str, Optional[float]]:
|
| 497 |
+
"""
|
| 498 |
+
Normalize all scores in a dictionary to their appropriate discrete scales.
|
| 499 |
+
|
| 500 |
+
Args:
|
| 501 |
+
scores: Dictionary with keys 'soundness', 'presentation', 'rating', 'confidence'
|
| 502 |
+
|
| 503 |
+
Returns:
|
| 504 |
+
Dictionary with normalized scores
|
| 505 |
+
"""
|
| 506 |
+
normalized = {}
|
| 507 |
+
|
| 508 |
+
# soundness, presentation, confidence use 0-5 scale
|
| 509 |
+
for key in ['soundness', 'presentation', 'confidence']:
|
| 510 |
+
normalized[key] = normalize_to_discrete_scale(scores.get(key), '0-5')
|
| 511 |
+
|
| 512 |
+
# rating uses 0-10 scale
|
| 513 |
+
normalized['rating'] = normalize_to_discrete_scale(scores.get('rating'), '0-10')
|
| 514 |
+
|
| 515 |
+
return normalized
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def calculate_score_metrics(
|
| 519 |
+
model_scores: Dict[str, float],
|
| 520 |
+
ground_truth_scores: Dict[str, float],
|
| 521 |
+
normalize: bool = False
|
| 522 |
+
) -> Dict[str, Any]:
|
| 523 |
+
"""
|
| 524 |
+
Calculate MSE and MAE metrics for each scoring dimension.
|
| 525 |
+
|
| 526 |
+
Args:
|
| 527 |
+
model_scores: Dictionary with model scores
|
| 528 |
+
ground_truth_scores: Dictionary with ground truth scores
|
| 529 |
+
normalize: If True, normalize scores to discrete scales before computing metrics
|
| 530 |
+
|
| 531 |
+
Returns:
|
| 532 |
+
Dictionary with MSE, MAE metrics and optionally normalized scores
|
| 533 |
+
"""
|
| 534 |
+
dimensions = ['soundness', 'presentation', 'rating', 'confidence']
|
| 535 |
+
|
| 536 |
+
# Normalize scores to discrete scales if requested
|
| 537 |
+
if normalize:
|
| 538 |
+
model_scores_normalized = normalize_scores_dict(model_scores)
|
| 539 |
+
gt_scores_normalized = normalize_scores_dict(ground_truth_scores)
|
| 540 |
+
else:
|
| 541 |
+
model_scores_normalized = model_scores
|
| 542 |
+
gt_scores_normalized = ground_truth_scores
|
| 543 |
+
|
| 544 |
+
mse_values = {}
|
| 545 |
+
mae_values = {}
|
| 546 |
+
valid_count = 0
|
| 547 |
+
|
| 548 |
+
for dim in dimensions:
|
| 549 |
+
# Use normalized scores for metric calculation
|
| 550 |
+
mse = calculate_mse(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
|
| 551 |
+
mae = calculate_mae(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
|
| 552 |
+
mse_values[f'{dim}_mse'] = mse
|
| 553 |
+
mae_values[f'{dim}_mae'] = mae
|
| 554 |
+
if mse is not None:
|
| 555 |
+
valid_count += 1
|
| 556 |
+
|
| 557 |
+
overall_error = sum([v for v in mse_values.values() if v is not None])
|
| 558 |
+
|
| 559 |
+
result = {
|
| 560 |
+
**mse_values,
|
| 561 |
+
**mae_values,
|
| 562 |
+
'overall_error': overall_error if valid_count > 0 else None,
|
| 563 |
+
'valid_dimensions': valid_count
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
# Include normalized scores in result for transparency (only if normalize=True)
|
| 567 |
+
if normalize:
|
| 568 |
+
result['model_scores_normalized'] = model_scores_normalized
|
| 569 |
+
result['gt_scores_normalized'] = gt_scores_normalized
|
| 570 |
+
|
| 571 |
+
return result
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def normalize_score_value(value):
|
| 575 |
+
"""Normalize score value to float, handling string representations."""
|
| 576 |
+
if value is None:
|
| 577 |
+
return None
|
| 578 |
+
if isinstance(value, (int, float)):
|
| 579 |
+
return float(value)
|
| 580 |
+
if isinstance(value, str):
|
| 581 |
+
# Try to extract numeric value from string (e.g., "2.75" -> 2.75)
|
| 582 |
+
try:
|
| 583 |
+
import re
|
| 584 |
+
match = re.search(r'(\d+\.?\d*)', value)
|
| 585 |
+
if match:
|
| 586 |
+
return float(match.group(1))
|
| 587 |
+
except:
|
| 588 |
+
pass
|
| 589 |
+
return None
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def normalize_decision(decision):
|
| 593 |
+
"""Normalize decision string to standard format."""
|
| 594 |
+
if decision is None:
|
| 595 |
+
return None
|
| 596 |
+
decision_lower = str(decision).lower().strip()
|
| 597 |
+
if 'accept' in decision_lower:
|
| 598 |
+
return 'accept'
|
| 599 |
+
elif 'reject' in decision_lower:
|
| 600 |
+
return 'reject'
|
| 601 |
+
elif 'undecided' in decision_lower:
|
| 602 |
+
return 'undecided'
|
| 603 |
+
else:
|
| 604 |
+
return decision_lower
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def extract_scores_from_dict(scores_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 608 |
+
"""
|
| 609 |
+
Extract scores from a structured dictionary (scores or initial_scores format).
|
| 610 |
+
|
| 611 |
+
Args:
|
| 612 |
+
scores_dict: Dict containing scores (e.g., {'rating': 5.75, 'soundness': '2.75', ...})
|
| 613 |
+
|
| 614 |
+
Returns:
|
| 615 |
+
Dict with normalized scores: {'soundness', 'presentation', 'rating', 'confidence', 'decision'}
|
| 616 |
+
"""
|
| 617 |
+
if not scores_dict:
|
| 618 |
+
return {
|
| 619 |
+
'soundness': None,
|
| 620 |
+
'presentation': None,
|
| 621 |
+
'rating': None,
|
| 622 |
+
'confidence': None,
|
| 623 |
+
'decision': None
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
return {
|
| 627 |
+
'soundness': normalize_score_value(scores_dict.get('soundness')),
|
| 628 |
+
'presentation': normalize_score_value(scores_dict.get('presentation')),
|
| 629 |
+
'rating': normalize_score_value(scores_dict.get('rating')),
|
| 630 |
+
'confidence': normalize_score_value(scores_dict.get('confidence')),
|
| 631 |
+
'decision': normalize_decision(scores_dict.get('decision'))
|
| 632 |
+
}
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def evaluate_review_auto_metric(entry: Dict[str, Any], use_initial_scores: bool = False, strict_mode: bool = False) -> Dict[str, Any]:
|
| 636 |
+
"""
|
| 637 |
+
Evaluate a single entry by extracting scores and calculating metrics.
|
| 638 |
+
|
| 639 |
+
Args:
|
| 640 |
+
entry: Evaluation entry containing model_review, scores, initial_scores, etc.
|
| 641 |
+
use_initial_scores: If True, use initial_scores instead of refined scores (for refined format)
|
| 642 |
+
|
| 643 |
+
Returns:
|
| 644 |
+
Dict containing evaluation metrics
|
| 645 |
+
"""
|
| 646 |
+
entry_id = entry.get('id', 'unknown')
|
| 647 |
+
model_review = entry.get('model_review', '')
|
| 648 |
+
format_type = entry.get('format', 'unknown')
|
| 649 |
+
|
| 650 |
+
# Extract scores based on format
|
| 651 |
+
model_scores = {}
|
| 652 |
+
model_decision = None
|
| 653 |
+
|
| 654 |
+
if format_type == 'refined' and not use_initial_scores:
|
| 655 |
+
# Use refined scores from structured data
|
| 656 |
+
scores_dict = entry.get('scores', {})
|
| 657 |
+
model_data = extract_scores_from_dict(scores_dict)
|
| 658 |
+
model_scores = {
|
| 659 |
+
'soundness': model_data.get('soundness'),
|
| 660 |
+
'presentation': model_data.get('presentation'),
|
| 661 |
+
'rating': model_data.get('rating'),
|
| 662 |
+
'confidence': model_data.get('confidence')
|
| 663 |
+
}
|
| 664 |
+
model_decision = model_data.get('decision')
|
| 665 |
+
elif format_type == 'refined' and use_initial_scores:
|
| 666 |
+
# Use initial scores from structured data
|
| 667 |
+
initial_scores_dict = entry.get('initial_scores', {})
|
| 668 |
+
model_data = extract_scores_from_dict(initial_scores_dict)
|
| 669 |
+
model_scores = {
|
| 670 |
+
'soundness': model_data.get('soundness'),
|
| 671 |
+
'presentation': model_data.get('presentation'),
|
| 672 |
+
'rating': model_data.get('rating'),
|
| 673 |
+
'confidence': model_data.get('confidence')
|
| 674 |
+
}
|
| 675 |
+
model_decision = model_data.get('decision')
|
| 676 |
+
elif format_type == 'original':
|
| 677 |
+
# Use initial scores from structured data
|
| 678 |
+
initial_scores_dict = entry.get('initial_scores', {})
|
| 679 |
+
model_data = extract_scores_from_dict(initial_scores_dict)
|
| 680 |
+
model_scores = {
|
| 681 |
+
'soundness': model_data.get('soundness'),
|
| 682 |
+
'presentation': model_data.get('presentation'),
|
| 683 |
+
'rating': model_data.get('rating'),
|
| 684 |
+
'confidence': model_data.get('confidence')
|
| 685 |
+
}
|
| 686 |
+
model_decision = model_data.get('decision')
|
| 687 |
+
|
| 688 |
+
# Fallback: If confidence is missing from structured data, try to extract from review text
|
| 689 |
+
# (meta_review may not have confidence field, but review text might)
|
| 690 |
+
if model_scores.get('confidence') is None and model_review:
|
| 691 |
+
try:
|
| 692 |
+
review_data = extract_scores_from_review(model_review)
|
| 693 |
+
if review_data.get('confidence') is not None:
|
| 694 |
+
model_scores['confidence'] = review_data.get('confidence')
|
| 695 |
+
except Exception:
|
| 696 |
+
pass # Keep confidence as None if extraction fails
|
| 697 |
+
else:
|
| 698 |
+
# Fallback: extract from markdown review text
|
| 699 |
+
model_data = extract_scores_from_review(model_review)
|
| 700 |
+
model_scores = {
|
| 701 |
+
'soundness': model_data.get('soundness'),
|
| 702 |
+
'presentation': model_data.get('presentation'),
|
| 703 |
+
'rating': model_data.get('rating'),
|
| 704 |
+
'confidence': model_data.get('confidence')
|
| 705 |
+
}
|
| 706 |
+
model_decision = model_data.get('decision')
|
| 707 |
+
|
| 708 |
+
# Get ground truth scores from golden_review ONLY
|
| 709 |
+
# Ground truth must ONLY come from golden_review, never from model output
|
| 710 |
+
# If extraction fails, leave fields as None (do not use model_review as fallback)
|
| 711 |
+
ground_truth_review = entry.get('golden_review', '')
|
| 712 |
+
ground_truth_scores = {}
|
| 713 |
+
gt_decision = None
|
| 714 |
+
|
| 715 |
+
if not ground_truth_review:
|
| 716 |
+
print(f"Warning: No golden_review found for entry {entry_id}. Ground truth scores will be empty.")
|
| 717 |
+
else:
|
| 718 |
+
try:
|
| 719 |
+
# Extract scores from golden_review markdown text
|
| 720 |
+
gt_data = extract_scores_from_review(ground_truth_review)
|
| 721 |
+
if not gt_data:
|
| 722 |
+
print(f"Warning: Failed to parse golden_review for entry {entry_id}. Ground truth scores will be empty.")
|
| 723 |
+
else:
|
| 724 |
+
ground_truth_scores = {
|
| 725 |
+
'soundness': gt_data.get('soundness'),
|
| 726 |
+
'presentation': gt_data.get('presentation'),
|
| 727 |
+
'rating': gt_data.get('rating'),
|
| 728 |
+
'confidence': gt_data.get('confidence')
|
| 729 |
+
}
|
| 730 |
+
gt_decision = normalize_decision(gt_data.get('decision'))
|
| 731 |
+
# Note: If any field is None, it stays None - we do NOT use model_review as fallback
|
| 732 |
+
# Using model output as ground truth would inflate evaluation scores
|
| 733 |
+
except Exception as e:
|
| 734 |
+
print(f"Warning: Failed to extract scores from golden_review for {entry_id}: {e}")
|
| 735 |
+
print(f" Ground truth scores will be empty. Error: {str(e)}")
|
| 736 |
+
|
| 737 |
+
# Calculate MSE and MAE metrics (with optional normalization in strict mode)
|
| 738 |
+
score_metrics = calculate_score_metrics(model_scores, ground_truth_scores, normalize=strict_mode)
|
| 739 |
+
|
| 740 |
+
# Calculate decision accuracy
|
| 741 |
+
decision_match = False
|
| 742 |
+
decision_accuracy = None
|
| 743 |
+
if model_decision is not None and gt_decision is not None:
|
| 744 |
+
model_decision_normalized = normalize_decision(model_decision)
|
| 745 |
+
decision_match = (model_decision_normalized == gt_decision)
|
| 746 |
+
decision_accuracy = 1.0 if decision_match else 0.0
|
| 747 |
+
|
| 748 |
+
result = {
|
| 749 |
+
'id': entry_id,
|
| 750 |
+
'format': format_type,
|
| 751 |
+
'model_soundness': model_scores.get('soundness'),
|
| 752 |
+
'model_presentation': model_scores.get('presentation'),
|
| 753 |
+
'model_rating': model_scores.get('rating'),
|
| 754 |
+
'model_confidence': model_scores.get('confidence'),
|
| 755 |
+
'model_decision': model_decision,
|
| 756 |
+
'gt_soundness': ground_truth_scores.get('soundness'),
|
| 757 |
+
'gt_presentation': ground_truth_scores.get('presentation'),
|
| 758 |
+
'gt_rating': ground_truth_scores.get('rating'),
|
| 759 |
+
'gt_confidence': ground_truth_scores.get('confidence'),
|
| 760 |
+
'gt_decision': gt_decision,
|
| 761 |
+
'decision_match': decision_match,
|
| 762 |
+
'decision_accuracy': decision_accuracy,
|
| 763 |
+
**score_metrics
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
# Add prefix to indicate which scores were used
|
| 767 |
+
if format_type == 'refined':
|
| 768 |
+
if use_initial_scores:
|
| 769 |
+
result['score_type'] = 'initial'
|
| 770 |
+
else:
|
| 771 |
+
result['score_type'] = 'refined'
|
| 772 |
+
else:
|
| 773 |
+
result['score_type'] = 'auto'
|
| 774 |
+
|
| 775 |
+
return result
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def calculate_pairwise_accuracies(paper_scores: List[Dict[str, float]]) -> Dict[str, float]:
|
| 779 |
+
"""Calculate pairwise accuracy for each metric by comparing rankings."""
|
| 780 |
+
if len(paper_scores) < 2:
|
| 781 |
+
return {}
|
| 782 |
+
|
| 783 |
+
total_valid_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
|
| 784 |
+
correct_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
|
| 785 |
+
|
| 786 |
+
for paper1, paper2 in combinations(paper_scores, 2):
|
| 787 |
+
# Check rating ranking
|
| 788 |
+
if (paper1.get('true_rating') is not None and paper2.get('true_rating') is not None and
|
| 789 |
+
paper1.get('pred_rating') is not None and paper2.get('pred_rating') is not None):
|
| 790 |
+
total_valid_pairs['rating'] += 1
|
| 791 |
+
true_order = paper1['true_rating'] > paper2['true_rating']
|
| 792 |
+
pred_order = paper1['pred_rating'] > paper2['pred_rating']
|
| 793 |
+
if true_order == pred_order:
|
| 794 |
+
correct_pairs['rating'] += 1
|
| 795 |
+
|
| 796 |
+
# Similar for other dimensions...
|
| 797 |
+
# (abbreviated for space, similar logic for soundness, presentation, confidence)
|
| 798 |
+
for metric in ['soundness', 'presentation', 'confidence']:
|
| 799 |
+
true_key = f'true_{metric}'
|
| 800 |
+
pred_key = f'pred_{metric}'
|
| 801 |
+
if (paper1.get(true_key) is not None and paper2.get(true_key) is not None and
|
| 802 |
+
paper1.get(pred_key) is not None and paper2.get(pred_key) is not None):
|
| 803 |
+
total_valid_pairs[metric] += 1
|
| 804 |
+
true_order = paper1[true_key] > paper2[true_key]
|
| 805 |
+
pred_order = paper1[pred_key] > paper2[pred_key]
|
| 806 |
+
if true_order == pred_order:
|
| 807 |
+
correct_pairs[metric] += 1
|
| 808 |
+
|
| 809 |
+
pairwise_accuracies = {
|
| 810 |
+
metric: correct_pairs[metric] / total_valid_pairs[metric] if total_valid_pairs[metric] > 0 else 0.0
|
| 811 |
+
for metric in ['rating', 'soundness', 'presentation', 'confidence']
|
| 812 |
+
}
|
| 813 |
+
|
| 814 |
+
return pairwise_accuracies
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
# ============================================================================
|
| 818 |
+
# Data Loading Functions
|
| 819 |
+
# ============================================================================
|
| 820 |
+
|
| 821 |
+
def load_rubrics_json(rubrics_path: str) -> Dict[str, Dict[str, Any]]:
|
| 822 |
+
"""Load rubrics JSON and create lookup by id."""
|
| 823 |
+
with open(rubrics_path, 'r', encoding='utf-8') as f:
|
| 824 |
+
data = json.load(f)
|
| 825 |
+
|
| 826 |
+
if isinstance(data, list):
|
| 827 |
+
return {item['id']: item for item in data}
|
| 828 |
+
elif isinstance(data, dict):
|
| 829 |
+
return data
|
| 830 |
+
else:
|
| 831 |
+
raise ValueError(f"Invalid rubrics JSON format: expected list or dict, got {type(data)}")
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def load_model_reviews_json(reviews_path: str, format_override: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
| 835 |
+
"""
|
| 836 |
+
Load model reviews JSON and extract reviews by id.
|
| 837 |
+
|
| 838 |
+
Supports two input formats:
|
| 839 |
+
1. Refined format: Contains 'scores' and 'initial_scores' fields (from refinement pipeline)
|
| 840 |
+
2. Original format: Contains 'model_prediction' with 'meta_review' and 'decision' (like ours.json)
|
| 841 |
+
|
| 842 |
+
Args:
|
| 843 |
+
reviews_path: Path to JSON file containing model reviews
|
| 844 |
+
format_override: Optional format override ('refined', 'original', or None for auto-detect)
|
| 845 |
+
|
| 846 |
+
Returns:
|
| 847 |
+
Dict mapping paper_id to dict containing:
|
| 848 |
+
- 'review': review text (markdown)
|
| 849 |
+
- 'scores': refined scores dict (if available)
|
| 850 |
+
- 'initial_scores': initial scores dict (if available)
|
| 851 |
+
- 'format': 'refined' or 'original'
|
| 852 |
+
"""
|
| 853 |
+
with open(reviews_path, 'r', encoding='utf-8') as f:
|
| 854 |
+
data = json.load(f)
|
| 855 |
+
|
| 856 |
+
if isinstance(data, dict):
|
| 857 |
+
data = list(data.values())
|
| 858 |
+
|
| 859 |
+
reviews_dict = {}
|
| 860 |
+
for item in data:
|
| 861 |
+
item_id = None
|
| 862 |
+
review_text = ''
|
| 863 |
+
scores = None
|
| 864 |
+
initial_scores = None
|
| 865 |
+
format_type = None
|
| 866 |
+
|
| 867 |
+
# Use format override if provided, otherwise auto-detect
|
| 868 |
+
if format_override and format_override != 'auto':
|
| 869 |
+
# Force use specified format
|
| 870 |
+
if format_override == 'refined':
|
| 871 |
+
item_id = item.get('paper_id') or item.get('id')
|
| 872 |
+
if not item_id:
|
| 873 |
+
continue
|
| 874 |
+
format_type = 'refined'
|
| 875 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 876 |
+
scores = item.get('scores', {})
|
| 877 |
+
initial_scores = item.get('initial_scores', {})
|
| 878 |
+
elif format_override == 'original':
|
| 879 |
+
item_id = item.get('id')
|
| 880 |
+
if not item_id:
|
| 881 |
+
continue
|
| 882 |
+
format_type = 'original'
|
| 883 |
+
model_prediction = item.get('model_prediction', {})
|
| 884 |
+
meta_review = model_prediction.get('meta_review', {})
|
| 885 |
+
review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
|
| 886 |
+
initial_scores = {
|
| 887 |
+
'rating': meta_review.get('rating'),
|
| 888 |
+
'soundness': meta_review.get('soundness'),
|
| 889 |
+
'presentation': meta_review.get('presentation'),
|
| 890 |
+
'contribution': meta_review.get('contribution'),
|
| 891 |
+
'decision': model_prediction.get('decision'),
|
| 892 |
+
}
|
| 893 |
+
else:
|
| 894 |
+
raise ValueError(f"Unknown format_override: {format_override}. Must be 'refined', 'original', or 'auto'")
|
| 895 |
+
else:
|
| 896 |
+
# Auto-detect format
|
| 897 |
+
if "paper_id" in item:
|
| 898 |
+
# Refined format (from refinement pipeline)
|
| 899 |
+
item_id = item.get('paper_id')
|
| 900 |
+
if not item_id:
|
| 901 |
+
continue
|
| 902 |
+
|
| 903 |
+
# Check if this is refined format (has scores and initial_scores)
|
| 904 |
+
if 'scores' in item and 'initial_scores' in item:
|
| 905 |
+
format_type = 'refined'
|
| 906 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 907 |
+
scores = item.get('scores', {})
|
| 908 |
+
initial_scores = item.get('initial_scores', {})
|
| 909 |
+
else:
|
| 910 |
+
# Standard format with paper_id
|
| 911 |
+
format_type = 'standard'
|
| 912 |
+
review_text = item.get('review_markdown', '') or item.get('review', '')
|
| 913 |
+
elif "model_prediction" in item:
|
| 914 |
+
# Original format (like ours.json) or cyclereviewer format
|
| 915 |
+
item_id = item.get('id')
|
| 916 |
+
if not item_id:
|
| 917 |
+
continue
|
| 918 |
+
|
| 919 |
+
format_type = 'original'
|
| 920 |
+
model_prediction = item.get('model_prediction', {})
|
| 921 |
+
meta_review = model_prediction.get('meta_review', {})
|
| 922 |
+
|
| 923 |
+
# Extract review content (prefer meta_review.content, fallback to raw_text)
|
| 924 |
+
review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
|
| 925 |
+
|
| 926 |
+
# Detect cyclereviewer format: has raw_text as markdown string with "## Rating" or "## Score:" patterns
|
| 927 |
+
is_cyclereviewer = False
|
| 928 |
+
if isinstance(review_text, str) and review_text:
|
| 929 |
+
# Check if it contains cyclereviewer patterns
|
| 930 |
+
if (re.search(r'##\s*(Rating|Score)\s*:', review_text, re.IGNORECASE) or
|
| 931 |
+
re.search(r'##\s*Rating\s*\n\n\s*\d+\s*:', review_text, re.IGNORECASE | re.MULTILINE)):
|
| 932 |
+
is_cyclereviewer = True
|
| 933 |
+
|
| 934 |
+
# Handle cyclereviewer format
|
| 935 |
+
if is_cyclereviewer:
|
| 936 |
+
review_text, meta_review = convert_cyclereviewer(review_text)
|
| 937 |
+
|
| 938 |
+
# Extract initial scores
|
| 939 |
+
# Use meta_review as primary source (from convert_cyclereviewer or original meta_review)
|
| 940 |
+
# Fallback to model_prediction.get('decision') if not in meta_review
|
| 941 |
+
initial_scores = {
|
| 942 |
+
'rating': meta_review.get('rating'),
|
| 943 |
+
'soundness': meta_review.get('soundness'),
|
| 944 |
+
'presentation': meta_review.get('presentation'),
|
| 945 |
+
'contribution': meta_review.get('contribution'),
|
| 946 |
+
'confidence': meta_review.get('confidence'),
|
| 947 |
+
'decision': meta_review.get('decision') or model_prediction.get('decision'),
|
| 948 |
+
}
|
| 949 |
+
else:
|
| 950 |
+
# Legacy format (pred_fast_mode)
|
| 951 |
+
item_id = item.get('id')
|
| 952 |
+
if not item_id:
|
| 953 |
+
continue
|
| 954 |
+
|
| 955 |
+
format_type = 'legacy'
|
| 956 |
+
review_dict = item.get('pred_fast_mode', {})
|
| 957 |
+
if isinstance(review_dict, dict):
|
| 958 |
+
# review_text = review_dict.get('raw_text', '')
|
| 959 |
+
review_text = review_dict
|
| 960 |
+
else:
|
| 961 |
+
review_text = str(review_dict)
|
| 962 |
+
|
| 963 |
+
# Extract review content from the review text field
|
| 964 |
+
try:
|
| 965 |
+
if review_text:
|
| 966 |
+
extracted_review = ReviewProcessor.extract_review_content(review_text)
|
| 967 |
+
else:
|
| 968 |
+
extracted_review = ''
|
| 969 |
+
|
| 970 |
+
reviews_dict[item_id] = {
|
| 971 |
+
'review': extracted_review,
|
| 972 |
+
'scores': scores,
|
| 973 |
+
'initial_scores': initial_scores,
|
| 974 |
+
'format': format_type
|
| 975 |
+
}
|
| 976 |
+
except Exception as e:
|
| 977 |
+
print(f"[WARN] Failed to extract review for {item_id}: {e}")
|
| 978 |
+
continue
|
| 979 |
+
|
| 980 |
+
return reviews_dict
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
def combine_rubrics_and_reviews(
|
| 984 |
+
rubrics_data: Dict[str, Dict[str, Any]],
|
| 985 |
+
reviews_dict: Dict[str, Dict[str, Any]]
|
| 986 |
+
) -> List[Dict[str, Any]]:
|
| 987 |
+
"""
|
| 988 |
+
Combine rubrics and reviews into evaluation entries.
|
| 989 |
+
|
| 990 |
+
Args:
|
| 991 |
+
rubrics_data: Dict mapping paper_id to rubric entry
|
| 992 |
+
reviews_dict: Dict mapping paper_id to dict containing 'review', 'scores', 'initial_scores', 'format'
|
| 993 |
+
|
| 994 |
+
Returns:
|
| 995 |
+
List of evaluation entries with model_review, scores, initial_scores, and format info
|
| 996 |
+
"""
|
| 997 |
+
combined = []
|
| 998 |
+
missing_reviews = []
|
| 999 |
+
|
| 1000 |
+
for paper_id, rubric_entry in rubrics_data.items():
|
| 1001 |
+
review_data = reviews_dict.get(paper_id)
|
| 1002 |
+
if not review_data or not review_data.get('review'):
|
| 1003 |
+
missing_reviews.append(paper_id)
|
| 1004 |
+
continue
|
| 1005 |
+
|
| 1006 |
+
entry = {
|
| 1007 |
+
'id': paper_id,
|
| 1008 |
+
'paper_context': rubric_entry.get('paper_context', ''),
|
| 1009 |
+
'decision': rubric_entry.get('decision', ''),
|
| 1010 |
+
'golden_review': rubric_entry.get('golden_review', ''),
|
| 1011 |
+
'rubrics': rubric_entry.get('rubrics', []),
|
| 1012 |
+
'model_review': review_data.get('review', ''),
|
| 1013 |
+
'scores': review_data.get('scores'), # Refined scores (if available)
|
| 1014 |
+
'initial_scores': review_data.get('initial_scores'), # Initial scores (if available)
|
| 1015 |
+
'format': review_data.get('format', 'unknown') # Format type
|
| 1016 |
+
}
|
| 1017 |
+
combined.append(entry)
|
| 1018 |
+
|
| 1019 |
+
if missing_reviews:
|
| 1020 |
+
print(f"[WARN] {len(missing_reviews)} papers have no model review, skipping them")
|
| 1021 |
+
|
| 1022 |
+
return combined
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
# ============================================================================
|
| 1026 |
+
# LLM Service Configuration
|
| 1027 |
+
# ============================================================================
|
| 1028 |
+
|
| 1029 |
+
def load_llm_config(config_path: str) -> Dict[str, Any]:
|
| 1030 |
+
"""Load LLM configuration from YAML file."""
|
| 1031 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 1032 |
+
config = yaml.safe_load(f)
|
| 1033 |
+
return config
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
|
| 1037 |
+
"""Create LLM service from configuration."""
|
| 1038 |
+
mode = config.get('mode', 'gpt').lower()
|
| 1039 |
+
|
| 1040 |
+
if mode == 'gpt':
|
| 1041 |
+
gpt_config = config.get('gpt', {})
|
| 1042 |
+
api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
|
| 1043 |
+
if not api_key:
|
| 1044 |
+
raise ValueError("GPT mode requires api_key in configs.yaml or OPENAI_API_KEY environment variable")
|
| 1045 |
+
|
| 1046 |
+
service = GPTService(
|
| 1047 |
+
api_key=api_key,
|
| 1048 |
+
model_name=gpt_config.get('model_name', 'gpt-4o'),
|
| 1049 |
+
base_url=gpt_config.get('base_url'),
|
| 1050 |
+
timeout=gpt_config.get('timeout', 300)
|
| 1051 |
+
)
|
| 1052 |
+
return service
|
| 1053 |
+
|
| 1054 |
+
elif mode == 'vllm':
|
| 1055 |
+
vllm_config = config.get('vllm', {})
|
| 1056 |
+
service = VLLMService(
|
| 1057 |
+
base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
|
| 1058 |
+
api_key=vllm_config.get('api_key', 'dummy-key'),
|
| 1059 |
+
model_name=vllm_config.get('model_name'),
|
| 1060 |
+
timeout=vllm_config.get('timeout', 300),
|
| 1061 |
+
max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
|
| 1062 |
+
max_retries=vllm_config.get('max_retries', 3),
|
| 1063 |
+
retry_delay=vllm_config.get('retry_delay', 1.0),
|
| 1064 |
+
retry_backoff=vllm_config.get('retry_backoff', 2.0)
|
| 1065 |
+
)
|
| 1066 |
+
return service
|
| 1067 |
+
|
| 1068 |
+
else:
|
| 1069 |
+
raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
# ============================================================================
|
| 1073 |
+
# Main Evaluation Functions
|
| 1074 |
+
# ============================================================================
|
| 1075 |
+
|
| 1076 |
+
def run_semantic_evaluation(
|
| 1077 |
+
evaluation_data: List[Dict[str, Any]],
|
| 1078 |
+
prompt_template: str,
|
| 1079 |
+
llm_service: LLMService,
|
| 1080 |
+
max_workers: int
|
| 1081 |
+
) -> tuple:
|
| 1082 |
+
"""Run semantic evaluation and return results and summary."""
|
| 1083 |
+
print(f"\n{'='*80}")
|
| 1084 |
+
print("RUNNING SEMANTIC EVALUATION")
|
| 1085 |
+
print(f"{'='*80}")
|
| 1086 |
+
print(f"Evaluating {len(evaluation_data)} reviews using {max_workers} workers...")
|
| 1087 |
+
|
| 1088 |
+
results = []
|
| 1089 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 1090 |
+
future_to_entry = {
|
| 1091 |
+
executor.submit(
|
| 1092 |
+
evaluate_review_semantic,
|
| 1093 |
+
entry,
|
| 1094 |
+
entry['paper_context'],
|
| 1095 |
+
prompt_template,
|
| 1096 |
+
llm_service
|
| 1097 |
+
): entry
|
| 1098 |
+
for entry in evaluation_data
|
| 1099 |
+
}
|
| 1100 |
+
|
| 1101 |
+
for future in tqdm(as_completed(future_to_entry), total=len(evaluation_data), desc="Semantic evaluation"):
|
| 1102 |
+
try:
|
| 1103 |
+
result = future.result()
|
| 1104 |
+
results.append(result)
|
| 1105 |
+
except Exception as e:
|
| 1106 |
+
entry = future_to_entry[future]
|
| 1107 |
+
print(f"\n[ERROR] Failed to process entry {entry.get('id', 'unknown')}: {e}")
|
| 1108 |
+
results.append({
|
| 1109 |
+
'id': entry.get('id', 'unknown'),
|
| 1110 |
+
'raw_scores': {},
|
| 1111 |
+
'weighted_scores': {},
|
| 1112 |
+
'total_score': 0.0,
|
| 1113 |
+
'error': str(e),
|
| 1114 |
+
'raw_response': ''
|
| 1115 |
+
})
|
| 1116 |
+
|
| 1117 |
+
# Calculate statistics
|
| 1118 |
+
valid_results = [r for r in results if 'error' not in r and r.get('weighted_scores')]
|
| 1119 |
+
review_scores = [r.get('total_score', 0.0) for r in valid_results]
|
| 1120 |
+
|
| 1121 |
+
summary = {
|
| 1122 |
+
'total_entries': len(results),
|
| 1123 |
+
'valid_entries': len(valid_results),
|
| 1124 |
+
'failed_entries': len(results) - len(valid_results)
|
| 1125 |
+
}
|
| 1126 |
+
|
| 1127 |
+
if review_scores:
|
| 1128 |
+
summary['overall_score'] = {
|
| 1129 |
+
'mean': sum(review_scores) / len(review_scores),
|
| 1130 |
+
'min': min(review_scores),
|
| 1131 |
+
'max': max(review_scores)
|
| 1132 |
+
}
|
| 1133 |
+
|
| 1134 |
+
# Calculate per-rubric statistics (extract rubric titles from first entry)
|
| 1135 |
+
if evaluation_data and evaluation_data[0].get('rubrics'):
|
| 1136 |
+
rubric_titles = [r['title'] for r in evaluation_data[0]['rubrics']]
|
| 1137 |
+
per_rubric_stats = calculate_per_rubric_statistics(valid_results, rubric_titles)
|
| 1138 |
+
summary['per_rubric_statistics'] = per_rubric_stats
|
| 1139 |
+
|
| 1140 |
+
return results, summary
|
| 1141 |
+
|
| 1142 |
+
|
| 1143 |
+
def run_auto_metric_evaluation(
|
| 1144 |
+
evaluation_data: List[Dict[str, Any]],
|
| 1145 |
+
strict_mode: bool = False
|
| 1146 |
+
) -> tuple:
|
| 1147 |
+
"""
|
| 1148 |
+
Run auto-metric evaluation and return results and summary.
|
| 1149 |
+
|
| 1150 |
+
For refined format (has scores and initial_scores), evaluates both:
|
| 1151 |
+
- Refined scores evaluation
|
| 1152 |
+
- Initial scores evaluation
|
| 1153 |
+
|
| 1154 |
+
For original format (only initial_scores), evaluates:
|
| 1155 |
+
- Initial scores evaluation only
|
| 1156 |
+
|
| 1157 |
+
Returns:
|
| 1158 |
+
Tuple of (results_list, summary_dict)
|
| 1159 |
+
- results_list: List of evaluation results (may contain both refined and initial results for refined format)
|
| 1160 |
+
- summary_dict: Summary statistics
|
| 1161 |
+
"""
|
| 1162 |
+
print(f"\n{'='*80}")
|
| 1163 |
+
print("RUNNING AUTO-METRIC EVALUATION")
|
| 1164 |
+
print(f"{'='*80}")
|
| 1165 |
+
print(f"Evaluating {len(evaluation_data)} entries...")
|
| 1166 |
+
|
| 1167 |
+
# Detect format types
|
| 1168 |
+
refined_format_count = sum(1 for e in evaluation_data if e.get('format') == 'refined')
|
| 1169 |
+
original_format_count = sum(1 for e in evaluation_data if e.get('format') == 'original')
|
| 1170 |
+
|
| 1171 |
+
if refined_format_count > 0:
|
| 1172 |
+
print(f"Detected {refined_format_count} entries in refined format (will evaluate both refined and initial scores)")
|
| 1173 |
+
if original_format_count > 0:
|
| 1174 |
+
print(f"Detected {original_format_count} entries in original format (will evaluate initial scores only)")
|
| 1175 |
+
|
| 1176 |
+
results = []
|
| 1177 |
+
for entry in tqdm(evaluation_data, desc="Auto-metric evaluation"):
|
| 1178 |
+
format_type = entry.get('format', 'unknown')
|
| 1179 |
+
|
| 1180 |
+
if format_type == 'refined':
|
| 1181 |
+
# Evaluate both refined scores and initial scores
|
| 1182 |
+
try:
|
| 1183 |
+
entry_id = entry.get('id', 'unknown')
|
| 1184 |
+
|
| 1185 |
+
# Evaluate refined scores
|
| 1186 |
+
refined_result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
|
| 1187 |
+
refined_result['paper_id'] = entry_id # Keep original paper_id
|
| 1188 |
+
refined_result['id'] = f"{entry_id}_refined"
|
| 1189 |
+
results.append(refined_result)
|
| 1190 |
+
|
| 1191 |
+
# Evaluate initial scores
|
| 1192 |
+
initial_result = evaluate_review_auto_metric(entry, use_initial_scores=True, strict_mode=strict_mode)
|
| 1193 |
+
initial_result['paper_id'] = entry_id # Keep original paper_id
|
| 1194 |
+
initial_result['id'] = f"{entry_id}_initial"
|
| 1195 |
+
results.append(initial_result)
|
| 1196 |
+
except Exception as e:
|
| 1197 |
+
print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
|
| 1198 |
+
results.append({
|
| 1199 |
+
'id': entry.get('id', 'unknown'),
|
| 1200 |
+
'error': str(e)
|
| 1201 |
+
})
|
| 1202 |
+
else:
|
| 1203 |
+
# Evaluate initial scores only (or extract from markdown)
|
| 1204 |
+
try:
|
| 1205 |
+
result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
|
| 1206 |
+
results.append(result)
|
| 1207 |
+
except Exception as e:
|
| 1208 |
+
print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
|
| 1209 |
+
results.append({
|
| 1210 |
+
'id': entry.get('id', 'unknown'),
|
| 1211 |
+
'error': str(e)
|
| 1212 |
+
})
|
| 1213 |
+
|
| 1214 |
+
# Calculate statistics
|
| 1215 |
+
valid_results = [r for r in results if 'error' not in r]
|
| 1216 |
+
mse_results = [r for r in valid_results if r.get('overall_error') is not None]
|
| 1217 |
+
|
| 1218 |
+
# Separate refined and initial results for refined format
|
| 1219 |
+
refined_results = [r for r in valid_results if r.get('score_type') == 'refined']
|
| 1220 |
+
initial_results = [r for r in valid_results if r.get('score_type') == 'initial']
|
| 1221 |
+
auto_results = [r for r in valid_results if r.get('score_type') == 'auto' or r.get('score_type') is None]
|
| 1222 |
+
|
| 1223 |
+
summary = {
|
| 1224 |
+
'total_entries': len(results),
|
| 1225 |
+
'valid_entries': len(valid_results),
|
| 1226 |
+
'mse_entries': len(mse_results),
|
| 1227 |
+
'refined_results_count': len(refined_results),
|
| 1228 |
+
'initial_results_count': len(initial_results),
|
| 1229 |
+
'auto_results_count': len(auto_results)
|
| 1230 |
+
}
|
| 1231 |
+
|
| 1232 |
+
# Calculate MSE/MAE statistics
|
| 1233 |
+
# For refined format, only use refined results for overall statistics (avoid double counting)
|
| 1234 |
+
# For other formats, use all results
|
| 1235 |
+
if refined_format_count > 0:
|
| 1236 |
+
# Refined format: use only refined results for overall statistics
|
| 1237 |
+
stats_results = [r for r in refined_results if r.get('overall_error') is not None]
|
| 1238 |
+
else:
|
| 1239 |
+
# Original/other formats: use all results
|
| 1240 |
+
stats_results = mse_results
|
| 1241 |
+
|
| 1242 |
+
if stats_results:
|
| 1243 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1244 |
+
mse_stats = {}
|
| 1245 |
+
mae_stats = {}
|
| 1246 |
+
|
| 1247 |
+
for dim in dimensions:
|
| 1248 |
+
mse_list = [r.get(f'{dim}_mse') for r in stats_results if r.get(f'{dim}_mse') is not None]
|
| 1249 |
+
mae_list = [r.get(f'{dim}_mae') for r in stats_results if r.get(f'{dim}_mae') is not None]
|
| 1250 |
+
|
| 1251 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1252 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1253 |
+
|
| 1254 |
+
if mse_clean:
|
| 1255 |
+
mse_stats[dim] = {
|
| 1256 |
+
'mean': sum(mse_clean) / len(mse_clean),
|
| 1257 |
+
'count': len(mse_clean)
|
| 1258 |
+
}
|
| 1259 |
+
if mae_clean:
|
| 1260 |
+
mae_stats[dim] = {
|
| 1261 |
+
'mean': sum(mae_clean) / len(mae_clean),
|
| 1262 |
+
'count': len(mae_clean)
|
| 1263 |
+
}
|
| 1264 |
+
|
| 1265 |
+
overall_errors = [r.get('overall_error') for r in stats_results if r.get('overall_error') is not None]
|
| 1266 |
+
overall_clean = [x for x in overall_errors if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1267 |
+
|
| 1268 |
+
if overall_clean:
|
| 1269 |
+
summary['overall_error'] = {
|
| 1270 |
+
'mean': sum(overall_clean) / len(overall_clean),
|
| 1271 |
+
'count': len(overall_clean)
|
| 1272 |
+
}
|
| 1273 |
+
|
| 1274 |
+
summary['mse_statistics'] = mse_stats
|
| 1275 |
+
summary['mae_statistics'] = mae_stats
|
| 1276 |
+
|
| 1277 |
+
# Calculate separate statistics for refined and initial results
|
| 1278 |
+
if refined_results:
|
| 1279 |
+
refined_mse_results = [r for r in refined_results if r.get('overall_error') is not None]
|
| 1280 |
+
if refined_mse_results:
|
| 1281 |
+
refined_mse_stats = {}
|
| 1282 |
+
refined_mae_stats = {}
|
| 1283 |
+
for dim in dimensions:
|
| 1284 |
+
mse_list = [r.get(f'{dim}_mse') for r in refined_mse_results if r.get(f'{dim}_mse') is not None]
|
| 1285 |
+
mae_list = [r.get(f'{dim}_mae') for r in refined_mse_results if r.get(f'{dim}_mae') is not None]
|
| 1286 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1287 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1288 |
+
if mse_clean:
|
| 1289 |
+
refined_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
|
| 1290 |
+
if mae_clean:
|
| 1291 |
+
refined_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
|
| 1292 |
+
summary['refined_mse_statistics'] = refined_mse_stats
|
| 1293 |
+
summary['refined_mae_statistics'] = refined_mae_stats
|
| 1294 |
+
|
| 1295 |
+
if initial_results:
|
| 1296 |
+
initial_mse_results = [r for r in initial_results if r.get('overall_error') is not None]
|
| 1297 |
+
if initial_mse_results:
|
| 1298 |
+
initial_mse_stats = {}
|
| 1299 |
+
initial_mae_stats = {}
|
| 1300 |
+
for dim in dimensions:
|
| 1301 |
+
mse_list = [r.get(f'{dim}_mse') for r in initial_mse_results if r.get(f'{dim}_mse') is not None]
|
| 1302 |
+
mae_list = [r.get(f'{dim}_mae') for r in initial_mse_results if r.get(f'{dim}_mae') is not None]
|
| 1303 |
+
mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1304 |
+
mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
|
| 1305 |
+
if mse_clean:
|
| 1306 |
+
initial_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
|
| 1307 |
+
if mae_clean:
|
| 1308 |
+
initial_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
|
| 1309 |
+
summary['initial_mse_statistics'] = initial_mse_stats
|
| 1310 |
+
summary['initial_mae_statistics'] = initial_mae_stats
|
| 1311 |
+
|
| 1312 |
+
# Calculate Spearman correlations
|
| 1313 |
+
def filter_valid_pairs(true_list, pred_list):
|
| 1314 |
+
filtered_true = []
|
| 1315 |
+
filtered_pred = []
|
| 1316 |
+
for t, p in zip(true_list, pred_list):
|
| 1317 |
+
if (t is not None and p is not None and
|
| 1318 |
+
not (isinstance(t, float) and math.isnan(t)) and
|
| 1319 |
+
not (isinstance(p, float) and math.isnan(p))):
|
| 1320 |
+
filtered_true.append(t)
|
| 1321 |
+
filtered_pred.append(p)
|
| 1322 |
+
return filtered_true, filtered_pred
|
| 1323 |
+
|
| 1324 |
+
# Calculate Spearman correlations
|
| 1325 |
+
# For refined format, calculate separately for refined and initial, and use refined for overall
|
| 1326 |
+
# For other formats, use all results
|
| 1327 |
+
if refined_format_count > 0:
|
| 1328 |
+
# Calculate refined spearman correlations
|
| 1329 |
+
refined_spearman_stats = {}
|
| 1330 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1331 |
+
for dim in dimensions:
|
| 1332 |
+
true_values = [r.get(f'gt_{dim}') for r in refined_results]
|
| 1333 |
+
pred_values = [r.get(f'model_{dim}') for r in refined_results]
|
| 1334 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1335 |
+
|
| 1336 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1337 |
+
try:
|
| 1338 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1339 |
+
if not math.isnan(corr):
|
| 1340 |
+
refined_spearman_stats[dim] = {
|
| 1341 |
+
'correlation': corr,
|
| 1342 |
+
'count': len(true_clean)
|
| 1343 |
+
}
|
| 1344 |
+
except Exception:
|
| 1345 |
+
pass
|
| 1346 |
+
|
| 1347 |
+
# Calculate initial spearman correlations
|
| 1348 |
+
initial_spearman_stats = {}
|
| 1349 |
+
for dim in dimensions:
|
| 1350 |
+
true_values = [r.get(f'gt_{dim}') for r in initial_results]
|
| 1351 |
+
pred_values = [r.get(f'model_{dim}') for r in initial_results]
|
| 1352 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1353 |
+
|
| 1354 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1355 |
+
try:
|
| 1356 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1357 |
+
if not math.isnan(corr):
|
| 1358 |
+
initial_spearman_stats[dim] = {
|
| 1359 |
+
'correlation': corr,
|
| 1360 |
+
'count': len(true_clean)
|
| 1361 |
+
}
|
| 1362 |
+
except Exception:
|
| 1363 |
+
pass
|
| 1364 |
+
|
| 1365 |
+
# Use refined for overall statistics (avoid double counting)
|
| 1366 |
+
summary['spearman_correlations'] = refined_spearman_stats
|
| 1367 |
+
summary['refined_spearman_correlations'] = refined_spearman_stats
|
| 1368 |
+
summary['initial_spearman_correlations'] = initial_spearman_stats
|
| 1369 |
+
else:
|
| 1370 |
+
# Original/other formats: use all results
|
| 1371 |
+
correlation_results = valid_results
|
| 1372 |
+
spearman_stats = {}
|
| 1373 |
+
dimensions = ['soundness', 'presentation', 'confidence', 'rating']
|
| 1374 |
+
for dim in dimensions:
|
| 1375 |
+
true_values = [r.get(f'gt_{dim}') for r in correlation_results]
|
| 1376 |
+
pred_values = [r.get(f'model_{dim}') for r in correlation_results]
|
| 1377 |
+
true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
|
| 1378 |
+
|
| 1379 |
+
if len(true_clean) >= 2 and len(pred_clean) >= 2:
|
| 1380 |
+
try:
|
| 1381 |
+
corr, _ = spearmanr(true_clean, pred_clean)
|
| 1382 |
+
if not math.isnan(corr):
|
| 1383 |
+
spearman_stats[dim] = {
|
| 1384 |
+
'correlation': corr,
|
| 1385 |
+
'count': len(true_clean)
|
| 1386 |
+
}
|
| 1387 |
+
except Exception:
|
| 1388 |
+
pass
|
| 1389 |
+
|
| 1390 |
+
summary['spearman_correlations'] = spearman_stats
|
| 1391 |
+
|
| 1392 |
+
# Calculate Decision metrics
|
| 1393 |
+
# For refined format, calculate separately for refined and initial, and use refined for overall
|
| 1394 |
+
# For other formats, use all results
|
| 1395 |
+
if refined_format_count > 0:
|
| 1396 |
+
# Calculate refined decision metrics
|
| 1397 |
+
refined_decision_results = [r for r in refined_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1398 |
+
if refined_decision_results:
|
| 1399 |
+
true_decisions = []
|
| 1400 |
+
pred_decisions = []
|
| 1401 |
+
decision_acc = []
|
| 1402 |
+
|
| 1403 |
+
for r in refined_decision_results:
|
| 1404 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1405 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1406 |
+
|
| 1407 |
+
if 'accept' in pred_decision:
|
| 1408 |
+
pred_binary = 1
|
| 1409 |
+
else:
|
| 1410 |
+
pred_binary = 0
|
| 1411 |
+
|
| 1412 |
+
if 'accept' in gt_decision:
|
| 1413 |
+
gt_binary = 1
|
| 1414 |
+
else:
|
| 1415 |
+
gt_binary = 0
|
| 1416 |
+
|
| 1417 |
+
true_decisions.append(gt_binary)
|
| 1418 |
+
pred_decisions.append(pred_binary)
|
| 1419 |
+
|
| 1420 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1421 |
+
decision_acc.append(1.0)
|
| 1422 |
+
else:
|
| 1423 |
+
decision_acc.append(0.0)
|
| 1424 |
+
|
| 1425 |
+
if decision_acc:
|
| 1426 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1427 |
+
try:
|
| 1428 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1429 |
+
refined_decision_metrics = {
|
| 1430 |
+
'accuracy': decision_accuracy,
|
| 1431 |
+
'f1_macro': f1_score,
|
| 1432 |
+
'count': len(decision_acc)
|
| 1433 |
+
}
|
| 1434 |
+
except Exception:
|
| 1435 |
+
refined_decision_metrics = {
|
| 1436 |
+
'accuracy': decision_accuracy,
|
| 1437 |
+
'count': len(decision_acc)
|
| 1438 |
+
}
|
| 1439 |
+
summary['refined_decision_metrics'] = refined_decision_metrics
|
| 1440 |
+
summary['decision_metrics'] = refined_decision_metrics # Use refined for overall
|
| 1441 |
+
|
| 1442 |
+
# Calculate initial decision metrics
|
| 1443 |
+
initial_decision_results = [r for r in initial_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1444 |
+
if initial_decision_results:
|
| 1445 |
+
true_decisions = []
|
| 1446 |
+
pred_decisions = []
|
| 1447 |
+
decision_acc = []
|
| 1448 |
+
|
| 1449 |
+
for r in initial_decision_results:
|
| 1450 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1451 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1452 |
+
|
| 1453 |
+
if 'accept' in pred_decision:
|
| 1454 |
+
pred_binary = 1
|
| 1455 |
+
else:
|
| 1456 |
+
pred_binary = 0
|
| 1457 |
+
|
| 1458 |
+
if 'accept' in gt_decision:
|
| 1459 |
+
gt_binary = 1
|
| 1460 |
+
else:
|
| 1461 |
+
gt_binary = 0
|
| 1462 |
+
|
| 1463 |
+
true_decisions.append(gt_binary)
|
| 1464 |
+
pred_decisions.append(pred_binary)
|
| 1465 |
+
|
| 1466 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1467 |
+
decision_acc.append(1.0)
|
| 1468 |
+
else:
|
| 1469 |
+
decision_acc.append(0.0)
|
| 1470 |
+
|
| 1471 |
+
if decision_acc:
|
| 1472 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1473 |
+
try:
|
| 1474 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1475 |
+
initial_decision_metrics = {
|
| 1476 |
+
'accuracy': decision_accuracy,
|
| 1477 |
+
'f1_macro': f1_score,
|
| 1478 |
+
'count': len(decision_acc)
|
| 1479 |
+
}
|
| 1480 |
+
except Exception:
|
| 1481 |
+
initial_decision_metrics = {
|
| 1482 |
+
'accuracy': decision_accuracy,
|
| 1483 |
+
'count': len(decision_acc)
|
| 1484 |
+
}
|
| 1485 |
+
summary['initial_decision_metrics'] = initial_decision_metrics
|
| 1486 |
+
else:
|
| 1487 |
+
# Original/other formats: use all results
|
| 1488 |
+
decision_results = [r for r in valid_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
|
| 1489 |
+
if decision_results:
|
| 1490 |
+
true_decisions = []
|
| 1491 |
+
pred_decisions = []
|
| 1492 |
+
decision_acc = []
|
| 1493 |
+
|
| 1494 |
+
for r in decision_results:
|
| 1495 |
+
gt_decision = str(r.get('gt_decision', '')).lower().strip()
|
| 1496 |
+
pred_decision = str(r.get('model_decision', '')).lower().strip()
|
| 1497 |
+
|
| 1498 |
+
if 'accept' in pred_decision:
|
| 1499 |
+
pred_binary = 1
|
| 1500 |
+
else:
|
| 1501 |
+
pred_binary = 0
|
| 1502 |
+
|
| 1503 |
+
if 'accept' in gt_decision:
|
| 1504 |
+
gt_binary = 1
|
| 1505 |
+
else:
|
| 1506 |
+
gt_binary = 0
|
| 1507 |
+
|
| 1508 |
+
true_decisions.append(gt_binary)
|
| 1509 |
+
pred_decisions.append(pred_binary)
|
| 1510 |
+
|
| 1511 |
+
if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
|
| 1512 |
+
decision_acc.append(1.0)
|
| 1513 |
+
else:
|
| 1514 |
+
decision_acc.append(0.0)
|
| 1515 |
+
|
| 1516 |
+
if decision_acc:
|
| 1517 |
+
decision_accuracy = sum(decision_acc) / len(decision_acc)
|
| 1518 |
+
try:
|
| 1519 |
+
_, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
|
| 1520 |
+
summary['decision_metrics'] = {
|
| 1521 |
+
'accuracy': decision_accuracy,
|
| 1522 |
+
'f1_macro': f1_score,
|
| 1523 |
+
'count': len(decision_acc)
|
| 1524 |
+
}
|
| 1525 |
+
except Exception:
|
| 1526 |
+
summary['decision_metrics'] = {
|
| 1527 |
+
'accuracy': decision_accuracy,
|
| 1528 |
+
'count': len(decision_acc)
|
| 1529 |
+
}
|
| 1530 |
+
|
| 1531 |
+
# Calculate Pairwise comparison
|
| 1532 |
+
# For refined format, only use refined results (avoid double counting)
|
| 1533 |
+
# For other formats, use all results
|
| 1534 |
+
if refined_format_count > 0:
|
| 1535 |
+
pairwise_results = refined_results
|
| 1536 |
+
else:
|
| 1537 |
+
pairwise_results = valid_results
|
| 1538 |
+
|
| 1539 |
+
paper_scores = []
|
| 1540 |
+
for r in pairwise_results:
|
| 1541 |
+
if (r.get('gt_rating') is not None and r.get('model_rating') is not None) or \
|
| 1542 |
+
(r.get('gt_soundness') is not None and r.get('model_soundness') is not None):
|
| 1543 |
+
paper_scores.append({
|
| 1544 |
+
'true_rating': r.get('gt_rating'),
|
| 1545 |
+
'pred_rating': r.get('model_rating'),
|
| 1546 |
+
'true_soundness': r.get('gt_soundness'),
|
| 1547 |
+
'pred_soundness': r.get('model_soundness'),
|
| 1548 |
+
'true_presentation': r.get('gt_presentation'),
|
| 1549 |
+
'pred_presentation': r.get('model_presentation'),
|
| 1550 |
+
'true_confidence': r.get('gt_confidence'),
|
| 1551 |
+
'pred_confidence': r.get('model_confidence')
|
| 1552 |
+
})
|
| 1553 |
+
|
| 1554 |
+
if len(paper_scores) >= 2:
|
| 1555 |
+
pairwise_accuracies = calculate_pairwise_accuracies(paper_scores)
|
| 1556 |
+
summary['pairwise_accuracies'] = pairwise_accuracies
|
| 1557 |
+
|
| 1558 |
+
return results, summary
|
| 1559 |
+
|
| 1560 |
+
|
| 1561 |
+
# ============================================================================
|
| 1562 |
+
# Main Function
|
| 1563 |
+
# ============================================================================
|
| 1564 |
+
|
| 1565 |
+
def parse_args():
|
| 1566 |
+
"""Parse command line arguments."""
|
| 1567 |
+
parser = argparse.ArgumentParser(description="Unified evaluation script for semantic and auto-metric evaluation")
|
| 1568 |
+
|
| 1569 |
+
# Input paths
|
| 1570 |
+
parser.add_argument("--rubrics_path", type=str, required=True,
|
| 1571 |
+
help="Path to eval_rubrics.json file (from 1_generate_review_based_rubrics.py)")
|
| 1572 |
+
parser.add_argument("--reviews_path", type=str, required=True,
|
| 1573 |
+
help="Path to JSON file with model reviews (contains pred_fast_mode)")
|
| 1574 |
+
|
| 1575 |
+
# Evaluation mode
|
| 1576 |
+
parser.add_argument("--mode", type=str, choices=["semantic", "auto_metric", "both"], default="both",
|
| 1577 |
+
help="Evaluation mode: semantic (LLM-based), auto_metric (rule-based), or both")
|
| 1578 |
+
|
| 1579 |
+
# Output paths
|
| 1580 |
+
parser.add_argument("--semantic_output", type=str, default=None,
|
| 1581 |
+
help="Path to output JSON file for semantic evaluation results (required if mode is semantic or both)")
|
| 1582 |
+
parser.add_argument("--auto_metric_output", type=str, default=None,
|
| 1583 |
+
help="Path to output JSON file for auto-metric evaluation results (required if mode is auto_metric or both)")
|
| 1584 |
+
|
| 1585 |
+
# Semantic evaluation settings
|
| 1586 |
+
parser.add_argument("--yaml_path", type=str, default=None,
|
| 1587 |
+
help="Path to prompts.yaml file (required for semantic evaluation)")
|
| 1588 |
+
parser.add_argument("--config_path", type=str, default=None,
|
| 1589 |
+
help="Path to configs.yaml file (required for semantic evaluation)")
|
| 1590 |
+
|
| 1591 |
+
# Multi-threading
|
| 1592 |
+
parser.add_argument("--max_workers", type=int, default=None,
|
| 1593 |
+
help="Maximum number of worker threads for semantic evaluation (default: 5)")
|
| 1594 |
+
|
| 1595 |
+
# Strict mode (normalize scores to discrete scales)
|
| 1596 |
+
parser.add_argument("--strict_mode", action="store_true", default=False,
|
| 1597 |
+
help="Enable strict mode: normalize scores to discrete scales before computing metrics (default: False)")
|
| 1598 |
+
|
| 1599 |
+
# Input format override
|
| 1600 |
+
parser.add_argument("--input_format", type=str, choices=['auto', 'refined', 'original'], default='auto',
|
| 1601 |
+
help="Manually specify input JSON format: 'refined' (has scores and initial_scores), 'original' (has model_prediction), or 'auto' for auto-detection (default: 'auto')")
|
| 1602 |
+
|
| 1603 |
+
return parser.parse_args()
|
| 1604 |
+
|
| 1605 |
+
|
| 1606 |
+
def main():
|
| 1607 |
+
"""Main execution function."""
|
| 1608 |
+
args = parse_args()
|
| 1609 |
+
|
| 1610 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 1611 |
+
|
| 1612 |
+
# Resolve paths
|
| 1613 |
+
rubrics_path = args.rubrics_path
|
| 1614 |
+
if not os.path.isabs(rubrics_path):
|
| 1615 |
+
rubrics_path = os.path.join(script_dir, rubrics_path)
|
| 1616 |
+
|
| 1617 |
+
reviews_path = args.reviews_path
|
| 1618 |
+
if not os.path.isabs(reviews_path):
|
| 1619 |
+
reviews_path = os.path.join(script_dir, reviews_path)
|
| 1620 |
+
|
| 1621 |
+
max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
|
| 1622 |
+
|
| 1623 |
+
# Validate mode and output paths
|
| 1624 |
+
if args.mode in ["semantic", "both"]:
|
| 1625 |
+
if not args.semantic_output:
|
| 1626 |
+
raise ValueError("--semantic_output is required when mode is 'semantic' or 'both'")
|
| 1627 |
+
if not args.yaml_path:
|
| 1628 |
+
raise ValueError("--yaml_path is required for semantic evaluation")
|
| 1629 |
+
if not args.config_path:
|
| 1630 |
+
raise ValueError("--config_path is required for semantic evaluation")
|
| 1631 |
+
|
| 1632 |
+
if args.mode in ["auto_metric", "both"]:
|
| 1633 |
+
if not args.auto_metric_output:
|
| 1634 |
+
raise ValueError("--auto_metric_output is required when mode is 'auto_metric' or 'both'")
|
| 1635 |
+
|
| 1636 |
+
# Check if files exist
|
| 1637 |
+
if not os.path.exists(rubrics_path):
|
| 1638 |
+
raise FileNotFoundError(f"Rubrics file not found: {rubrics_path}")
|
| 1639 |
+
if not os.path.exists(reviews_path):
|
| 1640 |
+
raise FileNotFoundError(f"Reviews file not found: {reviews_path}")
|
| 1641 |
+
|
| 1642 |
+
# Load data
|
| 1643 |
+
print(f"Loading rubrics from {rubrics_path}...")
|
| 1644 |
+
rubrics_data = load_rubrics_json(rubrics_path)
|
| 1645 |
+
print(f"Loaded {len(rubrics_data)} rubrics entries")
|
| 1646 |
+
|
| 1647 |
+
print(f"Loading model reviews from {reviews_path}...")
|
| 1648 |
+
if args.input_format != 'auto':
|
| 1649 |
+
print(f"Using manually specified format: {args.input_format}")
|
| 1650 |
+
else:
|
| 1651 |
+
print("Auto-detecting input format...")
|
| 1652 |
+
reviews_dict = load_model_reviews_json(reviews_path, format_override=args.input_format if args.input_format != 'auto' else None)
|
| 1653 |
+
print(f"Loaded {len(reviews_dict)} model reviews")
|
| 1654 |
+
|
| 1655 |
+
# Combine rubrics and reviews
|
| 1656 |
+
print("Combining rubrics and reviews...")
|
| 1657 |
+
evaluation_data = combine_rubrics_and_reviews(rubrics_data, reviews_dict)
|
| 1658 |
+
print(f"Prepared {len(evaluation_data)} entries for evaluation")
|
| 1659 |
+
|
| 1660 |
+
# Run evaluations based on mode
|
| 1661 |
+
if args.mode in ["semantic", "both"]:
|
| 1662 |
+
# Resolve semantic evaluation paths
|
| 1663 |
+
yaml_path = args.yaml_path
|
| 1664 |
+
if not os.path.isabs(yaml_path):
|
| 1665 |
+
yaml_path = os.path.join(script_dir, yaml_path)
|
| 1666 |
+
|
| 1667 |
+
config_path = args.config_path
|
| 1668 |
+
if not os.path.isabs(config_path):
|
| 1669 |
+
config_path = os.path.join(script_dir, config_path)
|
| 1670 |
+
|
| 1671 |
+
if not os.path.exists(yaml_path):
|
| 1672 |
+
raise FileNotFoundError(f"YAML file not found: {yaml_path}")
|
| 1673 |
+
if not os.path.exists(config_path):
|
| 1674 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 1675 |
+
|
| 1676 |
+
# Load prompt template
|
| 1677 |
+
print(f"Loading prompt template from {yaml_path}...")
|
| 1678 |
+
prompt_template = load_prompt_template(yaml_path)
|
| 1679 |
+
if not prompt_template:
|
| 1680 |
+
raise ValueError("Could not find 'v1_evaluator_prompt' in YAML file")
|
| 1681 |
+
|
| 1682 |
+
# Initialize LLM service
|
| 1683 |
+
print(f"Loading LLM configuration from {config_path}...")
|
| 1684 |
+
llm_config = load_llm_config(config_path)
|
| 1685 |
+
llm_service = create_llm_service_from_config(llm_config)
|
| 1686 |
+
mode = llm_config.get('mode', 'gpt')
|
| 1687 |
+
print(f"LLM service initialized (mode: {mode})")
|
| 1688 |
+
if hasattr(llm_service, 'model_name'):
|
| 1689 |
+
print(f"Using model: {llm_service.model_name}")
|
| 1690 |
+
|
| 1691 |
+
# Run semantic evaluation
|
| 1692 |
+
semantic_results, semantic_summary = run_semantic_evaluation(
|
| 1693 |
+
evaluation_data, prompt_template, llm_service, max_workers
|
| 1694 |
+
)
|
| 1695 |
+
|
| 1696 |
+
# Save semantic results
|
| 1697 |
+
semantic_output = args.semantic_output
|
| 1698 |
+
if not os.path.isabs(semantic_output):
|
| 1699 |
+
semantic_output = os.path.join(script_dir, semantic_output)
|
| 1700 |
+
|
| 1701 |
+
output_dir = os.path.dirname(semantic_output)
|
| 1702 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1703 |
+
|
| 1704 |
+
with open(semantic_output, 'w', encoding='utf-8') as f:
|
| 1705 |
+
json.dump(semantic_results, f, ensure_ascii=False, indent=2)
|
| 1706 |
+
print(f"\nSemantic evaluation results saved to {semantic_output}")
|
| 1707 |
+
|
| 1708 |
+
# Save semantic summary
|
| 1709 |
+
semantic_summary_path = semantic_output.replace('.json', '_summary.json')
|
| 1710 |
+
with open(semantic_summary_path, 'w', encoding='utf-8') as f:
|
| 1711 |
+
json.dump(semantic_summary, f, ensure_ascii=False, indent=2)
|
| 1712 |
+
print(f"Semantic evaluation summary saved to {semantic_summary_path}")
|
| 1713 |
+
|
| 1714 |
+
# Print semantic summary
|
| 1715 |
+
print("\n" + "="*80)
|
| 1716 |
+
print("SEMANTIC EVALUATION SUMMARY")
|
| 1717 |
+
print("="*80)
|
| 1718 |
+
print(f"Total entries: {semantic_summary['total_entries']}")
|
| 1719 |
+
print(f"Valid entries: {semantic_summary['valid_entries']}")
|
| 1720 |
+
print(f"Failed entries: {semantic_summary['failed_entries']}")
|
| 1721 |
+
if 'overall_score' in semantic_summary:
|
| 1722 |
+
score = semantic_summary['overall_score']
|
| 1723 |
+
print(f"\nOverall Score:")
|
| 1724 |
+
print(f" Mean: {score['mean']:.2f}")
|
| 1725 |
+
print(f" Min: {score['min']:.2f}")
|
| 1726 |
+
print(f" Max: {score['max']:.2f}")
|
| 1727 |
+
|
| 1728 |
+
if args.mode in ["auto_metric", "both"]:
|
| 1729 |
+
# Run auto-metric evaluation
|
| 1730 |
+
auto_metric_results, auto_metric_summary = run_auto_metric_evaluation(
|
| 1731 |
+
evaluation_data,
|
| 1732 |
+
strict_mode=args.strict_mode
|
| 1733 |
+
)
|
| 1734 |
+
|
| 1735 |
+
# Save auto-metric results
|
| 1736 |
+
auto_metric_output = args.auto_metric_output
|
| 1737 |
+
if not os.path.isabs(auto_metric_output):
|
| 1738 |
+
auto_metric_output = os.path.join(script_dir, auto_metric_output)
|
| 1739 |
+
|
| 1740 |
+
output_dir = os.path.dirname(auto_metric_output)
|
| 1741 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1742 |
+
|
| 1743 |
+
with open(auto_metric_output, 'w', encoding='utf-8') as f:
|
| 1744 |
+
json.dump(auto_metric_results, f, ensure_ascii=False, indent=2)
|
| 1745 |
+
print(f"\nAuto-metric evaluation results saved to {auto_metric_output}")
|
| 1746 |
+
|
| 1747 |
+
# Save auto-metric summary
|
| 1748 |
+
auto_metric_summary_path = auto_metric_output.replace('.json', '_summary.json')
|
| 1749 |
+
with open(auto_metric_summary_path, 'w', encoding='utf-8') as f:
|
| 1750 |
+
json.dump(auto_metric_summary, f, ensure_ascii=False, indent=2)
|
| 1751 |
+
print(f"Auto-metric evaluation summary saved to {auto_metric_summary_path}")
|
| 1752 |
+
|
| 1753 |
+
# Print auto-metric summary
|
| 1754 |
+
print("\n" + "="*80)
|
| 1755 |
+
print("AUTO-METRIC EVALUATION SUMMARY")
|
| 1756 |
+
print("="*80)
|
| 1757 |
+
print(f"Total entries: {auto_metric_summary['total_entries']}")
|
| 1758 |
+
print(f"Valid entries: {auto_metric_summary['valid_entries']}")
|
| 1759 |
+
print(f"MSE entries: {auto_metric_summary['mse_entries']}")
|
| 1760 |
+
|
| 1761 |
+
if 'mse_statistics' in auto_metric_summary:
|
| 1762 |
+
print("\nMSE Statistics:")
|
| 1763 |
+
for dim, stats in auto_metric_summary['mse_statistics'].items():
|
| 1764 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1765 |
+
|
| 1766 |
+
if 'mae_statistics' in auto_metric_summary:
|
| 1767 |
+
print("\nMAE Statistics:")
|
| 1768 |
+
for dim, stats in auto_metric_summary['mae_statistics'].items():
|
| 1769 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1770 |
+
|
| 1771 |
+
# Print refined and initial statistics if available
|
| 1772 |
+
if 'refined_mse_statistics' in auto_metric_summary:
|
| 1773 |
+
print("\nRefined Scores - MSE Statistics:")
|
| 1774 |
+
for dim, stats in auto_metric_summary['refined_mse_statistics'].items():
|
| 1775 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1776 |
+
|
| 1777 |
+
if 'refined_mae_statistics' in auto_metric_summary:
|
| 1778 |
+
print("\nRefined Scores - MAE Statistics:")
|
| 1779 |
+
for dim, stats in auto_metric_summary['refined_mae_statistics'].items():
|
| 1780 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1781 |
+
|
| 1782 |
+
if 'initial_mse_statistics' in auto_metric_summary:
|
| 1783 |
+
print("\nInitial Scores - MSE Statistics:")
|
| 1784 |
+
for dim, stats in auto_metric_summary['initial_mse_statistics'].items():
|
| 1785 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1786 |
+
|
| 1787 |
+
if 'initial_mae_statistics' in auto_metric_summary:
|
| 1788 |
+
print("\nInitial Scores - MAE Statistics:")
|
| 1789 |
+
for dim, stats in auto_metric_summary['initial_mae_statistics'].items():
|
| 1790 |
+
print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
|
| 1791 |
+
|
| 1792 |
+
if 'spearman_correlations' in auto_metric_summary:
|
| 1793 |
+
print("\nSpearman Correlations:")
|
| 1794 |
+
for dim, stats in auto_metric_summary['spearman_correlations'].items():
|
| 1795 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1796 |
+
|
| 1797 |
+
# Print refined and initial spearman correlations if available
|
| 1798 |
+
if 'refined_spearman_correlations' in auto_metric_summary:
|
| 1799 |
+
print("\nRefined Scores - Spearman Correlations:")
|
| 1800 |
+
for dim, stats in auto_metric_summary['refined_spearman_correlations'].items():
|
| 1801 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1802 |
+
|
| 1803 |
+
if 'initial_spearman_correlations' in auto_metric_summary:
|
| 1804 |
+
print("\nInitial Scores - Spearman Correlations:")
|
| 1805 |
+
for dim, stats in auto_metric_summary['initial_spearman_correlations'].items():
|
| 1806 |
+
print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
|
| 1807 |
+
|
| 1808 |
+
if 'decision_metrics' in auto_metric_summary:
|
| 1809 |
+
dm = auto_metric_summary['decision_metrics']
|
| 1810 |
+
print(f"\nDecision Metrics:")
|
| 1811 |
+
print(f" Accuracy: {dm['accuracy']:.4f} (n={dm['count']})")
|
| 1812 |
+
if 'f1_macro' in dm:
|
| 1813 |
+
print(f" F1 (macro): {dm['f1_macro']:.4f}")
|
| 1814 |
+
|
| 1815 |
+
# Print refined and initial decision metrics if available
|
| 1816 |
+
if 'refined_decision_metrics' in auto_metric_summary:
|
| 1817 |
+
print("\nRefined Scores - Decision Metrics:")
|
| 1818 |
+
rdm = auto_metric_summary['refined_decision_metrics']
|
| 1819 |
+
print(f" Accuracy: {rdm['accuracy']:.4f} (n={rdm['count']})")
|
| 1820 |
+
if 'f1_macro' in rdm:
|
| 1821 |
+
print(f" F1 (macro): {rdm['f1_macro']:.4f}")
|
| 1822 |
+
|
| 1823 |
+
if 'initial_decision_metrics' in auto_metric_summary:
|
| 1824 |
+
print("\nInitial Scores - Decision Metrics:")
|
| 1825 |
+
idm = auto_metric_summary['initial_decision_metrics']
|
| 1826 |
+
print(f" Accuracy: {idm['accuracy']:.4f} (n={idm['count']})")
|
| 1827 |
+
if 'f1_macro' in idm:
|
| 1828 |
+
print(f" F1 (macro): {idm['f1_macro']:.4f}")
|
| 1829 |
+
|
| 1830 |
+
print("\n" + "="*80)
|
| 1831 |
+
print("EVALUATION COMPLETE")
|
| 1832 |
+
print("="*80)
|
| 1833 |
+
|
| 1834 |
+
|
| 1835 |
+
if __name__ == "__main__":
|
| 1836 |
+
main()
|
| 1837 |
+
|
src/evaluator/configs.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLM Service Configuration for Rubric Generation
|
| 2 |
+
# Choose one mode: "gpt" or "vllm"
|
| 3 |
+
mode: "vllm" # Options: "gpt" or "vllm"
|
| 4 |
+
|
| 5 |
+
# GPT API Configuration (used when mode="gpt")
|
| 6 |
+
gpt:
|
| 7 |
+
api_key: "your-api-key-here" # Replace with your actual OpenAI API key or set OPENAI_API_KEY env var
|
| 8 |
+
model_name: "gpt-4o" # Options: gpt-4o, gpt-4-turbo, gpt-3.5-turbo, gpt-5, etc.
|
| 9 |
+
base_url: null # Default: https://api.openai.com/v1
|
| 10 |
+
timeout: 300
|
| 11 |
+
|
| 12 |
+
# Default sampling parameters
|
| 13 |
+
temperature: 0.7
|
| 14 |
+
top_p: 0.95
|
| 15 |
+
max_tokens: 16384
|
| 16 |
+
presence_penalty: 0.0
|
| 17 |
+
|
| 18 |
+
# vLLM Service Configuration (used when mode="vllm")
|
| 19 |
+
vllm:
|
| 20 |
+
base_url: "http://localhost:8000/" # vLLM server base URL
|
| 21 |
+
api_key: "dummy-key" # Not used for local vLLM, but required by OpenAI client
|
| 22 |
+
model_name: "openai/gpt-oss-120b" # Model name on vLLM server
|
| 23 |
+
timeout: 300
|
| 24 |
+
|
| 25 |
+
# Rate limiting: Maximum concurrent requests to vLLM server
|
| 26 |
+
max_concurrent_requests: 64
|
| 27 |
+
|
| 28 |
+
# Retry configuration for server errors
|
| 29 |
+
max_retries: 3
|
| 30 |
+
retry_delay: 1.0
|
| 31 |
+
retry_backoff: 2.0
|
| 32 |
+
|
| 33 |
+
# Default sampling parameters
|
| 34 |
+
temperature: 0.7
|
| 35 |
+
top_p: 0.8
|
| 36 |
+
top_k: 20
|
| 37 |
+
max_tokens: 16384
|
| 38 |
+
presence_penalty: 0.0
|