eigentom committed on
Commit 90c099b · 1 Parent(s): 37d42f7

Initial Update

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. README copy.md +313 -0
  2. README.md +4 -6
  3. app.py +470 -0
  4. example.py +54 -0
  5. gradio_app/__init__.py +9 -0
  6. gradio_app/app.py +466 -0
  7. gradio_app/components/__init__.py +73 -0
  8. gradio_app/components/formatters.py +504 -0
  9. gradio_app/components/header.py +39 -0
  10. gradio_app/components/results_panel.py +193 -0
  11. gradio_app/components/settings.py +82 -0
  12. gradio_app/components/styles.py +592 -0
  13. gradio_app/components/upload_section.py +117 -0
  14. gradio_app/utils_single_paper_inference.py +276 -0
  15. requirements.txt +26 -0
  16. scripts/gpt_oss_start_vllm_service.sh +47 -0
  17. scripts/start_load_balancer.sh +87 -0
  18. scripts/start_reranker_service.sh +116 -0
  19. scripts/start_vllm_with_balancer.sh +216 -0
  20. scripts/stop_reranker_services.sh +106 -0
  21. scripts/stop_vllm_services.sh +267 -0
  22. shared/configs/config.yaml +97 -0
  23. shared/configs/llm_service_config.yaml +57 -0
  24. shared/configs/prompts.yaml +580 -0
  25. shared/configs/reranker_endpoint_pool.txt +8 -0
  26. shared/configs/vllm_endpoint_pool.txt +7 -0
  27. shared/utils/__init__.py +113 -0
  28. shared/utils/asta_api_key_pool.py +205 -0
  29. shared/utils/gpt_service.py +210 -0
  30. shared/utils/json_parser.py +428 -0
  31. shared/utils/llm_service.py +64 -0
  32. shared/utils/llm_service_factory.py +191 -0
  33. shared/utils/load_balancer.py +382 -0
  34. shared/utils/mock_llm_service.py +280 -0
  35. shared/utils/prompt_loader.py +220 -0
  36. shared/utils/reranker.py +275 -0
  37. shared/utils/reranker_api_service.py +221 -0
  38. shared/utils/reranker_endpoint_pool.py +160 -0
  39. shared/utils/reranker_pool.py +78 -0
  40. shared/utils/review_logger.py +306 -0
  41. shared/utils/vllm_endpoint_pool.py +257 -0
  42. shared/utils/vllm_service.py +314 -0
  43. shared/utils/vllm_service_simple.py +314 -0
  44. src/__init__.py +6 -0
  45. src/evaluator/1_get_rubrics.py +601 -0
  46. src/evaluator/2_evaluate.py +1730 -0
  47. src/evaluator/2_evaluate_agenticreview.py +1866 -0
  48. src/evaluator/2_evaluate_aiscientist.py +1866 -0
  49. src/evaluator/2_evaluate_cyclereviewer.py +1837 -0
  50. src/evaluator/configs.yaml +38 -0
README copy.md ADDED
@@ -0,0 +1,313 @@
+ # ReviewGrounder: Improving Review Substantiveness with Rubric-Guided, Tool-Integrated Agents
+
+ This repository accompanies the paper *"ReviewGrounder: Improving Review Substantiveness with Rubric-Guided, Tool-Integrated Agents"*. It contains the implementation of **ReviewGrounder**, a rubric-guided, tool-integrated multi-agent framework for generating substantive, evidence-grounded academic paper reviews.
+
+ ReviewGrounder addresses a key limitation of existing LLM-based reviewers, namely their tendency to produce superficial, formulaic comments lacking substantive feedback, by explicitly leveraging reviewer rubrics and contextual grounding in existing work.
+
+ ## System Architecture
+
+ ReviewGrounder implements a multi-agent framework with clear role separation:
+
+ ### Drafting Agent (`paper_reviewer.py`)
+ The **drafter** generates an initial review draft based solely on the paper content. This stage produces a structured review with strengths, weaknesses, suggestions, and questions, but may lack deep contextual grounding.
+
+ ### Grounding Agents
+
+ 1. **Related Work Searcher** (`related_work_searcher.py`):
+    - Generates search keywords from paper content
+    - Retrieves relevant papers via academic APIs
+    - Summarizes and analyzes related work
+    - Provides context for novelty assessment
+
+ 2. **Paper Results Analyzer** (`paper_results_analyzer.py`):
+    - Extracts and analyzes experimental sections
+    - Summarizes experimental setup, results, and findings
+    - Identifies limitations and gaps
+
+ 3. **Paper Insight Miner** (`paper_insight_miner.py`):
+    - Extracts key insights and contributions
+    - Identifies technical strengths and weaknesses
+
+ 4. **Review Refiner** (`review_refiner.py`):
+    - Synthesizes information from all grounding agents
+    - Refines the initial draft with evidence-based critiques
+    - Ensures suggestions are actionable and well-justified
+    - Maintains consistency across review sections
+
+ ### Evaluation System (`src/evaluator/`)
+ The **ReviewBench** evaluation framework:
+ - **Rubric Generation**: Creates paper-specific rubrics from venue guidelines, paper content, and human reviews
+ - **LLM-based Evaluation**: Deep qualitative assessment aligned with rubrics
+ - **Rule-based Metrics**: Quantitative metrics (MSE, MAE, Spearman correlation)
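The rule-based metrics named above (MSE, MAE, Spearman correlation) can be sketched as follows. This is an illustrative, numpy-only implementation for comparing predicted review scores against human scores; the function name and the tie-free ranking are assumptions, not the repository's actual `src/evaluator` code.

```python
import numpy as np

def _ranks(x: np.ndarray) -> np.ndarray:
    # Rank values 1..n (no tie handling; adequate for distinct scores)
    order = np.argsort(x)
    ranks = np.empty(len(x), dtype=float)
    ranks[order] = np.arange(1, len(x) + 1)
    return ranks

def rule_based_metrics(predicted, human) -> dict:
    """MSE, MAE, and Spearman correlation between predicted and human scores."""
    p = np.asarray(predicted, dtype=float)
    h = np.asarray(human, dtype=float)
    mse = float(np.mean((p - h) ** 2))
    mae = float(np.mean(np.abs(p - h)))
    # Spearman correlation = Pearson correlation of the rank vectors
    spearman = float(np.corrcoef(_ranks(p), _ranks(h))[0, 1])
    return {"mse": mse, "mae": mae, "spearman": spearman}

metrics = rule_based_metrics([6, 4, 8, 5], [7, 3, 8, 5])
```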
+
+ ## Installation
+
+ ### Prerequisites
+
+ - Python >= 3.8
+ - CUDA-capable GPU (for local vLLM deployment; optional if using the OpenAI API)
+ - Sufficient GPU memory for your chosen model (if using vLLM)
+
+ ### Setup
+
+ 1. Clone the repository:
+    ```bash
+    git clone <repository-url>
+    cd ReviewGrounder
+    ```
+
+ 2. Install dependencies:
+    ```bash
+    uv venv
+    source .venv/bin/activate
+    uv pip install -r requirements.txt
+    ```
+
+ 3. Configure your API keys and settings:
+    - Copy `shared/configs/config.yaml` and customize as needed
+    - Set environment variables:
+      - `ASTA_API_KEY`: For paper search via the Asta API (recommended)
+      - `OPENAI_API_KEY`: If using the OpenAI API instead of vLLM
+      - `S2_API_KEY`: Alternative paper search API (optional)
+
+ 4. (Optional) If using local vLLM, start your vLLM service:
+    ```bash
+    # Start a vLLM service on a single port
+    bash scripts/gpt_oss_start_vllm_service.sh
+
+    # Or start multiple services with load balancing
+    bash scripts/start_vllm_with_balancer.sh
+    ```
+
+ ## Quick Start
+
+ ### Basic Usage
+
+ Generate a review using the command-line interface:
+
+ ```bash
+ python -m src.reviewer_agent.cli --paper paper.json --output review.json
+ ```
+
+ Here `paper.json` contains your paper data in JSON format, with fields like `title`, `abstract`, `text`, etc.
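For reference, a minimal `paper.json` could be built like this. Only `title`, `abstract`, and `text` are named in this README; the field values below (and the absence of other fields) are illustrative assumptions, not the pipeline's full schema.

```python
import json

# Minimal input file for the CLI above; field values are placeholders.
paper = {
    "title": "A Study of Example Systems",
    "abstract": "We study example systems and report findings.",
    "text": "1. Introduction\nExample systems are widely used...",
}

with open("paper.json", "w", encoding="utf-8") as f:
    json.dump(paper, f, indent=2)
```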
+
+ ### Using the Python API
+
+ For programmatic access:
+
+ ```python
+ from src.reviewer_agent import review_paper_with_refiner
+
+ # Load your paper data
+ paper_data = {
+     "title": "Your Paper Title",
+     "abstract": "Paper abstract...",
+     "text": "Full paper text...",
+     # ... other fields
+ }
+
+ # Generate review (drafting + grounding stages)
+ review = review_paper_with_refiner(paper_data=paper_data)
+ print(review)
+ ```
+
+ The `review_paper_with_refiner` function implements the full ReviewGrounder pipeline:
+ 1. **Drafting**: Generates the initial review draft
+ 2. **Grounding**: Retrieves related work, analyzes results, extracts insights
+ 3. **Refinement**: Synthesizes all information into a refined, evidence-grounded review
+
+ ## Usage Examples
+
+ ### Generate a Review with Related Work Context
+
+ ```bash
+ python -m src.reviewer_agent.cli \
+     --paper paper.json \
+     --max-related-papers 15 \
+     --review-format detailed \
+     --output review.json
+ ```
+
+ ### Filter Related Work by Date and Venue
+
+ ```bash
+ python -m src.reviewer_agent.cli \
+     --paper paper.json \
+     --publication-date-range "2020:" \
+     --venues "ICLR,NeurIPS,ICML" \
+     --output review.json
+ ```
+
+ ### Use a Custom vLLM Endpoint
+
+ ```bash
+ python -m src.reviewer_agent.cli \
+     --paper paper.json \
+     --vllm-url "http://your-server:8000/v1" \
+     --output review.json
+ ```
+
+ ### Evaluate Reviews on ReviewBench
+
+ ```python
+ # 1. Generate reviews
+ from src.reviewer_agent import review_paper_with_refiner
+ review = review_paper_with_refiner(paper_data={...})
+
+ # 2. Evaluate reviews using ReviewBench
+ from src.evaluator import evaluate_reviews
+ results = evaluate_reviews(parquet_path="reviews.parquet")
+ ```
+
+ ## Directory Structure
+
+ ```
+ anonymize_codebase/
+ ├── src/
+ │   ├── reviewer_agent/                  # ReviewGrounder implementation
+ │   │   ├── __init__.py
+ │   │   ├── paper_reviewer.py            # Drafting agent
+ │   │   ├── review_refiner.py            # Grounding agent: review refinement
+ │   │   ├── related_work_searcher.py     # Grounding agent: literature search
+ │   │   ├── paper_results_summarizer.py  # Grounding agent: results analysis
+ │   │   ├── paper_insight_miner.py       # Grounding agent: insight extraction
+ │   │   ├── main_pipeline.py             # Full pipeline orchestration
+ │   │   ├── cli.py                       # Command-line interface
+ │   │   └── paper_search/                # Paper search APIs
+ │   │       ├── asta_api.py
+ │   │       ├── semantic_scholar_api.py
+ │   │       └── paper_retriever.py
+ │   │
+ │   └── evaluator/                       # ReviewBench evaluation framework
+ │       ├── 1_get_rubrics.py             # Rubric generation
+ │       ├── 2_evaluate.py                # Review evaluation
+ │       └── ...
+ │
+ ├── shared/
+ │   ├── utils/                           # Shared utilities
+ │   │   ├── llm_service.py               # LLM service abstraction
+ │   │   ├── load_balancer.py             # Load balancing for vLLM
+ │   │   ├── reranker.py                  # Paper reranking
+ │   │   └── ...
+ │   │
+ │   └── configs/                         # Configuration files
+ │       ├── config.yaml                  # Main config
+ │       ├── llm_service_config.yaml      # LLM service settings
+ │       └── prompts.yaml                 # Review generation prompts
+ │
+ ├── scripts/                             # Utility scripts
+ │   ├── start_vllm_with_balancer.sh
+ │   ├── start_load_balancer.sh
+ │   └── ...
+ │
+ ├── requirements.txt                     # Python dependencies
+ └── README.md                            # This file
+ ```
+
+ ## Configuration Guide
+
+ ### LLM Service Configuration
+
+ ReviewGrounder supports two LLM backends:
+
+ 1. **vLLM** (recommended for local deployment): Fast inference on local GPUs
+    - Default: GPT-OSS-120B for the grounding stage
+    - Can use smaller models (e.g., Phi-4-14B) for the drafting stage
+
+ 2. **OpenAI API**: Cloud-based, no local GPU required
+
+ Configure in `shared/configs/llm_service_config.yaml`:
+
+ ```yaml
+ vllm:
+   base_url: "http://localhost:8000/"
+   model_name: "openai/gpt-oss-120b"
+   max_tokens: 16384
+
+ gpt:
+   enabled: false
+   api_key: "your-api-key-here"
+   model_name: "gpt-4o"
+ ```
+
+ You can also assign a different backend to each agent:
+
+ ```yaml
+ llm_assignments:
+   keyword_generator: "vllm"   # For related work search
+   paper_summarizer: "vllm"    # For results summarization
+   reviewer: "vllm"            # For the drafting stage
+   refiner: "vllm"             # For the grounding/refinement stage
+ ```
+
+ ### Paper Search Configuration
+
+ Configure paper search APIs in `shared/configs/config.yaml`:
+
+ ```yaml
+ paper_search:
+   asta:
+     api_key: null  # Set via the ASTA_API_KEY env var
+     endpoint: "https://asta-tools.allen.ai/mcp/v1"
+
+   semantic_scholar:
+     api_key: null  # Set via the S2_API_KEY env var
+ ```
+
+ ### Review Format Options
+
+ Choose from the following review formats:
+ - `detailed`: Comprehensive review with all sections (default)
+ - `summary`: Concise review summary
+ - `structured`: Structured format with specific sections
+ - `strict_detailed`: Strict adherence to the detailed format requirements
+
+ ## Load Balancing for vLLM
+
+ For production use with multiple GPUs, you can set up load balancing:
+
+ ```bash
+ # Start 4 vLLM services on ports 8000-8003
+ bash scripts/gpt_oss_start_vllm_service.sh
+
+ # Start the load balancer on port 8004
+ python -m shared.utils.load_balancer \
+     --backends http://localhost:8000/v1 http://localhost:8001/v1 http://localhost:8002/v1 http://localhost:8003/v1 \
+     --port 8004 \
+     --strategy round_robin
+ ```
+
+ Then point your config to `http://localhost:8004/v1`.
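The `round_robin` strategy can be illustrated with a minimal standalone selector. This is a sketch of the idea only; the actual balancer lives in `shared/utils/load_balancer.py`, and the class and method names below are assumptions.

```python
import itertools

class RoundRobinBalancer:
    """Cycle through backend URLs in order, returning one per request."""

    def __init__(self, backends):
        # itertools.cycle repeats the backend list indefinitely
        self._cycle = itertools.cycle(backends)

    def next_backend(self) -> str:
        return next(self._cycle)

lb = RoundRobinBalancer([
    "http://localhost:8000/v1",
    "http://localhost:8001/v1",
])
```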
+
+ ## Evaluation: ReviewBench
+
+ ReviewGrounder is evaluated on **ReviewBench**, a benchmark that:
+
+ - Leverages paper-specific rubrics derived from:
+   - Official venue guidelines (e.g., ACL, ICML, NeurIPS, ICLR)
+   - Paper content
+   - Human-written reviews
+
+ - Evaluates reviews across diverse dimensions:
+   - Evidence-based critique
+   - Constructive tone
+   - Technical depth
+   - And more
+
+ - Measures both:
+   - Alignment with human judgments (scores, decisions)
+   - Rubric-based quality (beyond just outcome prediction)
+
+ See `src/evaluator/` for the evaluation framework implementation.
+
+ ## Citation
+
+ If you use ReviewGrounder in your research, please cite:
+
+ ```bibtex
+ @inproceedings{reviewgrounder2026,
+   title={ReviewGrounder: Improving Review Substantiveness with Rubric-Guided, Tool-Integrated Agents},
+   author={Anonymous},
+   booktitle={Proceedings of ACL 2026},
+   year={2026}
+ }
+ ```
README.md CHANGED
@@ -1,14 +1,12 @@
  ---
- title: InteractiveDemo
- emoji: 📉
- colorFrom: pink
- colorTo: yellow
+ title: Test Reviewgrounder
+ emoji: 💻
+ colorFrom: yellow
+ colorTo: blue
  sdk: gradio
  sdk_version: 6.5.1
  app_file: app.py
  pinned: false
- license: apache-2.0
- short_description: This is the interactive demo of Review Grounder
  ---
  
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,470 @@
+ """
+ Review Grounder - Gradio App
+
+ Main entry point for the Hugging Face Space.
+ This module orchestrates the UI components and handles the review pipeline.
+
+ The app allows users to:
+ 1. Upload a research paper in PDF format
+ 2. Configure LLM settings (optional, uses OpenAI defaults)
+ 3. Generate a comprehensive AI-powered review
+ 4. View intermediate results from each pipeline stage
+
+ Components are organized in the `components/` directory for maintainability.
+ """
+
+ from __future__ import annotations
+
+ import os
+ import re
+ import tempfile
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Tuple, Iterator
+
+ import gradio as gr
+
+ # Import utility for running the review pipeline
+ from gradio_app.utils_single_paper_inference import (
+     run_single_paper_review_from_pdf_stepwise,
+ )
+
+ # Import UI components
+ from gradio_app.components import (
+     get_custom_css,
+     create_header,
+     create_upload_section,
+     create_advanced_settings,
+     create_results_panel,
+     format_initial_review_html,
+     format_related_work_html,
+     format_results_html,
+     format_insights_html,
+     format_final_review,
+     format_raw_json,
+ )
+
+
+ # ============================================================================
+ # App Configuration
+ # ============================================================================
+
+ APP_TITLE = "Review Grounder"
+ APP_DESCRIPTION = "AI-Powered Research Paper Review"
+
+
+ def _raw_json_md_to_file(raw_json_md: str) -> str:
+     """
+     Extract JSON from Raw JSON markdown (```json ... ```) and write to a temp file.
+     Returns the file path for gr.DownloadButton.
+     """
+     if not raw_json_md or not raw_json_md.strip():
+         with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
+             f.write("{}")
+         return f.name
+     text = raw_json_md.strip()
+     match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
+     if match:
+         text = match.group(1).strip()
+     fd, path = tempfile.mkstemp(suffix=".json", prefix="review_")
+     with os.fdopen(fd, "w", encoding="utf-8") as f:
+         f.write(text)
+     return path
+
+
+ # ============================================================================
+ # Environment Check
+ # ============================================================================
+
+ def _check_env() -> Tuple[bool, str]:
+     """
+     Check for required environment variables.
+
+     Returns:
+         Tuple of (success, message)
+     """
+     missing = []
+     if not os.environ.get("ASTA_API_KEY"):
+         missing.append("ASTA_API_KEY")
+
+     if missing:
+         return False, (
+             "Missing environment variables: "
+             + ", ".join(missing)
+             + ".\nPlease configure them in your Hugging Face Space settings."
+         )
+     return True, "Environment variables detected correctly."
+
+
+ # ============================================================================
+ # Review Pipeline Handler
+ # ============================================================================
+
+ def review_pdf_file(
+     file_obj,
+     api_base_url: str,
+     api_key: str,
+     model_name: str,
+     show_log: bool,
+     show_raw_json: bool,
+ ) -> Iterator[Tuple[str, str, str, str, str, str, str, gr.update, gr.update, gr.update]]:
+     """
+     Main callback: process PDF through the review pipeline with real-time updates.
+
+     Args:
+         file_obj: Uploaded PDF file
+         api_base_url: LLM API endpoint URL
+         api_key: API key for LLM provider
+         model_name: Model identifier
+         show_log: Whether to display the execution log
+         show_raw_json: Whether to display raw JSON output
+
+     Yields:
+         Tuple of all output component updates (no overview)
+     """
+     log_lines: list[str] = []
+
+     def _log(msg: str) -> None:
+         log_lines.append(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
+
+     def _log_text() -> str:
+         return "\n".join(log_lines) if log_lines else ""
+
+     # Validate file upload
+     if file_obj is None:
+         gr.Warning("Please upload a PDF file to start the review.")
+         _log("⚠️ Please upload a PDF file to start the review.")
+         yield (
+             _log_text(), "", "", "", "", "", "",
+             gr.update(interactive=True),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+         return
+
+     # Check environment
+     ok, msg = _check_env()
+     if not ok:
+         gr.Error(msg)
+         _log(f"❌ {msg}")
+         yield (
+             _log_text(), "", "", "", "", "", "",
+             gr.update(interactive=True),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+         return
+
+     # Start pipeline
+     _log("🚀 Pipeline started.")
+     yield (
+         _log_text(), "", "", "", "", "", "",
+         gr.update(interactive=False),
+         gr.update(visible=show_log),
+         gr.update(visible=show_raw_json),
+     )
+
+     try:
+         # Normalize file path
+         if isinstance(file_obj, dict) and "name" in file_obj:
+             src_path = Path(file_obj["name"])
+         else:
+             src_path = Path(getattr(file_obj, "name", "") or str(file_obj))
+
+         if not src_path or not src_path.exists():
+             with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
+                 tmp_path = Path(tmp.name)
+                 if hasattr(file_obj, "read"):
+                     tmp.write(file_obj.read())
+             src_path = tmp_path
+
+         # Initialize output variables
+         status = f"📄 Extracting text from PDF: {src_path.name}..."
+         _log(status)
+         yield (
+             _log_text(), "", "", "", "", "", "",
+             gr.update(interactive=False),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+
+         initial = ""
+         related_html = ""
+         results_html = ""
+         insights_html = ""
+         final_md = ""
+         raw_json = ""
+
+         # Run the stepwise pipeline
+         for ev in run_single_paper_review_from_pdf_stepwise(
+             str(src_path),
+             api_base_url=api_base_url or None,
+             api_key=api_key or None,
+             model_name=model_name or None,
+             enable_logging=True,
+             verbose=True,
+         ):
+             stage = ev.get("stage")
+
+             # Handle step-level errors
+             if stage == "results_analysis_error":
+                 err = ev.get("error", "Unknown error")
+                 gr.Warning(f"Results analysis failed: {err}")
+                 _log(f"⚠️ Results analysis failed: {err}")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+                 continue
+
+             if stage == "insights_error":
+                 err = ev.get("error", "Unknown error")
+                 gr.Warning(f"Insight mining failed: {err}")
+                 _log(f"⚠️ Insight mining failed: {err}")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+                 continue
+
+             if stage == "related_work_error":
+                 err = ev.get("error", "Unknown error")
+                 gr.Warning(f"Related work search failed: {err}")
+                 _log(f"⚠️ Related work search failed: {err}")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+                 continue
+
+             # Process each pipeline stage
+             if stage == "extract_pdf":
+                 status = f"📄 Extracting text from PDF: {src_path.name}..."
+                 _log(status)
+
+             elif stage == "parsed_pdf_text":
+                 _log("✅ Step 0: Extracting PDF text — done")
+                 _log("⏳ Step 1: Initial review draft — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "initial_review":
+                 tmp = {"initial_review": ev.get("initial_review", {})}
+                 tmp["title"] = ev.get("title") or tmp["initial_review"].get("title")
+                 tmp["abstract"] = ev.get("abstract") or tmp["initial_review"].get("abstract")
+                 initial = format_initial_review_html(tmp)
+                 _log("✅ Step 1: Initial review draft — done")
+                 _log("⏳ Step 2: Results analysis — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "results_analysis":
+                 tmp = {"results_analyzer_json": ev.get("results_analyzer_json")}
+                 results_html = format_results_html(tmp)
+                 _log("✅ Step 2: Results analysis — done")
+                 _log("⏳ Step 3: Insight mining — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "insights":
+                 tmp = {"insight_miner_json": ev.get("insight_miner_json")}
+                 insights_html = format_insights_html(tmp)
+                 _log("✅ Step 3: Insight mining — done")
+                 _log("⏳ Step 4: Related work — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "related_work":
+                 tmp = {
+                     "related_work_json_list": ev.get("related_work_json_list"),
+                     "search_keywords": ev.get("search_keywords"),
+                 }
+                 related_html = format_related_work_html(tmp)
+                 _log("✅ Step 4: Related work — done")
+                 _log("⏳ Step 5: Final refinement — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "final":
+                 review = ev.get("review", {}) or {}
+                 initial = format_initial_review_html(review)
+                 related_html = format_related_work_html(review) if not related_html else related_html
+                 results_html = format_results_html(review) if not results_html else results_html
+                 insights_html = format_insights_html(review) if not insights_html else insights_html
+                 final_md = format_final_review(review)
+                 raw_json = format_raw_json(review)
+                 _log("✅ Step 5: Final refinement — done")
+                 _log(f"🎉 Review complete for: {src_path.name}")
+
+             else:
+                 _log(f"⏳ Working... ({stage})")
+
+             yield (
+                 _log_text(), initial, related_html, results_html,
+                 insights_html, final_md, raw_json,
+                 gr.update(interactive=False),
+                 gr.update(visible=show_log),
+                 gr.update(visible=show_raw_json),
+             )
+
+         # Re-enable button at end
+         yield (
+             _log_text(), initial, related_html, results_html,
+             insights_html, final_md, raw_json,
+             gr.update(interactive=True),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+
+     except Exception as e:
+         import traceback
+         error_msg = f"❌ Error during review: {str(e)}"
+         error_details = traceback.format_exc()
+         gr.Error(f"{error_msg}\n\nDetails: {error_details[:500]}")
+         _log(error_msg)
+         yield (
+             _log_text(), "", "", "", "", "", "",
+             gr.update(interactive=True),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+
+
+ # ============================================================================
+ # Build the Gradio App
+ # ============================================================================
+
+ with gr.Blocks(
+     title=APP_TITLE,
+     css=get_custom_css(),
+     theme=gr.themes.Soft(),
+ ) as demo:
+
+     # Header section
+     create_header()
+
+     # Main content: two-column layout
+     with gr.Row():
+         # Left column: Upload and settings
+         with gr.Column(scale=2, elem_classes=["panel-card"]):
+             pdf_input, run_button = create_upload_section()
+
+             # Advanced settings (collapsed by default)
+             (
+                 api_base_url_in,
+                 api_key_in,
+                 model_name_in,
+                 show_log_toggle,
+                 show_raw_json_toggle,
+             ) = create_advanced_settings()
+
+         # Right column: Results (built from component)
+         with gr.Column(scale=3, elem_classes=["panel-card", "results-panel"]):
+             (
+                 initial_html,
+                 results_html,
+                 insights_html,
+                 related_html,
+                 final_md,
+                 status_output,
+                 raw_json_md,
+                 log_accordion,
+                 raw_json_tab,
+                 download_json_btn,
+             ) = create_results_panel(show_log=False, show_raw_json=False)
+
+     # Toggle visibility of log accordion
+     show_log_toggle.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=[show_log_toggle],
+         outputs=[log_accordion],
+     )
+
+     # Toggle visibility of raw JSON tab
+     show_raw_json_toggle.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=[show_raw_json_toggle],
+         outputs=[raw_json_tab],
+     )
+
+     # Download raw JSON as file
+     download_json_btn.click(
+         fn=_raw_json_md_to_file,
+         inputs=[raw_json_md],
+         outputs=[download_json_btn],
+     )
+
+     # Main review button click handler
+     run_button.click(
+         fn=review_pdf_file,
+         inputs=[
+             pdf_input,
+             api_base_url_in,
+             api_key_in,
+             model_name_in,
+             show_log_toggle,
+             show_raw_json_toggle,
+         ],
+         outputs=[
+             status_output,
+             initial_html,
+             related_html,
+             results_html,
+             insights_html,
+             final_md,
+             raw_json_md,
+             run_button,
+             log_accordion,
+             raw_json_tab,
+         ],
+     )
+
+     # Footer
+     gr.HTML("""
+         <div class="app-footer">
+             <p>🔬 Review Grounder · AI-Powered Research Paper Review</p>
+             <p>© 2026 ReviewGrounder. All rights reserved.</p>
+         </div>
+     """)
+
+
+ # ============================================================================
+ # Entry Point
+ # ============================================================================
+
+ if __name__ == "__main__":
+     demo.launch()
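The fenced-block regex used by `_raw_json_md_to_file` in `app.py` can be exercised in isolation. This standalone sketch reproduces only the extraction step (the helper function name here is mine, not the app's):

```python
import re

def extract_fenced_json(md: str) -> str:
    """Pull the body out of a ```json ... ``` block; fall back to the raw text."""
    match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", md.strip())
    return match.group(1).strip() if match else md.strip()

sample = "Here is the review:\n```json\n{\"score\": 7}\n```"
extracted = extract_fenced_json(sample)
```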
example.py ADDED
@@ -0,0 +1,54 @@
+ """
+ Example script for running a single-paper review.
+
+ This example is intentionally thin and delegates to the reusable
+ `review_single_paper_from_text` helper, which is what a Hugging Face
+ Space backend would call after performing PDF-to-text conversion.
+ """
+
+ import json
+ import sys
+ from pathlib import Path
+ import logging
+
+ # Add project root to path for imports
+ project_root = Path(__file__).parent.parent
+ if str(project_root) not in sys.path:
+     sys.path.insert(0, str(project_root))
+
+ from src.reviewer_agent.single_paper_inference import review_single_paper_from_text
+
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger("httpx").setLevel(logging.WARNING)
+ logger = logging.getLogger(__name__)
+
+
+ def main() -> None:
+     """Run a minimal single-paper review using the first example paper."""
+     json_path = project_root / "examples" / "example_papers.json"
+
+     with open(json_path, "r") as f:
+         data = json.load(f)
+
+     # For demonstration we take only the first paper
+     first_paper = data[0]
+     paper_text = first_paper.get("paper_context", "")
+
+     logger.info("Running single-paper review from example_papers.json...")
+     review = review_single_paper_from_text(
+         paper_text,
+         keywords=first_paper.get("keywords"),
+         # Use config defaults for review_format and verbosity
+         enable_logging=True,
+         verbose=False,
+     )
+
+     # Save the review content to a JSON file
+     with open("review_content.json", "w") as f:
+         json.dump([review], f, indent=2)
+
+     logger.info("Review saved to review_content.json")
+
+
+ if __name__ == "__main__":
+     main()
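`example.py` reads `examples/example_papers.json` and uses only the `paper_context` and `keywords` fields of the first entry. A minimal sketch of a compatible file follows; the field names come from how the script reads the data, while the values are placeholders, not real paper content:

```python
import json
import os
import tempfile

# Build a minimal example_papers.json with the two fields example.py reads.
# The values here are placeholders, not real paper data.
papers = [
    {
        "paper_context": "Full extracted text of the paper goes here ...",
        "keywords": ["peer review", "large language models"],
    }
]

path = os.path.join(tempfile.mkdtemp(), "example_papers.json")
with open(path, "w") as f:
    json.dump(papers, f, indent=2)

# Mirror how example.py consumes the file: take only the first entry.
with open(path) as f:
    data = json.load(f)

first_paper = data[0]
print(first_paper.get("keywords"))
```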
gradio_app/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """
+ Gradio integration package for the anonymized review system.
+
+ This package contains:
+ - Lightweight utilities that wrap the single-paper review pipeline for UI use
+ - The Gradio app definition used for Hugging Face Spaces deployment
+ """
+
+ __all__ = ["app", "utils_single_paper_inference"]
gradio_app/app.py ADDED
@@ -0,0 +1,466 @@
+ """
+ Review Grounder - Gradio App
+
+ Main entry point for the Hugging Face Space.
+ This module orchestrates the UI components and handles the review pipeline.
+
+ The app allows users to:
+ 1. Upload a research paper in PDF format
+ 2. Configure LLM settings (optional, uses OpenAI defaults)
+ 3. Generate a comprehensive AI-powered review
+ 4. View intermediate results from each pipeline stage
+
+ Components are organized in the `components/` directory for maintainability.
+ """
+
+ from __future__ import annotations
+
+ import os
+ import re
+ import tempfile
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Tuple, Iterator
+
+ import gradio as gr
+
+ # Import utility for running the review pipeline
+ from utils_single_paper_inference import (
+     run_single_paper_review_from_pdf_stepwise,
+ )
+
+ # Import UI components
+ from components import (
+     get_custom_css,
+     create_header,
+     create_upload_section,
+     create_advanced_settings,
+     create_results_panel,
+     format_initial_review_html,
+     format_related_work_html,
+     format_results_html,
+     format_insights_html,
+     format_final_review,
+     format_raw_json,
+ )
+
+
+ # ============================================================================
+ # App Configuration
+ # ============================================================================
+
+ APP_TITLE = "Review Grounder"
+ APP_DESCRIPTION = "AI-Powered Research Paper Review"
+
+
+ def _raw_json_md_to_file(raw_json_md: str) -> str:
+     """
+     Extract JSON from Raw JSON markdown (```json ... ```) and write to a temp file.
+     Returns the file path for gr.DownloadButton.
+     """
+     if not raw_json_md or not raw_json_md.strip():
+         with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
+             f.write("{}")
+         return f.name
+     text = raw_json_md.strip()
+     match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
+     if match:
+         text = match.group(1).strip()
+     fd, path = tempfile.mkstemp(suffix=".json", prefix="review_")
+     with os.fdopen(fd, "w", encoding="utf-8") as f:
+         f.write(text)
+     return path
+
+
+ # ============================================================================
+ # Environment Check
+ # ============================================================================
+
+ def _check_env() -> Tuple[bool, str]:
+     """
+     Check for required environment variables.
+
+     Returns:
+         Tuple of (success, message)
+     """
+     missing = []
+     if not os.environ.get("ASTA_API_KEY"):
+         missing.append("ASTA_API_KEY")
+
+     if missing:
+         return False, (
+             "Missing environment variables: "
+             + ", ".join(missing)
+             + ".\nPlease configure them in your Hugging Face Space settings."
+         )
+     return True, "Environment variables detected correctly."
+
+
+ # ============================================================================
+ # Review Pipeline Handler
+ # ============================================================================
+
+ def review_pdf_file(
+     file_obj,
+     api_base_url: str,
+     api_key: str,
+     model_name: str,
+     show_log: bool,
+     show_raw_json: bool,
+ ) -> Iterator[Tuple[str, str, str, str, str, str, str, gr.update, gr.update, gr.update]]:
+     """
+     Main callback: process PDF through the review pipeline with real-time updates.
+
+     Args:
+         file_obj: Uploaded PDF file
+         api_base_url: LLM API endpoint URL
+         api_key: API key for LLM provider
+         model_name: Model identifier
+         show_log: Whether to display the execution log
+         show_raw_json: Whether to display raw JSON output
+
+     Yields:
+         Tuple of all output component updates (no overview)
+     """
+     log_lines: list[str] = []
+
+     def _log(msg: str) -> None:
+         log_lines.append(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
+
+     def _log_text() -> str:
+         return "\n".join(log_lines) if log_lines else ""
+
+     # Validate file upload
+     if file_obj is None:
+         gr.Warning("Please upload a PDF file to start the review.")
+         _log("⚠️ Please upload a PDF file to start the review.")
+         yield (
+             _log_text(), "", "", "", "", "", "",
+             gr.update(interactive=True),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+         return
+
+     # Check environment
+     ok, msg = _check_env()
+     if not ok:
+         gr.Error(msg)
+         _log(f"❌ {msg}")
+         yield (
+             _log_text(), "", "", "", "", "", "",
+             gr.update(interactive=True),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+         return
+
+     # Start pipeline
+     _log("🚀 Pipeline started.")
+     yield (
+         _log_text(), "", "", "", "", "", "",
+         gr.update(interactive=False),
+         gr.update(visible=show_log),
+         gr.update(visible=show_raw_json),
+     )
+
+     try:
+         # Normalize file path
+         if isinstance(file_obj, dict) and "name" in file_obj:
+             src_path = Path(file_obj["name"])
+         else:
+             src_path = Path(getattr(file_obj, "name", "") or str(file_obj))
+
+         if not src_path or not src_path.exists():
+             with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
+                 tmp_path = Path(tmp.name)
+                 if hasattr(file_obj, "read"):
+                     tmp.write(file_obj.read())
+             src_path = tmp_path
+
+         # Initialize output variables
+         status = f"📄 Extracting text from PDF: {src_path.name}..."
+         _log(status)
+         yield (
+             _log_text(), "", "", "", "", "", "",
+             gr.update(interactive=False),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+
+         initial = ""
+         related_html = ""
+         results_html = ""
+         insights_html = ""
+         final_md = ""
+         raw_json = ""
+
+         # Run the stepwise pipeline
+         for ev in run_single_paper_review_from_pdf_stepwise(
+             str(src_path),
+             api_base_url=api_base_url or None,
+             api_key=api_key or None,
+             model_name=model_name or None,
+             enable_logging=True,
+             verbose=True,
+         ):
+             stage = ev.get("stage")
+
+             # Handle step-level errors
+             if stage == "results_analysis_error":
+                 err = ev.get("error", "Unknown error")
+                 gr.Warning(f"Results analysis failed: {err}")
+                 _log(f"⚠️ Results analysis failed: {err}")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+                 continue
+
+             if stage == "insights_error":
+                 err = ev.get("error", "Unknown error")
+                 gr.Warning(f"Insight mining failed: {err}")
+                 _log(f"⚠️ Insight mining failed: {err}")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+                 continue
+
+             if stage == "related_work_error":
+                 err = ev.get("error", "Unknown error")
+                 gr.Warning(f"Related work search failed: {err}")
+                 _log(f"⚠️ Related work search failed: {err}")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+                 continue
+
+             # Process each pipeline stage
+             if stage == "extract_pdf":
+                 status = f"📄 Extracting text from PDF: {src_path.name}..."
+                 _log(status)
+
+             elif stage == "parsed_pdf_text":
+                 _log("✅ Step 0: Extracting PDF text — done")
+                 _log("⏳ Step 1: Initial review draft — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "initial_review":
+                 tmp = {"initial_review": ev.get("initial_review", {})}
+                 tmp["title"] = ev.get("title") or tmp["initial_review"].get("title")
+                 tmp["abstract"] = ev.get("abstract") or tmp["initial_review"].get("abstract")
+                 initial = format_initial_review_html(tmp)
+                 _log("✅ Step 1: Initial review draft — done")
+                 _log("⏳ Step 2: Results analysis — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "results_analysis":
+                 tmp = {"results_analyzer_json": ev.get("results_analyzer_json")}
+                 results_html = format_results_html(tmp)
+                 _log("✅ Step 2: Results analysis — done")
+                 _log("⏳ Step 3: Insight mining — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "insights":
+                 tmp = {"insight_miner_json": ev.get("insight_miner_json")}
+                 insights_html = format_insights_html(tmp)
+                 _log("✅ Step 3: Insight mining — done")
+                 _log("⏳ Step 4: Related work — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "related_work":
+                 tmp = {
+                     "related_work_json_list": ev.get("related_work_json_list"),
+                     "search_keywords": ev.get("search_keywords"),
+                 }
+                 related_html = format_related_work_html(tmp)
+                 _log("✅ Step 4: Related work — done")
+                 _log("⏳ Step 5: Final refinement — started")
+                 yield (
+                     _log_text(), initial, related_html, results_html,
+                     insights_html, final_md, raw_json,
+                     gr.update(interactive=False),
+                     gr.update(visible=show_log),
+                     gr.update(visible=show_raw_json),
+                 )
+
+             elif stage == "final":
+                 review = ev.get("review", {}) or {}
+                 initial = format_initial_review_html(review)
+                 related_html = format_related_work_html(review) if not related_html else related_html
+                 results_html = format_results_html(review) if not results_html else results_html
+                 insights_html = format_insights_html(review) if not insights_html else insights_html
+                 final_md = format_final_review(review)
+                 raw_json = format_raw_json(review)
+                 _log("✅ Step 5: Final refinement — done")
+                 _log(f"🎉 Review complete for: {src_path.name}")
+
+             else:
+                 _log(f"⏳ Working... ({stage})")
+
+             yield (
+                 _log_text(), initial, related_html, results_html,
+                 insights_html, final_md, raw_json,
+                 gr.update(interactive=False),
+                 gr.update(visible=show_log),
+                 gr.update(visible=show_raw_json),
+             )
+
+         # Re-enable button at end
+         yield (
+             _log_text(), initial, related_html, results_html,
+             insights_html, final_md, raw_json,
+             gr.update(interactive=True),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+
+     except Exception as e:
+         import traceback
+         error_msg = f"❌ Error during review: {str(e)}"
+         error_details = traceback.format_exc()
+         gr.Error(f"{error_msg}\n\nDetails: {error_details[:500]}")
+         _log(error_msg)
+         yield (
+             _log_text(), "", "", "", "", "", "",
+             gr.update(interactive=True),
+             gr.update(visible=show_log),
+             gr.update(visible=show_raw_json),
+         )
+
+
+ # ============================================================================
+ # Build the Gradio App
+ # ============================================================================
+
+ with gr.Blocks(title=APP_TITLE, css=get_custom_css(), theme=gr.themes.Soft()) as demo:
+
+     # Header section
+     create_header()
+
+     # Main content: two-column layout
+     with gr.Row():
+         # Left column: Upload and settings
+         with gr.Column(scale=2, elem_classes=["panel-card"]):
+             pdf_input, run_button = create_upload_section()
+
+             # Advanced settings (collapsed by default)
+             (
+                 api_base_url_in,
+                 api_key_in,
+                 model_name_in,
+                 show_log_toggle,
+                 show_raw_json_toggle,
+             ) = create_advanced_settings()
+
+         # Right column: Results (built from component)
+         with gr.Column(scale=3, elem_classes=["panel-card", "results-panel"]):
+             (
+                 initial_html,
+                 results_html,
+                 insights_html,
+                 related_html,
+                 final_md,
+                 status_output,
+                 raw_json_md,
+                 log_accordion,
+                 raw_json_tab,
+                 download_json_btn,
+             ) = create_results_panel(show_log=False, show_raw_json=False)
+
+     # Toggle visibility of log accordion
+     show_log_toggle.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=[show_log_toggle],
+         outputs=[log_accordion],
+     )
+
+     # Toggle visibility of raw JSON tab
+     show_raw_json_toggle.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=[show_raw_json_toggle],
+         outputs=[raw_json_tab],
+     )
+
+     # Download raw JSON as file
+     download_json_btn.click(
+         fn=_raw_json_md_to_file,
+         inputs=[raw_json_md],
+         outputs=[download_json_btn],
+     )
+
+     # Main review button click handler
+     run_button.click(
+         fn=review_pdf_file,
+         inputs=[
+             pdf_input,
+             api_base_url_in,
+             api_key_in,
+             model_name_in,
+             show_log_toggle,
+             show_raw_json_toggle,
+         ],
+         outputs=[
+             status_output,
+             initial_html,
+             related_html,
+             results_html,
+             insights_html,
+             final_md,
+             raw_json_md,
+             run_button,
+             log_accordion,
+             raw_json_tab,
+         ],
+     )
+
+     # Footer
+     gr.HTML("""
+     <div class="app-footer">
+         <p>🔬 Review Grounder · AI-Powered Research Paper Review</p>
+         <p>© 2026 ReviewGrounder. All rights reserved.</p>
+     </div>
+     """)
+
+
+ # ============================================================================
+ # Entry Point
+ # ============================================================================
+
+ if __name__ == "__main__":
+     demo.launch()
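`_raw_json_md_to_file` above relies on a single regex to pull the JSON body out of a fenced code block before writing it to a temp file. That extraction step can be exercised on its own; the regex below is copied verbatim from the app, while the standalone helper name is just for illustration:

```python
import re

def extract_fenced_json(markdown: str) -> str:
    # Same regex as _raw_json_md_to_file: capture the body of a fenced
    # (optionally json-tagged) code block, falling back to the raw text.
    text = markdown.strip()
    match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
    return match.group(1).strip() if match else text

print(extract_fenced_json('```json\n{"rating": 6}\n```'))  # fenced input
print(extract_fenced_json('{"rating": 6}'))                # already-bare input
```

The lazy quantifier (`*?`) keeps the match inside the first fence, and the fallback means already-bare JSON passes through unchanged.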
gradio_app/components/__init__.py ADDED
@@ -0,0 +1,73 @@
+ """
+ Components package for the Review Grounder Gradio app.
+
+ This package contains modular UI components for building
+ the Review Grounder interface.
+
+ Modules:
+ - styles: Custom CSS styles
+ - formatters: Data formatting utilities
+ - header: App header component
+ - upload_section: PDF upload and instructions
+ - settings: Advanced settings panel
+ - results_panel: Review results display
+ """
+
+ from .styles import get_custom_css
+ from .formatters import (
+     safe_json_parse,
+     format_overview,
+     format_initial_review,
+     format_initial_review_html,
+     format_related_work_html,
+     format_results_html,
+     format_insights_html,
+     format_final_review,
+     format_raw_json,
+ )
+ from .header import create_header
+ from .upload_section import (
+     create_how_it_works,
+     create_upload_area,
+     create_action_buttons,
+     create_upload_section,
+ )
+ from .settings import (
+     create_advanced_settings,
+     DEFAULT_API_ENDPOINT,
+     DEFAULT_MODEL_NAME,
+ )
+ from .results_panel import (
+     create_results_placeholder,
+     create_results_panel,
+ )
+
+
+ __all__ = [
+     # Styles
+     "get_custom_css",
+     # Formatters
+     "safe_json_parse",
+     "format_overview",
+     "format_initial_review",
+     "format_initial_review_html",
+     "format_related_work_html",
+     "format_results_html",
+     "format_insights_html",
+     "format_final_review",
+     "format_raw_json",
+     # Header
+     "create_header",
+     # Upload section
+     "create_how_it_works",
+     "create_upload_area",
+     "create_action_buttons",
+     "create_upload_section",
+     # Settings
+     "create_advanced_settings",
+     "DEFAULT_API_ENDPOINT",
+     "DEFAULT_MODEL_NAME",
+     # Results panel
+     "create_results_placeholder",
+     "create_results_panel",
+ ]
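The formatters re-exported above lean on two small helpers defined in `formatters.py` (`_escape_html` and `_nl2br`): escape the HTML metacharacters first, then convert newlines to `<br>` tags. A standalone sketch of that order of operations, with the same replacement chain as the module:

```python
def escape_html(text: str) -> str:
    # Escape the four HTML metacharacters; ampersand must go first so the
    # entities produced by the later replacements are not double-escaped.
    return (
        str(text)
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
    )

def nl2br(text: str) -> str:
    # Escape first, then insert <br> tags; the reverse order would
    # escape the inserted tags themselves.
    return escape_html(text).replace("\n", "<br>\n") if text else ""

print(nl2br("R&D <draft>\nline two"))
```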
gradio_app/components/formatters.py ADDED
@@ -0,0 +1,504 @@
+ """
+ Formatting utilities for the Review Grounder Gradio app.
+
+ This module contains all functions for formatting review data
+ into displayable HTML or Markdown for the UI components.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from typing import Any
+
+
+ def safe_json_parse(value: Any) -> Any:
+     """
+     Safely parse a JSON string or return the value if already parsed.
+
+     Args:
+         value: A JSON string or already-parsed object
+
+     Returns:
+         Parsed JSON object or None if parsing fails
+     """
+     if value is None:
+         return None
+     try:
+         if isinstance(value, str):
+             return json.loads(value)
+         return value
+     except Exception as e:
+         # Print out the exact error
+         print(f"Error parsing JSON: {e}")
+
+     return None
+
+
+ def format_overview(review: dict) -> str:
+     """
+     Format high-level overview: scores and keywords only.
+
+     Args:
+         review: The review dictionary containing scores and metadata
+
+     Returns:
+         Formatted Markdown string with scores and search keywords
+     """
+     if not review:
+         return "No review data."
+
+     scores = review.get("scores", {}) or {}
+     rating = scores.get("rating") or review.get("rating")
+     confidence = scores.get("confidence") or review.get("confidence")
+     decision = scores.get("decision") or review.get("decision")
+
+     parts = [
+         "### Scores",
+         f"- **Rating**: {rating if rating is not None else 'N/A'}",
+         f"- **Confidence**: {confidence if confidence is not None else 'N/A'}",
+         f"- **Decision**: {decision or 'N/A'}",
+     ]
+
+     keywords = review.get("search_keywords")
+     if keywords:
+         parts.append("")
+         parts.append("### Search Keywords")
+         parts.append("".join(f"- {k}\n" for k in keywords).rstrip())
+
+     return "\n".join(parts)
+
+
+ def _escape_html(text: str) -> str:
+     """Escape HTML special characters for safe display."""
+     if not text:
+         return ""
+     return (
+         str(text)
+         .replace("&", "&amp;")
+         .replace("<", "&lt;")
+         .replace(">", "&gt;")
+         .replace('"', "&quot;")
+     )
+
+
+ def format_initial_review(review: dict) -> str:
+     """
+     Format the initial draft review as plain text (legacy).
+     Prefer format_initial_review_html for UI display.
+     """
+     initial = review.get("initial_review")
+     if not initial:
+         return "Initial draft review not available (pipeline may have failed early)."
+     text = initial.get("review") or ""
+     if not text:
+         return json.dumps(initial, indent=2, ensure_ascii=False)
+     return text
+
+
+ def format_initial_review_html(review: dict) -> str:
+     """
+     Format the initial draft review as styled HTML cards (never a raw JSON string).
+
+     Renders Summary, scores, Strengths, Weaknesses, and Questions in a
+     card/section layout. If initial_review arrives as a JSON string, it is
+     parsed first and then rendered as cards.
+
+     Args:
+         review: The review dictionary containing initial_review
+
+     Returns:
+         HTML string for display in gr.HTML
+     """
+     initial = review.get("initial_review")
+     if not initial:
+         return "<p class='review-message'>Initial draft review not available (pipeline may have failed early).</p>"
+
+     # If the backend passed a JSON string, parse it to a dict so we can render cards
+     if isinstance(initial, str):
+         initial = safe_json_parse(initial) or {}
+         if not initial:
+             return "<p class='review-message'>Initial draft data could not be parsed.</p>"
+
+     # Prefer structured HTML when we have typical draft fields;
+     # otherwise fall back to rendering the single "review" text.
+     if not _looks_like_raw_json(initial):
+         print("[WARN] Initial draft has no structured fields; treating it as a single text.")
+         text = initial.get("review") or ""
+         if text:
+             return f'<div class="review-draft-content"><div class="review-text">{_nl2br(text)}</div></div>'
+
+     # Structured format: summary, scores, strengths, weaknesses, questions, etc.
+     html = '<div class="initial-draft-card card-grid"><div class="card"><h4>📝 Initial Draft</h4>'
+
+     summary = initial.get("summary") or ""
+     if summary:
+         html += f'<div class="kv"><div class="k">Summary</div><div class="v">{_escape_html(summary)}</div></div>'
+
+     score_fields = [
+         ("soundness", "Soundness"),
+         ("presentation", "Presentation"),
+         ("contribution", "Contribution"),
+         ("rating", "Rating"),
+         ("confidence", "Confidence"),
+         ("decision", "Decision"),
+     ]
+     for key, label in score_fields:
+         val = initial.get(key)
+         if val is not None and val != "":
+             html += f'<div class="kv"><div class="k">{label}</div><div class="v">{_escape_html(str(val))}</div></div>'
+
+     strengths = initial.get("strengths") or ""
+     if strengths:
+         html += f'<div class="kv"><div class="k">Strengths</div><div class="v">{_nl2br(strengths)}</div></div>'
+
+     weaknesses = initial.get("weaknesses") or ""
+     if weaknesses:
+         html += f'<div class="kv"><div class="k">Weaknesses</div><div class="v">{_nl2br(weaknesses)}</div></div>'
+
+     questions = initial.get("questions")
+     if questions:
+         if isinstance(questions, list):
+             q_text = "\n".join(f"• {q}" for q in questions if q)
+         else:
+             q_text = str(questions)
+         if q_text:
+             html += f'<div class="kv"><div class="k">Questions</div><div class="v">{_nl2br(q_text)}</div></div>'
+
+     html += "</div></div>"
+     return html
+
+
+ def _looks_like_raw_json(obj: Any) -> bool:
+     """Heuristic: if a dict has typical review keys (summary, strengths, ...), treat it as structured."""
+     if not isinstance(obj, dict):
+         return False
+     return any(k in obj for k in ("summary", "strengths", "weaknesses", "soundness", "rating"))
+
+
+ def _nl2br(text: str) -> str:
+     """Escape HTML and convert newlines to <br> for safe display."""
+     if not text:
+         return ""
+     return _escape_html(text).replace("\n", "<br>\n")
+
+
+ def format_related_work_html(review: dict) -> str:
+     """
+     Format related work as HTML with styled cards.
+
+     Args:
+         review: The review dictionary containing related_work_json_list
+
+     Returns:
+         HTML string with related work cards
+     """
+     rw = review.get("related_work_json_list")
+     if not rw:
+         return "<p>No related work information available.</p>"
+
+     try:
+         data = json.loads(rw) if isinstance(rw, str) else rw
+     except Exception:
+         return f"<p>Error parsing related work data: {str(rw)[:200]}</p>"
+
+     if not data:
+         return "<p>No related work summaries found.</p>"
+
+     html = '<div class="related-work-container"><h3>Related Work Summaries</h3>'
+
+     for idx, item in enumerate(data, start=1):
+         summary = item.get("summary", "").strip()
+         main_methods = item.get("main_methods", "").strip()
+         key_findings = item.get("key_findings", "").strip()
+         relation = item.get("relation", "").strip()
+
+         html += '<div class="related-paper-card">'
+         html += f'<div class="paper-header">{idx}. {summary[:100] or "Related paper"}...</div>'
+
+         if summary:
+             html += f'''
+             <div class="paper-field">
+                 <div class="paper-field-label">Summary</div>
+                 <div class="paper-field-value">{summary}</div>
+             </div>
+             '''
+         if main_methods:
+             html += f'''
+             <div class="paper-field">
+                 <div class="paper-field-label">Main Methods</div>
+                 <div class="paper-field-value">{main_methods}</div>
+             </div>
+             '''
+         if key_findings:
+             html += f'''
+             <div class="paper-field">
+                 <div class="paper-field-label">Key Findings</div>
+                 <div class="paper-field-value">{key_findings}</div>
+             </div>
+             '''
+         if relation:
+             html += f'''
+             <div class="paper-field">
+                 <div class="paper-field-label">Relation</div>
+                 <div class="paper-field-value">{relation}</div>
+             </div>
+             '''
+         html += "</div>"
+
+     html += "</div>"
+     return html
+
+
+ def _render_review_issues_html(issues: dict) -> str:
+     """
+     Render review issues (incorrect, missing, needs specificity) as HTML.
+
+     Args:
+         issues: Dictionary containing issue categories
+
+     Returns:
+         HTML string with formatted issues
+     """
+     if not issues:
+         return ""
+
+     def render_issue_list(title: str, items: list) -> str:
+         if not items:
+             return ""
+         blocks = []
+         for it in items:
+             if isinstance(it, dict):
+                 head = (
+                     it.get("review_claim")
+                     or it.get("what_missing")
+                     or it.get("review_text")
+                     or "Issue"
+                 )
+                 body_parts = []
+                 for k in ["why_wrong", "why_important", "how_to_fix", "evidence"]:
+                     if it.get(k):
+                         body_parts.append(
+                             f"<div class='k'>{k.replace('_', ' ').title()}</div>"
+                             f"<div class='v'>{it.get(k)}</div>"
+                         )
+                 body = "".join(body_parts) or (
+                     f"<div class='v mono'>{json.dumps(it, indent=2, ensure_ascii=False)}</div>"
+                 )
+                 blocks.append(f"<details><summary>{head}</summary>{body}</details>")
+             else:
+                 blocks.append(f"<div class='v'>{str(it)}</div>")
+         return (
+             f'<div class="kv"><div class="k">{title}</div><div class="v">'
+             + "".join(blocks)
+             + "</div></div>"
+         )
+
+     html = ""
+     html += render_issue_list("Incorrect / Hallucinated", issues.get("incorrect_or_hallucinated", []))
+     html += render_issue_list("Missing Key Points", issues.get("missing_key_points", []))
+     html += render_issue_list("Needs Specificity", issues.get("needs_specificity", []))
+     return html
+
+
+ def _render_rewrite_suggestions_html(suggestions: list) -> str:
+     """
+     Render rewrite suggestions as collapsible HTML blocks.
+
+     Args:
+         suggestions: List of rewrite suggestion dictionaries
+
+     Returns:
+         HTML string with formatted suggestions
+     """
+     if not suggestions:
+         return ""
+
+     blocks = []
+     for s in suggestions:
+         if isinstance(s, dict):
+             head = f"{s.get('apply_to', 'Rewrite')} · {s.get('target', '')}".strip(" ·")
+             suggested = s.get("suggested_text", "")
+             evidence = s.get("evidence", "")
+             body = ""
+             if suggested:
+                 body += f"<div class='k'>Suggested Text</div><div class='v'>{suggested}</div>"
+             if evidence:
+                 body += f"<div class='k'>Evidence</div><div class='v'>{evidence}</div>"
+             blocks.append(
+                 f"<details><summary>{head or 'Rewrite suggestion'}</summary>{body}</details>"
+             )
+         else:
+             blocks.append(f"<div class='v'>{str(s)}</div>")
+
+     return (
+         '<div class="kv"><div class="k">Rewrite Suggestions</div><div class="v">'
+         + "".join(blocks)
+         + "</div></div>"
+     )
+
+
+ def format_results_html(review: dict) -> str:
+     """
+     Format results analyzer output as structured HTML cards.
+
+     Args:
+         review: The review dictionary containing results_analyzer_json
+
+     Returns:
+         HTML string with formatted results analysis
+     """
+     parsed = safe_json_parse(review.get("results_analyzer_json"))
+     if not parsed:
+         return "<p>Unable to parse the results analysis data.</p>"
+
+     facts = parsed.get("facts", {}) if isinstance(parsed, dict) else {}
+     datasets = facts.get("datasets", [])
+     metrics = facts.get("metrics", [])
+     baselines = facts.get("baselines", [])
+     key_results = facts.get("key_results", [])
+     review_issues = parsed.get("review_issues", {}) if isinstance(parsed, dict) else {}
+     rewrite_suggestions = parsed.get("rewrite_suggestions", []) if isinstance(parsed, dict) else []
+
+     html = '<div class="card-grid">'
+     html += '<div class="card"><h4>Results Analysis <span class="pill">structured</span></h4>'
+
+     if datasets:
+         html += (
+             '<div class="kv"><div class="k">Datasets</div><div class="v">'
+             + "\n".join(f"- {x}" for x in datasets)
+             + "</div></div>"
+         )
+     if metrics:
+         html += (
+             '<div class="kv"><div class="k">Metrics</div><div class="v">'
+             + "\n".join(f"- {x}" for x in metrics)
+             + "</div></div>"
+         )
+     if baselines:
+         html += (
+             '<div class="kv"><div class="k">Baselines</div><div class="v">'
+             + "\n".join(f"- {x}" for x in baselines)
+             + "</div></div>"
+         )
+     if key_results:
+         items = []
+         for kr in key_results:
+             if isinstance(kr, dict):
+                 claim = kr.get("claim", "")
+                 evidence = kr.get("evidence", "")
+                 block = (
+                     f"<details><summary>{claim or 'Key result'}</summary>"
+                     f"<div class='v'>{evidence}</div></details>"
+                 )
+                 items.append(block)
+         if items:
+             html += (
+                 '<div class="kv"><div class="k">Key Results (claim → evidence)</div>'
+                 '<div class="v">' + "".join(items) + "</div></div>"
+             )
+
+     if review_issues:
+         html += _render_review_issues_html(review_issues)
+     if rewrite_suggestions:
+         html += _render_rewrite_suggestions_html(rewrite_suggestions)
+
+     html += "</div></div>"
+     return html
+
+
+ def format_insights_html(review: dict) -> str:
+     """
+     Format insight miner output as structured HTML cards.
+
+     Args:
+         review: The review dictionary containing insight_miner_json
+
+     Returns:
+         HTML string with formatted insights
+     """
+     parsed = safe_json_parse(review.get("insight_miner_json"))
+     if not parsed:
+         return "<p>Unable to parse the insights data.</p>"
+
+     facts = parsed.get("facts", {}) if isinstance(parsed, dict) else {}
+     review_issues = parsed.get("review_issues", {}) if isinstance(parsed, dict) else {}
+     rewrite_suggestions = parsed.get("rewrite_suggestions", []) if isinstance(parsed, dict) else []
+
+     def render_list(title: str, items: list) -> str:
+         if not items:
+             return ""
+         blocks = []
+         for it in items:
+             if isinstance(it, dict):
+                 head = (
+                     it.get("claim")
+                     or it.get("point")
+                     or it.get("item")
+                     or it.get("what_missing")
+                     or "Item"
+                 )
+                 evidence = (
+                     it.get("evidence")
+                     or it.get("why_important")
+                     or it.get("why_wrong")
+                     or it.get("how_to_fix")
+                     or ""
+                 )
+ blocks.append(
451
+ f"<details><summary>{head}</summary><div class='v'>{evidence}</div></details>"
452
+ )
453
+ else:
454
+ blocks.append(f"<div class='v'>{str(it)}</div>")
455
+ return (
456
+ f'<div class="kv"><div class="k">{title}</div><div class="v">'
457
+ + "".join(blocks)
458
+ + "</div></div>"
459
+ )
460
+
461
+ html = '<div class="card-grid">'
462
+ html += '<div class="card"><h4>Paper Insights <span class="pill">structured</span></h4>'
463
+
464
+ html += render_list("Core Contributions", facts.get("core_contributions", []))
465
+ html += render_list("Method Summary", facts.get("method_summary", []))
466
+ html += render_list("Assumptions & Scope", facts.get("assumptions_and_scope", []))
467
+ html += render_list("Novelty Claims (paper)", facts.get("novelty_claims_in_paper", []))
468
+
469
+ if review_issues:
470
+ html += _render_review_issues_html(review_issues)
471
+ if rewrite_suggestions:
472
+ html += _render_rewrite_suggestions_html(rewrite_suggestions)
473
+
474
+ html += "</div></div>"
475
+ return html
476
+
477
+
478
+ def format_final_review(review: dict) -> str:
479
+ """
480
+ Extract and return the final review markdown.
481
+
482
+ Args:
483
+ review: The review dictionary containing the final review
484
+
485
+ Returns:
486
+ The final review markdown string
487
+ """
488
+ return review.get("review_markdown") or review.get("review") or "Final review markdown missing."
489
+
490
+
491
+ def format_raw_json(review: dict) -> str:
492
+ """
493
+ Format the complete review as a JSON code block.
494
+
495
+ Args:
496
+ review: The complete review dictionary
497
+
498
+ Returns:
499
+ JSON formatted as a Markdown code block
500
+ """
501
+ try:
502
+ return "```json\n" + json.dumps(review, indent=2, ensure_ascii=False) + "\n```"
503
+ except Exception:
504
+ return str(review)
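The fenced-JSON fallback in `format_raw_json` is easy to sanity-check in isolation. Below is a hypothetical standalone copy of that helper (the `FENCE` constant is only there so this snippet does not contain a literal triple-backtick; the committed code writes the fence inline):

```python
import json

FENCE = "`" * 3  # triple backtick, spelled out to keep this snippet fence-safe


def format_raw_json(review: dict) -> str:
    # Pretty-print the review dict as a fenced JSON block; fall back to
    # plain str() when the payload is not JSON-serializable.
    try:
        return FENCE + "json\n" + json.dumps(review, indent=2, ensure_ascii=False) + "\n" + FENCE
    except Exception:
        return str(review)
```

The broad `except Exception` is deliberate here: anything a caller stuffs into the review dict (sets, custom objects) should degrade to a readable string rather than crash the UI.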
gradio_app/components/header.py ADDED
@@ -0,0 +1,39 @@
+ """
+ Header component for the Review Grounder Gradio app.
+
+ This module provides the app header with gradient background,
+ title, BETA badge, and privacy notice banner.
+ """
+
+ import gradio as gr
+
+
+ def create_header() -> None:
+     """
+     Create the app header with gradient background and privacy notice.
+
+     Renders:
+         - Purple gradient header with title and BETA badge
+         - Privacy notice banner explaining the demo mode
+     """
+     # Main header with gradient background
+     gr.HTML("""
+     <div class="app-header">
+         <div class="app-header-content">
+             <h1 class="app-title" style="margin-bottom: 12px;">
+                 🔬 Review Grounder
+                 <span class="beta-badge">BETA</span>
+             </h1>
+         </div>
+
+
+         <div class="privacy-notice">
+             <span class="privacy-icon">🔒</span>
+             <div class="privacy-content">
+                 <h4 style="color: white;">Privacy Notice: Anonymous Demo</h4>
+                 <p style="color: white;">This is an anonymous demonstration. We do not save your PDF file, paper information, or any uploaded content.</p>
+             </div>
+         </div>
+
+     </div>
+     """)
gradio_app/components/results_panel.py ADDED
@@ -0,0 +1,193 @@
+ """
+ Results panel component for the Review Grounder Gradio app.
+
+ This module provides the right panel with tabs for displaying
+ different stages of the review pipeline results.
+ """
+
+ import gradio as gr
+ from typing import Tuple
+
+
+ def create_results_placeholder() -> str:
+     """
+     Return the HTML for the initial placeholder state.
+
+     Returns:
+         HTML string for the "Ready to Review" placeholder
+     """
+     return """
+     <div class="results-placeholder">
+         <div class="results-placeholder-icon">📋</div>
+         <h3>Ready to Review</h3>
+         <p>Upload your PDF and click "🚀 Generate AI Review" to get comprehensive feedback on your research.</p>
+     </div>
+     """
+
+ def create_initial_draft_placeholder() -> str:
+     """
+     Return the HTML for the initial draft placeholder state.
+
+     Returns:
+         HTML string for the initial draft placeholder
+     """
+     return """
+     <div class="results-placeholder">
+         <div class="results-placeholder-icon">📝</div>
+         <h3>Initial Draft</h3>
+         <p>The draft of the paper review will appear here.</p>
+     </div>
+     """
+
+ def create_results_analyzer_placeholder() -> str:
+     """
+     Return the HTML for the results analyzer placeholder state.
+
+     Returns:
+         HTML string for the results analyzer placeholder
+     """
+     return """
+     <div class="results-placeholder">
+         <div class="results-placeholder-icon">📈</div>
+         <h3>Results Analyzer</h3>
+         <p>The analysis of the paper's experiments and results will appear here.</p>
+     </div>
+     """
+
+ def create_insights_miner_placeholder() -> str:
+     """
+     Return the HTML for the insights miner placeholder state.
+
+     Returns:
+         HTML string for the insights miner placeholder
+     """
+     return """
+     <div class="results-placeholder">
+         <div class="results-placeholder-icon">💡</div>
+         <h3>Insight Miner</h3>
+         <p>The insights retrieved from the paper's content will appear here.</p>
+     </div>
+     """
+
+ def create_related_work_placeholder() -> str:
+     """
+     Return the HTML for the related work placeholder state.
+
+     Returns:
+         HTML string for the related work placeholder
+     """
+     return """
+     <div class="results-placeholder">
+         <div class="results-placeholder-icon">📚</div>
+         <h3>Related Work</h3>
+         <p>The curated research papers and their summaries related to your uploaded paper will appear here.</p>
+     </div>
+     """
+
+ def create_final_review_placeholder() -> str:
+     """
+     Return the HTML for the final review placeholder state.
+
+     Returns:
+         HTML string for the final review placeholder
+     """
+     return """
+     <div class="results-placeholder">
+         <div class="results-placeholder-icon">🎯</div>
+         <h3>Final Review</h3>
+         <p>The final refined review of the paper will appear here. It is the refinement of the initial draft based on the joint information from the results analyzer, insight miner, and related work.</p>
+     </div>
+     """
+
+
+ def create_results_panel(
+     show_log: bool = False,
+     show_raw_json: bool = False,
+ ) -> Tuple[gr.HTML, gr.HTML, gr.HTML, gr.HTML, gr.Markdown, gr.Textbox, gr.Markdown, gr.Accordion, gr.Tab, gr.DownloadButton]:
+     """
+     Create the results panel with tabs for different pipeline stages.
+
+     Returns:
+         Tuple including download_json_btn for downloading raw JSON.
+     """
+     gr.HTML('<div class="panel-title">📝 AI Review Results</div>')
+
+     # Status log (conditionally visible)
+     with gr.Accordion("📋 Pipeline Log", open=False, visible=show_log) as log_accordion:
+         status_output = gr.Textbox(
+             value="Ready. Upload a PDF and click '🚀 Generate AI Review' to start.",
+             lines=8,
+             max_lines=20,
+             interactive=False,
+             autoscroll=True,
+             elem_classes=["status-log"],
+             show_label=False,
+         )
+
+     # Main results tabs
+     with gr.Tabs():
+         with gr.Tab("🎯 Final Review", id="final"):
+             gr.HTML("""
+             <div class="final-review-toolbar">
+                 <button type="button" class="copy-final-btn" onclick="(function(){
+                     var el = document.getElementById('final-review-md');
+                     if (!el) el = document.querySelector('[id=\\'final-review-md\\']');
+                     if (!el) el = document.querySelector('.final-review-wrap .gr-markdown, .final-review-wrap [class*=\\'markdown\\']');
+                     var text = el ? (el.innerText || el.textContent || '') : '';
+                     if (text && navigator.clipboard && navigator.clipboard.writeText) {
+                         navigator.clipboard.writeText(text).then(function(){ alert('Copied to clipboard'); }).catch(function(){ alert('Copy failed'); });
+                     } else { alert('Nothing to copy'); }
+                 })();" title="Copy full review text">📋 Copy to clipboard</button>
+             </div>
+             """)
+             with gr.Group(elem_classes=["final-review-wrap"]):
+                 final_md = gr.Markdown(
+                     value=create_results_placeholder(),
+                     label="Final Review",
+                     elem_id="final-review-md",
+                 )
+
+         with gr.Tab("📝 Initial Draft", id="initial"):
+             initial_html = gr.HTML(
+                 value=create_initial_draft_placeholder(),
+                 label="Initial Draft",
+             )
+
+         with gr.Tab("📈 Results Analyzer", id="results"):
+             results_html = gr.HTML(
+                 value=create_results_analyzer_placeholder(),
+                 label="Results Analyzer",
+             )
+
+         with gr.Tab("💡 Insight Miner", id="insights"):
+             insights_html = gr.HTML(
+                 value=create_insights_miner_placeholder(),
+                 label="Insight Miner",
+             )
+
+         with gr.Tab("📚 Related Work", id="related"):
+             related_html = gr.HTML(
+                 value=create_related_work_placeholder(),
+                 label="Related Work",
+             )
+
+         # Raw JSON tab (conditionally visible based on toggle)
+         with gr.Tab("🔧 Raw JSON", id="raw_json", visible=show_raw_json) as raw_json_tab:
+             raw_json_md = gr.Markdown(
+                 value="Raw JSON output for debugging will appear here.",
+                 label="Raw JSON",
+             )
+             download_json_btn = gr.DownloadButton("⬇️ Download as JSON")
+
+     return (
+         initial_html,
+         results_html,
+         insights_html,
+         related_html,
+         final_md,
+         status_output,
+         raw_json_md,
+         log_accordion,
+         raw_json_tab,
+         download_json_btn,
+     )
gradio_app/components/settings.py ADDED
@@ -0,0 +1,82 @@
+ """
+ Settings component for the Review Grounder Gradio app.
+
+ This module provides the advanced settings panel with:
+ - LLM endpoint configuration
+ - API key input
+ - Model name selection
+ - Toggle for showing/hiding log and raw JSON
+ """
+
+ import gradio as gr
+ from typing import Tuple
+
+
+ # Default values for OpenAI API
+ DEFAULT_API_ENDPOINT = "https://api.openai.com/v1"
+ DEFAULT_MODEL_NAME = "gpt-4o"
+
+
+ def create_advanced_settings() -> Tuple[gr.Textbox, gr.Textbox, gr.Textbox, gr.Checkbox, gr.Checkbox]:
+     """
+     Create the advanced settings accordion with LLM configuration.
+
+     The accordion is collapsed by default to hide technical details
+     from casual users, while allowing power users to customize.
+
+     Returns:
+         Tuple containing:
+         - api_base_url_in: Textbox for LLM endpoint URL
+         - api_key_in: Textbox for API key (password masked)
+         - model_name_in: Textbox for model name
+         - show_log_toggle: Checkbox for showing/hiding log
+         - show_raw_json_toggle: Checkbox for showing/hiding raw JSON
+     """
+     with gr.Accordion(
+         "⚙️ Advanced Settings",
+         open=False,
+         elem_classes=["advanced-settings"]
+     ):
+         gr.Markdown("""
+         Configure your LLM provider and display preferences.
+         Leave fields empty to use environment variables.
+         """)
+
+         api_base_url_in = gr.Textbox(
+             label="🔗 LLM Endpoint (base_url)",
+             placeholder="https://api.openai.com/v1",
+             value=DEFAULT_API_ENDPOINT,
+             info="The base URL for your LLM API (OpenAI, OpenRouter, local, etc.)",
+         )
+
+         api_key_in = gr.Textbox(
+             label="🔑 API Key",
+             type="password",
+             placeholder="sk-...",
+             value="",
+             info="Your API key (leave empty to use OPENAI_API_KEY env var)",
+         )
+
+         model_name_in = gr.Textbox(
+             label="🤖 Model Name",
+             placeholder="gpt-4o",
+             value=DEFAULT_MODEL_NAME,
+             info="Model identifier (e.g., gpt-4o, gpt-4-turbo, claude-3-opus)",
+         )
+
+         gr.Markdown("---")
+         gr.Markdown("**Display Options**")
+
+         show_log_toggle = gr.Checkbox(
+             label="📋 Show Pipeline Log",
+             value=False,
+             info="Display detailed execution log during processing",
+         )
+
+         show_raw_json_toggle = gr.Checkbox(
+             label="📄 Show Raw JSON Output",
+             value=False,
+             info="Display raw JSON data in results (for debugging)",
+         )
+
+     return api_base_url_in, api_key_in, model_name_in, show_log_toggle, show_raw_json_toggle
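The help text above says empty fields fall back to environment variables; the fallback itself happens in the app wiring, not in this component. A minimal sketch of that resolution logic, with `resolve_setting` as a hypothetical helper name:

```python
import os


def resolve_setting(user_value: str, env_var: str, default: str = "") -> str:
    # Prefer an explicit value typed into the UI; otherwise fall back to
    # the named environment variable, then to a default.
    value = (user_value or "").strip()
    if value:
        return value
    return os.environ.get(env_var, default)
```

For example, the API key field would resolve as `resolve_setting(api_key_value, "OPENAI_API_KEY")`, matching the behavior the field's `info` string promises.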
gradio_app/components/styles.py ADDED
@@ -0,0 +1,592 @@
+ """
+ Custom CSS styles for the Review Grounder Gradio app.
+
+ This module provides all custom CSS needed for the modern UI design,
+ including the gradient header, card layouts, buttons, and other visual elements.
+ """
+
+
+ def get_custom_css() -> str:
+     """
+     Return the complete custom CSS for the app.
+
+     Includes styles for:
+     - Gradient header with purple theme
+     - Privacy notice banner
+     - How it works section with numbered steps
+     - Card-based layouts
+     - Custom button styling
+     - Results panel tabs
+     """
+     return """
+     /* ===== Global Styles ===== */
+     .gradio-container {
+         max-width: 1400px !important;
+         margin: 0 auto !important;
+         font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif !important;
+     }
+
+     /* ===== Header Styles ===== */
+     .app-header {
+         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+         padding: 20px 30px;
+         border-radius: 12px;
+         margin-bottom: 20px;
+         position: relative;
+         overflow: hidden;
+         display: flex;
+         flex-direction: column;
+         width: calc(100% + 24px) !important;
+         margin-left: -12px;
+         margin-right: -12px;
+         align-items: stretch; /* Stretch children to fill container width */
+         justify-content: flex-start; /* Align children from top to bottom */
+     }
+
+     .app-header::before {
+         content: '';
+         position: absolute;
+         top: 0;
+         left: 0;
+         right: 0;
+         bottom: 0;
+         background-image: radial-gradient(circle at 20% 50%, rgba(255,255,255,0.1) 1px, transparent 1px),
+                           radial-gradient(circle at 80% 50%, rgba(255,255,255,0.1) 1px, transparent 1px);
+         background-size: 40px 40px;
+     }
+
+     .app-header-content {
+         position: relative;
+         z-index: 1;
+         display: flex;
+         justify-content: space-between;
+         align-items: flex-start;
+         gap: 12px;
+     }
+
+
+
+     .app-title {
+         color: white;
+         font-size: 1.8em;
+         font-weight: 700;
+         margin: 0;
+         display: flex;
+         align-items: center;
+         gap: 12px;
+     }
+
+     .beta-badge {
+         background: rgba(255,255,255,0.2);
+         color: white;
+         padding: 4px 12px;
+         border-radius: 20px;
+         font-size: 0.5em;
+         font-weight: 600;
+         text-transform: uppercase;
+         letter-spacing: 0.5px;
+     }
+
+     .back-button {
+         background: rgba(255,255,255,0.2);
+         color: white;
+         padding: 8px 16px;
+         border-radius: 8px;
+         text-decoration: none;
+         font-size: 0.9em;
+         transition: background 0.2s;
+     }
+
+     .back-button:hover {
+         background: rgba(255,255,255,0.3);
+     }
+
+     /* ===== Privacy Notice Banner ===== */
+     .privacy-notice {
+         background: linear-gradient(90deg, #5b4fa8 0%, #7c3aed 100%);
+         color: white;
+
+         padding: 12px 20px;
+         border-radius: 8px;
+         margin-top: 20px;
+         margin-bottom: 20px;
+         display: flex;
+         align-items: flex-start;
+         gap: 12px;
+     }
+
+     .privacy-icon {
+         font-size: 1.2em;
+         margin-top: 2px;
+     }
+
+     .privacy-content h4 {
+         margin: 0 0 4px 0;
+         font-size: 1em;
+         font-weight: 600;
+     }
+
+     .privacy-content p {
+         margin: 0;
+         font-size: 0.85em;
+         opacity: 0.9;
+     }
+
+     /* ===== Main Panel Cards ===== */
+     .panel-card {
+         background: white;
+         border-radius: 12px;
+         padding: 24px;
+         box-shadow: 0 2px 8px rgba(0,0,0,0.08);
+         border: 1px solid #e5e7eb;
+         overflow: hidden;
+     }
+
+     /* Unify width: Gradio theme often gives HTML/Accordion blocks a max-width.
+        The rules below would force all blocks in the panel to span full width
+        like the button; kept disabled for now.
+     .panel-card .gr-block,
+     .panel-card .block,
+     .panel-card > div {
+         max-width: none !important;
+         width: 100% !important;
+         box-sizing: border-box;
+     }
+     .panel-card .how-it-works {
+         width: 100% !important;
+         box-sizing: border-box;
+     }
+     */
+
+     .panel-title {
+         font-size: 1.1em;
+         font-weight: 600;
+         color: #1f2937;
+         margin: 0 0 20px 0;
+         display: flex;
+         align-items: center;
+         gap: 8px;
+     }
+
+     /* ===== How It Works Section ===== */
+     /* Stretch to same width as button: cancel panel padding with negative margin, then add inner padding */
+     .how-it-works {
+         background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
+         border-radius: 10px;
+         padding: 16px 24px;
+         margin-left: -12px;
+         margin-right: -12px;
+         margin-bottom: 20px;
+         width: calc(100% + 24px) !important;
+         max-width: none;
+         box-sizing: border-box;
+     }
+
+     .how-it-works-title {
+         color: #d97706;
+         font-weight: 600;
+         font-size: 0.95em;
+         margin-bottom: 12px;
+         display: flex;
+         align-items: center;
+         gap: 6px;
+     }
+
+     .step-item {
+         display: flex;
+         align-items: flex-start;
+         gap: 10px;
+         margin-bottom: 8px;
+         font-size: 0.9em;
+         color: #374151;
+     }
+
+     .step-number {
+         background: #f97316;
+         color: white;
+         width: 20px;
+         height: 20px;
+         border-radius: 50%;
+         display: flex;
+         align-items: center;
+         justify-content: center;
+         font-size: 0.75em;
+         font-weight: 600;
+         flex-shrink: 0;
+     }
+
+     /* ===== Upload Area ===== */
+     .upload-area {
+         border: 2px dashed #d1d5db;
+         border-radius: 12px;
+         padding: 40px 20px;
+         text-align: center;
+         background: #f9fafb;
+         transition: all 0.2s;
+         margin-bottom: 20px;
+     }
+
+     .upload-area:hover {
+         border-color: #9ca3af;
+         background: #f3f4f6;
+     }
+
+     /* Hide default Gradio file upload prompt text ("Click to upload or drag and drop", etc.) */
+     .file-upload-minimal .gr-formatted-text,
+     .file-upload-minimal .gr-box > div:not([class*="file"]):not([class*="preview"]),
+     #pdf-upload .gr-formatted-text,
+     #pdf-upload .wrap-inner .gr-formatted-text {
+         display: none !important;
+     }
+
+     .upload-icon {
+         font-size: 3em;
+         color: #9ca3af;
+         margin-bottom: 12px;
+     }
+
+     .upload-text {
+         color: #6b7280;
+         font-size: 0.95em;
+     }
+
+     .upload-link {
+         color: #6366f1;
+         font-weight: 600;
+         cursor: pointer;
+     }
+
+     .upload-hint {
+         color: #9ca3af;
+         font-size: 0.85em;
+         margin-top: 8px;
+     }
+
+     /* ===== Primary Action Button ===== */
+     .primary-btn {
+         background: linear-gradient(135deg, #fb923c 0%, #f97316 100%) !important;
+         color: white !important;
+         padding: 14px 28px !important;
+         border-radius: 10px !important;
+         font-weight: 600 !important;
+         font-size: 1em !important;
+         border: none !important;
+         cursor: pointer !important;
+         transition: all 0.2s !important;
+         width: 100% !important;
+         box-shadow: 0 4px 12px rgba(249, 115, 22, 0.3) !important;
+     }
+
+     .primary-btn:hover {
+         transform: translateY(-1px) !important;
+         box-shadow: 0 6px 16px rgba(249, 115, 22, 0.4) !important;
+     }
+
+     .primary-btn:disabled {
+         opacity: 0.6 !important;
+         cursor: not-allowed !important;
+         transform: none !important;
+     }
+
+     /* ===== Secondary Buttons (Disabled State) ===== */
+     .secondary-btn {
+         background: #f3f4f6 !important;
+         color: #9ca3af !important;
+         padding: 12px 24px !important;
+         border-radius: 10px !important;
+         font-weight: 500 !important;
+         border: 1px solid #e5e7eb !important;
+         cursor: not-allowed !important;
+         width: 100% !important;
+         margin-top: 8px !important;
+     }
+
+     .unavailable-badge {
+         background: #e5e7eb;
+         color: #9ca3af;
+         padding: 2px 8px;
+         border-radius: 4px;
+         font-size: 0.7em;
+         margin-left: 8px;
+         text-transform: uppercase;
+     }
+
+     /* ===== Advanced Settings Accordion ===== */
+     /* Stretch to same width as button: cancel panel padding, align inner padding with panel */
+     .panel-card .advanced-settings,
+     .advanced-settings {
+         margin-left: -12px !important;
+         margin-right: -12px !important;
+         width: calc(100% + 24px) !important;
+         max-width: none !important;
+         box-sizing: border-box;
+     }
+
+     .advanced-settings .label-wrap {
+         background: #f9fafb !important;
+         border-radius: 8px !important;
+         padding: 12px 24px !important;
+     }
+
+     .advanced-settings .label-wrap span {
+         font-weight: 500 !important;
+         color: #4b5563 !important;
+     }
+
+     /* Accordion content area: same horizontal padding as panel */
+     .advanced-settings .gr-group,
+     .advanced-settings .wrap {
+         padding-left: 24px !important;
+         padding-right: 24px !important;
+         box-sizing: border-box;
+     }
+
+     /* ===== Results Panel ===== */
+     .results-panel {
+         background: white;
+         border-radius: 12px;
+         padding: 24px;
+         box-shadow: 0 2px 8px rgba(0,0,0,0.08);
+         border: 1px solid #e5e7eb;
+         min-height: 500px;
+     }
+
+     .results-placeholder {
+         text-align: center;
+         padding: 60px 20px;
+         color: #9ca3af;
+     }
+
+     .results-placeholder-icon {
+         font-size: 4em;
+         margin-bottom: 16px;
+     }
+
+     .results-placeholder h3 {
+         color: #374151;
+         margin: 0 0 8px 0;
+         font-weight: 600;
+     }
+
+     .results-placeholder p {
+         margin: 0;
+         font-size: 0.95em;
+     }
+
+     /* ===== Tab Styling ===== */
+     .tabs .tab-nav {
+         border-bottom: 2px solid #e5e7eb !important;
+     }
+
+     .tabs .tab-nav button {
+         font-weight: 500 !important;
+         color: #6b7280 !important;
+         padding: 12px 20px !important;
+         border: none !important;
+         background: transparent !important;
+     }
+
+     .tabs .tab-nav button.selected {
+         color: #6366f1 !important;
+         border-bottom: 2px solid #6366f1 !important;
+     }
+
+     /* ===== Initial Draft & Final Review content ===== */
+     .review-message {
+         color: #6b7280;
+         font-style: italic;
+         padding: 1em;
+     }
+
+     .review-draft-content,
+     .initial-draft-card {
+         max-width: 100%;
+     }
+
+     .review-text {
+         white-space: pre-wrap;
+         word-break: break-word;
+         line-height: 1.6;
+         color: #1f2937;
+     }
+
+     /* Final Review: toolbar with copy button at top right */
+     .final-review-toolbar {
+         display: flex;
+         justify-content: flex-end;
+         margin-bottom: 12px;
+     }
+     .copy-final-btn {
+         padding: 8px 16px;
+         border-radius: 8px;
+         border: 1px solid #e5e7eb;
+         background: #f9fafb;
+         color: #374151;
+         font-size: 0.9em;
+         cursor: pointer;
+         transition: background 0.2s, color 0.2s;
+     }
+     .copy-final-btn:hover {
+         background: #f3f4f6;
+         color: #111827;
+     }
+
+     /* Final Review markdown area: improve readability */
+     .results-panel .gr-markdown,
+     .results-panel .prose {
+         line-height: 1.7 !important;
+         color: #1f2937 !important;
+         max-width: 100% !important;
+     }
+
+     .results-panel .gr-markdown h1,
+     .results-panel .gr-markdown h2,
+     .results-panel .gr-markdown h3 {
+         margin-top: 1em !important;
+         margin-bottom: 0.5em !important;
+         color: #111827 !important;
+     }
+
+     .results-panel .gr-markdown ul,
+     .results-panel .gr-markdown ol {
+         padding-left: 1.5em !important;
+         margin: 0.5em 0 !important;
+     }
+
+     .results-panel .gr-markdown p {
+         margin: 0.5em 0 !important;
+     }
+
+     /* ===== Card Grid for Results ===== */
+     .card-grid {
+         display: flex;
+         flex-direction: column;
+         gap: 12px;
+     }
+
+     .card {
+         background: #f8f9fa;
+         border: 1px solid #e9ecef;
+         border-radius: 10px;
+         padding: 14px 16px;
+     }
+
+     .card h4 {
+         margin: 0 0 8px 0;
+         font-size: 1.05em;
+     }
+
+     .kv {
+         margin: 8px 0;
+         padding: 10px;
+         background: #ffffff;
+         border-radius: 8px;
+         border: 1px solid #eef1f4;
+     }
+
+     .k {
+         font-weight: 650;
+         color: #495057;
+         margin-bottom: 4px;
+     }
+
+     .v {
+         color: #212529;
+         line-height: 1.55;
+         white-space: pre-wrap;
+         word-break: break-word;
+     }
+
+     details {
+         background: #ffffff;
+         border: 1px solid #eef1f4;
+         border-radius: 8px;
+         padding: 10px 12px;
+     }
+
+     summary {
+         cursor: pointer;
+         font-weight: 650;
+         color: #212529;
+     }
+
+     .mono {
+         font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace;
+         font-size: 12.5px;
+     }
+
+     .pill {
+         display: inline-block;
+         padding: 2px 8px;
+         border-radius: 999px;
+         background: #e7f1ff;
+         color: #0b5ed7;
+         font-size: 12px;
+         margin-left: 8px;
+     }
+
+     /* ===== Related Work Cards ===== */
+     .related-work-container {
+         font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
+         max-width: 100%;
+     }
+
+     .related-paper-card {
+         background: #f8f9fa;
+         border-left: 4px solid #007bff;
+         border-radius: 6px;
+         padding: 16px;
+         margin-bottom: 16px;
+         box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+         transition: transform 0.2s, box-shadow 0.2s;
+     }
+
+     .related-paper-card:hover {
+         transform: translateY(-2px);
+         box-shadow: 0 4px 8px rgba(0,0,0,0.15);
+     }
+
+     .paper-header {
+         font-weight: 600;
+         font-size: 1.1em;
+         color: #212529;
+         margin-bottom: 12px;
+     }
+
+     .paper-field {
+         margin: 8px 0;
+         padding: 8px;
+         background: white;
+         border-radius: 4px;
+     }
+
+     .paper-field-label {
+         font-weight: 600;
+         color: #495057;
+         font-size: 0.9em;
+         margin-bottom: 4px;
+     }
+
+     .paper-field-value {
+         color: #212529;
+         line-height: 1.5;
+     }
+
+     /* ===== Log/Status Text Area ===== */
+     .status-log textarea {
+         font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace !important;
+         font-size: 0.85em !important;
+         background: #1f2937 !important;
+         color: #10b981 !important;
+         border-radius: 8px !important;
+     }
+
+     /* ===== Footer ===== */
+     .app-footer {
+         text-align: center;
+         padding: 20px;
+         color: #9ca3af;
+         font-size: 0.85em;
+         border-top: 1px solid #e5e7eb;
+         margin-top: 30px;
+     }
+     """
gradio_app/components/upload_section.py ADDED
@@ -0,0 +1,117 @@
+ """
+ Upload section component for the Review Grounder Gradio app.
+
+ This module provides the left panel with:
+ - "How it works" instructions
+ - PDF upload area
+ - Action buttons
+ """
+
+ import gradio as gr
+ from typing import Tuple
+
+
+ def create_how_it_works() -> None:
+     """
+     Create the "How it works" instruction section.
+
+     Displays a numbered list of steps explaining the review process.
+     """
+     gr.HTML("""
+     <div class="how-it-works">
+         <div class="how-it-works-title">
+             ⚡ How it works
+         </div>
+         <div class="step-item">
+             <span class="step-number">1</span>
+             <span>Upload your research paper in PDF format</span>
+         </div>
+         <div class="step-item">
+             <span class="step-number">2</span>
+             <span>Configure your LLM settings (or use defaults)</span>
+         </div>
+         <div class="step-item">
+             <span class="step-number">3</span>
+             <span>Click "🚀 Generate AI Review" to start</span>
+         </div>
+         <div class="step-item">
+             <span class="step-number">4</span>
+             <span>Watch as our AI analyzes your paper in real-time</span>
+         </div>
+         <div class="step-item">
+             <span class="step-number">5</span>
+             <span>Get comprehensive, grounded feedback on your research</span>
+         </div>
+     </div>
+     """)
+
+
+ def create_upload_area() -> gr.File:
+     """
+     Create the PDF upload area component.
+
+     Returns:
+         gr.File: The file upload component for PDF files
+     """
+     pdf_input = gr.File(
+         label="",
+         file_types=[".pdf"],
+         type="filepath",
+         elem_classes=["upload-area", "file-upload-minimal"],
+         elem_id="pdf-upload",
+         show_label=False,
+     )
+
+     gr.HTML("""
+     <div style="text-align: center; color: #9ca3af; font-size: 0.85em; margin-top: -10px; margin-bottom: 15px;">
+         PDF files only (max 10MB)
+     </div>
+     """)
+
+     return pdf_input
+
+
+ def create_action_buttons() -> gr.Button:
+     """
+     Create the action buttons for starting the review.
+
+     Returns:
+         gr.Button: The primary "Generate Review" button
+     """
+     run_button = gr.Button(
+         "🚀 Generate AI Review",
+         variant="primary",
+         elem_classes=["primary-btn"],
+     )
+
+     # # Placeholder buttons for future features (disabled)
+     # gr.HTML("""
+     # <button class="secondary-btn" disabled>
+     #     📊 DeepReviewer
+     #     <span class="unavailable-badge">COMING SOON</span>
+     # </button>
+     # <button class="secondary-btn" disabled>
+     #     🛡️ SafeReviewer
+     #     <span class="unavailable-badge">COMING SOON</span>
+     # </button>
+     # """)
+
+     return run_button
+
+
+ def create_upload_section() -> Tuple[gr.File, gr.Button]:
+     """
+     Create the complete upload section panel.
+
+     Returns:
+         Tuple containing:
+         - pdf_input: The file upload component
+         - run_button: The generate review button
+     """
+     gr.HTML('<div class="panel-title">📁 Upload Your Paper</div>')
+ gr.HTML('<div class="panel-title">📁 Upload Your Paper</div>')
112
+
113
+ create_how_it_works()
114
+ pdf_input = create_upload_area()
115
+ run_button = create_action_buttons()
116
+
117
+ return pdf_input, run_button
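The `run_button` returned here is typically wired to a handler that consumes the stepwise pipeline events and folds them into status text. A minimal stand-alone sketch of that folding logic, with a stub generator standing in for `run_single_paper_review_from_pdf_stepwise` (the stub's event payloads are illustrative, not real pipeline output):

```python
from typing import Any, Dict, Iterator, List


def fake_stepwise_events() -> Iterator[Dict[str, Any]]:
    # Stub standing in for run_single_paper_review_from_pdf_stepwise(pdf_path)
    yield {"stage": "extract_pdf", "pdf_path": "/tmp/paper.pdf"}
    yield {"stage": "initial_review", "initial_review": {"summary": "..."}}
    yield {"stage": "final", "review": {"review_markdown": "# Review"}}


def fold_events(events: Iterator[Dict[str, Any]]) -> List[str]:
    """Turn pipeline events into human-readable status lines."""
    lines: List[str] = []
    for event in events:
        stage = event.get("stage", "unknown")
        if stage.endswith("_error"):
            lines.append(f"[warn] {stage}: {event.get('error')}")
        elif stage == "final":
            lines.append("[done] review ready")
        else:
            lines.append(f"[info] {stage}")
    return lines


status = fold_events(fake_stepwise_events())
```

In the real app the handler would yield each accumulated status string back to a Gradio output so the log updates as stages complete.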
gradio_app/utils_single_paper_inference.py ADDED
@@ -0,0 +1,276 @@
+"""
+Utility wrappers and a minimal CLI for single-paper inference using:
+- ASTA API key from environment variable `ASTA_API_KEY`
+- OpenAI/OpenRouter endpoint and key from environment variables
+
+This module is designed to be imported by Gradio or other web frontends,
+while still remaining executable as a standalone CLI tool for debugging.
+"""
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+import json
+from typing import Any, Dict, Iterator, Optional
+
+PROJECT_ROOT = Path(__file__).parent.parent
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+from src.reviewer_agent.single_paper_inference import (
+    extract_text_from_pdf,
+    _split_paper_latex_sections,
+    _init_single_paper_pipeline,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def run_single_paper_review_from_pdf(
+    pdf_path: str,
+    *,
+    enable_logging: bool = True,
+    verbose: bool = True,
+    api_base_url: str | None = None,
+    api_key: str | None = None,
+    model_name: str | None = None,
+) -> dict:
+    """
+    High-level utility to run the single-paper review pipeline on a PDF path.
+
+    This is the main entry point intended to be called by Gradio or other UIs.
+    It delegates to `review_single_paper_from_pdf`, which uses:
+    - the ASTA API key from `ASTA_API_KEY`
+    - LLM settings and OpenAI/OpenRouter keys from environment/config files,
+      which can be overridden via `api_base_url`, `api_key`, and `model_name`.
+    """
+    pdf_path = str(Path(pdf_path).expanduser())
+    logger.info(f"Running single-paper review for PDF: {pdf_path}")
+    # Keep the original one-shot behavior for backward compatibility.
+    # For true streaming updates, use `run_single_paper_review_from_pdf_stepwise`.
+    from src.reviewer_agent.single_paper_inference import review_single_paper_from_pdf
+
+    review = review_single_paper_from_pdf(
+        pdf_path,
+        enable_logging=enable_logging,
+        verbose=verbose,
+        gpt_api_key=api_key,
+        gpt_base_url=api_base_url,
+        gpt_model_name=model_name,
+    )
+    return review
+
+
+def _normalize_base_url(base_url: Optional[str]) -> Optional[str]:
+    """
+    Normalize an OpenAI-compatible base_url.
+
+    The local gateway expects requests at:
+        http://localhost:8000/chat/completions
+
+    The OpenAI client will append `/chat/completions` to whatever `base_url`
+    we pass in. That means:
+        base_url = "http://localhost:8000"
+        -> "http://localhost:8000/chat/completions" ✅
+
+    If a user accidentally includes `/chat/completions` in the textbox, we
+    strip that suffix so the final URL is still correct.
+    """
+    if not base_url:
+        return None
+    u = base_url.strip()
+    if not u:
+        return None
+
+    # Strip trailing slash for normalization
+    u = u.rstrip("/")
+
+    # If the user pasted the full path (…/chat/completions), strip it back to the host.
+    if u.endswith("/chat/completions"):
+        u = u[: -len("/chat/completions")]
+
+    return u
+
+
+def run_single_paper_review_from_pdf_stepwise(
+    pdf_path: str,
+    *,
+    enable_logging: bool = True,
+    verbose: bool = True,
+    api_base_url: str | None = None,
+    api_key: str | None = None,
+    model_name: str | None = None,
+) -> Iterator[Dict[str, Any]]:
+    """
+    Stepwise (streamable) single-paper pipeline.
+
+    Yields dict events like:
+        {"stage": "extract_pdf", ...}
+        {"stage": "initial_review", "initial_review": {...}}
+        {"stage": "results_analysis", "results_analyzer_json": "..."}
+        {"stage": "insights", "insight_miner_json": "..."}
+        {"stage": "related_work", "related_work_json_list": [...], "search_keywords": [...]}
+        {"stage": "final", "review": {...}}
+    """
+    pdf_path = str(Path(pdf_path).expanduser())
+    yield {"stage": "extract_pdf", "pdf_path": pdf_path}
+
+    paper_text = extract_text_from_pdf(pdf_path)
+    yield {"stage": "parsed_pdf_text", "text_len": len(paper_text)}
+
+    sections = _split_paper_latex_sections(paper_text)
+    title = (sections.get("title") or "").strip()
+    abstract = (sections.get("abstract") or "").strip()
+    content = (sections.get("content") or "").strip()
+    yield {"stage": "parsed_sections", "title": title, "abstract": abstract}
+
+    reviewer, refiner, related_work_searcher, paper_results_analyzer, paper_insight_miner = (
+        _init_single_paper_pipeline(
+            enable_logging=enable_logging,
+            use_test_llm=False,
+            gpt_api_key=api_key,
+            gpt_base_url=api_base_url,
+            gpt_model_name=model_name,
+        )
+    )
+
+    # Step 1: initial draft (reviewer)
+    initial_review = reviewer.review_paper(
+        title=title,
+        abstract=abstract,
+        content=content,
+        keywords=None,
+        review_format="ai_researcher",
+        auto_save_log=False,
+        verbose=verbose,
+    )
+    yield {"stage": "initial_review", "initial_review": initial_review}
+
+    # Helper: format initial review for analyzers.
+    try:
+        initial_review_text = (
+            refiner._format_review_dict(initial_review, "detailed")
+            if hasattr(refiner, "_format_review_dict")
+            else str(initial_review)
+        )
+    except Exception:
+        initial_review_text = str(initial_review)
+
+    # Step 2a: results analyzer
+    results_analyzer_json = None
+    if paper_results_analyzer and content:
+        try:
+            results_analyzer_json = paper_results_analyzer.analyze_paper_results(
+                content, initial_review_text
+            )
+        except Exception as e:
+            results_analyzer_json = None
+            yield {"stage": "results_analysis_error", "error": str(e)}
+    yield {"stage": "results_analysis", "results_analyzer_json": results_analyzer_json}
+
+    # Step 2b: insight miner
+    insight_miner_json = None
+    if paper_insight_miner and content:
+        try:
+            insight_miner_json = paper_insight_miner.mine_paper_insights(
+                content, initial_review_text
+            )
+        except Exception as e:
+            insight_miner_json = None
+            yield {"stage": "insights_error", "error": str(e)}
+    yield {"stage": "insights", "insight_miner_json": insight_miner_json}
+
+    # Step 2c: related work (structured list)
+    related_work_list = []
+    search_keywords = None
+    if related_work_searcher:
+        try:
+            related_work_list = related_work_searcher.generate_related_work_json_list(
+                title=title,
+                abstract=abstract,
+                content=content,
+                keywords=None,
+                publication_date_range=None,
+                venues=None,
+            )
+            search_keywords = getattr(related_work_searcher, "last_keywords", None)
+        except Exception as e:
+            related_work_list = []
+            yield {"stage": "related_work_error", "error": str(e)}
+    yield {
+        "stage": "related_work",
+        "related_work_json_list": related_work_list,
+        "search_keywords": search_keywords,
+    }
+
+    # Step 3: refine (final)
+    related_work_json_str = json.dumps(related_work_list, ensure_ascii=False)
+    refined = refiner.refine_review(
+        initial_review=initial_review,
+        insight_miner_json=insight_miner_json,
+        results_analyzer_json=results_analyzer_json,
+        related_work_json_list=related_work_json_str,
+        title=title,
+        abstract=abstract,
+        content=content,
+        review_format="detailed",
+        verbose=verbose,
+    )
+    if search_keywords is not None:
+        refined["search_keywords"] = search_keywords
+
+    yield {"stage": "final", "review": refined}
+
+
+def main() -> None:
+    """
+    Simple CLI wrapper mainly for local debugging.
+
+    Example:
+        python -m gradio_app.utils_single_paper_inference \\
+            --pdf /path/to/paper.pdf
+    """
+    parser = argparse.ArgumentParser(
+        description="Run single-paper review on a PDF file."
+    )
+    parser.add_argument(
+        "--pdf",
+        type=str,
+        required=True,
+        help="Path to the PDF file to review.",
+    )
+    parser.add_argument(
+        "--no-logging",
+        action="store_true",
+        help="Disable on-disk logging for this run.",
+    )
+    parser.add_argument(
+        "--quiet",
+        action="store_true",
+        help="Reduce console output verbosity.",
+    )
+    args = parser.parse_args()
+
+    from time import time
+
+    start_time = time()
+    review = run_single_paper_review_from_pdf(
+        args.pdf,
+        enable_logging=not args.no_logging,
+        verbose=not args.quiet,
+    )
+    end_time = time()
+
+    print(f"Time taken: {end_time - start_time:.2f} seconds")
+    print("\n=== Review keys ===")
+    print(list(review.keys()))
+    print("\n=== Review Markdown ===")
+    print(review.get("review_markdown", ""))
+
+
+if __name__ == "__main__":
+    main()
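The base-url normalization in `_normalize_base_url` is worth pinning down with examples, since it decides what the OpenAI client ends up calling. A self-contained re-implementation of the same logic (simplified copy for illustration, not an import from the module):

```python
from typing import Optional


def normalize_base_url(base_url: Optional[str]) -> Optional[str]:
    """Mirror of _normalize_base_url: trim whitespace, drop trailing
    slashes, and strip an accidental /chat/completions suffix."""
    if not base_url:
        return None
    u = base_url.strip()
    if not u:
        return None
    u = u.rstrip("/")
    if u.endswith("/chat/completions"):
        u = u[: -len("/chat/completions")]
    return u


# All of these collapse to the bare host the OpenAI client expects:
print(normalize_base_url("http://localhost:8000/"))
print(normalize_base_url("http://localhost:8000/chat/completions"))
print(normalize_base_url("  http://localhost:8000/chat/completions/  "))
```

Empty or whitespace-only input returns `None`, signalling "fall back to the configured default" rather than passing a bogus URL downstream.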
requirements.txt ADDED
@@ -0,0 +1,26 @@
+# Unified Review System Requirements
+
+# Core dependencies
+pandas>=1.5.0
+pyarrow>=10.0.0
+pyyaml>=6.0.0
+python-dotenv>=1.0.0
+
+# LLM and ML dependencies
+vllm>=0.6.0
+openai>=1.0.0
+transformers>=4.35.0
+torch>=2.0.0
+numpy>=1.24.0
+FlagEmbedding>=1.2.0
+
+# Utilities
+requests>=2.31.0
+pydantic>=2.0.0
+tqdm>=4.66.0
+scipy>=1.10.0
+scikit-learn>=1.3.0
+
+# PDF + UI (for Gradio / Hugging Face Space)
+pdfminer.six>=20221105
+gradio>=4.0.0
scripts/gpt_oss_start_vllm_service.sh ADDED
@@ -0,0 +1,47 @@
+#!/bin/bash
+# Script to start a vLLM service for openai/gpt-oss-120b
+
+# Optional: limit GPU usage
+# export CUDA_VISIBLE_DEVICES=0,1,2,3
+export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+# Configuration
+# MODEL_NAME="Qwen/Qwen3-235B-A22B-Instruct-2507"
+MODEL_NAME="openai/gpt-oss-120b"
+PORT=${VLLM_PORT:-8000}
+TP_SIZE=${TP_SIZE:-4}  # Tensor parallelism size; must be <= the number of visible GPUs
+GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.85}  # ideally 0.85
+MAX_MODEL_LEN=${MAX_MODEL_LEN:-131072}  # Native context length; can be extended to 1010000
+
+# Check if a local model path is provided
+if [ -z "$MODEL_PATH" ]; then
+    MODEL_PATH="$MODEL_NAME"
+    echo "Using HuggingFace model: $MODEL_PATH"
+else
+    echo "Using local model: $MODEL_PATH"
+fi
+
+echo "Starting vLLM service..."
+echo "Model: $MODEL_PATH"
+echo "Port: $PORT"
+echo "Tensor Parallelism: $TP_SIZE"
+echo "GPU Memory Utilization: $GPU_MEMORY_UTILIZATION"
+echo "Max Model Length: $MAX_MODEL_LEN"
+
+# python3 -m vllm.entrypoints.openai.api_server \
+#     --model "$MODEL_PATH" \
+#     --port $PORT \
+#     --tensor-parallel-size $TP_SIZE \
+#     --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
+#     --max-model-len $MAX_MODEL_LEN \
+#     --trust-remote-code \
+#     --dtype bfloat16
+
+# Use $MODEL_PATH (not a hardcoded model name) so the local-path override works
+vllm serve "$MODEL_PATH" \
+    --port $PORT \
+    --tensor-parallel-size $TP_SIZE \
+    --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
+    --max-model-len $MAX_MODEL_LEN \
+    --trust-remote-code \
+    --dtype bfloat16
scripts/start_load_balancer.sh ADDED
@@ -0,0 +1,87 @@
+#!/bin/bash
+# Start Python Load Balancer for vLLM and Reranker services
+# Usage: ./scripts/start_load_balancer.sh [service_type] [num_instances] [base_port] [lb_port]
+
+set -e
+
+SERVICE_TYPE="${1:-vllm}"  # vllm or reranker
+NUM_INSTANCES="${2:-4}"
+BASE_PORT="${3:-8000}"
+LB_PORT="${4:-$BASE_PORT}"
+
+echo "Starting Load Balancer for $SERVICE_TYPE"
+echo "Number of instances: $NUM_INSTANCES"
+echo "Base port: $BASE_PORT"
+echo "Load balancer port: $LB_PORT"
+echo ""
+
+# Get script directory
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
+cd "$PROJECT_ROOT"
+
+# Activate virtual environment if it exists
+if [ -d ".venv" ]; then
+    source .venv/bin/activate
+fi
+
+# Check if FastAPI is installed
+python3 -c "import fastapi" 2>/dev/null || {
+    echo "Error: FastAPI not installed. Install with: pip install fastapi uvicorn httpx"
+    exit 1
+}
+
+# Build backend list
+BACKENDS=()
+for i in $(seq 0 $((NUM_INSTANCES - 1))); do
+    PORT=$((BASE_PORT + i))
+    if [ "$SERVICE_TYPE" = "vllm" ]; then
+        BACKENDS+=("http://localhost:${PORT}/v1")
+    else
+        BACKENDS+=("http://localhost:${PORT}")
+    fi
+done
+
+# Create logs directory based on service type
+if [ "$SERVICE_TYPE" = "vllm" ]; then
+    LB_LOG_DIR="logs/vllm"
+else
+    LB_LOG_DIR="logs/reranker"
+fi
+mkdir -p "$LB_LOG_DIR"
+
+echo "Backends:"
+for backend in "${BACKENDS[@]}"; do
+    echo "  - $backend"
+done
+echo ""
+
+# Start load balancer
+echo "Starting load balancer..."
+python3 -m shared.utils.load_balancer \
+    --backends "${BACKENDS[@]}" \
+    --host 0.0.0.0 \
+    --port "$LB_PORT" \
+    --strategy round_robin \
+    --health-check-interval 10.0 \
+    > "${LB_LOG_DIR}/load_balancer_${SERVICE_TYPE}_port${LB_PORT}.log" 2>&1 &
+
+LB_PID=$!
+
+# Save PID to file based on service type
+if [ "$SERVICE_TYPE" = "vllm" ]; then
+    PID_FILE="logs/vllm/vllm_lb_pid.txt"
+    mkdir -p logs/vllm
+else
+    PID_FILE="logs/reranker/reranker_lb_pid.txt"
+    mkdir -p logs/reranker
+fi
+echo "$LB_PID" > "$PID_FILE"
+
+echo "Load balancer started with PID: $LB_PID"
+echo "Load balancer URL: http://localhost:${LB_PORT}"
+echo "PID saved to: $PID_FILE"
+echo ""
+echo "To check status: curl http://localhost:${LB_PORT}/health"
+echo "To stop: ./scripts/stop_vllm_services.sh (for vllm) or ./scripts/stop_reranker_services.sh (for reranker)"
+echo "Or manually: kill $LB_PID"
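The backend-list construction above is the part worth understanding before changing port or instance counts: vLLM backends get a `/v1` suffix, reranker backends do not. A stand-alone sketch of that loop with illustrative values (`NUM_INSTANCES=3`, `BASE_PORT=8000`):

```shell
#!/bin/bash
# Rebuild the backend list the way start_load_balancer.sh does, for vllm mode.
SERVICE_TYPE="vllm"
NUM_INSTANCES=3
BASE_PORT=8000
BACKENDS=()
for i in $(seq 0 $((NUM_INSTANCES - 1))); do
    PORT=$((BASE_PORT + i))
    if [ "$SERVICE_TYPE" = "vllm" ]; then
        BACKENDS+=("http://localhost:${PORT}/v1")  # OpenAI-compatible API root
    else
        BACKENDS+=("http://localhost:${PORT}")     # reranker serves at the bare host
    fi
done
printf "%s\n" "${BACKENDS[@]}"
```

Running this prints `http://localhost:8000/v1` through `http://localhost:8002/v1`, which is exactly what gets passed to `--backends`.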
scripts/start_reranker_service.sh ADDED
@@ -0,0 +1,116 @@
+#!/bin/bash
+# Start Reranker API Service on multiple GPUs
+# Usage: ./scripts/start_reranker_service.sh [model_path] [num_gpus] [base_port]
+
+set -e
+
+# Default values
+# MODEL_PATH="${1:-OpenScholar/OpenScholar_Reranker}"
+# NUM_GPUS="${2:-4}"
+# BASE_PORT="${3:-8005}"
+
+# MODEL_PATH="BAAI/bge-reranker-base"
+# MODEL_PATH="BAAI/bge-reranker-large"
+MODEL_PATH="${1:-OpenScholar/OpenScholar_Reranker}"
+NUM_GPUS=8    # Hardcoded here; the [num_gpus] and [base_port] arguments are currently ignored
+BASE_PORT=8008
+
+echo "Starting Reranker API Service"
+echo "Model: $MODEL_PATH"
+echo "Number of GPUs: $NUM_GPUS"
+echo "Base port: $BASE_PORT"
+echo ""
+
+# Get script directory
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
+cd "$PROJECT_ROOT"
+
+# Activate virtual environment if it exists
+if [ -d ".venv" ]; then
+    source .venv/bin/activate
+fi
+
+# Check if FastAPI is installed
+python3 -c "import fastapi" 2>/dev/null || {
+    echo "Error: FastAPI not installed. Install with: pip install fastapi uvicorn"
+    exit 1
+}
+
+# Check if FlagEmbedding is installed
+python3 -c "from FlagEmbedding import FlagReranker" 2>/dev/null || {
+    echo "Error: FlagEmbedding not installed. Install with: pip install FlagEmbedding"
+    exit 1
+}
+
+# Create logs directory
+mkdir -p logs/reranker
+
+# PID file for stopping services later
+PID_FILE="logs/reranker/reranker_pids.txt"
+LB_PID_FILE="logs/reranker/reranker_lb_pid.txt"
+
+# Start services on each GPU
+PIDS=()
+ENDPOINTS=()
+
+# Start one service per GPU (GPU 0 .. NUM_GPUS-1)
+for i in $(seq 0 $((NUM_GPUS - 1))); do
+    PORT=$((BASE_PORT + i))
+    GPU_ID=$i
+
+    echo "Starting reranker service on GPU $GPU_ID, port $PORT..."
+
+    # Set CUDA device (each service will see only one GPU)
+    export CUDA_VISIBLE_DEVICES=$GPU_ID
+
+    # Start service in background
+    # Note: When CUDA_VISIBLE_DEVICES is set, cuda:0 refers to the visible GPU
+    nohup python3 -m shared.utils.reranker_api_service \
+        --model_path "$MODEL_PATH" \
+        --host 0.0.0.0 \
+        --port "$PORT" \
+        --use_fp16 \
+        --device "cuda:0" \
+        > "logs/reranker/reranker_service_gpu${GPU_ID}_port${PORT}.log" 2>&1 &
+
+    PID=$!
+    PIDS+=($PID)
+    ENDPOINTS+=("http://localhost:${PORT}")
+
+    echo "  Started with PID: $PID"
+    echo "  Endpoint: http://localhost:${PORT}"
+    sleep 2  # Give service time to start
+done
+
+echo ""
+echo "All reranker services started!"
+echo ""
+echo "Endpoints:"
+for endpoint in "${ENDPOINTS[@]}"; do
+    echo "  - $endpoint"
+done
+
+# Create endpoint pool file
+ENDPOINT_POOL_FILE="shared/configs/reranker_endpoint_pool.txt"
+mkdir -p "$(dirname "$ENDPOINT_POOL_FILE")"
+printf "%s\n" "${ENDPOINTS[@]}" > "$ENDPOINT_POOL_FILE"
+echo ""
+echo "Endpoint pool file created: $ENDPOINT_POOL_FILE"
+
+# Save PIDs to file (one per line)
+printf "%s\n" "${PIDS[@]}" > "$PID_FILE"
+echo ""
+echo "PIDs saved to: $PID_FILE"
+echo ""
+echo "To stop these specific reranker services, run:"
+echo "  ./scripts/stop_reranker_services.sh"
+echo ""
+echo "This will only kill the processes listed above, not other reranker services."
+echo ""
+echo "To check service status, run:"
+for endpoint in "${ENDPOINTS[@]}"; do
+    echo "curl $endpoint/health"
+done
scripts/start_vllm_with_balancer.sh ADDED
@@ -0,0 +1,216 @@
+#!/bin/bash
+# Script to start vLLM services on multiple GPUs and a load balancer
+# Usage: ./scripts/start_vllm_with_balancer.sh
+
+set -e
+
+# Get script directory
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
+cd "$PROJECT_ROOT"
+
+# Configuration
+# MODEL_NAME="ZhuofengLi/Qwen3-4B-Instruct-2507-DeepReview-lora-sft"
+MODEL_NAME="openai/gpt-oss-120b"
+# GPU_CONFIG="2:8001,3:8002,4:8003,5:8004"  # GPU:PORT pairs
+# GPU_CONFIG="1:8001,2:8002,3:8003,4:8004,5:8005,6:8006,7:8007,0:7999"  # GPU:PORT pairs
+GPU_CONFIG="0:8001,1:8002,2:8003,3:8004,5:8005,6:8006,7:8007"  # GPU:PORT pairs (no spaces around commas)
+TP_SIZE=1  # Tensor parallelism size per instance
+GPU_MEMORY_UTILIZATION=0.85
+MAX_MODEL_LEN=131072
+
+# Load balancer configuration
+LB_PORT=8000  # Load balancer port
+LB_STRATEGY="round_robin"  # or "least_conn"
+LB_HEALTH_CHECK_INTERVAL=10.0
+
+# Log directory
+LOG_DIR="./logs/vllm"
+mkdir -p "$LOG_DIR"
+
+# Endpoint pool file
+ENDPOINT_POOL_FILE="shared/configs/vllm_endpoint_pool.txt"
+mkdir -p "$(dirname "$ENDPOINT_POOL_FILE")"
+
+echo "=========================================="
+echo "Starting vLLM Services + Load Balancer"
+echo "=========================================="
+echo "Model: $MODEL_NAME"
+echo "GPU Configuration: $GPU_CONFIG"
+echo "Load Balancer Port: $LB_PORT"
+echo "Log Directory: $LOG_DIR"
+echo ""
+
+# Step 1: Start vLLM services
+echo "=== Step 1: Starting vLLM services ==="
+echo ""
+
+# Clear existing endpoints
+> "$ENDPOINT_POOL_FILE"
+
+# Parse GPU configuration
+IFS=',' read -ra GPU_CONFIGS <<< "$GPU_CONFIG"
+
+# Array to store PIDs
+VLLM_PIDS=()
+
+for gpu_config in "${GPU_CONFIGS[@]}"; do
+    IFS=':' read -r gpu_id port <<< "$gpu_config"
+
+    echo "Starting vLLM on GPU $gpu_id, port $port..."
+
+    # Set CUDA_VISIBLE_DEVICES for this specific GPU
+    export CUDA_VISIBLE_DEVICES=$gpu_id
+
+    # Log file
+    LOG_FILE="$LOG_DIR/vllm_gpu${gpu_id}_port${port}.log"
+
+    # Start vLLM service in background
+    (
+        echo "=== GPU $gpu_id, Port $port ===" >> "$LOG_FILE"
+        echo "Starting at $(date)" >> "$LOG_FILE"
+        echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" >> "$LOG_FILE"
+        echo "" >> "$LOG_FILE"
+
+        vllm serve "$MODEL_NAME" \
+            --port "$port" \
+            --tensor-parallel-size "$TP_SIZE" \
+            --gpu-memory-utilization "$GPU_MEMORY_UTILIZATION" \
+            --max-model-len "$MAX_MODEL_LEN" \
+            --trust-remote-code \
+            --dtype bfloat16 \
+            >> "$LOG_FILE" 2>&1
+    ) &
+
+    PID=$!
+    VLLM_PIDS+=($PID)
+
+    # Add endpoint to pool file (for load balancer)
+    echo "http://localhost:$port/v1" >> "$ENDPOINT_POOL_FILE"
+
+    echo "  -> Started with PID $PID"
+    echo "  -> Endpoint: http://localhost:$port/v1"
+    echo "  -> Log: $LOG_FILE"
+
+    # Wait a bit before starting next service
+    sleep 3
+done
+
+# Save PIDs (one per line for easier parsing)
+printf "%s\n" "${VLLM_PIDS[@]}" > "$LOG_DIR/vllm_pids.txt"
+echo ""
+echo "vLLM service PIDs saved to: $LOG_DIR/vllm_pids.txt"
+echo ""
+
+# Step 2: Wait for services to be ready
+echo "=== Step 2: Waiting for vLLM services to be ready ==="
+echo "Waiting 90 seconds for services to initialize..."
+sleep 90
+
+# Check service health
+echo ""
+echo "Checking service health..."
+HEALTHY_COUNT=0
+for gpu_config in "${GPU_CONFIGS[@]}"; do
+    IFS=':' read -r gpu_id port <<< "$gpu_config"
+    if curl -s "http://localhost:$port/v1/models" > /dev/null 2>&1; then
+        echo "  GPU $gpu_id (port $port): HEALTHY"
+        HEALTHY_COUNT=$((HEALTHY_COUNT + 1))
+    else
+        echo "  GPU $gpu_id (port $port): NOT READY (may still be initializing)"
+    fi
+done
+
+if [ $HEALTHY_COUNT -eq 0 ]; then
+    echo ""
+    echo "WARNING: No services are healthy yet. They may still be loading the model."
+    echo "You can check logs in $LOG_DIR/ for progress."
+fi
+
+echo ""
+
+# Step 3: Start load balancer
+echo "=== Step 3: Starting Load Balancer ==="
+echo ""
+
+# Build backend URLs
+BACKEND_URLS=()
+for gpu_config in "${GPU_CONFIGS[@]}"; do
+    IFS=':' read -r gpu_id port <<< "$gpu_config"
+    BACKEND_URLS+=("http://localhost:$port/v1")
+done
+
+echo "Load Balancer Configuration:"
+echo "  Port: $LB_PORT"
+echo "  Strategy: $LB_STRATEGY"
+echo "  Backends: ${BACKEND_URLS[*]}"
+echo ""
+
+# Activate virtual environment if it exists
+if [ -d ".venv" ]; then
+    source .venv/bin/activate
+fi
+
+# Check if FastAPI is installed
+python3 -c "import fastapi" 2>/dev/null || {
+    echo "Error: FastAPI not installed. Install with: pip install fastapi uvicorn httpx"
+    exit 1
+}
+
+# Start load balancer in background
+echo "Starting load balancer..."
+nohup python3 -m shared.utils.load_balancer \
+    --backends "${BACKEND_URLS[@]}" \
+    --host 0.0.0.0 \
+    --port "$LB_PORT" \
+    --strategy "$LB_STRATEGY" \
+    --health-check-interval "$LB_HEALTH_CHECK_INTERVAL" \
+    > "$LOG_DIR/load_balancer_port${LB_PORT}.log" 2>&1 &
+
+LB_PID=$!
+
+# Save load balancer PID
+echo "$LB_PID" > "$LOG_DIR/vllm_lb_pid.txt"
+
+echo "  -> Load balancer started with PID $LB_PID"
+echo "  -> Endpoint: http://localhost:$LB_PORT"
+echo "  -> Log: $LOG_DIR/load_balancer_port${LB_PORT}.log"
+echo "  -> PID saved to: $LOG_DIR/vllm_lb_pid.txt"
+
+# Wait a bit for load balancer to start
+sleep 5
+
+# Check load balancer health
+echo ""
+echo "Checking load balancer health..."
+if curl -s "http://localhost:$LB_PORT/health" > /dev/null 2>&1; then
+    echo "  Load balancer: HEALTHY"
+    curl -s "http://localhost:$LB_PORT/health" | python3 -m json.tool 2>/dev/null || curl -s "http://localhost:$LB_PORT/health"
+else
+    echo "  Load balancer: NOT READY (check log: $LOG_DIR/load_balancer_port${LB_PORT}.log)"
+fi
+
+echo ""
+echo "=========================================="
+echo "Deployment Complete!"
+echo "=========================================="
+echo ""
+echo "vLLM Services:"
+for i in "${!GPU_CONFIGS[@]}"; do
+    gpu_config="${GPU_CONFIGS[$i]}"
+    IFS=':' read -r gpu_id port <<< "$gpu_config"
+    PID="${VLLM_PIDS[$i]}"
+    echo "  GPU $gpu_id: http://localhost:$port/v1 (PID: $PID)"
+done
+echo ""
+echo "Load Balancer:"
+echo "  http://localhost:$LB_PORT (PID: $LB_PID)"
+echo ""
+echo "Configuration:"
+echo "  Update llm_service_config.yaml: base_url: \"http://localhost:$LB_PORT/v1\""
+echo ""
+echo "To stop these specific services, run:"
+echo "  ./scripts/stop_vllm_services.sh"
+echo ""
+echo "This will only kill the processes listed above, not other vLLM services."
+echo ""
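The `GPU:PORT` parsing in the script above is sensitive to formatting: because each pair is split on `:` after splitting the whole string on `,`, a space after a comma would leak into `gpu_id` and corrupt `CUDA_VISIBLE_DEVICES`. A stand-alone sketch of the same two-stage parse with an illustrative three-pair config:

```shell
#!/bin/bash
# Parse a GPU:PORT configuration string the way start_vllm_with_balancer.sh does.
# The string must not contain spaces around commas, or gpu_id would pick up
# a leading space and break "export CUDA_VISIBLE_DEVICES=$gpu_id".
GPU_CONFIG="0:8001,1:8002,2:8003"

# Stage 1: split on commas into GPU:PORT pairs
IFS=',' read -ra GPU_CONFIGS <<< "$GPU_CONFIG"

# Stage 2: split each pair on ':' into gpu_id and port
GPU_IDS=()
PORTS=()
for gpu_config in "${GPU_CONFIGS[@]}"; do
    IFS=':' read -r gpu_id port <<< "$gpu_config"
    GPU_IDS+=("$gpu_id")
    PORTS+=("$port")
    echo "GPU $gpu_id -> port $port"
done
```

With the config above this prints one `GPU <id> -> port <port>` line per pair; note that `read -r gpu_id port` silently ignores anything after a second `:`, so malformed pairs fail quietly rather than loudly.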
scripts/stop_reranker_services.sh ADDED
@@ -0,0 +1,106 @@
1
+ #!/bin/bash
2
+ # Script to stop reranker services and load balancer (only the ones we started)
3
+ # Usage: ./scripts/stop_reranker_services.sh
4
+
5
+ set -e
6
+
7
+ # Get script directory
8
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
9
+ PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
10
+ cd "$PROJECT_ROOT"
11
+
12
+ LOG_DIR="./logs/reranker"
13
+ PID_FILE="$LOG_DIR/reranker_pids.txt"
14
+ LB_PID_FILE="$LOG_DIR/reranker_lb_pid.txt"
15
+
16
+ echo "=== Stopping Reranker Services and Load Balancer ==="
17
+ echo ""
18
+
19
+ # Step 1: Stop load balancer (if PID file exists)
20
+ echo "Step 1: Stopping reranker load balancer..."
21
+ if [ -f "$LB_PID_FILE" ]; then
22
+ LB_PID=$(cat "$LB_PID_FILE" 2>/dev/null | head -1)
23
+ if [ -n "$LB_PID" ] && ps -p $LB_PID > /dev/null 2>&1; then
24
+ echo " Killing load balancer PID $LB_PID..."
25
+ kill -TERM $LB_PID 2>/dev/null || true
26
+ sleep 2
27
+ if ps -p $LB_PID > /dev/null 2>&1; then
28
+ echo " Force killing load balancer PID $LB_PID..."
29
+ kill -KILL $LB_PID 2>/dev/null || true
30
+ fi
31
+ echo " Load balancer stopped"
32
+ rm -f "$LB_PID_FILE"
33
+ else
34
+ echo " Load balancer PID from file not found (may have already terminated)"
35
+ rm -f "$LB_PID_FILE"
36
+ fi
37
+ else
38
+ echo " No load balancer PID file found ($LB_PID_FILE)"
39
+ echo " If load balancer is running, you may need to find and kill it manually"
40
+ fi
41
+
42
+ echo ""
43
+
44
+ # Step 2: Stop reranker services (ONLY the ones we started)
45
+ echo "Step 2: Stopping reranker services (only the ones we started)..."
46
+
47
+ if [ -f "$PID_FILE" ]; then
48
+ echo " Reading PIDs from $PID_FILE"
49
+ KILLED_COUNT=0
50
+ NOT_FOUND_COUNT=0
51
+
52
+ while IFS= read -r pid || [ -n "$pid" ]; do
53
+ # Skip empty lines
54
+ [ -z "$pid" ] && continue
55
+
56
+ if ps -p $pid > /dev/null 2>&1; then
57
+ echo " Killing reranker service PID $pid..."
58
+ kill -TERM $pid 2>/dev/null || true
59
+ KILLED_COUNT=$((KILLED_COUNT + 1))
60
+ else
61
+ echo " PID $pid: Process not found (may have already terminated)"
62
+ NOT_FOUND_COUNT=$((NOT_FOUND_COUNT + 1))
63
+ fi
64
+ done < "$PID_FILE"
65
+
66
+ if [ $KILLED_COUNT -gt 0 ]; then
67
+ echo " Waiting 3 seconds for graceful shutdown..."
68
+ sleep 3
69
+
70
+ # Force kill if still running
71
+ while IFS= read -r pid || [ -n "$pid" ]; do
72
+ [ -z "$pid" ] && continue
73
+ if ps -p $pid > /dev/null 2>&1; then
74
+ echo " Force killing reranker service PID $pid..."
75
+ kill -KILL $pid 2>/dev/null || true
76
+ fi
77
+ done < "$PID_FILE"
78
+
79
+ echo " Stopped $KILLED_COUNT reranker service(s)"
80
+ else
81
+ echo " No running processes found from saved PIDs"
82
+ fi
83
+
84
+ if [ $NOT_FOUND_COUNT -gt 0 ]; then
85
+ echo " ($NOT_FOUND_COUNT process(es) were already terminated)"
86
+ fi
87
+
88
+ # Remove PID file after stopping
89
+ rm -f "$PID_FILE"
90
+ else
91
+ echo " WARNING: $PID_FILE not found!"
92
+ echo " Cannot safely stop services without PID file."
93
+ echo " If you know the PIDs, you can manually kill them."
94
+ echo " To avoid affecting other users, DO NOT use pkill!"
95
+ fi
96
+
97
+ echo ""
98
+ echo " NOTE: Only processes from reranker_pids.txt were killed."
99
+ echo " Other reranker services (if any) were NOT affected."
100
+
101
+ echo ""
102
+ echo "=== Checking GPU status ==="
103
+ nvidia-smi --query-gpu=index,memory.used --format=csv,noheader | grep -E '^[0-3],' || echo " Could not read status for GPUs 0-3"
104
+
105
+ echo ""
106
+ echo "Done!"
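The TERM-wait-KILL pattern used above is simple enough to test in isolation. A minimal Python sketch of the same logic — illustrative only, since the repository drives this from bash, and `stop_from_pid_file` is a made-up name:

```python
import os
import signal
import time

def stop_from_pid_file(pid_file, grace=3.0):
    """Mirror the script: SIGTERM every saved PID, wait, then SIGKILL survivors."""
    if not os.path.exists(pid_file):
        return 0
    with open(pid_file) as f:
        pids = [int(line) for line in f if line.strip()]  # skip empty lines
    stopped = 0
    for pid in pids:
        try:
            os.kill(pid, signal.SIGTERM)  # graceful shutdown first
            stopped += 1
        except ProcessLookupError:
            pass  # "may have already terminated" branch
    if stopped:
        time.sleep(grace)
        for pid in pids:
            try:
                os.kill(pid, signal.SIGKILL)  # force-kill stragglers
            except ProcessLookupError:
                pass
    os.remove(pid_file)  # remove the PID file after stopping
    return stopped
```

As in the script, only PIDs recorded in the file are touched, so unrelated services on the same host are never affected.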
scripts/stop_vllm_services.sh ADDED
@@ -0,0 +1,267 @@
1
+ #!/bin/bash
2
+ # Script to stop vLLM services and load balancer
3
+ # Usage: ./scripts/stop_vllm_services.sh
4
+
5
+ # Don't use set -e here because we want to continue even if some kills fail
6
+ # set -e
7
+
8
+ # Get script directory
9
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
10
+ PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
11
+ cd "$PROJECT_ROOT"
12
+
13
+ LOG_DIR="./logs/vllm"
14
+
15
+ # Function to recursively collect all descendant PIDs of a given PID
16
+ # Returns space-separated list of all PIDs in the process tree
17
+ collect_descendant_pids() {
18
+ local root_pid=$1
19
+ local all_pids="$root_pid"
20
+ local to_check="$root_pid"
21
+ local new_pids=""
22
+
23
+ # Iteratively collect all descendants until no new children are found
24
+ while [ -n "$to_check" ]; do
25
+ new_pids=""
26
+ for pid in $to_check; do
27
+ # Find direct children of this PID
28
+ local children=$(ps -o pid --no-headers --ppid $pid 2>/dev/null | tr '\n' ' ')
29
+ if [ -n "$children" ]; then
30
+ # Add children to the list
31
+ all_pids="$all_pids $children"
32
+ new_pids="$new_pids $children"
33
+ fi
34
+ done
35
+ to_check="$new_pids"
36
+ done
37
+
38
+ echo "$all_pids"
39
+ }
40
+
41
+ # Function to collect log files opened by a process and its descendants
42
+ # Returns newline-separated list of log file paths
43
+ collect_process_log_files() {
44
+ local root_pid=$1
45
+ local log_files=""
46
+
47
+ # Collect all descendant PIDs (including the root)
48
+ local all_pids=$(collect_descendant_pids $root_pid)
49
+
50
+ # Use lsof to find all log files opened by these processes
51
+ # Look for files in the log directory that are opened by any of these PIDs
52
+ for pid in $all_pids; do
53
+ [ -z "$pid" ] && continue
54
+ if ps -p $pid > /dev/null 2>&1; then
55
+ # Find log files opened by this PID (files with .log extension in LOG_DIR)
56
+ # lsof output format: COMMAND PID USER FD TYPE DEVICE SIZE/OFF NODE NAME
57
+ # We need the last field (NAME) which is the file path
58
+ # Try both absolute and relative paths
59
+ local log_dir_abs=$(cd "$PROJECT_ROOT" && cd "$LOG_DIR" && pwd 2>/dev/null || echo "$LOG_DIR")
60
+ local pid_logs=$(lsof -p $pid 2>/dev/null | awk 'NR>1 {print $NF}' | grep -E "\.log$" | grep -E "(^$log_dir_abs/|$LOG_DIR/)" | sort -u)
61
+ if [ -n "$pid_logs" ]; then
62
+ log_files="$log_files"$'\n'"$pid_logs"
63
+ fi
64
+ fi
65
+ done
66
+
67
+ # Remove duplicates and empty lines, return unique log files
68
+ echo "$log_files" | grep -v '^$' | sort -u
69
+ }
70
+
71
+ # Function to kill a PID and all its descendants
72
+ # This ensures all child processes (including GPU processes) are terminated
73
+ kill_process_tree() {
74
+ local root_pid=$1
75
+ local signal=${2:-TERM}
76
+
77
+ if ! ps -p $root_pid > /dev/null 2>&1; then
78
+ return 1
79
+ fi
80
+
81
+ # Collect all descendant PIDs (including the root)
82
+ local all_pids=$(collect_descendant_pids $root_pid)
83
+
84
+ # Kill all processes
85
+ # For TERM, we kill from leaves to root (reverse order) for graceful shutdown
86
+ # For KILL, order doesn't matter
87
+ if [ "$signal" = "KILL" ]; then
88
+ # Force kill all processes
89
+ for pid in $all_pids; do
90
+ [ -z "$pid" ] && continue
91
+ kill -KILL $pid 2>/dev/null || true
92
+ done
93
+ else
94
+ # Graceful shutdown: kill children first, then parent
95
+ # Convert to array and kill in reverse order
96
+ local pids_array=($all_pids)
97
+ for ((idx=${#pids_array[@]}-1; idx>=0; idx--)); do
98
+ pid=${pids_array[$idx]}
99
+ [ -z "$pid" ] && continue
100
+ kill -TERM $pid 2>/dev/null || true
101
+ done
102
+ fi
103
+ }
104
+
105
+ echo "=== Stopping vLLM Services and Load Balancer ==="
106
+ echo ""
107
+
108
+ # Step 1: Stop load balancer (if PID file exists)
109
+ echo "Step 1: Stopping vLLM load balancer..."
110
+ LB_PID_FILE="$LOG_DIR/vllm_lb_pid.txt"
111
+ LB_LOG_FILES=""
112
+ if [ -f "$LB_PID_FILE" ]; then
113
+ LB_PID=$(cat "$LB_PID_FILE" 2>/dev/null | head -1)
114
+ if [ -n "$LB_PID" ] && ps -p $LB_PID > /dev/null 2>&1; then
115
+ echo " Killing load balancer PID $LB_PID..."
116
+ # Collect log files before killing
117
+ LB_LOG_FILES=$(collect_process_log_files $LB_PID)
118
+ # Also try to find load balancer log files by pattern (fallback if lsof doesn't work)
119
+ if [ -z "$LB_LOG_FILES" ]; then
120
+ LB_LOG_FILES=$(find "$LOG_DIR" -maxdepth 1 -name "load_balancer*.log" -type f 2>/dev/null)
121
+ fi
122
+ kill -TERM $LB_PID 2>/dev/null || true
123
+ sleep 2
124
+ if ps -p $LB_PID > /dev/null 2>&1; then
125
+ echo " Force killing load balancer PID $LB_PID..."
126
+ kill -KILL $LB_PID 2>/dev/null || true
127
+ fi
128
+ echo " Load balancer stopped"
129
+ rm -f "$LB_PID_FILE"
130
+
131
+ # Remove load balancer log files
132
+ if [ -n "$LB_LOG_FILES" ]; then
133
+ echo " Removing load balancer log files..."
134
+ while IFS= read -r log_file; do
135
+ [ -z "$log_file" ] && continue
136
+ if [ -f "$log_file" ]; then
137
+ rm -f "$log_file"
138
+ echo " Removed: $log_file"
139
+ fi
140
+ done <<< "$LB_LOG_FILES"
141
+ else
142
+ echo " Note: Could not detect load balancer log file (process may have already terminated)"
143
+ fi
144
+ else
145
+ echo " Load balancer PID from file not found (may have already terminated)"
146
+ rm -f "$LB_PID_FILE"
147
+ fi
148
+ else
149
+ echo " No load balancer PID file found ($LB_PID_FILE)"
150
+ echo " If load balancer is running, you may need to find and kill it manually"
151
+ fi
152
+
153
+ echo ""
154
+
155
+ # Step 2: Stop vLLM services (ONLY the ones we started)
156
+ echo "Step 2: Stopping vLLM services (only the ones we started)..."
157
+
158
+ # Try to read PIDs from file
159
+ if [ -f "$LOG_DIR/vllm_pids.txt" ]; then
160
+ echo " Reading PIDs from $LOG_DIR/vllm_pids.txt"
161
+
162
+ # Read all PIDs into an array
163
+ pids_array=()
164
+ while IFS= read -r pid || [ -n "$pid" ]; do
165
+ # Skip empty lines
166
+ [ -z "$pid" ] && continue
167
+ pids_array+=($pid)
168
+ done < "$LOG_DIR/vllm_pids.txt"
169
+
170
+ KILLED_COUNT=0
171
+ NOT_FOUND_COUNT=0
172
+
173
+ # Collect log files for all vLLM services before killing
174
+ vllm_log_files=""
175
+ for pid in "${pids_array[@]}"; do
176
+ if ps -p $pid > /dev/null 2>&1; then
177
+ # Collect log files for this PID
178
+ pid_logs=$(collect_process_log_files $pid)
179
+ if [ -n "$pid_logs" ]; then
180
+ vllm_log_files="$vllm_log_files"$'\n'"$pid_logs"
181
+ fi
182
+ fi
183
+ done
184
+
185
+ # First pass: graceful shutdown (TERM signal)
186
+ for pid in "${pids_array[@]}"; do
187
+ if ps -p $pid > /dev/null 2>&1; then
188
+ echo " Killing vLLM service PID $pid and all its descendant processes..."
189
+ # Collect and show how many processes will be killed
190
+ descendant_pids=$(collect_descendant_pids $pid)
191
+ pid_count=$(echo $descendant_pids | wc -w)
192
+ echo " Found $pid_count process(es) in the process tree"
193
+ # Use our recursive function to kill the entire process tree
194
+ kill_process_tree $pid TERM
195
+ KILLED_COUNT=$((KILLED_COUNT + 1))
196
+ else
197
+ echo " PID $pid: Process not found (may have already terminated)"
198
+ NOT_FOUND_COUNT=$((NOT_FOUND_COUNT + 1))
199
+ fi
200
+ done
201
+
202
+ if [ $KILLED_COUNT -gt 0 ]; then
203
+ echo " Waiting 3 seconds for graceful shutdown..."
204
+ sleep 3
205
+
206
+ # Second pass: force kill (KILL signal) if still running
207
+ for pid in "${pids_array[@]}"; do
208
+ if ps -p $pid > /dev/null 2>&1; then
209
+ echo " Force killing vLLM service PID $pid and all its descendant processes..."
210
+ # Collect and show how many processes will be force killed
211
+ descendant_pids=$(collect_descendant_pids $pid)
212
+ pid_count=$(echo $descendant_pids | wc -w)
213
+ echo " Force killing $pid_count process(es) in the process tree"
214
+ # Use our recursive function to force kill the entire process tree
215
+ kill_process_tree $pid KILL
216
+ fi
217
+ done
218
+
219
+ echo " Stopped $KILLED_COUNT vLLM service(s)"
220
+ else
221
+ echo " No running processes found from saved PIDs"
222
+ fi
223
+
224
+ if [ $NOT_FOUND_COUNT -gt 0 ]; then
225
+ echo " ($NOT_FOUND_COUNT process(es) were already terminated)"
226
+ fi
227
+
228
+ # Remove vLLM log files
229
+ if [ -n "$vllm_log_files" ]; then
230
+ echo ""
231
+ echo " Removing vLLM service log files..."
232
+ removed_count=0
233
+ while IFS= read -r log_file; do
234
+ [ -z "$log_file" ] && continue
235
+ if [ -f "$log_file" ]; then
236
+ rm -f "$log_file"
237
+ echo " Removed: $log_file"
238
+ removed_count=$((removed_count + 1))
239
+ fi
240
+ done <<< "$vllm_log_files"
241
+ if [ $removed_count -eq 0 ] && [ -n "$vllm_log_files" ]; then
242
+ echo " (No log files found to remove - they may have already been deleted)"
243
+ fi
244
+ else
245
+ echo ""
246
+ echo " Note: Could not detect vLLM log files (processes may have already terminated)"
247
+ fi
248
+
249
+ # Remove PID file after stopping
250
+ rm -f "$LOG_DIR/vllm_pids.txt"
251
+ else
252
+ echo " WARNING: $LOG_DIR/vllm_pids.txt not found!"
253
+ echo " Cannot safely stop services without PID file."
254
+ echo " If you know the PIDs, you can manually kill them."
255
+ echo " To avoid affecting other users, DO NOT use pkill!"
256
+ fi
257
+
258
+ echo ""
259
+ echo " NOTE: Only processes from vllm_pids.txt were killed."
260
+ echo " Other vLLM services (if any) were NOT affected."
261
+
262
+ echo ""
263
+ echo "=== Checking GPU status ==="
264
+ nvidia-smi --query-gpu=index,memory.used --format=csv,noheader | grep -E '^[4-7],' || echo " Could not read status for GPUs 4-7"
265
+
266
+ echo ""
267
+ echo "Done!"
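The iterative descendant walk in `collect_descendant_pids` translates directly to other languages. A Python sketch of the same breadth-first traversal, one `ps --ppid` call per node (GNU ps / Linux assumed, like the script itself):

```python
import subprocess

def collect_descendant_pids(root_pid):
    """Breadth-first walk of the process tree rooted at root_pid,
    mirroring the shell helper: returns the root plus all descendants."""
    all_pids = [root_pid]
    frontier = [root_pid]
    while frontier:
        next_frontier = []
        for pid in frontier:
            out = subprocess.run(
                ["ps", "-o", "pid", "--no-headers", "--ppid", str(pid)],
                capture_output=True, text=True,
            ).stdout
            children = [int(tok) for tok in out.split()]
            all_pids.extend(children)       # record every direct child
            next_frontier.extend(children)  # and scan it for grandchildren
        frontier = next_frontier
    return all_pids
```

Killing in `reversed(all_pids)` order approximates the script's leaves-first SIGTERM pass, which matters for vLLM because the GPU workers are children of the launcher process.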
shared/configs/config.yaml ADDED
@@ -0,0 +1,97 @@
1
+ # Configuration file for Paper Reviewer Agent
2
+ #
3
+ # Note: LLM service configuration is in configs/llm_service_config.yaml
4
+ # Note: Prompts configuration is in configs/prompts.yaml
5
+
6
+ # Global verbose mode - controls all step outputs (Step 1, Step 2, Step 3, etc.)
7
+ # Set to false to suppress intermediate progress output for faster execution
8
+ verbose: false
9
+
10
+ # Paper Search Configuration
11
+ paper_search:
12
+ # Asta API Configuration
13
+ asta:
14
+ # Single API key (backward compatible, lower priority)
15
+ api_key: null # Set via ASTA_API_KEY env var
16
+ # API key pool
17
+ api_key_pool_path: "asta_api_pool.txt" # path relative to shared/configs/ or absolute path
18
+ endpoint: "https://asta-tools.allen.ai/mcp/v1"
19
+
20
+ # Semantic Scholar API Configuration (alternative)
21
+ semantic_scholar:
22
+ api_key: null # Set via S2_API_KEY env var
23
+
24
+ # Reranker Configuration
25
+ reranker:
26
+ # Reranker model path (for direct mode)
27
+ model: "OpenScholar/OpenScholar_Reranker" # e.g., "OpenScholar/OpenScholar_Reranker" or "BAAI/bge-reranker-base"
28
+ use_fp16: true
29
+
30
+ # Reranker API Configuration (for API mode with load balancing)
31
+ # If base_url is set, use API mode with load balancer
32
+ # If endpoint_pool_path is set, use API mode with endpoint pool
33
+ # If both are None, use direct mode (load model directly)
34
+ api:
35
+ # Base URL for reranker API service (load balancer address)
36
+ # Example: "http://localhost:8009" (load balancer that distributes to 8005-8008)
37
+ # If set, will use API mode with load balancer
38
+ base_url: "http://localhost:8008"
39
+
40
+ # Endpoint pool file path (alternative to base_url)
41
+ # Example: "reranker_endpoint_pool.txt" (contains list of endpoints: http://localhost:8005, http://localhost:8006, ...)
42
+ # If set, will use API mode with endpoint pool (round-robin load balancing)
43
+ # endpoint_pool_path: "reranker_endpoint_pool.txt" # Set to use API mode with endpoint pool
44
+ endpoint_pool_path: null
45
+
46
+ # Request timeout in seconds
47
+ timeout: 30.0
48
+
49
+ # Retrieval Configuration
50
+ retrieval:
51
+ top_n: 10
52
+ use_abstract: true
53
+ norm_cite: false
54
+ min_citation: null
55
+ limit_per_keyword: 20
56
+
57
+ # Related Work Searcher Configuration
58
+ related_work_searcher:
59
+ max_related_papers: 10
60
+ max_parallel_summaries: 1
61
+ publication_date_range: null # e.g., "2020:" for papers from 2020 onwards
62
+ venues: null # e.g., "ICLR,NeurIPS"
63
+ verbose: false # Set to false to suppress intermediate progress output (faster, less output)
64
+
65
+ # Paper Reviewer Configuration
66
+ paper_reviewer:
67
+ review_format: "ai_researcher" # Options: "detailed", "summary", "structured", "ai_researcher"
68
+ max_tokens: 16384 # Maximum tokens for reviewer output (increased for detailed reviews)
69
+
70
+ # Review Refiner Configuration
71
+ review_refiner:
72
+ review_format: "strict_detailed" # Options: "detailed", "summary", "structured", "strict_detailed" - should match paper_reviewer format
73
+ max_tokens: 16384 # Maximum tokens for refiner output (increased for detailed reviews)
74
+
75
+ # Output Configuration
76
+ output:
77
+ save_reviews: true
78
+ output_dir: "./outputs"
79
+ format: "json" # Options: "json", "markdown", "txt"
80
+
81
+ # Evaluation Configuration
82
+ evaluation:
83
+ # Default number of worker threads for concurrent evaluation
84
+ max_workers: 16
85
+
86
+ # Component name for LLM service assignment (used with llm_service_config.yaml)
87
+ # Options: "keyword_generator", "paper_summarizer", "reviewer", "refiner"
88
+ # Defaults to "reviewer" if not specified
89
+ llm_component: "reviewer"
90
+
91
+ # Default model name (can be overridden by command line args or llm_service_config.yaml)
92
+ default_model_name: "Qwen/Qwen2.5-72B-Instruct"
93
+
94
+ # Prompt versions
95
+ rubric_generation_prompt_version: "v2" # Options: "v1", "v2"
96
+ evaluator_prompt_version: "v1" # Options: "v0", "v1"
97
+
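The reranker comments above describe a three-way mode selection: `base_url` takes priority, then `endpoint_pool_path`, and direct model loading is the fallback. A sketch of that precedence, with `resolve_reranker_mode` as a hypothetical helper name (the real logic lives in `shared/utils/`):

```python
def resolve_reranker_mode(api_cfg):
    """Selection order implied by the config comments:
    base_url -> API mode via load balancer,
    endpoint_pool_path -> API mode via round-robin endpoint pool,
    neither -> load the reranker model directly."""
    if api_cfg.get("base_url"):
        return "api_load_balancer"
    if api_cfg.get("endpoint_pool_path"):
        return "api_endpoint_pool"
    return "direct"
```

With the values shipped in this file (`base_url: "http://localhost:8008"`, `endpoint_pool_path: null`), the load-balancer API mode wins.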
shared/configs/llm_service_config.yaml ADDED
@@ -0,0 +1,57 @@
1
+ # vLLM Service Configuration
2
+
3
+ # vLLM server settings
4
+ vllm:
5
+ # Base URL for vLLM service
6
+ # In production, this should point to a load balancer (nginx/HAProxy) that distributes
7
+ # requests across multiple vLLM service instances running on different GPUs.
8
+ # Example: "http://localhost:8000/v1" (load balancer address)
9
+ base_url: "http://localhost:8000/" # directly to the load balancer, 8000-8003 are used by the vllm services
10
+
11
+ api_key: "dummy-key" # Not used for local vLLM, but required by OpenAI client
12
+ model_name: "openai/gpt-oss-120b"
13
+ # model_name: "Qwen/Qwen3-4B-Instruct-2507"
14
+ # model_name: "Qwen/Qwen3-235B-A22B-Instruct-2507"
15
+ timeout: 300
16
+
17
+ # Rate limiting: Maximum concurrent requests to vLLM server
18
+ # Lower this if you're getting 500 errors (suggests server overload)
19
+ # Recommended: 4-8 for small models, 2-4 for large models
20
+ max_concurrent_requests: 64
21
+
22
+ # Retry configuration for server errors
23
+ max_retries: 3 # Number of retries for 500/502/503/504 errors
24
+ retry_delay: 1.0 # Initial delay in seconds
25
+ retry_backoff: 2.0 # Exponential backoff multiplier
26
+
27
+ # Default sampling parameters
28
+ temperature: 0.7
29
+ top_p: 0.8
30
+ top_k: 20
31
+ max_tokens: 16384
32
+ presence_penalty: 0.0
33
+
34
+ # GPT / OpenRouter API Configuration
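With the retry settings above (`max_retries: 3`, `retry_delay: 1.0`, `retry_backoff: 2.0`), the waits before the three retries are 1 s, 2 s, and 4 s. A one-line sketch of that schedule (`retry_schedule` is a hypothetical helper, not a function in this repository):

```python
def retry_schedule(max_retries, retry_delay, retry_backoff):
    """Exponential backoff delays: retry_delay * retry_backoff**attempt."""
    return [retry_delay * (retry_backoff ** attempt) for attempt in range(max_retries)]
```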
35
+ gpt:
36
+ enabled: true
37
+ # Leave api_key null so it is taken from OPENAI_API_KEY or OPENROUTER_API_KEY
38
+ api_key: null
39
+ # Use the OpenRouter model you specified
40
+ model_name: "openai/gpt-oss-120b"
41
+ # Point the OpenAI-compatible client to OpenRouter
42
+ base_url: "http://localhost:8000/"
43
+ timeout: 300
44
+
45
+ # Default sampling parameters
46
+ temperature: 0.7
47
+ top_p: 0.95
48
+ max_tokens: 16384
49
+ presence_penalty: 0.0
50
+
51
+ # LLM Service Assignment
52
+ # Specify which LLM service to use for each component
53
+ llm_assignments:
54
+ keyword_generator: "gpt" # Options: "vllm", "gpt"
55
+ paper_summarizer: "gpt" # Options: "vllm", "gpt"
56
+ reviewer: "gpt" # Options: "vllm", "gpt"
57
+ refiner: "gpt" # Options: "vllm", "gpt" - defaults to reviewer if not specified
shared/configs/prompts.yaml ADDED
@@ -0,0 +1,580 @@
1
+ # Prompt templates for the paper reviewer agent
2
+
3
+ # Keyword generation prompt
4
+ # Based on OpenScholar's keyword extraction prompt, adapted for JSON output format
5
+ keyword_generation:
6
+ system: "You are an experienced research assistant helping to find related work for a paper. Always respond with valid JSON."
7
+ user: |
8
+ Suggest search queries to retrieve relevant papers related to the following paper. The search queries must be simple, short, and comma separated. Focus on core technical concepts, methods, and key techniques that would help find related work.
9
+
10
+ Here's an example:
11
+ ##
12
+ Paper: How has prior work incorporated personality attributes to train personalized dialogue generation models?
13
+ Search queries: personalized dialogue generation, personalized language models, personalized dialogue
14
+ ##
15
+ Paper: How do retrieval-augmented LMs perform well in knowledge-intensive tasks?
16
+ Search queries: retrieval-augmented LMs, knowledge-intensive tasks, large language models for knowledge-intensive tasks, retrieval-augmented generation
17
+ ##
18
+
19
+ Paper information:
20
+ {context}
21
+
22
+ IMPORTANT: You must respond with valid JSON only. Use this format:
23
+ {{
24
+ "keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"]
25
+ }}
26
+
27
+ Return only the JSON, no additional text or explanation. Generate 3-5 short, comma-separated keywords as search queries.
28
+
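Prompts like this one demand JSON-only output, but models often wrap the JSON in Markdown code fences anyway. A tolerant parsing sketch — illustrative only; the repository's actual parser is `shared/utils/json_parser.py`:

```python
import json
import re

def parse_keywords(raw):
    """Extract the "keywords" list from a response that should be pure JSON,
    stripping an optional Markdown code fence first."""
    text = raw.strip()
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text)  # drop optional fences
    data = json.loads(text)
    keywords = data.get("keywords", [])
    return [k.strip() for k in keywords if isinstance(k, str) and k.strip()]
```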
29
+ # DOMAIN-SPECIFIC AGENTS
30
+ # Paper summarization prompt
31
+ # Generate structured summary of each related paper: summary, main methods, key findings, relation with the target paper
32
+ paper_summarization:
33
+ user: |
34
+ You are a senior research assistant who is proficient at identifying the main contributions and key findings of papers and the relations between different works.
35
+
36
+ For the reference paper below:
37
+
38
+ {reference_paper}
39
+
40
+ You are given this paper as a related work to the reference paper:
41
+
42
+ {related_paper}
43
+
44
+ Now, you need to provide concise information on what the related work is about, its main methods and results, and its relation to the reference paper, to make it easier for the supervisor to write the review.
45
+
46
+ Focusing on the relationship between the reference paper and the related work, concisely summarize the related work's main methods and results and how the two papers relate.
47
+
48
+ IMPORTANT: You must respond with valid JSON only. Use this format:
49
+ {{
50
+ "summary": "Your concise summary here in 2-3 sentences.",
51
+ "main_methods": "The main methods of the related work.",
52
+ "key_findings": "The key findings of the related work.",
53
+ "relation": "The relation between the related work and the paper you are reviewing, such as how they share similar ideas, solving the same problem, have diverged claims, etc."
54
+ }}
55
+ Return only the JSON, no additional text or explanation.
56
+
57
+ # Paper Insight Miner Prompt
58
+ paper_insight_miner:
59
+ user: |
60
+ You are an expert research assistant. Your task is to help refine the method/contribution parts of a candidate review, using the paper content as the source of truth.
61
+
62
+ SCOPE (strict):
63
+ - ONLY cover: core contributions, technical approach, model/algorithm design, mathematical formulation, assumptions, optimization/training, implementation details, and limitations of the method.
64
+ - Novelty: ONLY assess novelty claims AS PRESENTED IN THE PAPER (no external knowledge, no web search, no comparing to papers not mentioned in the text).
65
+ - Do NOT comment on experimental results, benchmarks, or score/decision fields (handled by another module).
66
+ - Do NOT do external related-work positioning (handled by another module).
67
+
68
+ Paper content:
69
+ {content}
70
+
71
+ Candidate review:
72
+ {candidate_review}
73
+
74
+ What to do:
75
+ 1) Extract the paper’s core contributions and method details (paper-grounded).
76
+ 2) Check the candidate review’s method/contribution claims against the paper and identify:
77
+ - incorrect/hallucinated/contradicted claims,
78
+ - missing key technical points,
79
+ - vague or generic statements that should be made specific.
80
+ 3) Provide short rewrite suggestions WITH evidence anchors (Section/Equation/Algorithm/Figure/snippet if available).
81
+
82
+ Output JSON only:
83
+ {
84
+ "facts": {
85
+ "core_contributions": [
86
+ {"claim": "...", "evidence": "..."}
87
+ ],
88
+ "method_summary": [
89
+ {"point": "key component / step / design choice", "evidence": "..."}
90
+ ],
91
+ "assumptions_and_scope": [
92
+ {"item": "...", "evidence": "..."}
93
+ ],
94
+ "novelty_claims_in_paper": [
95
+ {"claim": "as stated by the authors", "evidence": "..."}
96
+ ]
97
+ },
98
+ "review_issues": {
99
+ "incorrect_or_hallucinated": [
100
+ {"review_claim": "...", "why_wrong": "...", "evidence": "..."}
101
+ ],
102
+ "missing_key_points": [
103
+ {"what_missing": "...", "why_important": "...", "evidence": "..."}
104
+ ],
105
+ "needs_specificity": [
106
+ {"review_text": "...", "how_to_fix": "name the component/assumption/equation/step", "evidence": "..."}
107
+ ]
108
+ },
109
+ "rewrite_suggestions": [
110
+ {
111
+ "apply_to": "Summary|Strengths|Weaknesses|Suggestions|Questions (method-related only)",
112
+ "target": "Core Contribution Accuracy|Evidence-Based Critique|Critique Clarity",
113
+ "suggested_text": "1-2 sentences",
114
+ "evidence": "..."
115
+ }
116
+ ]
117
+ }
118
+
119
+ Rules:
120
+ - If you cannot find support in the paper text, set evidence to "not_found_in_text"; do NOT assert the paper is missing it.
121
+ - Keep each list short (<=5 items). Prefer the most important contributions/components/issues.
122
+ - Return JSON only. No extra text.
123
+
124
+ # Paper results summarization prompt for Result Analyzer Agent
125
+ # according to the paper content and candidate review, pinpoint the issues and provide refinement suggestions with concrete evidence to the candidate review.
126
+ paper_results_analyzer:
127
+ user: |
128
+ You are an expert research assistant. Your task is to help refine the experiment/evaluation parts of a candidate review, using the paper content as the source of truth.
129
+
130
+ SCOPE (strict):
131
+ - ONLY cover experimental evaluation: datasets, baselines, metrics, tables/figures, quantitative results, statistical evidence, ablations.
132
+ - Do NOT comment on novelty, related-work positioning, writing/presentation quality, or overall recommendation. Other agents will handle those.
133
+
134
+ Paper content:
135
+ {content}
136
+
137
+ Candidate review:
138
+ {candidate_review}
139
+
140
+ What to do:
141
+ 1) Extract key experimental facts from the paper.
142
+ 2) Check experiment-related claims in the candidate review and identify:
143
+ - incorrect/hallucinated/contradicted claims,
144
+ - missing key experimental points,
145
+ - vague statements that should be made specific.
146
+ 3) Provide short rewrite suggestions WITH evidence anchors (Table/Figure/Section/snippet if available).
147
+
148
+ Output JSON only:
149
+ {
150
+ "facts": {
151
+ "datasets": ["..."],
152
+ "metrics": ["..."],
153
+ "baselines": ["..."],
154
+ "key_results": [
155
+ {"claim": "...", "evidence": "..."}
156
+ ]
157
+ },
158
+ "review_issues": {
159
+ "incorrect_or_hallucinated": [
160
+ {"review_claim": "...", "why_wrong": "...", "evidence": "..."}
161
+ ],
162
+ "missing_key_points": [
163
+ {"what_missing": "...", "why_important": "...", "evidence": "..."}
164
+ ],
165
+ "needs_specificity": [
166
+ {"review_text": "...", "how_to_fix": "...", "evidence": "..."}
167
+ ]
168
+ },
169
+ "rewrite_suggestions": [
170
+ {
171
+ "apply_to": "Summary|Strengths|Weaknesses|Suggestions|Questions (experiment-related only)",
172
+ "target": "Results Interpretation|Evidence-Based Critique",
173
+ "suggested_text": "1-2 sentences",
174
+ "evidence": "..."
175
+ }
176
+ ]
177
+ }
178
+
179
+ Rules:
180
+ - If you cannot find support in the paper text, set evidence to "not_found_in_text"; do NOT assert the paper is missing it.
181
+ - Keep each list short (<=5 items). Prefer the most important issues/results.
182
+ - Return JSON only. No extra text.
183
+
184
+
185
+ # Reviewer model prompts (these may never be used)
186
+ # Needs to be adapted to the actual review model
187
+ review_prompts:
188
+ detailed: |
189
+ You are reviewing a research paper. Please provide a detailed review in markdown format covering the following sections in this exact order:
190
+
191
+ ## Summary
192
+ A brief summary of the paper's contributions and main findings.
193
+
194
+ ## Soundness
195
+ Rate the technical soundness and correctness of the methodology (provide a score from 1 to 5, where 1 is very poor and 5 is excellent). Do not add any explanation.
196
+
197
+ ## Presentation
198
+ Rate the clarity and quality of presentation (provide a score from 1 to 5, where 1 is very poor and 5 is excellent). Do not add any explanation.
199
+
200
+ ## Contribution
201
+ Rate the significance and novelty of the contribution (provide a score from 1 to 5, where 1 is very poor and 5 is excellent). Do not add any explanation.
202
+
203
+ ## Strengths
204
+ What are the main strengths of this paper? Consider:
205
+ - Novelty and originality
206
+ - Technical soundness
207
+ - Experimental validation
208
+ - Clarity of presentation
209
+ - Significance of contributions
210
+
211
+ ## Weaknesses
212
+ What are the main weaknesses or concerns? Consider:
213
+ - Methodological issues
214
+ - Missing experiments or baselines
215
+ - Limitations not acknowledged
216
+ - Clarity issues
217
+ - Reproducibility concerns
218
+
219
+ ## Questions
220
+ Any questions you have for the authors that would help clarify aspects of the work.
221
+
222
+ ## Rating
223
+ Overall rating of the paper (provide a score from 1 to 10, following the reviewer scale). Do not add any explanation.
224
+
225
+ ## Confidence
226
+ Your confidence in your assessment (provide a score from 1 to 5, following the reviewer scale). Do not add any explanation.
227
+
228
+ ## Decision
229
+ Your recommendation: "accept", "reject", or "undecided". Do not add any explanation.
230
+
231
+ IMPORTANT: Write your review in markdown format with the exact section headers above (## Summary, ## Soundness, etc.). Provide the score for each scoring section without added explanation, as instructed above. Be constructive, specific, and fair in your review.
232
+
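Because the "detailed" format is plain markdown, downstream code has to pull the numeric scores back out of the section bodies. A small extraction sketch — `extract_scores` is a hypothetical helper; the project's real formatters live under `gradio_app/components/`:

```python
import re

def extract_scores(review_md):
    """Pull numeric scores from the score-only sections of a markdown review
    (## Soundness, ## Presentation, ## Contribution, ## Rating, ## Confidence)."""
    scores = {}
    for name in ("Soundness", "Presentation", "Contribution", "Rating", "Confidence"):
        match = re.search(
            rf"^## {name}\s*\n+\s*(\d+(?:\.\d+)?)", review_md, flags=re.MULTILINE
        )
        if match:
            scores[name.lower()] = float(match.group(1))
    return scores
```

Keeping the score sections explanation-free, as the prompt instructs, is what makes this kind of regex extraction reliable.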
233
+ ai_researcher: |
234
+ You are an expert academic reviewer. Your task is to provide a thorough, structured, and balanced review of the following research paper.
235
+
236
+ Step 1: Read and Analyze the Paper Carefully
237
+ - Read the paper paragraph by paragraph.
238
+ - For each paragraph:
239
+ - Perform detailed analysis and document your thought process using <think></think> tags.
240
+ - Identify strengths, weaknesses, unclear points, logical flaws, technical inconsistencies, or missing references.
241
+ - Highlight both strengths and weaknesses.
242
+ - Ensure all observations are supported by reasoning inside <think></think> tags.
243
+
244
+ Step 2: Conduct the Review
245
+ - After completing paragraph-by-paragraph analysis, provide an overall assessment following the structure below.
246
+ - Provide scores, recommendations, and a final decision in the strict JSON format.
247
+ - Do **not** include <think> tags inside JSON format.
248
+ - Be concise yet sufficiently detailed for an academic review.
249
+
250
+ Step 3: Organize your reviews into the following JSON format:
251
+ {
252
+ "summary": [Concise, detailed summary covering methodology, key ideas, and results.],
253
+ "soundness": [Score 1-5],
254
+ "presentation": [Score 1-5],
255
+ "contribution": [Score 1-5],
256
+ "strengths": [List major strengths],
257
+ "weaknesses": [List major weaknesses, with confidence levels if applicable],
258
+ "suggestions": [Concrete recommendations to address weaknesses],
259
+ "questions": [Outstanding questions or clarifications needed],
260
+ "rating": [Overall score, e.g., 1-10],
261
+ "confidence": [Confidence in assessment, e.g., 1-5],
262
+ "decision": [Accept, Reject]
263
+ }
264
+
265
+ Few-shot example (for format reference only; do not copy the example):
266
+
267
+ {
268
+ "summary": "This paper introduces a novel algorithm for modeling shared dynamics between multiple observation processes, validated on both simulated and real-world data.",
269
+ "soundness": 3.0,
270
+ "presentation": 3.0,
271
+ "contribution": 3.0,
272
+ "strengths": "- Novel decomposition approach.\n - Separation of shared and residual dynamics.\n - Validated on real and simulated data.",
273
+ "weaknesses": "- Strong linearity assumption (high confidence).\n - Limited experiments (medium confidence).",
274
+ "suggestions": "- Test on nonlinear systems.\n - Expand evaluation datasets.",
275
+ "questions": "- Sensitivity to deviations from assumed structure?\n - Performance on nonlinear data?",
276
+ "rating": 6.5,
277
+ "confidence": 3.0,
278
+ "decision": "Accept"
279
+ }
280
+
281
+ Few-shot example (strictly follow this structure and do not copy the example):
282
+ {
283
+ "summary": "This paper introduces PG-LDS-ID, a novel algorithm designed to model the shared dynamics between two observation processes: a continuous-time Gaussian process and a discrete-time Poisson process. The core idea is to use a latent state-space model to capture the underlying dynamics that influence both observation streams, while also accounting for residual dynamics unique to the Poisson process. The authors propose a two-stage approach, where the first stage identifies the shared dynamics using a covariance-based subspace identification method, and the second stage identifies the residual dynamics that are only observable through the Poisson process. A key contribution is the introduction of a block-structured system matrix, which facilitates the separation of shared and residual dynamics. The method is motivated by applications in neuroscience, where one might want to model the relationship between continuous behavioral trajectories and discrete neural spiking activity. The authors validate their approach using both simulated data and a real-world dataset of non-human primate neural spiking activity and arm movements. The simulation results demonstrate the algorithm's ability to accurately recover the shared dynamics, and the real data experiment shows improved prediction accuracy compared to a prior method, PLDSID. The paper's significance lies in its ability to handle coupled observation processes with different statistical properties, while explicitly disentangling shared and residual dynamics, a capability not simultaneously offered by existing analytical methods. However, the paper's reliance on strong assumptions, such as linearity and a specific block structure, and the limited scope of its experimental validation, raise important questions about its broader applicability and robustness.",
+ "soundness": 2.8,
+ "presentation": 2.4,
+ "contribution": 2.6,
+ "strengths": "One of the primary strengths of this paper is the novel decomposition technique introduced through Equation (6), which employs a block-structured system matrix. This decomposition is a significant contribution as it simplifies the modeling of shared dynamics between data streams with different statistical properties, specifically Gaussian and Poisson processes. By breaking down the problem into manageable components, the authors enhance the overall approach's ease of handling and implementation. This is particularly valuable in practical scenarios where dealing with coupled dynamics can be complex. Furthermore, the paper's focus on explicitly disentangling shared and residual dynamics is a notable advancement. Existing methods often model the collective dynamics of multiple modalities in the same latent states, whereas PG-LDS-ID explicitly separates the shared dynamics from those unique to the Poisson process. This distinction is crucial for understanding the underlying mechanisms that drive different observation streams. The authors demonstrate the practical applicability of their method through both simulated and real-world experiments. The simulation results show that the proposed method can accurately recover the shared dynamics, and the real data experiment on non-human primate data shows that PG-LDS-ID achieves better prediction accuracies compared to PLDSID, a prior method. This empirical validation provides evidence for the effectiveness of the algorithm in a realistic setting. Finally, the method's ability to handle generalized linear processes, as opposed to being limited to Gaussian processes, is another strength. By using second-order moments, the proposed method can now deal with a broader class of observation models, making it more versatile and applicable to a wider range of problems.",
+ "weaknesses": "After a thorough review of the paper and the reviewer comments, I have identified several key weaknesses that significantly impact the paper's conclusions and broader applicability. First, the paper's strong reliance on the assumption of latent linear dynamics is a major limitation. The entire method is built upon linear dynamical state-space models, which, as the authors acknowledge in the 'Limitations' section, can only provide an approximation of nonlinear dynamics. This assumption is particularly concerning given that many real-world systems, especially those in neuroscience, exhibit nonlinear behavior. The authors do not provide any experimental results or analysis of how the method performs when applied to observations generated by a latent nonlinear system. This lack of evaluation makes it difficult to assess the method's robustness and applicability in real-world scenarios. The confidence level for this weakness is high, as the paper explicitly states its reliance on linear models and lacks any analysis of nonlinear systems. Second, the paper introduces a specific block structure in Equation (6) for the system matrices, which is a critical assumption for the method's ability to dissociate shared and residual dynamics. While the authors justify this structure as a design choice to facilitate the separation of dynamics, they do not sufficiently discuss the conditions under which this decomposition can be effectively implemented, or the consequences of deviations from this structure. Specifically, the paper does not explore what happens if the true coefficient matrix has non-zero values in the upper right block, which would violate the assumed block structure. The practical implications of this choice are not fully explored, and the paper lacks any sensitivity analysis to assess the robustness of the method to such deviations. 
The confidence level for this weakness is high, as the paper introduces the block structure as a key design choice without addressing its limitations or potential for misapplication. Third, the paper lacks a detailed comparison with recent, relevant subspace identification methods that also leverage multimodal data. The authors compare their method against PLDSID, a method from 2012, but do not compare against more recent techniques such as those presented in Ahmadipour et al. (2023) and Vahidi et al. (2023). This lack of comparison makes it difficult to assess the novelty and specific advantages of the proposed method compared to the current state-of-the-art. The paper mentions that existing methods do not explicitly tease apart shared and residual dynamics, but a more thorough comparison is needed to justify the contribution of this work. The confidence level for this weakness is high, as the paper does not include a comparison with recent, relevant methods. Fourth, the paper does not adequately address the estimation of the Gaussian observation noise variance. While the optimization procedure in Section 3.2.3 ensures valid noise statistics, the explicit estimation of the noise variance of the Gaussian observation process is not clearly outlined as a separate step before the optimization. This omission raises concerns about the method's sensitivity to variations in the noise variance and its impact on the accuracy of the estimated latent states. The confidence level for this weakness is medium, as the paper implicitly addresses noise statistics but does not explicitly detail the estimation of the Gaussian noise variance. Fifth, the paper's experimental evaluation is limited in scope. The authors primarily compare their method against PLDSID and do not include comparisons with more recent and competitive methods. This limited evaluation makes it difficult to assess the proposed algorithm's strengths and weaknesses in the current research landscape. 
Furthermore, the paper uses only one real-world dataset (NHP data), which limits the assessment of the model's broader applicability. The confidence level for this weakness is high, as the experimental section lacks comparisons with recent methods and uses a limited number of datasets. Finally, the paper claims that the algorithm can be generalized to non-Poisson/non-Gaussian models but does not provide any experimental evidence to support this claim. The paper states that the moment transformation step is key to extending the method, but no results are shown for any other distributions. This lack of empirical evidence makes the claim of generalizability unsubstantiated. The confidence level for this weakness is high, as the claim is made without any supporting experimental results.",
+ "suggestions": "To address the identified weaknesses, I recommend several concrete improvements. First, the authors should conduct a thorough analysis of the method's sensitivity to violations of the linearity assumption. This could involve simulating data from a variety of nonlinear dynamical systems and assessing the accuracy of the estimated latent states and their dimensions. For example, they could use simple nonlinear systems like the Duffing oscillator or the Lorenz attractor to generate synthetic data and then apply their method to this data. The performance of the method could be evaluated by comparing the estimated latent states and their dimensions to the true values. Furthermore, it would be beneficial to explore how the method's performance changes as the degree of nonlinearity increases. This analysis would provide a more comprehensive understanding of the method's limitations and its applicability to real-world scenarios where nonlinearities are common. Second, the authors should provide a more detailed analysis of the method's sensitivity to deviations from the assumed block structure in Equation (6). This could involve simulations where the true coefficient matrix has small non-zero values in the upper right block and assessing whether the method still converges to a reasonable estimate. A sensitivity analysis exploring the robustness of the method to such deviations would be crucial. Furthermore, the paper should provide more guidance on how to choose the dimensions of the latent spaces ($n_1$ and $n_x$). The current description is somewhat vague, and a more concrete procedure, perhaps based on information criteria or cross-validation, would be highly valuable. Third, the authors should include a detailed comparison of their approach with recent subspace identification techniques that use both behavioral and neural data, such as those presented in [1] and [2]. 
This comparison should include a discussion of the assumptions made by each method, the optimization procedures used, and the types of data that can be handled. For example, the authors should clearly explain how their method differs from the approaches presented in [1] and [2] in terms of the way they model the shared and residual dynamics. They should also discuss the advantages and disadvantages of their method compared to these existing techniques. Fourth, the authors should clarify the role of the noise variance of the Gaussian observation process in their method. They should provide a detailed analysis of the method's sensitivity to variations in the noise variance. This analysis could include simulations with different noise levels and a quantitative assessment of the error in latent state estimation and dimensionality identification. Furthermore, they should discuss how the method's performance is affected by the choice of the noise model. Fifth, the experimental evaluation should be expanded to include comparisons with more recent and competitive methods. While PLDSID is a relevant baseline, the field has seen significant advancements since 2012. Including comparisons with state-of-the-art methods, such as more recent deep learning approaches for time series modeling, would provide a more comprehensive assessment of the proposed algorithm's performance. This would not only highlight the strengths of the proposed method but also reveal its limitations and areas for future improvement. Finally, the authors should provide empirical support for their claim that the algorithm can be generalized to non-Poisson/non-Gaussian models. This could involve testing the method on synthetic datasets from simple models with alternative distributions. The authors should also consider including a simple example of how the moment transformation would be derived for a different distribution, such as Bernoulli, to further support their claim.",
+ "questions": "Several key uncertainties remain after my review of this paper. First, I am particularly interested in the justification for the block structure assumed in Equation (6). While the authors claim this structure does not lose generality, I would like to understand the practical implications of this choice more thoroughly. Specifically, how does the method behave when the true underlying system deviates from this block structure, even slightly? What happens if there are small non-zero values in the upper right block of the coefficient matrix? Does the method still converge to a reasonable estimate, or does it break down? A more detailed explanation of the assumptions underlying this block structure, and a sensitivity analysis exploring its robustness, would be highly beneficial. Second, I am curious about the method's performance when applied to nonlinear systems. The paper acknowledges the limitation of assuming linear dynamics but does not provide any analysis of how the method performs when this assumption is violated. How does the method perform when the underlying system is nonlinear? How does the accuracy of the estimated latent states and their dimensions change as the degree of nonlinearity increases? I would like to see more systematic evaluations of the method's performance under nonlinear conditions. Third, I would like to understand how the method compares to existing approaches that use subspace identification for multimodal data, specifically those mentioned in [1] and [2]. How does the proposed method differ in terms of the assumptions made, the optimization procedures used, and the types of data that can be handled? A more detailed comparison with these methods is needed to justify the specific contribution of this work. Fourth, I am interested in the role of the noise variance of the Gaussian observation process in the method. How does the noise variance affect the accuracy of the estimated latent states and their dimensions? 
How does the method's performance change as the noise variance varies? A more thorough analysis of the method's sensitivity to variations in the noise variance would be valuable. Finally, I would like to understand the practical limitations of the proposed method. What are the assumptions underlying the method, and when might these assumptions be violated in practice? Are there specific types of dynamical systems for which the method is not suitable? A clear discussion of these limitations would help readers understand the scope of the method and avoid misapplications.",
+ "rating": 6.8,
+ "confidence": 2.6,
+ "decision": "Reject"
+ }
+
+ NOTE: Output the JSON format ONLY. DO NOT output anything other than the JSON, and DO NOT include your thinking process.
+
+
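Callers of the reviewer prompt have to parse a JSON-only reply. A minimal sketch of how that reply could be validated — the repo's own `parse_json_response` in `shared/utils/json_parser.py` presumably handles this more robustly, and `parse_review_json` is an illustrative name, not from the codebase:

```python
import json

def parse_review_json(raw: str) -> dict:
    """Parse a reviewer response that should be JSON-only.

    Tolerates the common failure mode of the model wrapping its
    JSON in a ```json ... ``` code fence despite instructions.
    """
    text = raw.strip()
    if text.startswith("```"):
        # drop the opening fence (with optional language tag) and the closing fence
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rsplit("```", 1)[0]
    return json.loads(text)
```

A stricter version could additionally check that the required keys (`summary`, `rating`, `decision`, ...) are present before accepting the review.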
+ # System message for reviewer
+ reviewer_system: "You are an expert academic reviewer with deep knowledge in the field. Always respond with valid JSON format."
+
+ # Refiner prompts
+ refiner_prompts:
+ detailed: |
+ You are a senior researcher refining an existing peer review. Your job is to improve factual grounding, coverage, and usefulness while preserving the draft’s structure and intent. Treat the paper text as the source of truth.
+
+ You will be given:
+ (1) Paper text (plain text converted from PDF)
+ (2) Draft review (structured)
+ (3) Method/Contribution audit report (from Paper Insight Miner; paper-grounded)
+ (4) Experiments/Results audit report (from Paper Results Analyzer; paper-grounded)
+ (5) Related-work summaries (each item is a JSON summary of one retrieved paper, written relative to the target paper)
+
+ ========================
+ Primary objectives (what to improve)
+ ========================
+ Refine the review to satisfy these content-quality dimensions:
+
+ 1) Core Contribution Accuracy
+ 2) Results Interpretation
+ 3) Comparative Analysis / Positioning
+ 4) Evidence-Based Critique
+ 5) Critique Clarity
+ 6) Completeness Coverage
+ 7) Constructive Tone
+ 8) Avoid False or Contradictory Claims (critical)
+
+ ========================
+ Hard constraints (must follow)
+ ========================
+ A. Paper-grounded correctness is mandatory:
+ - If the audit reports mark a draft claim as incorrect/hallucinated/contradicted, you MUST fix or remove it.
+ - Do NOT introduce new factual claims about the paper unless you can anchor them to the paper text or the audit reports’ evidence.
+
+ B. Evidence anchoring rule:
+ - Every major critique (esp. in Weaknesses/Suggestions/Questions) must include a verifiable anchor:
+ section name, table/figure identifier, equation/algorithm reference, dataset/metric name, or a short quote snippet (<= 20 words).
+ - If you cannot find support, convert the statement into a question or a suggestion for clarification (do not assert absence).
+
+ C. Related-work usage rule (anti-leak / anti-overclaim):
+ - Retrieved related-work summaries are NOT guaranteed to be cited by the submission.
+ - Never claim “the paper compares to/cites X” unless the paper text actually contains X.
+ - When using retrieved works, attribute them as external context:
+ “The related-work search suggests …; it would help to clarify/compare …”
+ - Use related work to: (i) sharpen positioning, (ii) propose missing baselines/comparisons, (iii) raise targeted questions.
+
+ D. Minimal-change policy:
+ - Keep the original structure and as much of the draft wording as possible.
+ - Do NOT shorten aggressively; do NOT rewrite into a totally new review.
+ - Prefer targeted edits, insertions, and corrections.
+
+ E. Numeric fields policy (IMPORTANT):
+ - Default: keep ALL numeric fields and the decision unchanged.
+ - Change numeric fields ONLY if the refined textual assessment would otherwise be clearly inconsistent, or if a major factual correction materially changes the evaluation.
+ - If you change any numeric field: change the minimum number of fields, and keep changes small unless necessary.
+
+ ========================
+ How to use the tool reports (operational)
+ ========================
+ 1) Apply Paper Insight Miner (method/contribution):
+ - Use `review_issues.incorrect_or_hallucinated` to remove/correct wrong claims in Summary/Strengths/Weaknesses.
+ - Use `missing_key_points` and `needs_specificity` to improve technical specificity.
+ - Incorporate `rewrite_suggestions` where appropriate (method-related only).
+
+ 2) Apply Paper Results Analyzer (experiments/results):
+ - Correct any wrong result interpretation.
+ - Add missing datasets/baselines/metrics/key results if they are important and supported.
+ - Convert vague experiment critiques into concrete, testable suggestions with anchors.
+ - Incorporate `rewrite_suggestions` where appropriate (experiment-related only).
+
+ 3) Use Related-work summaries:
+ - Use each item’s `relation` to craft 1–3 concrete positioning points:
+ - what is similar/different,
+ - what comparisons would strengthen the paper,
+ - what claims need clarification.
+ - Do NOT dump a bibliography; only mention the most relevant comparisons (typically <= 3 items).
+ - Phrase as external suggestions, not accusations.
+
+ ========================
+ Refinement checklist (do in order)
+ ========================
+ Step 1: Fix incorrect/hallucinated statements flagged by the two audit reports.
+ Step 2: Improve Summary and Strengths with paper-grounded method + results highlights.
+ Step 3: Strengthen Weaknesses with evidence anchors and clearer critique.
+ Step 4: Add actionable Suggestions (each mapped to a weakness).
+ Step 5: Improve Questions to resolve uncertainties (especially when evidence is not found).
+ Step 6: Improve Comparative Analysis using related-work summaries with proper attribution.
+ Step 7: Ensure constructive tone and completeness across method / experiments / positioning.
+
+ ========================
+ Output format (JSON ONLY)
+ ========================
+ Return a JSON object with the following keys ONLY.
+ - Numeric fields must be numbers (not strings).
+ - decision must be one of: "accept", "reject".
+ - Do not output any text outside JSON.
+
+ {
+ "summary": "...",
+ "strengths": "...",
+ "weaknesses": "...",
+ "questions": "...",
+ "soundness": 0,
+ "presentation": 0,
+ "contribution": 0,
+ "rating": 0,
+ "confidence": 0,
+ "decision": "your_decision"
+ }
+
+ ========================
+ Inputs
+ ========================
+ [Paper Text]
+ <<paper_text>>
+
+ [Draft Review]
+ <<draft_review>>
+
+ [Paper Insight Miner Output (JSON)]
+ <<insight_miner_json>>
+
+ [Paper Results Analyzer Output (JSON)]
+ <<results_analyzer_json>>
+
+ [Related-work Summaries (JSON list)]
+ <<related_work_json_list>>
+
+
+ # System message for refiner
+ refiner_system: "You are an expert review refiner with deep knowledge in academic review quality standards and meta rubrics."
+
+ # Evaluation prompts for review-based rubric generation and evaluation
+ # Rubric template for generating paper-specific rubrics
+ rubrics: |
+ [
+ {
+ "title": "Core Contribution Accuracy",
+ "description": "Essential Criteria: Identifies whether the review accurately describes the paper's main contributions and central methodological innovations in the summary and strengths sections without misinterpretation.",
+ "weight": 1
+ },
+ {
+ "title": "Results Interpretation",
+ "description": "Important Criteria: Explains whether the review correctly interprets the empirical results in the summary and strengths sections, including tables, figures, and statistical comparisons.",
+ "weight": 1
+ },
+ {
+ "title": "Comparative Analysis",
+ "description": "Important Criteria: States whether the review appropriately discusses comparisons with baselines and related work presented in the paper.",
+ "weight": 1
+ },
+ {
+ "title": "Evidence-Based Critique",
+ "description": "Essential Criteria: Identifies whether weaknesses and criticisms in the Weaknesses, Suggestions, and Questions sections are supported by specific references to paper content such as sections, equations, tables, or figures.",
+ "weight": 1
+ },
+ {
+ "title": "Critique Clarity",
+ "description": "Important Criteria: Explains whether the identified weaknesses in the Weaknesses, Suggestions, and Questions sections are stated clearly and specifically enough for authors to understand what needs improvement.",
+ "weight": 1
+ },
+ {
+ "title": "Completeness Coverage",
+ "description": "Important Criteria: Identifies whether the review addresses all major components of the paper including methodology, theory, experiments, and related work.",
+ "weight": 1
+ },
+ {
+ "title": "Constructive Tone",
+ "description": "Optional Criteria: States whether the review maintains a professional and constructive tone that encourages improvement rather than discouragement.",
+ "weight": 1
+ },
+ {
+ "title": "False or Contradictory Claims",
+ "description": "Pitfall Criteria: Does not mention content or experiments that are actually absent from the paper, incorrectly claim something is missing when it exists, or make statements that contradict the paper's stated results, conclusions, or explicitly documented design choices.",
+ "weight": -1
+ }
+ ]
+
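The template's weights only mark polarity (+1 for positive rubrics, -1 for the pitfall), while the evaluator assigns each rubric a score from -2 to 2. The config does not specify how per-rubric scores are combined into one number, so the normalization below is an assumed aggregation, shown purely for illustration:

```python
def aggregate_rubric_scores(rubrics: list, scores: dict) -> float:
    """Sum evaluator scores and normalize by the maximum attainable total.

    Positive rubrics contribute 0..2 each; the pitfall (negative-weight)
    rubric contributes -2..0, so only positive rubrics define the maximum.
    Missing scores default to 0.
    """
    total = sum(scores.get(r["title"], 0) for r in rubrics)
    max_total = 2 * sum(1 for r in rubrics if r["weight"] > 0)
    return total / max_total if max_total else 0.0

rubrics = [
    {"title": "Core Contribution Accuracy", "weight": 1},
    {"title": "Evidence-Based Critique", "weight": 1},
    {"title": "False or Contradictory Claims", "weight": -1},
]
scores = {
    "Core Contribution Accuracy": 2,
    "Evidence-Based Critique": 1,
    "False or Contradictory Claims": -1,
}
ratio = aggregate_rubric_scores(rubrics, scores)  # (2 + 1 - 1) / 4
```

Other aggregations (e.g. dropping the pitfall penalty, or averaging raw scores) are equally plausible; the repo's evaluation code defines the authoritative one.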
+ # Rubric generation prompt (v2 - uses template)
+ v2_rubric_generation_prompt: |
+ You are an expert rubric writer for evaluating the quality of paper reviews in AI-related academic fields, including:
+
+ - Machine Learning
+ - Deep Learning
+ - Natural Language Processing (NLP)
+ - Computer Vision
+ - Robotics
+ - Reinforcement Learning
+ - Optimization
+ - Data-centric AI
+ - Related subdisciplines
+
+ You are given the paper content, a golden review, and a review evaluation rubric template.
+ Based on the paper content and the golden review, your task is to combine the rubric template with them to form a complete review evaluation rubric set, so that it can be used to judge whether other candidate reviews capture the key contents of the golden review and candidly reflect the understanding, strengths, and weaknesses of the paper.
+ The combined rubrics should be precise, comprehensive, and actionable.
+
+ ## Rubric Construction Rules
+
+ ### Total Items
+
+ - Rewrite each rubric item in the template as a self-contained evaluation criterion following the above areas and criteria, adhering dynamically to the content of the golden review. Keep the original weight of the rubric item.
+ - Do NOT add or remove any rubric item. Strictly adhere to the areas and criteria in the template.
+
+ ### Category Guidance
+
+ - **Essential:** Critical facts or safety checks; missing this invalidates the response.
+ - **Important:** Key reasoning, completeness, or clarity; strongly affects quality.
+ - **Optional:** Helpful stylistic or depth additions; not required.
+ - **Pitfall:** Common important mistakes; each must begin with "Pitfall Criteria: Does not mention …" or "Pitfall Criteria: Recommends …"
+
+ ---
+
+ ## Output Requirements
+
+ - Provide a **JSON array** of rubric objects.
+ - Each object must contain **exactly three keys**: `{ "title": "...", "description": "...", "weight": ... }`
+ - No additional keys allowed.
+ - Do **not** copy large blocks of the question or reference answer.
+ - Every description must **start with its category prefix**.
+
+ Now, provided is the golden review for you to refer to:
+
+ <<golden_review>>
+
+ Following is the paper content for you to refer to:
+
+ <<paper_context>>
+
+ And below is the rubric template as JSON for you to refer to:
+
+ <<rubric_template>>
+
+ Please start your rubric generation following the above information provided and the instructions.
+
+ # Evaluator prompt (v1 - for evaluating reviews using rubrics)
+ v1_evaluator_prompt: |
+ You are an expert academic reviewer tasked with evaluating a research paper review following a list of rubrics.
+
+ Coupled with the paper content and the review, you need to score the review on each rubric and provide corresponding rationales.
+
+ The rubrics are as follows:
+
+ {rubrics_json}
+
+ The score should be in the range of -2 to 2. Do NOT refer to the value of the weight when assigning the score during evaluation; treat weights only as indications of whether a rubric is positive or negative.
+ If the weight is positive, this rubric is positive. If the weight is negative, this rubric is negative.
+
+ For each rubric:
+ - If this rubric is positive:
+ - If the review meets none of the key points in this rubric, assign 0.
+ - If the review meets at least half of the key points in this rubric, assign 1.
+ - If the review meets all of the key points in this rubric, assign 2.
+ - If this rubric is negative:
+ - If the review does NOT suffer from the pitfall (good), assign 0.
+ - If it DOES suffer from any of the key points in this pitfall rubric, assign -1.
+ - If it DOES suffer from all of the key points in this pitfall rubric, assign -2.
+
+ Your output format should be:
+ {{
+ "<first_rubric_title>": {{
+ "score": <-2 to 2>,
+ "rationale": "<rationale explaining the score>"
+ }},
+ ...
+ "<last_rubric_title>": {{
+ "score": <-2 to 2>,
+ "rationale": "<rationale explaining the score>"
+ }}
+ }}
+
+ DO NOT include any other text in your output. Output the JSON string ONLY.
+
+ Now, provided is the paper content:
+
+ <<paper_content>>
+
+ Now, provided is the review:
+
+ <<review>>
+
+ Please start your evaluation following the above information provided and the instructions.
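Prompts throughout this file embed their inputs via `<<name>>` placeholders (`<<paper_text>>`, `<<golden_review>>`, `<<review>>`, ...). A sketch of how such a template might be filled at call time — the repo's `shared/utils/prompt_loader.py` presumably does something equivalent, and `fill_prompt` is an illustrative name, not the actual API:

```python
import re

def fill_prompt(template: str, values: dict) -> str:
    """Replace every <<name>> placeholder; fail loudly if any is left unfilled."""
    filled = template
    for name, value in values.items():
        filled = filled.replace(f"<<{name}>>", value)
    leftover = re.findall(r"<<(\w+)>>", filled)
    if leftover:
        # catching this early beats sending a half-filled prompt to the LLM
        raise KeyError(f"unfilled placeholders: {leftover}")
    return filled

prompt = fill_prompt(
    "Now, provided is the review:\n\n<<review>>",
    {"review": '{"summary": "..."}'},
)
```

Failing on leftover placeholders is a design choice: silent `<<review>>` strings in a dispatched prompt are a common and hard-to-spot bug.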
+
shared/configs/reranker_endpoint_pool.txt ADDED
@@ -0,0 +1,8 @@
+ http://localhost:8008
+ http://localhost:8009
+ http://localhost:8010
+ http://localhost:8011
+ http://localhost:8012
+ http://localhost:8013
+ http://localhost:8014
+ http://localhost:8015
shared/configs/vllm_endpoint_pool.txt ADDED
@@ -0,0 +1,7 @@
+ http://localhost:8001/v1
+ http://localhost:8002/v1
+ http://localhost:8003/v1
+ http://localhost:8004/v1
+ http://localhost:8005/v1
+ http://localhost:8006/v1
+ http://localhost:8007/v1
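Both pool files use the same convention: one base URL per line, one line per local service instance. A minimal round-robin selector over such a file's contents — the actual `VLLMEndpointPool` and reranker pool classes in `shared/utils/` presumably add status tracking on top; this shows only the core idea:

```python
from itertools import cycle

def load_endpoints(text: str) -> list:
    """Parse a pool file: one endpoint per line, blank lines ignored."""
    return [line.strip() for line in text.splitlines() if line.strip()]

# contents in the same format as shared/configs/vllm_endpoint_pool.txt
pool_file = """\
http://localhost:8001/v1
http://localhost:8002/v1
http://localhost:8003/v1
"""

endpoints = load_endpoints(pool_file)
rotation = cycle(endpoints)          # endless round-robin iterator
first, second = next(rotation), next(rotation)
```

In practice the file would be read with `Path(...).read_text()` and each endpoint health-checked before use.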
shared/utils/__init__.py ADDED
@@ -0,0 +1,113 @@
+ """
+ Shared utilities for the unified review system.
+ """
+ # Core utilities (always available)
+ from .json_parser import (
+     parse_review_markdown,
+     parse_keywords_json,
+     parse_summary_json,
+     parse_json_response,
+ )
+
+ from .prompt_loader import get_prompt_loader
+
+ # API Key Pool and Endpoint Pool (always available)
+ try:
+     from .asta_api_key_pool import AstaAPIKeyPool
+     _all_pools = ['AstaAPIKeyPool']
+ except ImportError:
+     AstaAPIKeyPool = None
+     _all_pools = []
+
+ try:
+     from .vllm_endpoint_pool import VLLMEndpointPool
+     _all_pools.append('VLLMEndpointPool')
+ except ImportError:
+     VLLMEndpointPool = None
+
+ if _all_pools:
+     __all__ = _all_pools
+
+ # Lazy imports for heavy dependencies (LLM-related)
+ # These may fail if dependencies are not installed, but that's okay
+ def _lazy_import_llm_services():
+     """Lazy import LLM services to avoid dependency issues"""
+     try:
+         from .llm_service import LLMService, ChatMessage
+         from .llm_service_factory import (
+             get_llm_service_factory,
+             LLMServiceFactory,
+             load_api_key_from_config,
+         )
+         return {
+             'LLMService': LLMService,
+             'ChatMessage': ChatMessage,
+             'get_llm_service_factory': get_llm_service_factory,
+             'LLMServiceFactory': LLMServiceFactory,
+             'load_api_key_from_config': load_api_key_from_config,
+         }
+     except ImportError:
+         return {}
+
+ def _lazy_import_llm_implementations():
+     """Lazy import LLM service implementations"""
+     result = {}
+     try:
+         from .vllm_service import VLLMService
+         result['VLLMService'] = VLLMService
+     except ImportError:
+         pass
+
+     try:
+         from .gpt_service import GPTService
+         result['GPTService'] = GPTService
+     except ImportError:
+         pass
+
+     try:
+         from .mock_llm_service import MockLLMService, extract_title_from_latex, extract_abstract_from_latex
+         result['MockLLMService'] = MockLLMService
+         result['extract_title_from_latex'] = extract_title_from_latex
+         result['extract_abstract_from_latex'] = extract_abstract_from_latex
+     except ImportError:
+         pass
+
+     return result
+
+ def _lazy_import_other():
+     """Lazy import other utilities"""
+     result = {}
+     try:
+         from .reranker import rerank_paragraphs_bge
+         result['rerank_paragraphs_bge'] = rerank_paragraphs_bge
+     except ImportError:
+         pass
+
+     try:
+         from .review_logger import ReviewLogger
+         result['ReviewLogger'] = ReviewLogger
+     except ImportError:
+         pass
+
+     return result
+
+ # Populate __all__ dynamically
+ _llm_services = _lazy_import_llm_services()
+ _llm_impls = _lazy_import_llm_implementations()
+ _other = _lazy_import_other()
+
+ # Make all lazy imports available at module level
+ globals().update(_llm_services)
+ globals().update(_llm_impls)
+ globals().update(_other)
+
+ __all__ = [
+     'parse_review_markdown',
+     'parse_keywords_json',
+     'parse_summary_json',
+     'parse_json_response',
+     'get_prompt_loader',
+ ] + list(_llm_services.keys()) + list(_llm_impls.keys()) + list(_other.keys())
+
+ # Include whichever pool classes imported successfully (AstaAPIKeyPool, VLLMEndpointPool)
+ __all__.extend(_all_pools)
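The `__init__.py` above builds its public surface from whatever optional imports succeed. The same pattern in miniature, using stdlib modules so it runs anywhere (names here are illustrative, not from the repo):

```python
def optional_import(module_name: str, attr: str):
    """Return a module attribute if importable, else None.

    Mirrors the try/except ImportError pattern used in
    shared/utils/__init__.py for heavy optional dependencies.
    """
    try:
        module = __import__(module_name)
        return getattr(module, attr)
    except ImportError:
        return None

exported = []
loads = optional_import("json", "loads")                      # stdlib: always present
if loads is not None:
    exported.append("loads")
missing = optional_import("definitely_not_a_real_module", "x")  # import fails -> None
```

The payoff is that `import shared.utils` never crashes on a machine without, say, `torch` installed; only the features backed by missing packages are absent.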
shared/utils/asta_api_key_pool.py ADDED
@@ -0,0 +1,205 @@
1
+ """
2
+ Asta API Key Pool Manager
3
+
4
+ Manage multiple Asta API keys, implement key rotation and error handling.
5
+ """
6
+ import os
7
+ import random
8
+ import time
9
+ from pathlib import Path
10
+ from typing import List, Optional, Dict
11
+ from threading import Lock
12
+
13
+
14
+ class AstaAPIKeyPool:
15
+ """
16
+ Asta API Key Pool Manager
17
+
18
+ Features:
19
+ 1. Load multiple API keys from file
20
+ 2. Randomly rotate keys
21
+ 3. Track each key's usage status and errors
22
+ 4. Penalize recently failed keys when selecting (debounce-style retry strategy)
23
+ """
24
+
25
+ def __init__(self, pool_path: Optional[str] = None, keys: Optional[List[str]] = None):
26
+ """
27
+ Initialize API Key Pool
28
+
29
+ Args:
30
+ pool_path: API keys file path (one key per line)
31
+ keys: directly provide a list of keys (takes priority over pool_path)
32
+ """
33
+ self.keys: List[str] = []
34
+ self.used_indices: List[int] = [] # indices used in current rotation
35
+ self.key_status: Dict[str, Dict] = {} # status information for each key
36
+ self.lock = Lock() # thread safe lock
37
+
38
+ # load keys
39
+ if keys:
40
+ self.keys = [k.strip() for k in keys if k.strip()]
41
+ elif os.environ.get("ASTA_API_KEY"):
42
+ # Try to get one or more keys from environment variable (comma-separated)
43
+ self.keys = [k.strip() for k in os.environ.get("ASTA_API_KEY").split(",") if k.strip()]
44
+ elif pool_path:
45
+ self._load_from_file(pool_path)
46
+ else:
47
+ raise ValueError(
48
+ "No API keys available. Provide keys via pool_path, keys parameter, "
49
+ "or ASTA_API_KEY environment variable."
50
+ )
51
+
52
+ if not self.keys:
53
+ raise ValueError(
54
+ "No API keys available. Provide keys via pool_path, keys parameter, "
55
+ "or ASTA_API_KEY environment variable."
56
+ )
57
+
58
+ # initialize status for each key
59
+ for key in self.keys:
60
+ self.key_status[key] = {
61
+ 'error_count': 0,
62
+ 'last_error_time': None,
63
+ 'consecutive_errors': 0,
64
+ 'total_requests': 0,
65
+ 'successful_requests': 0,
66
+ }
67
+
68
+ def _load_from_file(self, pool_path: str):
69
+ """Load API keys from file"""
70
+ path = Path(pool_path)
71
+
72
+ # if relative path, try to find file relative to shared/configs
73
+ if not path.is_absolute():
74
+ # try to find file relative to project root
75
+ project_root = Path(__file__).parent.parent.parent
76
+ path = project_root / "shared" / "configs" / pool_path
77
+ if not path.exists():
78
+ # try to find file relative to shared/configs
79
+ path = Path(__file__).parent.parent / "configs" / pool_path
80
+
81
+ if not path.exists():
82
+ raise FileNotFoundError(
83
+ f"API key pool file not found: {pool_path} (tried: {path})"
84
+ )
85
+
86
+ with open(path, 'r', encoding='utf-8') as f:
87
+ lines = f.readlines()
88
+
89
+ self.keys = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
90
+
91
+ if not self.keys:
92
+ raise ValueError(f"No valid API keys found in pool file: {pool_path}")
93
+
94
+ def get_key(self) -> str:
95
+ """
96
+ Get next available API key (rotation strategy)
97
+
98
+ Strategy:
99
+ 1. If current rotation is not complete, continue using unused keys
100
+ 2. If current rotation is complete, start a new round (reset used_indices)
101
+ 3. Prioritize keys with no recent errors
102
+
103
+ Returns:
104
+ Available API key
105
+ """
106
+ with self.lock:
107
+ if not self.keys:
108
+ raise ValueError("No API keys available in pool")
109
+
110
+ # if current rotation is complete, start a new round
111
+ if len(self.used_indices) >= len(self.keys):
112
+ self.used_indices = []
113
+
114
+ # get indices not used in current rotation
115
+ available_indices = [i for i in range(len(self.keys)) if i not in self.used_indices]
116
+
117
+ if not available_indices:
118
+ # all keys are used in current rotation, start a new round
119
+ available_indices = list(range(len(self.keys)))
120
+ self.used_indices = []
121
+
122
+ # score each key and prefer those with higher success rates and fewer recent errors
123
+ key_scores = []
124
+ for idx in available_indices:
125
+ key = self.keys[idx]
126
+ status = self.key_status[key]
127
+
128
+ # calculate score: fewer errors and a higher success rate yield a higher score
129
+ error_count = status['error_count']
130
+ total = status['total_requests']
131
+ success_rate = (status['successful_requests'] / total) if total > 0 else 1.0
132
+
133
+ # if recent error, reduce score
134
+ recent_error_penalty = 0
135
+ if status['last_error_time']:
136
+ time_since_error = time.time() - status['last_error_time']
137
+ if time_since_error < 60: # 1 minute
138
+ recent_error_penalty = 0.5
139
+
140
+ score = success_rate - (error_count * 0.1) - recent_error_penalty
141
+ key_scores.append((idx, score))
142
+
143
+ # sort by score, select highest score (but add some randomness)
144
+ key_scores.sort(key=lambda x: x[1], reverse=True)
145
+
146
+ # select from top 50% (add randomness but prioritize better keys)
147
+ top_n = max(1, len(key_scores) // 2) if len(key_scores) > 1 else 1
148
+ selected_idx, _ = random.choice(key_scores[:top_n])
149
+
150
+ # mark as used
151
+ self.used_indices.append(selected_idx)
152
+
153
+ selected_key = self.keys[selected_idx]
154
+ self.key_status[selected_key]['total_requests'] += 1
155
+
156
+ return selected_key
157
+
158
+ def mark_success(self, key: str):
159
+ """Mark a request with this key as successful"""
160
+ with self.lock:
161
+ if key in self.key_status:
162
+ self.key_status[key]['successful_requests'] += 1
163
+ self.key_status[key]['consecutive_errors'] = 0
164
+
165
+ def mark_error(self, key: str, error_type: str = "rate_limit"):
166
+ """
167
+ Mark a request with this key as failed
168
+
169
+ Args:
170
+ key: failed API key
171
+ error_type: error type ("rate_limit", "auth_error", "server_error", "other")
172
+ """
173
+ with self.lock:
174
+ if key in self.key_status:
175
+ status = self.key_status[key]
176
+ status['error_count'] += 1
177
+ status['consecutive_errors'] += 1
178
+ status['last_error_time'] = time.time()
179
+
180
+ def get_status(self) -> Dict:
181
+ """Get pool status information (for debugging)"""
182
+ with self.lock:
183
+ return {
184
+ 'total_keys': len(self.keys),
185
+ 'current_round_progress': f"{len(self.used_indices)}/{len(self.keys)}",
186
+ 'keys_status': {
187
+ key: {
188
+ 'error_count': status['error_count'],
189
+ 'successful_requests': status['successful_requests'],
190
+ 'total_requests': status['total_requests'],
191
+ 'success_rate': (
192
+ status['successful_requests'] / status['total_requests']
193
+ if status['total_requests'] > 0 else 0.0
194
+ ),
195
+ 'consecutive_errors': status['consecutive_errors'],
196
+ 'last_error_time': status['last_error_time'],
197
+ }
198
+ for key, status in self.key_status.items()
199
+ }
200
+ }
201
+
202
+ def reset_round(self):
203
+ """Reset the current rotation (force start of a new round)"""
204
+ with self.lock:
205
+ self.used_indices = []
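The scoring rule inside `get_key()` can be isolated for illustration. The sketch below (the `key_score` helper name is illustrative, not part of the class) mirrors the formula used above: success rate minus 0.1 per recorded error, minus a 0.5 penalty if the key failed within the last minute.

```python
import time

def key_score(status: dict, now: float) -> float:
    # Mirrors AstaAPIKeyPool.get_key(): success rate minus error penalties,
    # with an extra penalty for keys that failed within the last 60 seconds.
    total = status['total_requests']
    success_rate = status['successful_requests'] / total if total else 1.0
    recent = status['last_error_time']
    recent_penalty = 0.5 if recent is not None and now - recent < 60 else 0.0
    return success_rate - status['error_count'] * 0.1 - recent_penalty

# A fresh key (no requests yet) scores 1.0; a flaky key is penalized.
fresh = {'total_requests': 0, 'successful_requests': 0,
         'error_count': 0, 'last_error_time': None}
flaky = {'total_requests': 10, 'successful_requests': 5,
         'error_count': 5, 'last_error_time': time.time()}
assert key_score(fresh, time.time()) == 1.0
assert abs(key_score(flaky, time.time()) - (-0.5)) < 1e-9
```

Note that a brand-new key always outranks any key with errors, which is why fresh keys are tried first in a rotation.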
shared/utils/gpt_service.py ADDED
@@ -0,0 +1,210 @@
1
+ """
2
+ OpenAI GPT API service implementation
3
+ """
4
+ import os
5
+ from typing import List, Dict, Optional, Any, Union
6
+ from openai import OpenAI
7
+ from .llm_service import LLMService, ChatMessage
8
+
9
+
10
+ class GPTService(LLMService):
11
+ """
12
+ OpenAI GPT API service wrapper
13
+
14
+ This service connects to OpenAI's API (or compatible API)
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ api_key: Optional[str] = None,
20
+ model_name: str = "gpt-4o",
21
+ base_url: Optional[str] = None,
22
+ timeout: int = 300,
23
+ ):
24
+ """
25
+ Initialize GPT service
26
+
27
+ Args:
28
+ api_key: OpenAI / OpenRouter API key (set via env if omitted)
29
+ model_name: Model name (e.g., gpt-4o, openai/gpt-oss-120b:free, etc.)
30
+ base_url: API base URL (default: https://api.openai.com/v1)
31
+ timeout: Request timeout in seconds
32
+ """
33
+ # Prefer explicit parameter, then common environment variables.
34
+ # This allows using OpenRouter (OPENROUTER_API_KEY) without hard-coding secrets.
35
+ self.api_key = (
36
+ api_key
37
+ or os.environ.get("OPENAI_API_KEY")
38
+ or os.environ.get("OPENROUTER_API_KEY")
39
+ )
40
+ if not self.api_key:
41
+ raise ValueError(
42
+ "API key is required. Set OPENAI_API_KEY or OPENROUTER_API_KEY "
43
+ "environment variable, or pass api_key parameter."
44
+ )
45
+
46
+ self.model_name = model_name
47
+ # Prefer explicit base_url, then environment variables, then OpenAI default.
48
+ # This allows swapping in any OpenAI-compatible endpoint (e.g., OpenRouter)
49
+ # without changing code.
50
+ self.base_url = (
51
+ base_url
52
+ or os.environ.get("OPENAI_BASE_URL")
53
+ or os.environ.get("OPENROUTER_BASE_URL")
54
+ or "https://api.openai.com/v1"
55
+ )
56
+ self.timeout = timeout
57
+
58
+ self.client = OpenAI(
59
+ api_key=self.api_key,
60
+ base_url=self.base_url,
61
+ timeout=self.timeout,
62
+ )
63
+
64
+ def _format_messages(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> List[Dict[str, str]]:
65
+ """Format messages for OpenAI API"""
66
+ formatted = []
67
+ for msg in messages:
68
+ if isinstance(msg, ChatMessage):
69
+ formatted.append({"role": msg.role, "content": msg.content})
70
+ elif isinstance(msg, dict):
71
+ formatted.append(msg)
72
+ else:
73
+ raise ValueError(f"Invalid message type: {type(msg)}")
74
+ return formatted
75
+
76
+ def generate(
77
+ self,
78
+ messages: List[Union[ChatMessage, Dict[str, str]]],
79
+ temperature: float = 0.7,
80
+ top_p: float = 0.95,
81
+ top_k: int = 20,
82
+ max_tokens: int = 16384,
83
+ presence_penalty: float = 0.0,
84
+ **kwargs
85
+ ) -> str:
86
+ """
87
+ Generate text from messages
88
+
89
+ Args:
90
+ messages: List of chat messages
91
+ temperature: Sampling temperature
92
+ top_p: Top-p sampling parameter
93
+ top_k: Top-k sampling parameter (not used by GPT API, but kept for compatibility)
94
+ max_tokens: Maximum tokens to generate
95
+ presence_penalty: Presence penalty (0-2)
96
+ **kwargs: Additional parameters
97
+
98
+ Returns:
99
+ Generated text
100
+ """
101
+ formatted_messages = self._format_messages(messages)
102
+
103
+ try:
104
+ # GPT API doesn't support top_k, so we exclude it
105
+ # Some newer models (e.g., the o1 family and GPT-5) use max_completion_tokens instead of max_tokens
106
+ params = {
107
+ "model": self.model_name,
108
+ "messages": formatted_messages,
109
+ "temperature": temperature,
110
+ "top_p": top_p,
111
+ "presence_penalty": presence_penalty,
112
+ }
113
+
114
+ # Check if model requires max_completion_tokens instead of max_tokens
115
+ # Models that use max_completion_tokens: o1, o1-preview, o1-mini, and newer models
116
+ if any(model_name in self.model_name.lower() for model_name in ["o1", "gpt-5", "gpt5"]):
117
+ params["max_completion_tokens"] = max_tokens
118
+ else:
119
+ params["max_tokens"] = max_tokens
120
+
121
+ params.update({k: v for k, v in kwargs.items() if k not in ["top_k", "max_tokens", "max_completion_tokens"]})
122
+
123
+ response = self.client.chat.completions.create(**params)
124
+
125
+ return response.choices[0].message.content
126
+
127
+ except Exception as e:
128
+ # If max_tokens fails, try max_completion_tokens as fallback
129
+ if "max_tokens" in str(e) and "max_completion_tokens" in str(e):
130
+ try:
131
+ params = {
132
+ "model": self.model_name,
133
+ "messages": formatted_messages,
134
+ "temperature": temperature,
135
+ "top_p": top_p,
136
+ "max_completion_tokens": max_tokens,
137
+ "presence_penalty": presence_penalty,
138
+ }
139
+ params.update({k: v for k, v in kwargs.items() if k not in ["top_k", "max_tokens", "max_completion_tokens"]})
140
+ response = self.client.chat.completions.create(**params)
141
+ return response.choices[0].message.content
142
+ except Exception as e2:
143
+ raise RuntimeError(f"Error generating text from GPT service: {e2}")
144
+ raise RuntimeError(f"Error generating text from GPT service: {e}")
145
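The token-parameter selection in `generate()` can be summarized as a small standalone sketch (the `pick_token_param` name is illustrative): models whose names match the o1/gpt-5 family get `max_completion_tokens`, everything else gets `max_tokens`.

```python
def pick_token_param(model_name: str, max_tokens: int) -> dict:
    # Newer reasoning models (o1 family, gpt-5) reject max_tokens and
    # require max_completion_tokens instead; older chat models use max_tokens.
    if any(tag in model_name.lower() for tag in ("o1", "gpt-5", "gpt5")):
        return {"max_completion_tokens": max_tokens}
    return {"max_tokens": max_tokens}

assert pick_token_param("o1-mini", 128) == {"max_completion_tokens": 128}
assert pick_token_param("gpt-4o", 128) == {"max_tokens": 128}
```

The substring check is deliberately loose; the try/except fallback in `generate()` covers models the heuristic misses.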
+
146
+ def stream_generate(
147
+ self,
148
+ messages: List[Union[ChatMessage, Dict[str, str]]],
149
+ temperature: float = 0.7,
150
+ top_p: float = 0.95,
151
+ top_k: int = 20,
152
+ max_tokens: int = 16384,
153
+ presence_penalty: float = 0.0,
154
+ **kwargs
155
+ ):
156
+ """
157
+ Stream generate text from messages
158
+
159
+ Yields:
160
+ Generated text chunks
161
+ """
162
+ formatted_messages = self._format_messages(messages)
163
+
164
+ try:
165
+ params = {
166
+ "model": self.model_name,
167
+ "messages": formatted_messages,
168
+ "temperature": temperature,
169
+ "top_p": top_p,
170
+ "presence_penalty": presence_penalty,
171
+ "stream": True,
172
+ }
173
+
174
+ # Check if model requires max_completion_tokens instead of max_tokens
175
+ if any(model_name in self.model_name.lower() for model_name in ["o1", "gpt-5", "gpt5"]):
176
+ params["max_completion_tokens"] = max_tokens
177
+ else:
178
+ params["max_tokens"] = max_tokens
179
+
180
+ params.update({k: v for k, v in kwargs.items() if k not in ["top_k", "max_tokens", "max_completion_tokens"]})
181
+
182
+ stream = self.client.chat.completions.create(**params)
183
+
184
+ for chunk in stream:
185
+ if chunk.choices[0].delta.content:
186
+ yield chunk.choices[0].delta.content
187
+
188
+ except Exception as e:
189
+ # If max_tokens fails, try max_completion_tokens as fallback
190
+ if "max_tokens" in str(e) and "max_completion_tokens" in str(e):
191
+ try:
192
+ params = {
193
+ "model": self.model_name,
194
+ "messages": formatted_messages,
195
+ "temperature": temperature,
196
+ "top_p": top_p,
197
+ "max_completion_tokens": max_tokens,
198
+ "presence_penalty": presence_penalty,
199
+ "stream": True,
200
+ }
201
+ params.update({k: v for k, v in kwargs.items() if k not in ["top_k", "max_tokens", "max_completion_tokens"]})
202
+ stream = self.client.chat.completions.create(**params)
203
+ for chunk in stream:
204
+ if chunk.choices[0].delta.content:
205
+ yield chunk.choices[0].delta.content
206
+ return
207
+ except Exception as e2:
208
+ raise RuntimeError(f"Error streaming text from GPT service: {e2}")
209
+ raise RuntimeError(f"Error streaming text from GPT service: {e}")
210
+
shared/utils/json_parser.py ADDED
@@ -0,0 +1,428 @@
1
+ """
2
+ Robust JSON parsing utilities for LLM responses
3
+ """
4
+ import json
5
+ import re
6
+ from typing import Any, Dict, List, Optional
7
+
8
+
9
+ def extract_json_from_text(text: str) -> Optional[str]:
10
+ """
11
+ Extract JSON from text by removing markdown code block markers
12
+
13
+ Args:
14
+ text: Text that may contain JSON in markdown code blocks or plain JSON
15
+
16
+ Returns:
17
+ Extracted JSON string or None if not found
18
+ """
19
+ if not text:
20
+ return None
21
+
22
+ text_stripped = text.strip()
23
+
24
+ # Try to parse as plain JSON first (no code blocks)
25
+ try:
26
+ json.loads(text_stripped)
27
+ return text_stripped
28
+ except json.JSONDecodeError:
29
+ pass
30
+
31
+ # Remove markdown code block markers: ```json ... ``` or ``` ... ```
32
+ if text_stripped.startswith('```json'):
33
+ # Remove ```json at start and ``` at end
34
+ if text_stripped.endswith('```'):
35
+ text_stripped = text_stripped[7:-3].strip()
36
+ else:
37
+ # No closing ```, just remove opening
38
+ text_stripped = text_stripped[7:].strip()
39
+ elif text_stripped.startswith('```'):
40
+ # Handle ``` ... ``` (without json label)
41
+ if text_stripped.endswith('```'):
42
+ text_stripped = text_stripped[3:-3].strip()
43
+ else:
44
+ return None
45
+
46
+ # Try to parse as JSON after removing code block markers
47
+ try:
48
+ json.loads(text_stripped)
49
+ return text_stripped
50
+ except json.JSONDecodeError:
51
+ return None
52
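The fence-stripping path above can be exercised in isolation. This is a condensed sketch of the same logic (the `strip_code_fence` name is illustrative): accept plain JSON, a ```json fence, or a bare ``` fence, and validate the result with `json.loads`.

```python
import json
from typing import Optional

def strip_code_fence(text: str) -> Optional[str]:
    """Return the JSON payload of `text`, unwrapping markdown code fences."""
    t = text.strip()
    try:
        json.loads(t)
        return t  # already plain JSON
    except json.JSONDecodeError:
        pass
    if t.startswith('```json'):
        t = t[7:-3].strip() if t.endswith('```') else t[7:].strip()
    elif t.startswith('```') and t.endswith('```'):
        t = t[3:-3].strip()
    else:
        return None
    try:
        json.loads(t)
        return t
    except json.JSONDecodeError:
        return None

assert strip_code_fence('```json\n{"a": 1}\n```') == '{"a": 1}'
assert strip_code_fence('{"a": 1}') == '{"a": 1}'
assert strip_code_fence('not json') is None
```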
+
53
+
54
+ def parse_json_response(text: str, fallback: Any = None) -> Any:
55
+ """
56
+ Parse JSON from LLM response with robust error handling
57
+
58
+ Args:
59
+ text: LLM response text
60
+ fallback: Fallback value if parsing fails
61
+
62
+ Returns:
63
+ Parsed JSON object or fallback
64
+ """
65
+ if not text:
66
+ return fallback
67
+
68
+ # Extract JSON from text
69
+ json_str = extract_json_from_text(text)
70
+
71
+ if json_str is None:
72
+ return fallback
73
+
74
+ try:
75
+ return json.loads(json_str)
76
+ except json.JSONDecodeError as e:
77
+ # Try to fix common JSON issues
78
+ json_str = fix_json_common_issues(json_str)
79
+ try:
80
+ return json.loads(json_str)
81
+ except json.JSONDecodeError:
82
+ return fallback
83
+
84
+
85
+ def fix_json_common_issues(json_str: str) -> str:
86
+ """
87
+ Fix common JSON formatting issues
88
+
89
+ Args:
90
+ json_str: JSON string that may have issues
91
+
92
+ Returns:
93
+ Fixed JSON string
94
+ """
95
+ # Remove trailing commas
96
+ json_str = re.sub(r',\s*}', '}', json_str)
97
+ json_str = re.sub(r',\s*]', ']', json_str)
98
+
99
+ # Fix single quotes to double quotes (basic)
100
+ json_str = re.sub(r"'(\w+)':", r'"\1":', json_str)
101
+
102
+ # Remove comments (basic)
103
+ json_str = re.sub(r'//.*?$', '', json_str, flags=re.MULTILINE)
104
+ json_str = re.sub(r'/\*.*?\*/', '', json_str, flags=re.DOTALL)
105
+
106
+ return json_str
107
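The trailing-comma repair above is the most common fix in practice; a minimal standalone demonstration of just that part (the single-quote and comment fixes are omitted here):

```python
import json
import re

def fix_trailing_commas(s: str) -> str:
    # Remove trailing commas before } or ] so json.loads accepts the string.
    s = re.sub(r',\s*}', '}', s)
    s = re.sub(r',\s*]', ']', s)
    return s

broken = '{"keywords": ["a", "b",], "n": 2,}'
assert json.loads(fix_trailing_commas(broken)) == {"keywords": ["a", "b"], "n": 2}
```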
+
108
+
109
+ def parse_keywords_json(response: str) -> List[str]:
110
+ """
111
+ Parse keywords from JSON response
112
+
113
+ Expected format:
114
+ {"keywords": ["keyword1", "keyword2", ...]}
115
+ or
116
+ ["keyword1", "keyword2", ...]
117
+
118
+ Args:
119
+ response: LLM response text
120
+
121
+ Returns:
122
+ List of at most five keywords, or an empty list if parsing fails
123
+ """
124
+ if response is None:
125
+ return []
126
+
127
+ parsed = parse_json_response(response, fallback=None)
128
+
129
+ if parsed is None:
130
+ return []
131
+
132
+ # Handle dict format: {"keywords": [...]}
133
+ if isinstance(parsed, dict):
134
+ if "keywords" in parsed and isinstance(parsed["keywords"], list):
135
+ return parsed["keywords"][:5]
136
+ return []
137
+
138
+ # Handle list format: ["keyword1", "keyword2", ...]
139
+ if isinstance(parsed, list):
140
+ return parsed[:5]
141
+
142
+ return []
143
+
144
+
145
+ def parse_summary_json(response: str) -> str:
146
+ """
147
+ Parse summary from JSON response
148
+
149
+ Expected format:
150
+ {"summary": "summary text"}
151
+ or
152
+ {"text": "summary text", "summary": "summary text"}
153
+
154
+ Args:
155
+ response: LLM response text
156
+
157
+ Returns:
158
+ Summary text
159
+ """
160
+ parsed = parse_json_response(response, fallback=None)
161
+
162
+ if parsed is None:
163
+ # Fallback to text parsing
164
+ return response.strip()
165
+
166
+ if isinstance(parsed, dict):
167
+ # Try different possible keys
168
+ for key in ["summary", "text", "content", "description"]:
169
+ if key in parsed:
170
+ summary = str(parsed[key]).strip()
171
+ if summary:
172
+ return summary
173
+
174
+ # Fallback to text parsing
175
+ return response.strip()
176
+
177
+
178
+ def parse_review_json(response: str, review_format: str = "detailed") -> Dict[str, Any]:
179
+ """
180
+ Parse review from JSON or markdown response
181
+
182
+ Expected formats:
183
+ - JSON: {"summary": "...", "soundness": 5, ...}
184
+ - Markdown: ## Summary\n\n...\n## Soundness\n\n...
185
+
186
+ Args:
187
+ response: LLM response text (JSON or markdown)
188
+ review_format: Review format type (detailed, summary, structured); currently unused
189
+
190
+ Returns:
191
+ Review dictionary with parsed fields
192
+ """
193
+ # First try to parse as JSON
194
+ parsed = parse_json_response(response, fallback=None)
195
+
196
+ if parsed is not None and isinstance(parsed, dict):
197
+ # JSON format - ensure it has required fields
198
+ if "review" not in parsed:
199
+ parsed["review"] = response.strip()
200
+ return parsed
201
+
202
+ # If not JSON, try to parse as markdown
203
+ if "## " in response or "##" in response:
204
+ markdown_parsed = parse_review_markdown(response)
205
+ if len(markdown_parsed) > 1: # More than just "review" field
206
+ return markdown_parsed
207
+
208
+ # Fallback to text parsing
209
+ return {"review": response.strip()}
210
+
211
+
212
+ def _extract_score(section_content: str) -> Optional[float]:
+     """
+     Extract a numeric score from a review section.
+ 
+     Handles a bare number ("3.5"), a fraction ("3 / 5", "**3 / 5**"),
+     and "score:" / "rating:" prefixes anywhere in the section.
+     """
+     lines = section_content.split('\n')
+     if lines:
+         first_line_clean = re.sub(r'[`\*]', '', lines[0].strip())
+ 
+         # Bare number at the start that is NOT followed by "/" (not a fraction)
+         num_match = re.match(r'^(\d+\.?\d*)(\s*)', first_line_clean)
+         if num_match:
+             remaining = first_line_clean[len(num_match.group(0)):].strip()
+             if not remaining.startswith('/'):
+                 try:
+                     return float(num_match.group(1))
+                 except ValueError:
+                     pass
+ 
+         # Fraction form: extract the number before the slash ("3 / 5" -> 3)
+         if '/' in first_line_clean:
+             fraction_match = re.match(r'^\s*(\d+\.?\d*)\s*/\s*\d+', first_line_clean)
+             if fraction_match:
+                 try:
+                     return float(fraction_match.group(1))
+                 except ValueError:
+                     pass
+ 
+     # Fall back to a number after "score:" or "rating:"
+     score_match = re.search(r'(?:score|rating)\s*[:=]\s*(\d+\.?\d*)', section_content, re.IGNORECASE)
+     if score_match:
+         try:
+             return float(score_match.group(1))
+         except ValueError:
+             pass
+ 
+     return None
+ 
+ 
+ def parse_review_markdown(markdown_text: str) -> Dict[str, Any]:
+     """
+     Parse review from markdown format with sections like:
+     ## Summary
+     ...
+     ## Soundness
+     ...
+     etc.
+ 
+     Args:
+         markdown_text: Markdown formatted review text
+ 
+     Returns:
+         Review dictionary with parsed fields (scores are kept as floats)
+     """
+     review_dict = {"review": markdown_text.strip()}
+ 
+     # Pattern to match markdown sections: ## SectionName\n\ncontent
+     section_pattern = r'##\s*([^\n]+)\s*\n\n(.*?)(?=\n##\s*|$)'
+     matches = re.finditer(section_pattern, markdown_text, re.DOTALL)
+ 
+     # Sections whose content is a numeric score
+     score_sections = ("soundness", "presentation", "contribution", "confidence")
+ 
+     for match in matches:
+         # Normalize section name (case-insensitive)
+         section_name_lower = match.group(1).strip().lower()
+         section_content = match.group(2).strip()
+ 
+         # Map section names to dictionary keys
+         if "summary" in section_name_lower:
+             review_dict["summary"] = section_content
+         elif "strength" in section_name_lower:
+             review_dict["strengths"] = section_content
+         elif "weakness" in section_name_lower:
+             review_dict["weaknesses"] = section_content
+         elif "question" in section_name_lower:
+             review_dict["questions"] = section_content
+         elif "decision" in section_name_lower:
+             review_dict["decision"] = section_content
+         elif "rating" in section_name_lower and "confidence" not in section_name_lower:
+             score_val = _extract_score(section_content)
+             if score_val is not None:
+                 review_dict["rating"] = score_val
+         else:
+             for field in score_sections:
+                 if field in section_name_lower:
+                     score_val = _extract_score(section_content)
+                     if score_val is not None:
+                         review_dict[field] = score_val
+                     break
+ 
+     return review_dict
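The section-splitting regex used in `parse_review_markdown` can be checked in isolation. A minimal demonstration (the `SECTION_RE` and `md` names are illustrative):

```python
import re

# Matches "## Name\n\ncontent", with content running lazily up to the next "##"
SECTION_RE = r'##\s*([^\n]+)\s*\n\n(.*?)(?=\n##\s*|$)'

md = "## Summary\n\nGood paper.\n## Rating\n\n7 / 10"
sections = {m.group(1).strip(): m.group(2).strip()
            for m in re.finditer(SECTION_RE, md, re.DOTALL)}
assert sections["Summary"] == "Good paper."
assert sections["Rating"] == "7 / 10"
```

Because the content group is lazy and anchored by the `(?=\n##\s*|$)` lookahead, multi-paragraph sections are captured whole without swallowing the next heading.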
shared/utils/llm_service.py ADDED
@@ -0,0 +1,64 @@
1
+ """
2
+ Abstract base class for LLM services
3
+ """
4
+ from abc import ABC, abstractmethod
5
+ from typing import List, Dict, Optional, Any, Union
6
+ from pydantic import BaseModel
7
+
8
+
9
+ class ChatMessage(BaseModel):
10
+ """Chat message model"""
11
+ role: str # "system", "user", "assistant"
12
+ content: str
13
+
14
+
15
+ class LLMService(ABC):
16
+ """Abstract base class for LLM services"""
17
+
18
+ @abstractmethod
19
+ def generate(
20
+ self,
21
+ messages: List[Union[ChatMessage, Dict[str, str]]],
22
+ temperature: float = 0.7,
23
+ top_p: float = 0.8,
24
+ top_k: int = 20,
25
+ max_tokens: int = 16384,
26
+ presence_penalty: float = 0.0,
27
+ **kwargs
28
+ ) -> str:
29
+ """
30
+ Generate text from messages
31
+
32
+ Args:
33
+ messages: List of chat messages
34
+ temperature: Sampling temperature
35
+ top_p: Top-p sampling parameter
36
+ top_k: Top-k sampling parameter
37
+ max_tokens: Maximum tokens to generate
38
+ presence_penalty: Presence penalty (0-2)
39
+ **kwargs: Additional parameters
40
+
41
+ Returns:
42
+ Generated text
43
+ """
44
+ pass
45
+
46
+ @abstractmethod
47
+ def stream_generate(
48
+ self,
49
+ messages: List[Union[ChatMessage, Dict[str, str]]],
50
+ temperature: float = 0.7,
51
+ top_p: float = 0.8,
52
+ top_k: int = 20,
53
+ max_tokens: int = 16384,
54
+ presence_penalty: float = 0.0,
55
+ **kwargs
56
+ ):
57
+ """
58
+ Stream generate text from messages
59
+
60
+ Yields:
61
+ Generated text chunks
62
+ """
63
+ pass
64
+
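A concrete subclass only needs to implement the two abstract methods. This toy sketch (a trimmed mirror of the interface; `EchoService` is illustrative and not part of the repo) shows the contract:

```python
from abc import ABC, abstractmethod
from typing import Dict, List

class LLMService(ABC):  # trimmed mirror of the interface above
    @abstractmethod
    def generate(self, messages: List[Dict[str, str]], **kwargs) -> str: ...

    @abstractmethod
    def stream_generate(self, messages: List[Dict[str, str]], **kwargs): ...

class EchoService(LLMService):
    """Toy implementation: returns the last user message verbatim."""
    def generate(self, messages, **kwargs) -> str:
        return messages[-1]["content"]

    def stream_generate(self, messages, **kwargs):
        yield messages[-1]["content"]

svc = EchoService()
assert svc.generate([{"role": "user", "content": "hi"}]) == "hi"
assert list(svc.stream_generate([{"role": "user", "content": "hi"}])) == ["hi"]
```

`VLLMService` and `GPTService` in this package follow the same pattern against real backends.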
shared/utils/llm_service_factory.py ADDED
@@ -0,0 +1,191 @@
1
+ """
2
+ Factory for creating LLM services from configuration
3
+ """
4
+ import os
5
+ import yaml
6
+ from pathlib import Path
7
+ from typing import Optional, Dict, Any
8
+ from .llm_service import LLMService
9
+ from .vllm_service import VLLMService
10
+ from .gpt_service import GPTService
11
+
12
+
13
+ class LLMServiceFactory:
14
+ """Factory for creating LLM services from configuration"""
15
+
16
+ def __init__(self, config_file: Optional[str] = None):
17
+ """
18
+ Initialize factory with configuration
19
+
20
+ Args:
21
+ config_file: Path to vLLM service config YAML file
22
+ """
23
+ if config_file is None:
24
+ project_root = Path(__file__).parent.parent.parent
25
+ config_file = project_root / "shared" / "configs" / "llm_service_config.yaml"
26
+
27
+ self.config_file = Path(config_file)
28
+ self._config = None
29
+ self._load_config()
30
+
31
+ def _load_config(self):
32
+ """Load configuration from YAML file"""
33
+ if not self.config_file.exists():
34
+ raise FileNotFoundError(f"Config file not found: {self.config_file}")
35
+
36
+ with open(self.config_file, 'r', encoding='utf-8') as f:
37
+ self._config = yaml.safe_load(f)
38
+
39
+ def create_vllm_service(self, **override_params) -> VLLMService:
40
+ """
41
+ Create vLLM service from configuration
42
+
43
+ Args:
44
+ **override_params: Parameters to override config values
45
+
46
+ Returns:
47
+ VLLMService instance
48
+ """
49
+ vllm_config = self._config.get("vllm", {})
50
+
51
+ # Merge config with overrides
52
+ params = {
53
+ "base_url": vllm_config.get("base_url", "http://localhost:8000/v1"),
54
+ "api_key": vllm_config.get("api_key", "dummy-key"),
55
+ "model_name": vllm_config.get("model_name", "Qwen/Qwen3-4B-Instruct-2507"),
56
+ "timeout": vllm_config.get("timeout", 300),
57
+ }
58
+ params.update(override_params)
59
+
60
+ return VLLMService(**params)
61
+
62
+ def create_gpt_service(self, **override_params) -> GPTService:
63
+ """
64
+ Create GPT service from configuration
65
+
66
+ Args:
67
+ **override_params: Parameters to override config values
68
+
69
+ Returns:
70
+ GPTService instance
71
+ """
72
+ gpt_config = self._config.get("gpt", {})
73
+
74
+ # if not gpt_config.get("enabled", False):
75
+ # raise ValueError("GPT service is not enabled in configuration")
76
+
77
+ # Merge config with overrides
78
+ params = {
79
+ "api_key": gpt_config.get("api_key") or os.environ.get("OPENAI_API_KEY"),
80
+ "model_name": gpt_config.get("model_name", "gpt-4o"),
81
+ "base_url": gpt_config.get("base_url"),
82
+ "timeout": gpt_config.get("timeout", 300),
83
+ }
84
+ params.update(override_params)
85
+
86
+ return GPTService(**params)
87
+
88
+ def create_service(self, service_type: str = "vllm", **override_params) -> LLMService:
89
+ """
90
+ Create LLM service by type
91
+
92
+ Args:
93
+ service_type: Service type ("vllm" or "gpt")
94
+ **override_params: Parameters to override config values
95
+
96
+ Returns:
97
+ LLMService instance
98
+ """
99
+ if service_type == "vllm":
100
+ return self.create_vllm_service(**override_params)
101
+ elif service_type == "gpt":
102
+ return self.create_gpt_service(**override_params)
103
+ else:
104
+ raise ValueError(f"Unknown service type: {service_type}")
105
+
106
+ def get_llm_assignment(self, component: str) -> str:
107
+ """
108
+ Get LLM service assignment for a component
109
+
110
+ Args:
111
+ component: Component name ("keyword_generator", "paper_summarizer", "reviewer", "refiner")
112
+
113
+ Returns:
114
+ Service type ("vllm" or "gpt")
115
+
116
+ Raises:
117
+ KeyError: If component is not found and no fallback is available
118
+ """
119
+ assignments = self._config.get("llm_assignments", {})
120
+ if component in assignments:
121
+ return assignments[component]
122
+
123
+ # Fallback: if refiner not configured, use reviewer's assignment
124
+ if component == "refiner" and "reviewer" in assignments:
125
+ return assignments["reviewer"]
126
+
127
+ # Default fallback to vllm (may cause connection errors if vllm is not running)
+ return "vllm"
129
+
130
+ def create_service_for_component(self, component: str, **override_params) -> LLMService:
131
+ """
132
+ Create LLM service for a specific component based on configuration
133
+
134
+ Args:
135
+ component: Component name ("keyword_generator", "paper_summarizer", "reviewer")
136
+ **override_params: Parameters to override config values
137
+
138
+ Returns:
139
+ LLMService instance
140
+ """
141
+ service_type = self.get_llm_assignment(component)
142
+ return self.create_service(service_type, **override_params)
143
+
144
+
145
+ # Global factory instance
146
+ _factory: Optional[LLMServiceFactory] = None
147
+
148
+
149
+ def load_api_key_from_config(config_path: str) -> Optional[str]:
150
+ """
151
+ Load API key from a YAML config file.
152
+
153
+ Args:
154
+ config_path: Path to YAML config file
155
+
156
+ Returns:
157
+ API key string, or None if not found
158
+
159
+ Note:
160
+ Returns None (instead of raising) if file doesn't exist or key not found,
161
+ to allow graceful fallback to environment variables.
162
+ """
+
165
+ config_file = Path(config_path)
166
+ if not config_file.exists():
167
+ return None
168
+
169
+ try:
170
+ with open(config_file, 'r', encoding='utf-8') as f:
171
+ config = yaml.safe_load(f)
172
+ return config.get('api_key')
173
+ except Exception:
174
+ return None
175
+
176
+
177
+ def get_llm_service_factory(config_file: Optional[str] = None) -> LLMServiceFactory:
178
+ """
179
+ Get or create global LLM service factory
180
+
181
+ Args:
182
+ config_file: Optional path to config file
183
+
184
+ Returns:
185
+ LLMServiceFactory instance
186
+ """
187
+ global _factory
188
+ if _factory is None or config_file is not None:
189
+ _factory = LLMServiceFactory(config_file)
190
+ return _factory
191
+
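The override mechanism in `create_vllm_service` is a plain dict merge: config defaults are assembled first, then caller overrides win via `dict.update`. A self-contained sketch with an inline config dict (values are illustrative):

```python
def build_vllm_params(config: dict, **override_params) -> dict:
    """Mirror of the merge in LLMServiceFactory.create_vllm_service."""
    vllm_config = config.get("vllm", {})
    params = {
        "base_url": vllm_config.get("base_url", "http://localhost:8000/v1"),
        "api_key": vllm_config.get("api_key", "dummy-key"),
        "model_name": vllm_config.get("model_name", "Qwen/Qwen3-4B-Instruct-2507"),
        "timeout": vllm_config.get("timeout", 300),
    }
    params.update(override_params)  # caller overrides beat config values
    return params


config = {"vllm": {"base_url": "http://10.0.0.5:8000/v1", "timeout": 120}}
params = build_vllm_params(config, timeout=60)
```

Keys absent from both config and overrides keep their hard-coded defaults, which is why a partial YAML file is enough.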
shared/utils/load_balancer.py ADDED
@@ -0,0 +1,382 @@
1
+ """
2
+ Simple Python Load Balancer
3
+
4
+ A lightweight load balancer for vLLM and Reranker services.
5
+ Uses FastAPI to forward requests to multiple backend services.
6
+ """
7
+ import sys
+ import asyncio
+ import httpx
+ from typing import List, Dict, Optional
14
+ from threading import Lock
15
+ import argparse
16
+
17
+ try:
18
+ from fastapi import FastAPI, Request, HTTPException
19
+ from fastapi.responses import StreamingResponse, Response
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ import uvicorn
22
+ HAS_FASTAPI = True
23
+ except ImportError:
24
+ HAS_FASTAPI = False
25
+ print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
26
+
27
+
28
+ class SimpleLoadBalancer:
29
+ """Simple load balancer with round-robin and least-connection strategies"""
30
+
31
+ def __init__(
32
+ self,
33
+ backends: List[str],
34
+ strategy: str = "round_robin",
35
+ health_check_interval: float = 10.0,
36
+ ):
37
+ """
38
+ Initialize load balancer
39
+
40
+ Args:
41
+ backends: List of backend URLs (e.g., ["http://localhost:8000", "http://localhost:8001"])
42
+ strategy: Load balancing strategy ("round_robin" or "least_conn")
43
+ health_check_interval: Health check interval in seconds
44
+ """
45
+ self.backends = backends
46
+ self.strategy = strategy
47
+ self.health_check_interval = health_check_interval
48
+
49
+ # Round-robin state
50
+ self.current_index = 0
51
+ self.index_lock = Lock()
52
+
53
+ # Least-connection state
54
+ self.connection_counts: Dict[str, int] = {backend: 0 for backend in backends}
55
+ self.conn_lock = Lock()
56
+
57
+ # Health check state
58
+ self.healthy_backends: Dict[str, bool] = {backend: True for backend in backends}
59
+ self.health_lock = Lock()
60
+
61
+ # HTTP client for forwarding requests
62
+ self.client = httpx.AsyncClient(timeout=300.0)
63
+
64
+ print(f"Load balancer initialized with {len(backends)} backends")
65
+ print(f"Strategy: {strategy}")
66
+ for i, backend in enumerate(backends):
67
+ print(f" [{i+1}] {backend}")
68
+
69
+ def get_backend(self) -> Optional[str]:
70
+ """Get next backend based on strategy"""
71
+ with self.health_lock:
72
+ available_backends = [b for b in self.backends if self.healthy_backends.get(b, True)]
73
+
74
+ if not available_backends:
75
+ # If no healthy backends, try all backends
76
+ available_backends = self.backends
77
+
78
+ if not available_backends:
79
+ return None
80
+
81
+ if self.strategy == "round_robin":
82
+ with self.index_lock:
83
+ backend = available_backends[self.current_index % len(available_backends)]
84
+ self.current_index = (self.current_index + 1) % len(available_backends)
85
+ return backend
86
+
87
+ elif self.strategy == "least_conn":
88
+ with self.conn_lock:
89
+ # Find backend with least connections
90
+ backend = min(available_backends, key=lambda b: self.connection_counts.get(b, 0))
91
+ self.connection_counts[backend] = self.connection_counts.get(backend, 0) + 1
92
+ return backend
93
+
94
+ else:
95
+ # Default to round-robin
96
+ with self.index_lock:
97
+ backend = available_backends[self.current_index % len(available_backends)]
98
+ self.current_index = (self.current_index + 1) % len(available_backends)
99
+ return backend
100
+
101
+ def release_backend(self, backend: str):
102
+ """Release a backend (for least-conn strategy)"""
103
+ if self.strategy == "least_conn":
104
+ with self.conn_lock:
105
+ self.connection_counts[backend] = max(0, self.connection_counts.get(backend, 0) - 1)
106
+
107
+ async def health_check(self, backend: str) -> bool:
108
+ """Check if a backend is healthy"""
109
+ try:
110
+ # For vLLM backends (URLs ending with /v1), use /models endpoint
111
+ # For other backends, try /health first, then root
112
+ if backend.endswith("/v1"):
113
+ # vLLM endpoint: try /models (which becomes /v1/models)
114
+ endpoints = ["/models", "/"]
115
+ else:
116
+ # Other services: try /health, then root
117
+ endpoints = ["/health", "/"]
118
+
119
+ for endpoint in endpoints:
120
+ try:
121
+ response = await self.client.get(f"{backend}{endpoint}", timeout=5.0)
122
+ if response.status_code < 500:
123
+ return True
124
+ except Exception:
125
+ continue
126
+ return False
127
+ except Exception:
128
+ return False
129
+
130
+ async def check_all_backends(self):
131
+ """Check health of all backends"""
132
+ while True:
133
+ for backend in self.backends:
134
+ is_healthy = await self.health_check(backend)
135
+ with self.health_lock:
136
+ self.healthy_backends[backend] = is_healthy
137
+ if not is_healthy:
138
+ print(f"Warning: Backend {backend} is unhealthy")
139
+ await asyncio.sleep(self.health_check_interval)
140
+
141
+ async def forward_request(
142
+ self,
143
+ method: str,
144
+ path: str,
145
+ request: Request,
146
+ backend: Optional[str] = None
147
+ ) -> Response:
148
+ """Forward a request to a backend"""
149
+ if backend is None:
150
+ backend = self.get_backend()
151
+
152
+ if backend is None:
153
+ raise HTTPException(status_code=503, detail="No healthy backends available")
154
+
155
+ try:
156
+ # Get request body
157
+ body = await request.body()
158
+
159
+ # Get query parameters
160
+ query_params = dict(request.query_params)
161
+
162
+ # Get headers (exclude host and connection)
163
+ headers = dict(request.headers)
164
+ headers.pop("host", None)
165
+ headers.pop("connection", None)
166
+ headers.pop("content-length", None)
167
+
168
+ # Forward request
169
+ url = f"{backend}{path}"
+
+ # Let httpx encode the query parameters instead of joining them by hand
+ response = await self.client.request(
+ method=method,
+ url=url,
+ content=body,
+ headers=headers,
+ params=query_params,
+ )
179
+
180
+ # Create response
181
+ return Response(
182
+ content=response.content,
183
+ status_code=response.status_code,
184
+ headers=dict(response.headers),
185
+ )
186
+
187
+ except Exception as e:
+ # Mark backend as unhealthy; the finally block handles the release
+ with self.health_lock:
+ self.healthy_backends[backend] = False
+
+ raise HTTPException(status_code=502, detail=f"Backend error: {str(e)}")
+ finally:
+ self.release_backend(backend)
196
+
197
+ async def forward_streaming_request(
198
+ self,
199
+ method: str,
200
+ path: str,
201
+ request: Request,
202
+ backend: Optional[str] = None
203
+ ):
204
+ """Forward a streaming request to a backend"""
205
+ if backend is None:
206
+ backend = self.get_backend()
207
+
208
+ if backend is None:
209
+ raise HTTPException(status_code=503, detail="No healthy backends available")
210
+
211
+ try:
212
+ # Get request body
213
+ body = await request.body()
214
+
215
+ # Get query parameters
216
+ query_params = dict(request.query_params)
217
+
218
+ # Get headers
219
+ headers = dict(request.headers)
220
+ headers.pop("host", None)
221
+ headers.pop("connection", None)
222
+ headers.pop("content-length", None)
223
+
224
+ # Forward request
225
+ url = f"{backend}{path}"
+
+ # Open the upstream stream without a context manager: the body is consumed
+ # lazily by StreamingResponse after this function returns, so the client
+ # and response must stay open until the generator finishes.
+ client = httpx.AsyncClient(timeout=300.0)
+ upstream = client.build_request(method=method, url=url, content=body, headers=headers, params=query_params)
+ response = await client.send(upstream, stream=True)
+
+ async def generate():
+ try:
+ async for chunk in response.aiter_bytes():
+ yield chunk
+ finally:
+ await response.aclose()
+ await client.aclose()
+
+ return StreamingResponse(
+ generate(),
+ status_code=response.status_code,
+ headers=dict(response.headers),
+ )
245
+
246
+ except Exception as e:
+ # Mark backend as unhealthy; the finally block handles the release
+ with self.health_lock:
+ self.healthy_backends[backend] = False
+
+ raise HTTPException(status_code=502, detail=f"Backend error: {str(e)}")
+ finally:
+ self.release_backend(backend)
255
+
256
+
257
+ def create_load_balancer_app(
258
+ backends: List[str],
259
+ strategy: str = "round_robin",
260
+ health_check_interval: float = 10.0,
261
+ ) -> FastAPI:
262
+ """Create FastAPI app with load balancer"""
263
+ if not HAS_FASTAPI:
264
+ raise RuntimeError("FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
265
+
266
+ app = FastAPI(title="Simple Load Balancer", version="1.0.0")
267
+
268
+ # Add CORS middleware
269
+ app.add_middleware(
270
+ CORSMiddleware,
271
+ allow_origins=["*"],
272
+ allow_credentials=True,
273
+ allow_methods=["*"],
274
+ allow_headers=["*"],
275
+ )
276
+
277
+ # Create load balancer
278
+ lb = SimpleLoadBalancer(backends, strategy, health_check_interval)
279
+
280
+ # Start health check task
281
+ @app.on_event("startup")
282
+ async def start_health_check():
283
+ asyncio.create_task(lb.check_all_backends())
284
+
285
+ # Health check endpoint
286
+ @app.get("/health")
287
+ async def health():
288
+ healthy_count = sum(1 for h in lb.healthy_backends.values() if h)
289
+ return {
290
+ "status": "healthy" if healthy_count > 0 else "unhealthy",
291
+ "healthy_backends": healthy_count,
292
+ "total_backends": len(lb.backends),
293
+ "backends": [
294
+ {
295
+ "url": backend,
296
+ "healthy": lb.healthy_backends.get(backend, False),
297
+ "connections": lb.connection_counts.get(backend, 0) if lb.strategy == "least_conn" else None,
298
+ }
299
+ for backend in lb.backends
300
+ ]
301
+ }
302
+
303
+ # Forward all other requests
304
+ @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
305
+ async def forward(request: Request, path: str):
306
+ method = request.method
307
+ # Heuristic: only detects streaming requested via the query string; an
+ # OpenAI-style JSON body with "stream": true falls through to the buffered path
+ is_streaming = "stream" in request.query_params or "stream=true" in str(request.url)
308
+
309
+ if is_streaming:
310
+ return await lb.forward_streaming_request(method, f"/{path}", request)
311
+ else:
312
+ return await lb.forward_request(method, f"/{path}", request)
313
+
314
+ return app
315
+
316
+
317
+ def main():
318
+ """Main entry point for load balancer"""
319
+ parser = argparse.ArgumentParser(description="Simple Python Load Balancer")
320
+ parser.add_argument(
321
+ "--backends",
322
+ type=str,
323
+ nargs="+",
324
+ required=True,
325
+ help="Backend URLs (e.g., http://localhost:8000 http://localhost:8001)"
326
+ )
327
+ parser.add_argument(
328
+ "--host",
329
+ type=str,
330
+ default="0.0.0.0",
331
+ help="Host to bind to (default: 0.0.0.0)"
332
+ )
333
+ parser.add_argument(
334
+ "--port",
335
+ type=int,
336
+ default=8000,
337
+ help="Port to bind to (default: 8000)"
338
+ )
339
+ parser.add_argument(
340
+ "--strategy",
341
+ type=str,
342
+ default="round_robin",
343
+ choices=["round_robin", "least_conn"],
344
+ help="Load balancing strategy (default: round_robin)"
345
+ )
346
+ parser.add_argument(
347
+ "--health-check-interval",
348
+ type=float,
349
+ default=10.0,
350
+ help="Health check interval in seconds (default: 10.0)"
351
+ )
352
+
353
+ args = parser.parse_args()
354
+
355
+ if not HAS_FASTAPI:
356
+ print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn httpx")
357
+ sys.exit(1)
358
+
359
+ # Create app
360
+ app = create_load_balancer_app(
361
+ backends=args.backends,
362
+ strategy=args.strategy,
363
+ health_check_interval=args.health_check_interval,
364
+ )
365
+
366
+ # Run server
367
+ print(f"Starting load balancer on {args.host}:{args.port}")
368
+ print(f"Strategy: {args.strategy}")
369
+ print(f"Backends: {', '.join(args.backends)}")
370
+
371
+ uvicorn.run(
372
+ app,
373
+ host=args.host,
374
+ port=args.port,
375
+ log_level="info"
376
+ )
377
+
378
+
379
+ if __name__ == "__main__":
380
+ main()
381
+
382
+ # python -m shared.utils.load_balancer --backends http://localhost:8000 http://localhost:8001 http://localhost:8002 http://localhost:8003 --port 8004 --strategy round_robin
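The round-robin branch of `get_backend` reduces to a lock-guarded counter; stripped of the health-check state it behaves as sketched below (`RoundRobinPicker` is a hypothetical name for illustration):

```python
from threading import Lock


class RoundRobinPicker:
    """Round-robin core of SimpleLoadBalancer.get_backend, without health state."""

    def __init__(self, backends):
        self.backends = backends
        self.current_index = 0
        self.index_lock = Lock()

    def get_backend(self):
        with self.index_lock:
            # Pick the next backend, then advance the shared counter
            backend = self.backends[self.current_index % len(self.backends)]
            self.current_index = (self.current_index + 1) % len(self.backends)
            return backend


picker = RoundRobinPicker(["http://localhost:8000", "http://localhost:8001"])
order = [picker.get_backend() for _ in range(4)]
```

The lock makes selection safe under concurrent requests; fairness only degrades when the healthy-backend set changes size between calls, since the index is taken modulo the current list length.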
shared/utils/mock_llm_service.py ADDED
@@ -0,0 +1,280 @@
1
+ """
2
+ Mock LLM Service that returns pre-generated reviews from a JSON file
3
+ This is a hack for testing the refiner pipeline with existing reviews
4
+ """
5
+ import json
6
+ import re
7
+ from pathlib import Path
8
+ from typing import List, Dict, Optional, Any, Union
9
+
10
+ from .llm_service import LLMService, ChatMessage
11
+
12
+
13
+ def extract_title_from_latex(paper_context: str) -> Optional[str]:
14
+ """Extract title from LaTeX format \\title{...}"""
15
+ match = re.search(r'\\title\{([^}]+)\}', paper_context)
16
+ if match:
17
+ return match.group(1).strip()
18
+ return None
19
+
20
+
21
+ def extract_abstract_from_latex(paper_context: str) -> Optional[str]:
22
+ """Extract abstract from LaTeX format \\begin{abstract}...\\end{abstract}"""
23
+ match = re.search(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', paper_context, re.DOTALL)
24
+ if match:
25
+ abstract = match.group(1).strip()
26
+ # Clean up LaTeX commands
27
+ abstract = re.sub(r'\\[a-zA-Z]+\{([^}]+)\}', r'\1', abstract) # Remove LaTeX commands
28
+ abstract = re.sub(r'\$([^$]+)\$', r'\1', abstract) # Remove math mode
29
+ return abstract
30
+ return None
31
+
32
+
33
+ class MockLLMService(LLMService):
34
+ """
35
+ Mock LLM Service that returns pre-generated reviews from a JSON file
36
+
37
+ This service matches papers by extracting title and abstract from paper_context
38
+ and returns the corresponding pred_fast_mode_baseline from the JSON file.
39
+ """
40
+
41
+ def __init__(self, json_file_path: str):
42
+ """
43
+ Initialize Mock LLM Service
44
+
45
+ Args:
46
+ json_file_path: Path to JSON file containing pre-generated reviews
47
+ """
48
+ self.json_file_path = Path(json_file_path)
49
+ if not self.json_file_path.exists():
50
+ raise FileNotFoundError(f"JSON file not found: {json_file_path}")
51
+
52
+ # Load JSON data
53
+ with open(self.json_file_path, 'r', encoding='utf-8') as f:
54
+ self.data = json.load(f)
55
+
56
+ # Build index for faster lookup
57
+ self._build_index()
58
+
59
+ def _build_index(self):
60
+ """Build index mapping (title, abstract) to review"""
61
+ self.index = {}
62
+ self.entries = [] # Store full entries for fallback matching
63
+ self.initial_scores_index = {} # Store initial scores and decision for each entry
64
+
65
+ for entry in self.data:
66
+ paper_context = entry.get('paper_context', '')
67
+ title = extract_title_from_latex(paper_context)
68
+ abstract = extract_abstract_from_latex(paper_context)
69
+
70
+ # Extract review content (prefer meta_review.content, fallback to raw_text)
71
+ model_prediction = entry.get('model_prediction', {})
72
+ meta_review = model_prediction.get('meta_review', {})
73
+ review_content = meta_review.get('content', '') or model_prediction.get('raw_text', '')
74
+
75
+ # Extract initial scores and decision
76
+ initial_scores = {
77
+ 'rating': meta_review.get('rating'),
78
+ 'soundness': meta_review.get('soundness'),
79
+ 'presentation': meta_review.get('presentation'),
80
+ 'contribution': meta_review.get('contribution'),
81
+ 'decision': model_prediction.get('decision'),
82
+ }
83
+
84
+ if title and abstract:
85
+ # Use normalized title and first 200 chars of abstract as key
86
+ normalized_title = title.lower().strip()
87
+ normalized_abstract = abstract[:200].lower().strip()
88
+ key = (normalized_title, normalized_abstract)
89
+ self.index[key] = review_content
90
+ self.initial_scores_index[key] = initial_scores
91
+
92
+
93
+ # Store entry for fallback matching
94
+ self.entries.append({
95
+ 'title': title,
96
+ 'abstract': abstract,
97
+ 'paper_context': paper_context,
98
+ 'review': review_content,
99
+ 'id': entry.get('id', ''),
100
+ 'initial_scores': initial_scores,
101
+ })
102
+
103
+ def _find_entry(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> Optional[Dict[str, Any]]:
104
+ """
105
+ Find entry by matching title and abstract from messages
106
+
107
+ Args:
108
+ messages: List of chat messages
109
+
110
+ Returns:
111
+ Entry dict with 'review' and 'initial_scores' or None if not found
112
+ """
113
+ # Extract paper context from user message
114
+ user_message = None
115
+ for msg in messages:
116
+ if isinstance(msg, dict):
117
+ if msg.get('role') == 'user':
118
+ user_message = msg.get('content', '')
119
+ elif isinstance(msg, ChatMessage):
120
+ if msg.role == 'user':
121
+ user_message = msg.content
122
+
123
+ if not user_message:
124
+ return None
125
+
126
+ # Try to extract title and abstract from user message
127
+ # Look for patterns like "Title: ..." or "Abstract: ..."
128
+ title_match = re.search(r'Title:\s*(.+?)(?:\n|$)', user_message, re.IGNORECASE)
129
+ abstract_match = re.search(r'Abstract:\s*(.+?)(?:\n\n|Content:|$)', user_message, re.DOTALL | re.IGNORECASE)
130
+
131
+ extracted_title = None
132
+ extracted_abstract = None
133
+
134
+ if title_match and abstract_match:
135
+ extracted_title = title_match.group(1).strip()
136
+ extracted_abstract = abstract_match.group(1).strip()
137
+ else:
138
+ # Fallback: search in paper_context if available
139
+ paper_context_match = re.search(r'Paper to review:\s*(.+?)(?:Please provide|$)', user_message, re.DOTALL)
140
+ if paper_context_match:
141
+ paper_context = paper_context_match.group(1)
142
+ extracted_title = extract_title_from_latex(paper_context)
143
+ extracted_abstract = extract_abstract_from_latex(paper_context)
144
+
145
+ if extracted_title and extracted_abstract:
146
+ # Normalize for matching
147
+ normalized_title = extracted_title.lower().strip()
148
+ normalized_abstract = extracted_abstract[:200].lower().strip()
149
+
150
+ # Try exact match first
151
+ key = (normalized_title, normalized_abstract)
152
+ if key in self.index:
153
+ return {
154
+ 'review': self.index[key],
155
+ 'initial_scores': self.initial_scores_index.get(key, {})
156
+ }
157
+
158
+ # Try fuzzy match (check if title matches)
159
+ for (index_title, index_abstract), review in self.index.items():
160
+ # Check title similarity (either contains or is contained)
161
+ title_similar = (
162
+ normalized_title in index_title or
163
+ index_title in normalized_title or
164
+ normalized_title == index_title
165
+ )
166
+
167
+ # Check abstract similarity (first 100 chars)
168
+ abstract_similar = (
169
+ normalized_abstract[:100] in index_abstract[:100] or
170
+ index_abstract[:100] in normalized_abstract[:100] or
171
+ normalized_abstract[:100] == index_abstract[:100]
172
+ )
173
+
174
+ if title_similar and abstract_similar:
175
+ return {
176
+ 'review': review,
177
+ 'initial_scores': self.initial_scores_index.get((index_title, index_abstract), {})
178
+ }
179
+
180
+ # Final fallback: try to match by paper_context in entries
181
+ for entry in self.entries:
182
+ if entry['paper_context']:
183
+ # Check if user message contains similar content
184
+ entry_title = entry['title']
185
+ if entry_title and extracted_title:
186
+ if entry_title.lower().strip() in extracted_title.lower() or extracted_title.lower() in entry_title.lower():
187
+ return {
188
+ 'review': entry['review'],
189
+ 'initial_scores': entry.get('initial_scores', {})
190
+ }
191
+
192
+ return None
193
+
194
+ def _find_review(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> Optional[str]:
195
+ """
196
+ Find review by matching title and abstract from messages
197
+
198
+ Args:
199
+ messages: List of chat messages
200
+
201
+ Returns:
202
+ Review text or None if not found
203
+ """
204
+ entry = self._find_entry(messages)
205
+ if entry:
206
+ return entry['review']
207
+ return None
208
+
209
+ def get_initial_scores(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> Optional[Dict[str, Any]]:
210
+ """
211
+ Get initial scores and decision by matching title and abstract from messages
212
+
213
+ Args:
214
+ messages: List of chat messages
215
+
216
+ Returns:
217
+ Dict with initial scores (rating, soundness, presentation, contribution, decision) or None if not found
218
+ """
219
+ entry = self._find_entry(messages)
220
+ if entry:
221
+ return entry.get('initial_scores', {})
222
+ return None
223
+
224
+ def generate(
225
+ self,
226
+ messages: List[Union[ChatMessage, Dict[str, str]]],
227
+ temperature: float = 0.7,
228
+ top_p: float = 0.8,
229
+ top_k: int = 20,
230
+ max_tokens: int = 16384,
231
+ presence_penalty: float = 0.0,
232
+ **kwargs
233
+ ) -> str:
234
+ """
235
+ Generate text from messages (returns pre-generated review)
236
+
237
+ Args:
238
+ messages: List of chat messages
239
+ temperature: Ignored (for compatibility)
240
+ top_p: Ignored (for compatibility)
241
+ top_k: Ignored (for compatibility)
242
+ max_tokens: Ignored (for compatibility)
243
+ presence_penalty: Ignored (for compatibility)
244
+ **kwargs: Additional parameters (ignored)
245
+
246
+ Returns:
247
+ Pre-generated review text
248
+ """
249
+ review = self._find_review(messages)
250
+ if review:
251
+ return review
252
+
253
+ # Fallback: return a default message
254
+ return "## Summary:\n\nReview not found in pre-generated data."
255
+
256
+ def stream_generate(
257
+ self,
258
+ messages: List[Union[ChatMessage, Dict[str, str]]],
259
+ temperature: float = 0.7,
260
+ top_p: float = 0.8,
261
+ top_k: int = 20,
262
+ max_tokens: int = 16384,
263
+ presence_penalty: float = 0.0,
264
+ **kwargs
265
+ ):
266
+ """
267
+ Stream generate text from messages (yields pre-generated review)
268
+
269
+ Yields:
270
+ Pre-generated review text chunks
271
+ """
272
+ review = self._find_review(messages)
273
+ if review:
274
+ # Yield in chunks to simulate streaming
275
+ chunk_size = 100
276
+ for i in range(0, len(review), chunk_size):
277
+ yield review[i:i + chunk_size]
278
+ else:
279
+ yield "## Summary:\n\nReview not found in pre-generated data."
280
+
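The two LaTeX helpers at the top of `mock_llm_service.py` can be smoke-tested in isolation (same regexes as in the file; note that the `[^}]+` pattern stops at the first closing brace, so titles containing nested braces are truncated):

```python
import re
from typing import Optional


def extract_title_from_latex(paper_context: str) -> Optional[str]:
    match = re.search(r'\\title\{([^}]+)\}', paper_context)
    return match.group(1).strip() if match else None


def extract_abstract_from_latex(paper_context: str) -> Optional[str]:
    match = re.search(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', paper_context, re.DOTALL)
    if not match:
        return None
    abstract = match.group(1).strip()
    abstract = re.sub(r'\\[a-zA-Z]+\{([^}]+)\}', r'\1', abstract)  # strip commands, keep arguments
    abstract = re.sub(r'\$([^$]+)\$', r'\1', abstract)             # unwrap inline math
    return abstract


paper = r"\title{A Study} \begin{abstract}We \emph{show} that $x$ matters.\end{abstract}"
title = extract_title_from_latex(paper)
abstract = extract_abstract_from_latex(paper)
```

This is the normalization that feeds the (title, abstract-prefix) index keys, so any cleanup change here must be mirrored on both the index-build and lookup sides.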
shared/utils/prompt_loader.py ADDED
@@ -0,0 +1,220 @@
1
+ """
2
+ Utility for loading prompts from YAML configuration files
3
+ """
4
+ import yaml
5
+ from pathlib import Path
6
+ from typing import Dict, Any, Optional
7
+
8
+
9
+ class PromptLoader:
10
+ """Load and manage prompts from YAML files"""
11
+
12
+ def __init__(self, prompts_file: Optional[str] = None):
13
+ """
14
+ Initialize prompt loader
15
+
16
+ Args:
17
+ prompts_file: Path to prompts YAML file. If None, uses default location.
18
+ """
19
+ if prompts_file is None:
20
+ # Default to shared/configs/prompts.yaml relative to project root
21
+ project_root = Path(__file__).parent.parent.parent
22
+ prompts_file = project_root / "shared" / "configs" / "prompts.yaml"
23
+
24
+ self.prompts_file = Path(prompts_file)
25
+ self._prompts = None
26
+ self._load_prompts()
27
+
28
+ def _load_prompts(self):
29
+ """Load prompts from YAML file"""
30
+ if not self.prompts_file.exists():
31
+ raise FileNotFoundError(f"Prompts file not found: {self.prompts_file}")
32
+
33
+ with open(self.prompts_file, 'r', encoding='utf-8') as f:
34
+ self._prompts = yaml.safe_load(f)
35
+
36
+ def get_keyword_generation_prompt(self, context: str) -> str:
37
+ """
38
+ Get keyword generation prompt with context filled in
39
+
40
+ Args:
41
+ context: Paper information context
42
+
43
+ Returns:
44
+ Formatted prompt string
45
+ """
46
+ template = self._prompts["keyword_generation"]["user"]
47
+ return template.format(context=context)
48
+
49
+ def get_keyword_generation_system(self) -> str:
50
+ """Get keyword generation system message"""
51
+ return self._prompts["keyword_generation"].get("system", "")
52
+
53
+ def get_paper_summarization_prompt(self, reference_paper: str, related_paper: str) -> str:
54
+ """
55
+ Get paper summarization prompt with reference_paper and related_paper filled in
56
+
57
+ Args:
58
+ reference_paper: Reference paper information (the paper being reviewed)
59
+ related_paper: Related paper information
60
+
61
+ Returns:
62
+ Formatted prompt string
63
+ """
64
+ template = self._prompts["paper_summarization"]["user"]
65
+ return template.format(reference_paper=reference_paper, related_paper=related_paper)
66
+
67
+ def get_paper_results_summarization_prompt(self, content: str) -> str:
68
+ """
69
+ Get paper results summarization prompt with content filled in
70
+
71
+ Args:
72
+ content: Paper content (experiment results section)
73
+
74
+ Returns:
75
+ Formatted prompt string
76
+ """
77
+ template = self._prompts["paper_results_summarization"]["user"]
78
+ return template.format(content=content)
79
+
80
+ def get_paper_insight_miner_prompt(self, content: str, candidate_review: str) -> str:
81
+ """
82
+ Get paper insight miner prompt with content and candidate_review filled in
83
+
84
+ Args:
85
+ content: Paper content
86
+ candidate_review: Candidate review draft
87
+
88
+ Returns:
89
+ Formatted prompt string
90
+ """
91
+ template = self._prompts["paper_insight_miner"]["user"]
92
+ # Use replace instead of format to avoid issues with JSON braces in the template
93
+ prompt = template.replace("{content}", content)
94
+ prompt = prompt.replace("{candidate_review}", candidate_review)
95
+ return prompt
96
+
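The replace-based substitution above exists because `str.format` treats every brace as a field, so templates containing literal JSON (common in these prompts) raise; a quick illustration with a hypothetical template, not one from prompts.yaml:

```python
template = 'Return JSON like {"score": 5}. Paper: {content}'

# str.format chokes on the literal {"score": 5} braces
try:
    template.format(content="...")
    format_ok = True
except (KeyError, ValueError, IndexError):
    format_ok = False

# Sequential str.replace touches only the named placeholder
prompt = template.replace("{content}", "Attention Is All You Need")
```

The trade-off is that `replace` silently ignores missing placeholders instead of failing loudly, so template typos go unnoticed.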
97
+ def get_paper_results_analyzer_prompt(self, content: str, candidate_review: str) -> str:
98
+ """
99
+ Get paper results analyzer prompt with content and candidate_review filled in
100
+
101
+ Args:
102
+ content: Paper content
103
+ candidate_review: Candidate review draft
104
+
105
+ Returns:
106
+ Formatted prompt string
107
+ """
108
+ template = self._prompts["paper_results_analyzer"]["user"]
109
+ # Use replace instead of format to avoid issues with JSON braces in the template
110
+ prompt = template.replace("{content}", content)
111
+ prompt = prompt.replace("{candidate_review}", candidate_review)
112
+ return prompt
113
+
114
+ def get_review_prompt(self, review_format: str = "detailed") -> str:
115
+ """
116
+ Get review prompt for specified format
117
+
118
+ Args:
119
+ review_format: Review format ("detailed", "summary", "structured")
120
+
121
+ Returns:
122
+ Review prompt string
123
+ """
124
+ if review_format not in self._prompts["review_prompts"]:
125
+ review_format = "detailed"
126
+
127
+ return self._prompts["review_prompts"][review_format]
128
+
129
+ def get_reviewer_system_message(self) -> str:
130
+ """Get system message for reviewer"""
131
+ return self._prompts.get("reviewer_system", "You are an expert academic reviewer with deep knowledge in the field.")
132
+
133
+ def get_refiner_prompt(self, review_format: str = "detailed") -> str:
134
+ """
135
+ Get refiner prompt for specified format
136
+
137
+ Args:
138
+ review_format: Review format ("detailed", "summary", "structured")
139
+
140
+ Returns:
141
+ Refiner prompt string
142
+ """
143
+ if "refiner_prompts" not in self._prompts:
144
+ raise ValueError("refiner_prompts not found in prompts file")
145
+
146
+ if review_format not in self._prompts["refiner_prompts"]:
147
+ review_format = "detailed"
148
+
149
+ return self._prompts["refiner_prompts"][review_format]
150
+
151
+ def get_refiner_system_message(self) -> str:
152
+ """Get system message for refiner"""
153
+ return self._prompts.get("refiner_system", "You are an expert review refiner with deep knowledge in academic review quality standards and meta rubrics.")
154
+
155
+ def get_rubrics_template(self) -> str:
156
+ """
157
+ Get the rubrics template for generating paper-specific rubrics.
158
+
159
+ Returns:
160
+ Rubrics template string (JSON array format)
161
+ """
162
+ return self._prompts.get("rubrics", "")
163
+
164
+ def get_rubric_generation_prompt(self, version: str = "v2") -> str:
165
+ """
166
+ Get rubric generation prompt.
167
+
168
+ Args:
169
+ version: Prompt version ("v1" or "v2", default: "v2")
170
+
171
+ Returns:
172
+ Rubric generation prompt template string
173
+ """
174
+ key = f"{version}_rubric_generation_prompt"
175
+ prompt = self._prompts.get(key, "")
176
+
177
+ # For v2, replace rubric_template placeholder with actual template
178
+ if version == "v2" and "<<rubric_template>>" in prompt:
179
+ rubric_template = self.get_rubrics_template()
180
+ prompt = prompt.replace("<<rubric_template>>", rubric_template)
181
+
182
+ return prompt
183
+
184
+ def get_evaluator_prompt(self, version: str = "v1") -> str:
185
+ """
186
+ Get evaluator prompt for evaluating reviews using rubrics.
187
+
188
+ Args:
189
+ version: Prompt version ("v0" or "v1", default: "v1")
190
+
191
+ Returns:
192
+ Evaluator prompt template string
193
+ """
194
+ key = f"{version}_evaluator_prompt"
195
+ return self._prompts.get(key, "")
196
+
197
+ def reload(self):
198
+ """Reload prompts from file"""
199
+ self._load_prompts()
200
+
201
+
202
+ # Global prompt loader instance
203
+ _prompt_loader: Optional[PromptLoader] = None
204
+
205
+
206
+ def get_prompt_loader(prompts_file: Optional[str] = None) -> PromptLoader:
207
+ """
208
+ Get or create global prompt loader instance
209
+
210
+ Args:
211
+ prompts_file: Optional path to prompts file
212
+
213
+ Returns:
214
+ PromptLoader instance
215
+ """
216
+ global _prompt_loader
217
+ if _prompt_loader is None or prompts_file is not None:
218
+ _prompt_loader = PromptLoader(prompts_file)
219
+ return _prompt_loader
220
+
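The brace-handling comments in `get_paper_insight_miner_prompt` and `get_paper_results_analyzer_prompt` above can be illustrated with a standalone sketch. The template text here is a hypothetical stand-in; the real templates live in `shared/configs/prompts.yaml`.

```python
# Shows why these getters use str.replace rather than str.format:
# a template containing literal JSON braces trips up format().
# (Hypothetical template text, not the real prompts.yaml content.)
template = 'Respond in JSON: {"insights": [...]}\nPaper: {content}'

try:
    template.format(content="some paper text")
except KeyError as e:
    # format() treats {"insights": ...} as a replacement field and fails
    print(f"format() raised KeyError: {e}")

# replace() only touches the exact placeholder string
prompt = template.replace("{content}", "some paper text")
print(prompt)
```

With `replace`, any literal braces elsewhere in the template pass through untouched, at the cost of not validating that the placeholder actually exists.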
shared/utils/reranker.py ADDED
@@ -0,0 +1,275 @@
1
+ """
2
+ Reranker utilities for paper retrieval
3
+ Based on OpenScholar's rerank_paragraphs_bge function
4
+
5
+ Supports two modes:
6
+ 1. Direct mode: Use FlagReranker directly (requires global lock for thread-safety)
7
+ 2. API mode: Use reranker API service with load balancing (recommended for multi-GPU)
8
+ """
9
+ import os
10
+ import threading
11
+ import time
12
+ import requests
13
+ from typing import List, Dict, Any, Optional, Tuple, Union
14
+
15
+ # Suppress transformers progress bars
16
+ os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')
17
+
18
+ # Global lock for reranker usage (FlagReranker's tokenizer is not thread-safe)
19
+ # This prevents "Already borrowed" errors when multiple threads use the same reranker
20
+ # NOTE: Not needed when using API mode
21
+ _reranker_usage_lock = threading.Lock()
22
+
23
+ # Try to import endpoint pool for API mode
24
+ try:
25
+ from .reranker_endpoint_pool import RerankerEndpointPool
26
+ HAS_ENDPOINT_POOL = True
27
+ except ImportError:
28
+ HAS_ENDPOINT_POOL = False
29
+ RerankerEndpointPool = None
30
+
31
+
32
+ def rerank_paragraphs_bge(
33
+ query: str,
34
+ paragraphs: List[Dict[str, Any]],
35
+ reranker: Optional[Any] = None,
36
+ reranker_endpoint_pool: Optional[Any] = None,
37
+ norm_cite: bool = False,
38
+ start_index: int = 0,
39
+ use_abstract: bool = False,
40
+ timeout: float = 30.0,
41
+ ) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
42
+ """
43
+ Rerank paragraphs using BGE reranker (from OpenScholar)
44
+
45
+ Supports two modes:
46
+ 1. Direct mode: Pass FlagReranker instance (uses global lock, thread-safe but serialized)
47
+ 2. API mode: Pass RerankerEndpointPool (recommended for multi-GPU, parallel requests)
48
+
49
+ Args:
50
+ query: Search query
51
+ paragraphs: List of paragraph/paper dictionaries
52
+ reranker: FlagReranker instance (for direct mode, optional if using API mode)
53
+ reranker_endpoint_pool: RerankerEndpointPool instance (for API mode, optional if using direct mode)
54
+ norm_cite: Whether to normalize citation counts and add to scores
55
+ start_index: Starting index for id mapping
56
+ use_abstract: Whether to include abstract in reranking text
57
+ timeout: Request timeout for API mode (seconds)
58
+
59
+ Returns:
60
+ Tuple of:
61
+ - reranked_paragraphs: List of reranked paragraphs
62
+ - result_dict: Dictionary mapping original index to score
63
+ - id_mapping: Dictionary mapping new index to original index
64
+ """
65
+ # Filter out paragraphs without text
66
+ paragraphs = [p for p in paragraphs if p.get("text") is not None]
67
+
68
+ if not paragraphs:
69
+ return [], {}, {}
70
+
71
+ # Build paragraph texts for reranking
72
+ if use_abstract:
73
+ paragraph_texts = [
74
+ p["title"] + "\n" + p["abstract"] + "\n" + p["text"]
75
+ if "title" in p and "abstract" in p and p.get("title") and p.get("abstract")
76
+ else p["text"]
77
+ for p in paragraphs
78
+ ]
79
+ else:
80
+ paragraph_texts = [
81
+ p["title"] + " " + p["text"]
82
+ if "title" in p and p.get("title") is not None
83
+ else p["text"]
84
+ for p in paragraphs
85
+ ]
86
+
87
+ # Filter out empty or None texts
88
+ valid_indices = []
89
+ valid_texts = []
90
+ for i, text in enumerate(paragraph_texts):
91
+ if text and isinstance(text, str) and text.strip():
92
+ valid_indices.append(i)
93
+ valid_texts.append(text)
94
+
95
+ # If no valid texts, return empty results
96
+ if not valid_texts:
97
+ return [], {}, {}
98
+
99
+ # If some texts were filtered out, update paragraphs list
100
+ if len(valid_indices) < len(paragraphs):
101
+ paragraphs = [paragraphs[i] for i in valid_indices]
102
+ paragraph_texts = valid_texts
103
+
104
+ # Compute reranking scores
105
+ if reranker is None and reranker_endpoint_pool is None:
106
+ # If no reranker, return original order
107
+ id_mapping = {i: i + start_index for i in range(len(paragraphs))}
108
+ result_dict = {i: 0.0 for i in range(len(paragraphs))}
109
+ return paragraphs, result_dict, id_mapping
110
+
111
+ # API mode: Use reranker API service (recommended for multi-GPU)
112
+ if reranker_endpoint_pool is not None:
113
+ return _rerank_via_api(
114
+ query=query,
115
+ paragraph_texts=paragraph_texts,
116
+ paragraphs=paragraphs,
117
+ reranker_endpoint_pool=reranker_endpoint_pool,
118
+ norm_cite=norm_cite,
119
+ start_index=start_index,
120
+ timeout=timeout
121
+ )
122
+
123
+ # Direct mode: Use FlagReranker directly (requires global lock)
124
+ # Suppress transformers warnings and progress bars during computation
125
+ original_verbosity = os.environ.get('TRANSFORMERS_VERBOSITY', '')
126
+ os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
127
+
128
+ # Use lock to prevent "Already borrowed" errors from Rust tokenizer
129
+ # FlagReranker's tokenizer is not thread-safe, so we need to serialize access
130
+ with _reranker_usage_lock:
131
+ try:
132
+ # Ensure we have at least one valid text before calling compute_score
133
+ if not paragraph_texts:
134
+ return [], {}, {}
135
+ scores = reranker.compute_score([[query, p] for p in paragraph_texts], batch_size=100)
136
+ finally:
137
+ # Restore original verbosity
138
+ if original_verbosity:
139
+ os.environ['TRANSFORMERS_VERBOSITY'] = original_verbosity
140
+ elif 'TRANSFORMERS_VERBOSITY' in os.environ:
141
+ del os.environ['TRANSFORMERS_VERBOSITY']
142
+
143
+ # Handle score format (can be float or list)
144
+ if isinstance(scores, float):
145
+ result_dict = {0: scores}
146
+ else:
147
+ result_dict = {p_id: score for p_id, score in enumerate(scores)}
148
+
149
+ # Add normalized citation counts if enabled
150
+ if norm_cite:
151
+ citation_items = [
152
+ item["citation_counts"]
153
+ for item in paragraphs
154
+ if "citation_counts" in item and item["citation_counts"] is not None
155
+ ]
156
+ if len(citation_items) > 0:
157
+ max_citations = max(citation_items)
158
+ for p_id in result_dict:
159
+ if (
160
+ "citation_counts" in paragraphs[p_id]
161
+ and paragraphs[p_id]["citation_counts"] is not None
162
+ ):
163
+ result_dict[p_id] = result_dict[p_id] + (
164
+ paragraphs[p_id]["citation_counts"] / max_citations
165
+ )
166
+
167
+ # Sort by score
168
+ p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
169
+
170
+ # Build reranked list and id mapping
171
+ new_orders = []
172
+ id_mapping = {}
173
+ for i, (p_id, _) in enumerate(p_ids):
174
+ new_orders.append(paragraphs[p_id])
175
+ id_mapping[i] = int(p_id) + start_index
176
+
177
+ return new_orders, result_dict, id_mapping
178
+
179
+
180
+ def _rerank_via_api(
181
+ query: str,
182
+ paragraph_texts: List[str],
183
+ paragraphs: List[Dict[str, Any]],
184
+ reranker_endpoint_pool: Any,
185
+ norm_cite: bool = False,
186
+ start_index: int = 0,
187
+ timeout: float = 30.0,
188
+ ) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
189
+ """
190
+ Rerank paragraphs via API service (supports load balancing across multiple GPUs)
191
+
192
+ Args:
193
+ query: Search query
194
+ paragraph_texts: List of paragraph texts (already formatted)
195
+ paragraphs: List of paragraph dictionaries
196
+ reranker_endpoint_pool: RerankerEndpointPool instance
197
+ norm_cite: Whether to normalize citation counts
198
+ start_index: Starting index for id mapping
199
+ timeout: Request timeout
200
+
201
+ Returns:
202
+ Tuple of reranked paragraphs, result dict, and id mapping
203
+ """
204
+ if not paragraph_texts:
205
+ return [], {}, {}
206
+
207
+ # Get endpoint from pool (round-robin load balancing)
208
+ endpoint = reranker_endpoint_pool.get_endpoint()
209
+ api_url = f"{endpoint}/rerank"
210
+
211
+ # Prepare request
212
+ request_data = {
213
+ "query": query,
214
+ "paragraphs": paragraph_texts,
215
+ "batch_size": 100
216
+ }
217
+
218
+ start_time = time.time()
219
+ try:
220
+ # Make API request
221
+ response = requests.post(
222
+ api_url,
223
+ json=request_data,
224
+ timeout=timeout
225
+ )
226
+ response.raise_for_status()
227
+
228
+ result = response.json()
229
+ scores = result.get("scores", [])
230
+ response_time = time.time() - start_time
231
+
232
+ # Mark success
233
+ reranker_endpoint_pool.mark_success(endpoint, response_time)
234
+
235
+ except requests.exceptions.RequestException as e:
236
+ # Mark error
237
+ reranker_endpoint_pool.mark_error(endpoint, str(e))
238
+ raise RuntimeError(f"Reranker API request failed: {e}")
239
+
240
+ # Handle score format (should be list from API)
241
+ if isinstance(scores, float):
242
+ result_dict = {0: scores}
243
+ else:
244
+ result_dict = {p_id: score for p_id, score in enumerate(scores)}
245
+
246
+ # Add normalized citation counts if enabled
247
+ if norm_cite:
248
+ citation_items = [
249
+ item["citation_counts"]
250
+ for item in paragraphs
251
+ if "citation_counts" in item and item["citation_counts"] is not None
252
+ ]
253
+ if len(citation_items) > 0:
254
+ max_citations = max(citation_items)
255
+ for p_id in result_dict:
256
+ if (
257
+ "citation_counts" in paragraphs[p_id]
258
+ and paragraphs[p_id]["citation_counts"] is not None
259
+ ):
260
+ result_dict[p_id] = result_dict[p_id] + (
261
+ paragraphs[p_id]["citation_counts"] / max_citations
262
+ )
263
+
264
+ # Sort by score
265
+ p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
266
+
267
+ # Build reranked list and id mapping
268
+ new_orders = []
269
+ id_mapping = {}
270
+ for i, (p_id, _) in enumerate(p_ids):
271
+ new_orders.append(paragraphs[p_id])
272
+ id_mapping[i] = int(p_id) + start_index
273
+
274
+ return new_orders, result_dict, id_mapping
275
+
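The sort-and-remap step shared by the direct and API paths above can be sketched standalone. The paragraphs and scores below are made up; the real scores come from the reranker.

```python
# Standalone sketch of the final sort-and-remap step in rerank_paragraphs_bge.
# Paragraphs and scores are illustrative stand-ins.
paragraphs = [{"text": "a"}, {"text": "b"}, {"text": "c"}]
result_dict = {0: 0.2, 1: 0.9, 2: 0.5}
start_index = 0

# Sort original indices by score, descending
p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)

# Rebuild the paragraph list and map new positions to original indices
new_orders = [paragraphs[p_id] for p_id, _ in p_ids]
id_mapping = {i: int(p_id) + start_index for i, (p_id, _) in enumerate(p_ids)}

print(new_orders)   # highest-scoring paragraph first
print(id_mapping)   # new position -> original index
```

`id_mapping` is what lets callers trace a reranked position back to the paper's index in the original retrieval list.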
shared/utils/reranker_api_service.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ Reranker API Service
3
+
4
+ Wraps FlagReranker in an HTTP API service, with support for multi-GPU load balancing.
5
+ """
6
+ import os
7
+ import sys
8
+ from pathlib import Path
9
+ from typing import List, Dict, Any, Optional
10
+ import argparse
11
+
12
+ # Suppress transformers warnings
13
+ os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')
14
+
15
+ try:
16
+ from fastapi import FastAPI, HTTPException
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel
19
+ import uvicorn
20
+ HAS_FASTAPI = True
21
+ except ImportError:
22
+ HAS_FASTAPI = False
23
+ print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn")
24
+
25
+ try:
26
+ from FlagEmbedding import FlagReranker
27
+ HAS_FLAGEMBEDDING = True
28
+ except ImportError:
29
+ HAS_FLAGEMBEDDING = False
30
+ print("Warning: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
31
+
32
+
33
+ # Request/Response models
34
+ class RerankRequest(BaseModel):
35
+ query: str
36
+ paragraphs: List[str]
37
+ batch_size: int = 100
38
+
39
+
40
+ class RerankResponse(BaseModel):
41
+ scores: List[float]
42
+ success: bool
43
+ message: Optional[str] = None
44
+
45
+
46
+ # Global reranker instance
47
+ _reranker: Optional[Any] = None
48
+
49
+
50
+ def create_app(model_path: str, use_fp16: bool = True, device: Optional[str] = None):
51
+ """Create FastAPI app with reranker"""
52
+ global _reranker
53
+
54
+ app = FastAPI(title="Reranker API Service", version="1.0.0")
55
+
56
+ # Add CORS middleware
57
+ app.add_middleware(
58
+ CORSMiddleware,
59
+ allow_origins=["*"],
60
+ allow_credentials=True,
61
+ allow_methods=["*"],
62
+ allow_headers=["*"],
63
+ )
64
+
65
+ @app.on_event("startup")
66
+ async def load_reranker():
67
+ """Load reranker model on startup"""
68
+ global _reranker
69
+ if not HAS_FLAGEMBEDDING:
70
+ raise RuntimeError("FlagEmbedding not installed")
71
+
72
+ print(f"Loading reranker model: {model_path}")
73
+ print(f"Using FP16: {use_fp16}")
74
+ if device:
75
+ print(f"Using device: {device}")
76
+
77
+ try:
78
+ _reranker = FlagReranker(
79
+ model_path,
80
+ use_fp16=use_fp16,
81
+ )
82
+ if device:
83
+ # Note: FlagReranker may not support explicit device setting
84
+ # This is a placeholder for future support
85
+ pass
86
+ print("Reranker model loaded successfully")
87
+ except Exception as e:
88
+ print(f"Error loading reranker: {e}")
89
+ raise
90
+
91
+ @app.get("/health")
92
+ async def health_check():
93
+ """Health check endpoint"""
94
+ return {
95
+ "status": "healthy",
96
+ "model_loaded": _reranker is not None
97
+ }
98
+
99
+ @app.post("/rerank", response_model=RerankResponse)
100
+ async def rerank(request: RerankRequest):
101
+ """Rerank paragraphs given a query"""
102
+ global _reranker
103
+
104
+ if _reranker is None:
105
+ raise HTTPException(status_code=503, detail="Reranker not loaded")
106
+
107
+ if not request.paragraphs:
108
+ return RerankResponse(
109
+ scores=[],
110
+ success=True,
111
+ message="No paragraphs to rerank"
112
+ )
113
+
114
+ try:
115
+ # Prepare sentence pairs: [[query, paragraph], ...]
116
+ sentence_pairs = [[request.query, p] for p in request.paragraphs]
117
+
118
+ # Compute scores
119
+ scores = _reranker.compute_score(
120
+ sentence_pairs,
121
+ batch_size=request.batch_size
122
+ )
123
+
124
+ # Handle score format (can be float or list)
125
+ if isinstance(scores, float):
126
+ scores = [scores]
127
+ elif not isinstance(scores, list):
128
+ scores = list(scores)
129
+
130
+ return RerankResponse(
131
+ scores=scores,
132
+ success=True
133
+ )
134
+ except Exception as e:
135
+ print(f"Error during reranking: {e}")
136
+ import traceback
137
+ traceback.print_exc()
138
+ raise HTTPException(status_code=500, detail=str(e))
139
+
140
+ return app
141
+
142
+
143
+ def main():
144
+ """Main entry point for reranker API service"""
145
+ parser = argparse.ArgumentParser(description="Reranker API Service")
146
+ parser.add_argument(
147
+ "--model_path",
148
+ type=str,
149
+ required=True,
150
+ help="Path to reranker model (e.g., 'OpenScholar/OpenScholar_Reranker')"
151
+ )
152
+ parser.add_argument(
153
+ "--host",
154
+ type=str,
155
+ default="0.0.0.0",
156
+ help="Host to bind to (default: 0.0.0.0)"
157
+ )
158
+ parser.add_argument(
159
+ "--port",
160
+ type=int,
161
+ default=8004,
162
+ help="Port to bind to (default: 8004)"
163
+ )
164
+ parser.add_argument(
165
+ "--use_fp16",
166
+ action="store_true",
167
+ default=True,
168
+ help="Use FP16 precision (default: True)"
169
+ )
170
+ parser.add_argument(
171
+ "--no_fp16",
172
+ dest="use_fp16",
173
+ action="store_false",
174
+ help="Disable FP16 precision"
175
+ )
176
+ parser.add_argument(
177
+ "--device",
178
+ type=str,
179
+ default=None,
180
+ help="Device to use (e.g., 'cuda:0', 'cuda:1')"
181
+ )
182
+ parser.add_argument(
183
+ "--workers",
184
+ type=int,
185
+ default=1,
186
+ help="Number of worker processes (default: 1, use 1 for reranker)"
187
+ )
188
+
189
+ args = parser.parse_args()
190
+
191
+ if not HAS_FASTAPI:
192
+ print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn")
193
+ sys.exit(1)
194
+
195
+ if not HAS_FLAGEMBEDDING:
196
+ print("Error: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
197
+ sys.exit(1)
198
+
199
+ # Create app
200
+ app = create_app(
201
+ model_path=args.model_path,
202
+ use_fp16=args.use_fp16,
203
+ device=args.device
204
+ )
205
+
206
+ # Run server
207
+ print(f"Starting reranker API service on {args.host}:{args.port}")
208
+ print(f"Model: {args.model_path}")
209
+ print(f"FP16: {args.use_fp16}")
210
+
211
+ uvicorn.run(
212
+ app,
213
+ host=args.host,
214
+ port=args.port,
215
+ workers=args.workers,
216
+ log_level="info"
217
+ )
218
+
219
+
220
+ if __name__ == "__main__":
221
+ main()
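A client call to the `/rerank` endpoint above would send a JSON body matching `RerankRequest`. The sketch below only builds and inspects the payload so it stands alone; the host/port and query text are illustrative, not part of the service code.

```python
import json

# Hypothetical client-side payload for the /rerank endpoint defined above.
request_data = {
    "query": "graph neural networks for molecules",
    "paragraphs": ["GNNs for chemistry ...", "CNNs for images ..."],
    "batch_size": 100,
}
body = json.dumps(request_data)
print(body)

# With a server running, the call would look like:
#   resp = requests.post("http://localhost:8004/rerank", json=request_data, timeout=30)
#   scores = resp.json()["scores"]  # one float per input paragraph
```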
shared/utils/reranker_endpoint_pool.py ADDED
@@ -0,0 +1,160 @@
1
+ """
2
+ Reranker Endpoint Pool Manager
3
+
4
+ Manages multiple reranker API endpoints, implementing round-robin access and load balancing.
5
+ Reuses the logic of VLLMEndpointPool.
6
+ """
7
+ from pathlib import Path
8
+ from typing import List, Optional, Dict
9
+ from threading import Lock
10
+
11
+
12
+ class RerankerEndpointPool:
13
+ """
14
+ Reranker Endpoint Pool Manager
15
+
16
+ Features:
17
+ 1. Load multiple reranker API endpoints from file
18
+ 2. Cycle through endpoints round-robin (ensures uniform distribution)
19
+ 3. Track usage status and errors for each endpoint
20
+ """
21
+
22
+ def __init__(self, pool_path: Optional[str] = None, endpoints: Optional[List[str]] = None, use_round_robin: bool = True):
23
+ """
24
+ Initialize Reranker Endpoint Pool
25
+
26
+ Args:
27
+ pool_path: Endpoints file path (one endpoint URL per line)
28
+ endpoints: List of endpoint URLs provided directly (takes priority over pool_path)
29
+ use_round_robin: Whether to use the round-robin strategy (default: True, recommended)
30
+ """
31
+ self.endpoints: List[str] = []
32
+ self.current_index: int = 0 # Round-robin current index
33
+ self.endpoint_status: Dict[str, Dict] = {} # Status information for each endpoint
34
+ self.lock = Lock() # Thread-safety lock
35
+ self.use_round_robin = use_round_robin # Whether to use round-robin selection
36
+
37
+ # load endpoints
38
+ if endpoints:
39
+ self.endpoints = endpoints
40
+ elif pool_path:
41
+ self._load_from_file(pool_path)
42
+ else:
43
+ raise ValueError("Either pool_path or endpoints must be provided")
44
+
45
+ if not self.endpoints:
46
+ raise ValueError("No endpoints loaded")
47
+
48
+ # initialize status for each endpoint
49
+ for endpoint in self.endpoints:
50
+ if endpoint not in self.endpoint_status:
51
+ self.endpoint_status[endpoint] = {
52
+ 'total_requests': 0,
53
+ 'successful_requests': 0,
54
+ 'failed_requests': 0,
55
+ 'total_response_time': 0.0,
56
+ 'last_error': None,
57
+ }
58
+
59
+ print(f"RerankerEndpointPool initialized with {len(self.endpoints)} endpoints")
60
+ for i, endpoint in enumerate(self.endpoints):
61
+ print(f" [{i+1}] {endpoint}")
62
+
63
+ def _load_from_file(self, pool_path: str):
64
+ """load endpoints from file"""
65
+ path = Path(pool_path)
66
+ if not path.is_absolute():
67
+ # try to find file relative to shared/configs/
68
+ project_root = Path(__file__).parent.parent.parent
69
+ path = project_root / "shared" / "configs" / pool_path
70
+
71
+ if not path.exists():
72
+ raise FileNotFoundError(f"Reranker endpoint pool file not found: {path}")
73
+
74
+ with open(path, 'r', encoding='utf-8') as f:
75
+ lines = f.readlines()
76
+
77
+ self.endpoints = []
78
+ for line in lines:
79
+ line = line.strip()
80
+ if line and not line.startswith('#'):
81
+ # Ensure the URL includes a scheme
82
+ if not line.startswith('http://') and not line.startswith('https://'):
83
+ line = f"http://{line}"
84
+ self.endpoints.append(line)
85
+
86
+ def get_endpoint(self) -> str:
87
+ """
88
+ Get next available endpoint (round-robin strategy)
89
+
90
+ Returns:
91
+ Available endpoint URL
92
+ """
93
+ with self.lock:
94
+ if not self.endpoints:
95
+ raise ValueError("No reranker endpoints available in pool")
96
+
97
+ if self.use_round_robin:
98
+ # Simple round-robin to ensure uniform distribution
99
+ selected_idx = self.current_index
100
+ self.current_index = (self.current_index + 1) % len(self.endpoints)
101
+
102
+ selected_endpoint = self.endpoints[selected_idx]
103
+ self.endpoint_status[selected_endpoint]['total_requests'] += 1
104
+
105
+ return selected_endpoint
106
+ else:
107
+ # Smart selection mode (could factor in error rate, etc.)
108
+ # Simple implementation: select the endpoint with the fewest requests
109
+ min_requests = min(
110
+ self.endpoint_status[ep]['total_requests']
111
+ for ep in self.endpoints
112
+ )
113
+ candidates = [
114
+ ep for ep in self.endpoints
115
+ if self.endpoint_status[ep]['total_requests'] == min_requests
116
+ ]
117
+ selected_endpoint = candidates[0]
118
+ self.endpoint_status[selected_endpoint]['total_requests'] += 1
119
+ return selected_endpoint
120
+
121
+ def mark_success(self, endpoint: str, response_time: float = 0.0):
122
+ """mark endpoint request as successful"""
123
+ with self.lock:
124
+ if endpoint in self.endpoint_status:
125
+ self.endpoint_status[endpoint]['successful_requests'] += 1
126
+ self.endpoint_status[endpoint]['total_response_time'] += response_time
127
+
128
+ def mark_error(self, endpoint: str, error: str):
129
+ """mark endpoint request as failed"""
130
+ with self.lock:
131
+ if endpoint in self.endpoint_status:
132
+ self.endpoint_status[endpoint]['failed_requests'] += 1
133
+ self.endpoint_status[endpoint]['last_error'] = error
134
+
135
+ def get_status(self) -> Dict:
136
+ """get pool status information"""
137
+ with self.lock:
138
+ endpoints_status = {}
139
+ for endpoint, status in self.endpoint_status.items():
140
+ total = status['total_requests']
141
+ success = status['successful_requests']
142
+ failed = status['failed_requests']
143
+ avg_time = (
144
+ status['total_response_time'] / success
145
+ if success > 0 else 0.0
146
+ )
147
+
148
+ endpoints_status[endpoint] = {
149
+ 'total_requests': total,
150
+ 'successful_requests': success,
151
+ 'failed_requests': failed,
152
+ 'success_rate': success / total if total > 0 else 0.0,
153
+ 'avg_response_time': avg_time,
154
+ 'last_error': status['last_error'],
155
+ }
156
+
157
+ return {
158
+ 'total_endpoints': len(self.endpoints),
159
+ 'endpoints_status': endpoints_status,
160
+ }
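The round-robin selection in `get_endpoint` above can be sketched standalone; locking and status tracking are omitted, and the endpoint URLs are made up.

```python
# Minimal standalone sketch of RerankerEndpointPool's round-robin selection.
endpoints = ["http://gpu0:8004", "http://gpu1:8004", "http://gpu2:8004"]
current_index = 0

def get_endpoint():
    global current_index
    selected = endpoints[current_index]
    current_index = (current_index + 1) % len(endpoints)
    return selected

picks = [get_endpoint() for _ in range(5)]
print(picks)  # wraps around after the third pick
```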
shared/utils/reranker_pool.py ADDED
@@ -0,0 +1,78 @@
1
+ """
2
+ Reranker Pool for multiprocessing-safe reranker sharing
3
+
4
+ Works around FlagReranker not being pickleable in multi-process/multi-thread environments.
5
+ """
6
+ import os
7
+ import threading
8
+ from typing import Optional, Dict
9
+ from pathlib import Path
10
+
11
+ # Global reranker storage (thread-safe)
12
+ # Note: Dictionary access is atomic in Python for simple operations,
13
+ # but we use a lock for thread-safety when modifying the dict
14
+ _reranker_pool: Dict[str, object] = {}
15
+ _reranker_lock = threading.Lock()
16
+
17
+
18
+ def get_reranker(model_path: str, use_fp16: bool = True):
19
+ """
20
+ Get or create reranker (thread-safe, process-shared)
21
+
22
+ Performance optimization:
23
+ - Load and cache on first call
24
+ - Return the cached instance on subsequent calls (lock-free fast path)
25
+ - Use the double-checked locking pattern to reduce lock contention
26
+
27
+ Args:
28
+ model_path: Reranker model path
29
+ use_fp16: Whether to use FP16 precision
30
+
31
+ Returns:
32
+ FlagReranker instance
33
+ """
34
+ global _reranker_pool
35
+
36
+ # Create a unique cache key
37
+ key = f"{model_path}_{use_fp16}"
38
+
39
+ # Fast path: check the cache without taking the lock
40
+ if key in _reranker_pool:
41
+ return _reranker_pool[key]
42
+
43
+ # Slow path: load the model while holding the lock
44
+ with _reranker_lock:
45
+ # Double-check: another thread may have loaded it while we waited for the lock
46
+ if key not in _reranker_pool:
47
+ # Lazy import so FlagEmbedding is only required when actually used
48
+ try:
49
+ from FlagEmbedding import FlagReranker
50
+
51
+ # Suppress the transformers progress bar while loading
52
+ original_verbosity = os.environ.get('TRANSFORMERS_VERBOSITY', '')
53
+ os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
54
+
55
+ try:
56
+ # load model
57
+ reranker = FlagReranker(model_path, use_fp16=use_fp16)
58
+ _reranker_pool[key] = reranker
59
+ finally:
60
+ # restore original verbosity
61
+ if original_verbosity:
62
+ os.environ['TRANSFORMERS_VERBOSITY'] = original_verbosity
63
+ elif 'TRANSFORMERS_VERBOSITY' in os.environ:
64
+ del os.environ['TRANSFORMERS_VERBOSITY']
65
+
66
+ except ImportError:
67
+ raise ImportError(
68
+ "FlagEmbedding not installed. Install it with: pip install FlagEmbedding"
69
+ )
70
+
71
+ return _reranker_pool[key]
72
+
73
+
74
+ def clear_reranker_pool():
75
+ """clear reranker pool (mainly for testing)"""
76
+ global _reranker_pool
77
+ with _reranker_lock:
78
+ _reranker_pool.clear()
shared/utils/review_logger.py ADDED
@@ -0,0 +1,306 @@
1
+ """
2
+ Review Logger Utility
3
+
4
+ Captures and logs all intermediate outputs from the review pipeline
5
+ """
6
+ import json
7
+ import uuid
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from typing import Dict, Any, Optional, List
11
+
12
+
13
+ class ReviewLogger:
14
+ """
15
+ Logger for capturing complete review pipeline execution logs
16
+ """
17
+
18
+ def __init__(self, log_dir: Optional[str] = None, enabled: bool = True):
19
+ """
20
+ Initialize Review Logger
21
+
22
+ Args:
23
+ log_dir: Directory to save log files. If None, uses current directory.
24
+ enabled: Whether logging is enabled
25
+ """
26
+ self.enabled = enabled
27
+ self.log_dir = Path(log_dir) if log_dir else Path.cwd()
28
+ self.log_dir.mkdir(parents=True, exist_ok=True)
29
+
30
+ # Current run data
31
+ self.current_run_id: Optional[str] = None
32
+ self.current_run_data: Optional[Dict[str, Any]] = None
33
+
34
+ def start_run(
35
+ self,
36
+ title: str,
37
+ abstract: str,
38
+ content: Optional[str] = None,
39
+ keywords: Optional[List[str]] = None,
40
+ publication_date_range: Optional[str] = None,
41
+ venues: Optional[str] = None,
42
+ review_format: str = "detailed",
43
+ ) -> str:
44
+ """
45
+ Start a new review run and generate UUID
46
+
47
+ IMPORTANT: If current_run_data already exists, this method will preserve existing
48
+ intermediate_outputs data to prevent data loss. Only input data and metadata are updated.
49
+
50
+ Args:
51
+ title: Paper title
52
+ abstract: Paper abstract
53
+ content: Paper content (optional)
54
+ keywords: Existing keywords (optional)
55
+ publication_date_range: Date range filter (optional)
56
+ venues: Venue filter (optional)
57
+ review_format: Review format
58
+
59
+ Returns:
60
+ Run UUID string
61
+ """
62
+ if not self.enabled:
63
+ return ""
64
+
65
+ # Generate UUID based on timestamp
66
+ timestamp = datetime.now()
67
+ # Use timestamp-based UUID (UUID1 uses MAC address + timestamp)
68
+ run_id = str(uuid.uuid1())
69
+
70
+ # PRESERVE existing intermediate_outputs if current_run_data already exists
71
+ # This prevents data loss if start_run() is called multiple times
72
+ existing_intermediate_outputs = None
73
+ existing_final_output = None
74
+ existing_errors = None
75
+ if self.current_run_data is not None:
76
+ existing_intermediate_outputs = self.current_run_data.get("intermediate_outputs")
77
+ existing_final_output = self.current_run_data.get("final_output")
78
+ existing_errors = self.current_run_data.get("errors", [])
79
+
80
+ self.current_run_id = run_id
81
+
82
+ # Initialize intermediate_outputs: use existing data if available, otherwise create new
83
+ if existing_intermediate_outputs is not None:
84
+ # Preserve existing intermediate outputs
85
+ intermediate_outputs = existing_intermediate_outputs
86
+ # Only initialize None fields if they don't exist
87
+ if "generated_keywords" not in intermediate_outputs:
88
+ intermediate_outputs["generated_keywords"] = None
89
+ if "retrieved_papers" not in intermediate_outputs:
90
+ intermediate_outputs["retrieved_papers"] = []
91
+ if "paper_summaries" not in intermediate_outputs:
92
+ intermediate_outputs["paper_summaries"] = []
93
+ if "related_work_json_list" not in intermediate_outputs:
94
+ intermediate_outputs["related_work_json_list"] = None
95
+ if "paper_results_analyzer_output" not in intermediate_outputs:
96
+ intermediate_outputs["paper_results_analyzer_output"] = None
97
+ if "paper_insight_miner_output" not in intermediate_outputs:
98
+ intermediate_outputs["paper_insight_miner_output"] = None
99
+ if "review_prompt" not in intermediate_outputs:
100
+ intermediate_outputs["review_prompt"] = None
101
+ if "review_llm_response" not in intermediate_outputs:
102
+ intermediate_outputs["review_llm_response"] = None
103
+ if "parsed_review" not in intermediate_outputs:
104
+ intermediate_outputs["parsed_review"] = None
105
+ if "refiner_prompt" not in intermediate_outputs:
106
+ intermediate_outputs["refiner_prompt"] = None
107
+ if "refiner_llm_response" not in intermediate_outputs:
108
+ intermediate_outputs["refiner_llm_response"] = None
109
+ if "parsed_refined_review" not in intermediate_outputs:
110
+ intermediate_outputs["parsed_refined_review"] = None
111
+ else:
112
+ # Create new intermediate_outputs structure
113
+ intermediate_outputs = {
114
+ "generated_keywords": None,
115
+ "retrieved_papers": [],
116
+ "paper_summaries": [],
117
+ "related_work_json_list": None,
118
+ "paper_results_analyzer_output": None,
119
+ "paper_insight_miner_output": None,
120
+ "review_prompt": None,
121
+ "review_llm_response": None,
122
+ "parsed_review": None,
123
+ "refiner_prompt": None,
124
+ "refiner_llm_response": None,
125
+ "parsed_refined_review": None,
126
+ }
127
+
128
+ self.current_run_data = {
129
+ "run_id": run_id,
130
+ "timestamp": timestamp.isoformat(),
131
+ "input": {
132
+ "title": title,
133
+ "abstract": abstract,
134
+ "content": content,
135
+ "keywords": keywords,
136
+ "publication_date_range": publication_date_range,
137
+ "venues": venues,
138
+ "review_format": review_format,
139
+ },
140
+ "intermediate_outputs": intermediate_outputs,
141
+ "final_output": existing_final_output,
142
+ "errors": existing_errors if existing_errors is not None else [],
143
+ }
144
+
145
+ return run_id
146
+
147
+ def log_keywords(self, keywords: List[str]):
148
+ """Log generated search keywords"""
149
+ if self.enabled and self.current_run_data:
150
+ # Ensure intermediate_outputs exists
151
+ if "intermediate_outputs" not in self.current_run_data:
152
+ self.current_run_data["intermediate_outputs"] = {}
153
+ self.current_run_data["intermediate_outputs"]["generated_keywords"] = keywords
154
+
155
+ def log_retrieved_papers(self, papers: List[Dict[str, Any]]):
156
+ """Log retrieved papers (raw)"""
157
+ if self.enabled and self.current_run_data:
158
+ # Ensure intermediate_outputs exists
159
+ if "intermediate_outputs" not in self.current_run_data:
160
+ self.current_run_data["intermediate_outputs"] = {}
161
+ # Store paper metadata (may be large, so we store essential info)
162
+ self.current_run_data["intermediate_outputs"]["retrieved_papers"] = [
163
+ {
164
+ "paper_id": p.get("paper_id"),
165
+ "title": p.get("title"),
166
+ "authors": p.get("authors", [])[:10], # Limit authors
167
+ "year": p.get("year"),
168
+ "venue": p.get("venue"),
169
+ "abstract": p.get("abstract", "")[:500], # Truncate abstract
170
+ "citation_counts": p.get("citation_counts", 0),
171
+ }
172
+ for p in papers
173
+ ]
174
+
175
+ def log_paper_summary(self, paper_title: str, summary: str, paper_index: int):
176
+ """Log a single paper summary"""
177
+ if self.enabled and self.current_run_data:
178
+ # Ensure intermediate_outputs exists
179
+ if "intermediate_outputs" not in self.current_run_data:
180
+ self.current_run_data["intermediate_outputs"] = {}
181
+ if "paper_summaries" not in self.current_run_data["intermediate_outputs"]:
182
+ self.current_run_data["intermediate_outputs"]["paper_summaries"] = []
183
+ self.current_run_data["intermediate_outputs"]["paper_summaries"].append({
184
+ "paper_index": paper_index,
185
+ "paper_title": paper_title,
186
+ "summary": summary,
187
+ })
188
+
189
+ def log_related_work_json_list(self, related_work_json_list: List[Dict[str, Any]]):
190
+ """Log the final related work JSON list"""
191
+ if self.enabled and self.current_run_data:
192
+ # Ensure intermediate_outputs exists
193
+ if "intermediate_outputs" not in self.current_run_data:
194
+ self.current_run_data["intermediate_outputs"] = {}
195
+ self.current_run_data["intermediate_outputs"]["related_work_json_list"] = related_work_json_list
196
+
197
+ def log_paper_results_analyzer_output(self, results_analyzer_output: str):
198
+ """Log the paper results analyzer JSON output"""
199
+ if self.enabled and self.current_run_data:
200
+ # Ensure intermediate_outputs exists
201
+ if "intermediate_outputs" not in self.current_run_data:
202
+ self.current_run_data["intermediate_outputs"] = {}
203
+ self.current_run_data["intermediate_outputs"]["paper_results_analyzer_output"] = results_analyzer_output
204
+
205
+ def log_paper_insight_miner_output(self, insight_miner_output: str):
206
+ """Log the paper insight miner JSON output"""
207
+ if self.enabled and self.current_run_data:
208
+ # Ensure intermediate_outputs exists
209
+ if "intermediate_outputs" not in self.current_run_data:
210
+ self.current_run_data["intermediate_outputs"] = {}
211
+ self.current_run_data["intermediate_outputs"]["paper_insight_miner_output"] = insight_miner_output
212
+
213
+ def log_review_prompt(self, prompt: str, system_message: Optional[str] = None):
214
+ """Log the review prompt sent to LLM"""
215
+ if self.enabled and self.current_run_data:
216
+ # Ensure intermediate_outputs exists
217
+ if "intermediate_outputs" not in self.current_run_data:
218
+ self.current_run_data["intermediate_outputs"] = {}
219
+ self.current_run_data["intermediate_outputs"]["review_prompt"] = {
220
+ "system_message": system_message,
221
+ "user_prompt": prompt,
222
+ }
223
+
224
+ def log_review_llm_response(self, response: str):
225
+ """Log the raw LLM response for review"""
226
+ if self.enabled and self.current_run_data:
227
+ # Ensure intermediate_outputs exists
228
+ if "intermediate_outputs" not in self.current_run_data:
229
+ self.current_run_data["intermediate_outputs"] = {}
230
+ self.current_run_data["intermediate_outputs"]["review_llm_response"] = response
231
+
232
+ def log_parsed_review(self, parsed_review: Dict[str, Any]):
233
+ """Log the parsed review dictionary"""
234
+ if self.enabled and self.current_run_data:
235
+ # Ensure intermediate_outputs exists
236
+ if "intermediate_outputs" not in self.current_run_data:
237
+ self.current_run_data["intermediate_outputs"] = {}
238
+ self.current_run_data["intermediate_outputs"]["parsed_review"] = parsed_review
239
+
240
+ def log_refiner_prompt(self, prompt: str, system_message: Optional[str] = None):
241
+ """Log the refiner prompt sent to LLM"""
242
+ if self.enabled and self.current_run_data:
243
+ # Ensure intermediate_outputs exists
244
+ if "intermediate_outputs" not in self.current_run_data:
245
+ self.current_run_data["intermediate_outputs"] = {}
246
+ self.current_run_data["intermediate_outputs"]["refiner_prompt"] = {
247
+ "system_message": system_message,
248
+ "user_prompt": prompt,
249
+ }
250
+
251
+ def log_refiner_llm_response(self, response: str):
252
+ """Log the raw LLM response for refiner"""
253
+ if self.enabled and self.current_run_data:
254
+ # Ensure intermediate_outputs exists
255
+ if "intermediate_outputs" not in self.current_run_data:
256
+ self.current_run_data["intermediate_outputs"] = {}
257
+ self.current_run_data["intermediate_outputs"]["refiner_llm_response"] = response
258
+
259
+ def log_parsed_refined_review(self, parsed_review: Dict[str, Any]):
260
+ """Log the parsed refined review dictionary"""
261
+ if self.enabled and self.current_run_data:
262
+ # Ensure intermediate_outputs exists
263
+ if "intermediate_outputs" not in self.current_run_data:
264
+ self.current_run_data["intermediate_outputs"] = {}
265
+ self.current_run_data["intermediate_outputs"]["parsed_refined_review"] = parsed_review
266
+
267
+ def log_final_output(self, final_output: Dict[str, Any]):
268
+ """Log the final review output"""
269
+ if self.enabled and self.current_run_data:
270
+ self.current_run_data["final_output"] = final_output
271
+
272
+ def log_error(self, error: str, step: Optional[str] = None):
273
+ """Log an error that occurred during execution"""
274
+ if self.enabled and self.current_run_data:
275
+ if "errors" not in self.current_run_data:
276
+ self.current_run_data["errors"] = []
277
+ self.current_run_data["errors"].append({
278
+ "step": step,
279
+ "error": error,
280
+ "timestamp": datetime.now().isoformat(),
281
+ })
282
+
283
+ def save_run(self) -> Optional[str]:
284
+ """
285
+ Save the current run to a JSON file
286
+
287
+ Returns:
288
+ Path to saved log file, or None if logging is disabled
289
+ """
290
+ if not self.enabled or not self.current_run_data:
291
+ return None
292
+
293
+ # Generate filename with timestamp and UUID
294
+ timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
295
+ filename = f"review_log_{timestamp_str}_{self.current_run_id[:8]}.json"
296
+ log_path = self.log_dir / filename
297
+
298
+ # Save to JSON
299
+ with open(log_path, 'w', encoding='utf-8') as f:
300
+ json.dump(self.current_run_data, f, indent=2, ensure_ascii=False)
301
+
302
+ return str(log_path)
303
+
304
+ def get_current_run_id(self) -> Optional[str]:
305
+ """Get the current run ID"""
306
+ return self.current_run_id if self.enabled else None
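The `start_run()` preservation rule above — reuse existing `intermediate_outputs` and fill in only the missing keys — can be sketched as a small standalone helper. This is an illustrative sketch only; `DEFAULT_OUTPUTS` and `init_intermediate_outputs` are hypothetical names, not part of the module.

```python
# Hypothetical sketch of start_run()'s merge rule: existing intermediate
# outputs survive a repeated start_run() call; only missing keys get defaults.
from typing import Any, Dict, Optional

DEFAULT_OUTPUTS: Dict[str, Any] = {
    "generated_keywords": None,
    "retrieved_papers": [],
    "paper_summaries": [],
}

def init_intermediate_outputs(existing: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Reuse existing outputs when present; otherwise build a fresh structure."""
    if existing is None:
        # copy list defaults so separate runs never share mutable state
        return {k: (list(v) if isinstance(v, list) else v) for k, v in DEFAULT_OUTPUTS.items()}
    for key, default in DEFAULT_OUTPUTS.items():
        existing.setdefault(key, list(default) if isinstance(default, list) else default)
    return existing

outputs = {"generated_keywords": ["llm"]}
merged = init_intermediate_outputs(outputs)
# merged keeps the existing keywords and gains only the missing default keys
```

The same in-place `setdefault` pattern is what prevents data loss when `start_run()` is called twice on one logger instance.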
shared/utils/vllm_endpoint_pool.py ADDED
@@ -0,0 +1,257 @@
1
+ """
2
+ VLLM Endpoint Pool Manager
3
+
4
+ Manage multiple vLLM endpoints, implement round-robin access and load balancing.
5
+ """
6
+ import random
7
+ from pathlib import Path
8
+ from typing import List, Optional, Dict
9
+ from threading import Lock
10
+
11
+
12
+ class VLLMEndpointPool:
13
+ """
14
+ VLLM Endpoint Pool Manager
15
+
16
+ Features:
17
+ 1. Load multiple vLLM endpoints from file
18
+ 2. Round-robin access endpoints (ensure uniform distribution)
19
+ 3. Track usage status and errors for each endpoint
20
+ 4. Smart selection (based on error rate and success rate, as backup)
21
+ """
22
+
23
+ def __init__(self, pool_path: Optional[str] = None, endpoints: Optional[List[str]] = None, use_round_robin: bool = True):
24
+ """
25
+ Initialize VLLM Endpoint Pool
26
+
27
+ Args:
28
+ pool_path: Endpoints file path (one endpoint URL per line)
29
+ endpoints: explicit list of endpoint URLs (takes precedence over pool_path)
30
+ use_round_robin: whether to use round-robin strategy (True=uniform distribution, False=smart selection)
31
+ """
32
+ self.endpoints: List[str] = []
33
+ self.current_index: int = 0 # Round-robin current index
34
+ self.used_indices: List[int] = [] # used indices in current round (for smart selection)
35
+ self.endpoint_status: Dict[str, Dict] = {} # status information for each endpoint
36
+ self.lock = Lock() # thread safe lock
37
+ self.use_round_robin = use_round_robin # whether to use round-robin
38
+
39
+ # load endpoints
40
+ if endpoints:
41
+ self.endpoints = [e.strip() for e in endpoints if e.strip()]
42
+ elif pool_path:
43
+ self._load_from_file(pool_path)
44
+ else:
45
+ # try to get single endpoint from environment variable (backward compatibility)
46
+ import os
47
+ env_endpoint = os.environ.get("VLLM_BASE_URL")
48
+ if env_endpoint:
49
+ # ensure format is correct (may need to add /v1)
50
+ if not env_endpoint.endswith('/v1'):
51
+ if env_endpoint.endswith('/'):
52
+ env_endpoint = env_endpoint.rstrip('/') + '/v1'
53
+ else:
54
+ env_endpoint = env_endpoint + '/v1'
55
+ self.endpoints = [env_endpoint]
56
+
57
+ if not self.endpoints:
58
+ raise ValueError(
59
+ "No vLLM endpoints available. Provide endpoints via pool_path, endpoints parameter, "
60
+ "or VLLM_BASE_URL environment variable."
61
+ )
62
+
63
+ # initialize status for each endpoint
64
+ for endpoint in self.endpoints:
65
+ self.endpoint_status[endpoint] = {
66
+ 'error_count': 0,
67
+ 'last_error_time': None,
68
+ 'consecutive_errors': 0,
69
+ 'total_requests': 0,
70
+ 'successful_requests': 0,
71
+ 'total_response_time': 0.0,  # cumulative response time (seconds)
72
+ }
73
+
74
+ def _load_from_file(self, pool_path: str):
75
+ """load vLLM endpoints from file"""
76
+ path = Path(pool_path)
77
+
78
+ # if relative path, try to find file relative to shared/configs
79
+ if not path.is_absolute():
80
+ # try to find file relative to project root
81
+ project_root = Path(__file__).parent.parent.parent
82
+ path = project_root / "shared" / "configs" / pool_path
83
+ if not path.exists():
84
+ # fall back to the configs directory next to this package
85
+ path = Path(__file__).parent.parent / "configs" / pool_path
86
+
87
+ if not path.exists():
88
+ raise FileNotFoundError(
89
+ f"VLLM endpoint pool file not found: {pool_path} (tried: {path})"
90
+ )
91
+
92
+ with open(path, 'r', encoding='utf-8') as f:
93
+ lines = f.readlines()
94
+
95
+ endpoints = []
96
+ for line in lines:
97
+ line = line.strip()
98
+ if line and not line.startswith('#'):
99
+ # ensure format is correct (may need to add /v1)
100
+ if not line.endswith('/v1'):
101
+ if line.endswith('/'):
102
+ line = line.rstrip('/') + '/v1'
103
+ else:
104
+ line = line + '/v1'
105
+ endpoints.append(line)
106
+
107
+ self.endpoints = endpoints
108
+
109
+ if not self.endpoints:
110
+ raise ValueError(f"No valid vLLM endpoints found in pool file: {pool_path}")
111
+
112
+ def get_endpoint(self) -> str:
113
+ """
114
+ Get next available endpoint (round-robin strategy)
115
+
116
+ Strategy:
117
+ - Round-robin mode (default): simple round-robin, ensure uniform distribution
118
+ - Smart selection mode: select best endpoint based on error rate, success rate, response time
119
+
120
+ Returns:
121
+ Available endpoint URL
122
+ """
123
+ import time
124
+
125
+ with self.lock:
126
+ if not self.endpoints:
127
+ raise ValueError("No vLLM endpoints available in pool")
128
+
129
+ if self.use_round_robin:
130
+ # Round-robin: simple round-robin, ensure uniform distribution
131
+ selected_idx = self.current_index
132
+ self.current_index = (self.current_index + 1) % len(self.endpoints)
133
+
134
+ selected_endpoint = self.endpoints[selected_idx]
135
+ self.endpoint_status[selected_endpoint]['total_requests'] += 1
136
+
137
+ return selected_endpoint
138
+ else:
139
+ # smart selection mode (original logic)
140
+ # if current round is complete, start a new round
141
+ if len(self.used_indices) >= len(self.endpoints):
142
+ self.used_indices = []
143
+
144
+ # get indices not used in current round
145
+ available_indices = [i for i in range(len(self.endpoints)) if i not in self.used_indices]
146
+
147
+ if not available_indices:
148
+ # all endpoints were used this round; start a new round
149
+ available_indices = list(range(len(self.endpoints)))
150
+ self.used_indices = []
151
+
152
+ # prioritize endpoints with fewer errors and higher success rate
153
+ endpoint_scores = []
154
+ for idx in available_indices:
155
+ endpoint = self.endpoints[idx]
156
+ status = self.endpoint_status[endpoint]
157
+
158
+ # compute a score from error count, success rate, and response time (higher is better)
159
+ error_count = status['error_count']
160
+ total = status['total_requests']
161
+ success_rate = (status['successful_requests'] / total) if total > 0 else 1.0
162
+
163
+ # calculate average response time (shorter is better)
164
+ avg_response_time = (
165
+ status['total_response_time'] / status['successful_requests']
166
+ if status['successful_requests'] > 0 else 0.0
167
+ )
168
+ # normalize response time into a score (10-second baseline; faster responses score higher)
169
+ response_time_score = 1.0 / (1.0 + avg_response_time / 10.0)
170
+
171
+ # if recent error, reduce score
172
+ recent_error_penalty = 0
173
+ if status['last_error_time']:
174
+ time_since_error = time.time() - status['last_error_time']
175
+ if time_since_error < 60:  # within the last minute
176
+ recent_error_penalty = 0.5
177
+
178
+ score = success_rate - (error_count * 0.1) - recent_error_penalty + (response_time_score * 0.2)
179
+ endpoint_scores.append((idx, score))
180
+
181
+ # sort by score, select highest score (but add some randomness)
182
+ endpoint_scores.sort(key=lambda x: x[1], reverse=True)
183
+
184
+ # pick randomly from the top 50% (randomness while still favoring better endpoints)
185
+ top_n = max(1, len(endpoint_scores) // 2) if len(endpoint_scores) > 1 else 1
186
+ selected_idx, _ = random.choice(endpoint_scores[:top_n])
187
+
188
+ # mark as used
189
+ self.used_indices.append(selected_idx)
190
+
191
+ selected_endpoint = self.endpoints[selected_idx]
192
+ self.endpoint_status[selected_endpoint]['total_requests'] += 1
193
+
194
+ return selected_endpoint
195
+
196
+ def mark_success(self, endpoint: str, response_time: float = 0.0):
197
+ """
198
+ Mark a request to an endpoint as successful.
199
+
200
+ Args:
201
+ endpoint: successful endpoint URL
202
+ response_time: response time (seconds)
203
+ """
204
+ with self.lock:
205
+ if endpoint in self.endpoint_status:
206
+ status = self.endpoint_status[endpoint]
207
+ status['successful_requests'] += 1
208
+ status['consecutive_errors'] = 0
209
+ status['total_response_time'] += response_time
210
+
211
+ def mark_error(self, endpoint: str, error_type: str = "server_error"):
212
+ """
213
+ Mark a request to an endpoint as failed.
214
+
215
+ Args:
216
+ endpoint: failed endpoint URL
217
+ error_type: error type ("server_error", "timeout", "connection_error", "other")
218
+ """
219
+ import time
220
+
221
+ with self.lock:
222
+ if endpoint in self.endpoint_status:
223
+ status = self.endpoint_status[endpoint]
224
+ status['error_count'] += 1
225
+ status['consecutive_errors'] += 1
226
+ status['last_error_time'] = time.time()
227
+
228
+ def get_status(self) -> Dict:
229
+ """get pool status information (for debugging)"""
230
+ with self.lock:
231
+ return {
232
+ 'total_endpoints': len(self.endpoints),
233
+ 'current_round_progress': f"{len(self.used_indices)}/{len(self.endpoints)}",
234
+ 'endpoints_status': {
235
+ endpoint: {
236
+ 'error_count': status['error_count'],
237
+ 'successful_requests': status['successful_requests'],
238
+ 'total_requests': status['total_requests'],
239
+ 'success_rate': (
240
+ status['successful_requests'] / status['total_requests']
241
+ if status['total_requests'] > 0 else 0.0
242
+ ),
243
+ 'avg_response_time': (
244
+ status['total_response_time'] / status['successful_requests']
245
+ if status['successful_requests'] > 0 else 0.0
246
+ ),
247
+ 'consecutive_errors': status['consecutive_errors'],
248
+ 'last_error_time': status['last_error_time'],
249
+ }
250
+ for endpoint, status in self.endpoint_status.items()
251
+ }
252
+ }
253
+
254
+ def reset_round(self):
255
+ """reset current round (force start a new round)"""
256
+ with self.lock:
257
+ self.used_indices = []
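The round-robin branch of `get_endpoint()` reduces to a single index that wraps modulo the pool size under a lock. A minimal sketch under that assumption — the `RoundRobin` class name is illustrative, not part of the module:

```python
from threading import Lock
from typing import List

class RoundRobin:
    """Minimal round-robin selector: each call returns the next endpoint in order."""

    def __init__(self, endpoints: List[str]):
        self.endpoints = endpoints
        self.index = 0
        self.lock = Lock()  # same thread-safety guarantee as VLLMEndpointPool

    def next(self) -> str:
        with self.lock:
            endpoint = self.endpoints[self.index]
            # wrap around so requests are distributed uniformly across the pool
            self.index = (self.index + 1) % len(self.endpoints)
            return endpoint

rr = RoundRobin(["http://a:8000/v1", "http://b:8000/v1"])
picks = [rr.next() for _ in range(4)]
# picks → ["http://a:8000/v1", "http://b:8000/v1", "http://a:8000/v1", "http://b:8000/v1"]
```

The lock makes selection safe when many worker threads request endpoints concurrently; the smart-selection mode layers scoring on top of this same index bookkeeping.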
shared/utils/vllm_service.py ADDED
@@ -0,0 +1,314 @@
1
+ """
2
+ Simplified vLLM service for review system
3
+
4
+ This service only handles API calls, no load balancing logic.
5
+ Load balancing should be handled at the deployment service level (e.g., nginx reverse proxy).
6
+ """
7
+ import os
8
+ import time
9
+ import random
10
+ import yaml
11
+ from pathlib import Path
12
+ from typing import List, Dict, Optional, Any, Union
13
+ from threading import Semaphore, Lock as ThreadLock
14
+ from openai import OpenAI
15
+ from .llm_service import LLMService, ChatMessage
16
+
17
+
18
+ class VLLMService(LLMService):
19
+ """
20
+ Simplified vLLM service wrapper for local LLM deployment
21
+
22
+ This service connects to a vLLM server endpoint.
23
+ Load balancing should be handled at the deployment level (e.g., nginx, multiple services behind a load balancer).
24
+
25
+ Features:
26
+ - Simple API calls to a single endpoint
27
+ - Automatic retry with exponential backoff for 500 errors
28
+ - Configurable max concurrent requests (per service instance)
29
+ """
30
+
31
+ # Class-level semaphore for rate limiting (shared across all instances of this service)
32
+ # Use lazy initialization to avoid pickle issues with multiprocessing
33
+ _request_semaphore: Optional[Semaphore] = None
34
+ _max_concurrent_requests: int = 8 # Default limit
35
+ _semaphore_lock = ThreadLock() # Thread-safe initialization lock
36
+
37
+ def __init__(
38
+ self,
39
+ base_url: Optional[str] = None,
40
+ api_key: Optional[str] = None,
41
+ model_name: Optional[str] = None,
42
+ timeout: Optional[int] = None,
43
+ config_file: Optional[str] = None,
44
+ max_concurrent_requests: Optional[int] = None,
45
+ max_retries: int = 3,
46
+ retry_delay: float = 1.0,
47
+ retry_backoff: float = 2.0,
48
+ ):
49
+ """
50
+ Initialize vLLM service
51
+
52
+ Args:
53
+ base_url: vLLM server base URL (default: from config or http://localhost:8000/v1)
54
+ api_key: API key (overrides config)
55
+ model_name: Model name identifier (overrides config)
56
+ timeout: Request timeout in seconds (overrides config)
57
+ config_file: Path to config file (default: configs/llm_service_config.yaml)
58
+ max_concurrent_requests: Maximum concurrent requests per service instance (default: 8)
59
+ max_retries: Maximum number of retries for failed requests (default: 3)
60
+ retry_delay: Initial retry delay in seconds (default: 1.0)
61
+ retry_backoff: Retry delay multiplier (default: 2.0)
62
+ """
63
+ # Load config from YAML
64
+ config = self._load_config(config_file)
65
+ vllm_config = config.get("vllm", {})
66
+
67
+ # Use provided values or fall back to config, then environment variables
68
+ self.base_url = base_url or vllm_config.get("base_url") or os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1")
69
+ self.model_name = model_name or vllm_config.get("model_name", "Qwen/Qwen3-4B-Instruct-2507")
70
+ self.api_key = api_key or vllm_config.get("api_key", "dummy-key")
71
+ self.timeout = timeout or vllm_config.get("timeout", 300)
72
+
73
+ # Retry configuration
74
+ self.max_retries = max_retries
75
+ self.retry_delay = retry_delay
76
+ self.retry_backoff = retry_backoff
77
+
78
+ # Rate limiting: Initialize class-level semaphore if not already initialized
79
+ # Use lazy initialization with thread-safe check to avoid pickle issues
80
+ if max_concurrent_requests is not None:
81
+ VLLMService._max_concurrent_requests = max_concurrent_requests
82
+ else:
83
+ # Try to get from config
84
+ config_max_concurrent = vllm_config.get("max_concurrent_requests")
85
+ if config_max_concurrent is not None:
86
+ VLLMService._max_concurrent_requests = config_max_concurrent
87
+
88
+ # Lazy initialization of semaphore will happen on first use
89
+ # This avoids pickle issues when using multiprocessing/ThreadPoolExecutor
90
+
91
+ # Store default sampling parameters from config
92
+ self.default_temperature = vllm_config.get("temperature", 0.7)
93
+ self.default_top_p = vllm_config.get("top_p", 0.8)
94
+ self.default_top_k = vllm_config.get("top_k", 20)
95
+ self.default_max_tokens = vllm_config.get("max_tokens", 16384)
96
+ self.default_presence_penalty = vllm_config.get("presence_penalty", 0.0)
97
+
98
+ # Create OpenAI client
99
+ self.client = OpenAI(
100
+ api_key=self.api_key,
101
+ base_url=self.base_url,
102
+ timeout=self.timeout,
103
+ )
104
+
105
+ @staticmethod
106
+ def _load_config(config_file: Optional[str] = None) -> Dict[str, Any]:
107
+ """
108
+ Load configuration from YAML file
109
+
110
+ Args:
111
+ config_file: Path to config file
112
+
113
+ Returns:
114
+ Configuration dictionary
115
+ """
116
+ if config_file is None:
117
+ project_root = Path(__file__).parent.parent.parent
118
+ config_file = project_root / "shared" / "configs" / "llm_service_config.yaml"
119
+
120
+ config_path = Path(config_file)
121
+ if not config_path.exists():
122
+ # Return defaults if config file doesn't exist
123
+ return {
124
+ "vllm": {
125
+ "base_url": "http://localhost:8000/v1",
126
+ "api_key": "dummy-key",
127
+ "model_name": "Qwen/Qwen3-4B-Instruct-2507",
128
+ "timeout": 300,
129
+ "max_concurrent_requests": 8,
130
+ "max_retries": 3,
131
+ "retry_delay": 1.0,
132
+ "retry_backoff": 2.0,
133
+ "temperature": 0.7,
134
+ "top_p": 0.8,
135
+ "top_k": 20,
136
+ "max_tokens": 16384,
137
+ "presence_penalty": 0.0,
138
+ }
139
+ }
140
+
141
+ with open(config_path, 'r', encoding='utf-8') as f:
142
+ return yaml.safe_load(f) or {}
143
+
144
+ @classmethod
145
+ def _ensure_semaphore(cls):
146
+ """Thread-safe lazy initialization of semaphore to avoid pickle issues"""
147
+ if cls._request_semaphore is None:
148
+ with cls._semaphore_lock:
149
+ # Double-check pattern
150
+ if cls._request_semaphore is None:
151
+ cls._request_semaphore = Semaphore(cls._max_concurrent_requests)
152
+
153
+ def _format_messages(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> List[Dict[str, str]]:
154
+ """Format messages for OpenAI API"""
155
+ formatted = []
156
+ for msg in messages:
157
+ if isinstance(msg, ChatMessage):
158
+ formatted.append({"role": msg.role, "content": msg.content})
159
+ elif isinstance(msg, dict):
160
+ formatted.append(msg)
161
+ else:
162
+ raise ValueError(f"Invalid message type: {type(msg)}")
163
+ return formatted
164
+
165
+ def generate(
166
+ self,
167
+ messages: List[Union[ChatMessage, Dict[str, str]]],
168
+ temperature: Optional[float] = None,
169
+ top_p: Optional[float] = None,
170
+ top_k: Optional[int] = None,
171
+ max_tokens: Optional[int] = None,
172
+ presence_penalty: Optional[float] = None,
173
+ **kwargs
174
+ ) -> str:
175
+ """
176
+ Generate text from messages
177
+
178
+ Args:
179
+ messages: List of chat messages
180
+ temperature: Sampling temperature (uses config default if None)
181
+ top_p: Top-p sampling parameter (uses config default if None)
182
+ top_k: Top-k sampling parameter (uses config default if None)
183
+ max_tokens: Maximum tokens to generate (uses config default if None)
184
+ presence_penalty: Presence penalty (uses config default if None)
185
+ **kwargs: Additional parameters
186
+
187
+ Returns:
188
+ Generated text
189
+ """
190
+ formatted_messages = self._format_messages(messages)
191
+
192
+ # Use provided values or fall back to config defaults
193
+ temperature = temperature if temperature is not None else self.default_temperature
194
+ top_p = top_p if top_p is not None else self.default_top_p
195
+ max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
196
+ presence_penalty = presence_penalty if presence_penalty is not None else self.default_presence_penalty
197
+
198
+ # Ensure semaphore is initialized (lazy, thread-safe)
199
+ self._ensure_semaphore()
200
+
201
+ # Use semaphore to limit concurrent requests
202
+ with VLLMService._request_semaphore:
203
+ last_exception = None
204
+
205
+ for retry_attempt in range(self.max_retries + 1):
206
+ try:
207
+ response = self.client.chat.completions.create(
208
+ model=self.model_name,
209
+ messages=formatted_messages,
210
+ temperature=temperature,
211
+ top_p=top_p,
212
+ max_tokens=max_tokens,
213
+ presence_penalty=presence_penalty,
214
+ **kwargs
215
+ )
216
+
217
+ return response.choices[0].message.content
218
+
219
+ except Exception as e:
220
+ last_exception = e
221
+
222
+ # Check if it's a server error (500, 502, 503, 504) that we should retry
223
+ should_retry = False
224
+ error_str = str(e).lower()
225
+
226
+ if any(code in error_str for code in ["500", "502", "503", "504"]):
227
+ should_retry = True
228
+ elif "server error" in error_str or "internal server error" in error_str:
229
+ should_retry = True
230
+
231
+ # Don't retry on last attempt
232
+ if retry_attempt < self.max_retries and should_retry:
233
+ # Calculate delay with exponential backoff and jitter
234
+ delay = self.retry_delay * (self.retry_backoff ** retry_attempt)
235
+ jitter = random.uniform(0, delay * 0.1) # 10% jitter
236
+ time.sleep(delay + jitter)
237
+ continue
238
+ else:
239
+ # Either not a retryable error or out of retries
240
+ raise last_exception
241
+
242
+ def stream_generate(
243
+ self,
244
+ messages: List[Union[ChatMessage, Dict[str, str]]],
245
+ temperature: Optional[float] = None,
246
+ top_p: Optional[float] = None,
247
+ top_k: Optional[int] = None,
248
+ max_tokens: Optional[int] = None,
249
+ presence_penalty: Optional[float] = None,
250
+ **kwargs
251
+ ):
252
+ """
253
+ Stream generate text from messages
254
+
255
+ Yields:
256
+ Generated text chunks
257
+ """
258
+ formatted_messages = self._format_messages(messages)
259
+
260
+ # Use provided values or fall back to config defaults
261
+ temperature = temperature if temperature is not None else self.default_temperature
262
+ top_p = top_p if top_p is not None else self.default_top_p
263
+ max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
264
+ presence_penalty = presence_penalty if presence_penalty is not None else self.default_presence_penalty
265
+
266
+ # Ensure semaphore is initialized (lazy, thread-safe)
267
+ self._ensure_semaphore()
268
+
269
+ # Use semaphore to limit concurrent requests
270
+ with VLLMService._request_semaphore:
271
+ last_exception = None
272
+
273
+ for retry_attempt in range(self.max_retries + 1):
274
+ try:
275
+ stream = self.client.chat.completions.create(
276
+ model=self.model_name,
277
+ messages=formatted_messages,
278
+ temperature=temperature,
279
+ top_p=top_p,
280
+ max_tokens=max_tokens,
281
+ presence_penalty=presence_penalty,
282
+ stream=True,
283
+ **kwargs
284
+ )
285
+
286
+ # Stream chunks
287
+ for chunk in stream:
288
+ if chunk.choices[0].delta.content:
289
+ yield chunk.choices[0].delta.content
290
+
291
+ return # Success, exit retry loop
292
+
293
+ except Exception as e:
294
+ last_exception = e
295
+
296
+ # Check if it's a server error that we should retry
297
+ should_retry = False
298
+ error_str = str(e).lower()
299
+
300
+ if any(code in error_str for code in ["500", "502", "503", "504"]):
301
+ should_retry = True
302
+ elif "server error" in error_str or "internal server error" in error_str:
303
+ should_retry = True
304
+
305
+ # Don't retry on last attempt
306
+ if retry_attempt < self.max_retries and should_retry:
307
+ # Calculate delay with exponential backoff and jitter
308
+ delay = self.retry_delay * (self.retry_backoff ** retry_attempt)
309
+ jitter = random.uniform(0, delay * 0.1) # 10% jitter
310
+ time.sleep(delay + jitter)
311
+ continue
312
+ else:
313
+ # Either not a retryable error or out of retries
314
+ raise last_exception
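The retry schedule used by both `generate()` and `stream_generate()` — `retry_delay * retry_backoff ** attempt` plus up to 10% jitter — can be previewed in isolation. `backoff_delays` is a hypothetical helper written only for this sketch; the service computes the same quantities inline:

```python
import random
from typing import List

def backoff_delays(retry_delay: float = 1.0, retry_backoff: float = 2.0,
                   max_retries: int = 3, seed: int = 0) -> List[float]:
    """Return the sleep duration before each retry: exponential base plus 10% jitter."""
    rng = random.Random(seed)  # seeded here only to make the sketch reproducible
    delays = []
    for attempt in range(max_retries):
        base = retry_delay * (retry_backoff ** attempt)
        delays.append(base + rng.uniform(0, base * 0.1))
    return delays

delays = backoff_delays()
# roughly 1s, 2s, 4s, each nudged upward by at most 10%
```

The jitter spreads out retries from concurrent workers so a recovering endpoint is not hit by a synchronized burst of requests.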
shared/utils/vllm_service_simple.py ADDED
@@ -0,0 +1,314 @@
1
+ """
2
+ Simplified vLLM service for review system
3
+
4
+ This service handles API calls only; it contains no load-balancing logic.
5
+ Load balancing should be handled at the deployment service level (e.g., nginx reverse proxy).
6
+ """
7
+ import os
8
+ import time
9
+ import random
10
+ import yaml
11
+ from pathlib import Path
12
+ from typing import List, Dict, Optional, Any, Union
13
+ from threading import Semaphore, Lock as ThreadLock
14
+ from openai import OpenAI
15
+ from .llm_service import LLMService, ChatMessage
16
+
17
+
18
+ class VLLMService(LLMService):
19
+ """
20
+ Simplified vLLM service wrapper for local LLM deployment
21
+
22
+ This service connects to a vLLM server endpoint.
23
+ Load balancing should be handled at the deployment level (e.g., nginx, multiple services behind a load balancer).
24
+
25
+ Features:
26
+ - Simple API calls to a single endpoint
27
+ - Automatic retry with exponential backoff for 500 errors
28
+ - Configurable max concurrent requests (per service instance)
29
+ """
30
+
31
+ # Class-level semaphore for rate limiting (shared across all instances of this service)
32
+ # Use lazy initialization to avoid pickle issues with multiprocessing
33
+ _request_semaphore: Optional[Semaphore] = None
34
+ _max_concurrent_requests: int = 8 # Default limit
35
+ _semaphore_lock = ThreadLock() # Thread-safe initialization lock
36
+
37
+ def __init__(
38
+ self,
39
+ base_url: Optional[str] = None,
40
+ api_key: Optional[str] = None,
41
+ model_name: Optional[str] = None,
42
+ timeout: Optional[int] = None,
43
+ config_file: Optional[str] = None,
44
+ max_concurrent_requests: Optional[int] = None,
45
+ max_retries: int = 3,
46
+ retry_delay: float = 1.0,
47
+ retry_backoff: float = 2.0,
48
+ ):
49
+ """
50
+ Initialize vLLM service
51
+
52
+ Args:
53
+ base_url: vLLM server base URL (default: from config or http://localhost:8000/v1)
54
+ api_key: API key (overrides config)
55
+ model_name: Model name identifier (overrides config)
56
+ timeout: Request timeout in seconds (overrides config)
57
+ config_file: Path to config file (default: configs/llm_service_config.yaml)
58
+ max_concurrent_requests: Maximum concurrent requests per service instance (default: 8)
59
+ max_retries: Maximum number of retries for failed requests (default: 3)
60
+ retry_delay: Initial retry delay in seconds (default: 1.0)
61
+ retry_backoff: Retry delay multiplier (default: 2.0)
62
+ """
63
+ # Load config from YAML
64
+ config = self._load_config(config_file)
65
+ vllm_config = config.get("vllm", {})
66
+
67
+ # Use provided values or fall back to config, then environment variables
68
+ self.base_url = base_url or vllm_config.get("base_url") or os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1")
69
+ self.model_name = model_name or vllm_config.get("model_name", "Qwen/Qwen3-4B-Instruct-2507")
70
+ self.api_key = api_key or vllm_config.get("api_key", "dummy-key")
71
+ self.timeout = timeout or vllm_config.get("timeout", 300)
72
+
73
+ # Retry configuration
74
+ self.max_retries = max_retries
75
+ self.retry_delay = retry_delay
76
+ self.retry_backoff = retry_backoff
77
+
78
+ # Rate limiting: Initialize class-level semaphore if not already initialized
79
+ # Use lazy initialization with thread-safe check to avoid pickle issues
80
+ if max_concurrent_requests is not None:
81
+ VLLMService._max_concurrent_requests = max_concurrent_requests
82
+ else:
83
+ # Try to get from config
84
+ config_max_concurrent = vllm_config.get("max_concurrent_requests")
85
+ if config_max_concurrent is not None:
86
+ VLLMService._max_concurrent_requests = config_max_concurrent
87
+
88
+ # Lazy initialization of semaphore will happen on first use
89
+ # This avoids pickle issues when using multiprocessing/ThreadPoolExecutor
90
+
91
+ # Store default sampling parameters from config
92
+ self.default_temperature = vllm_config.get("temperature", 0.7)
93
+ self.default_top_p = vllm_config.get("top_p", 0.8)
94
+ self.default_top_k = vllm_config.get("top_k", 20)
95
+ self.default_max_tokens = vllm_config.get("max_tokens", 16384)
96
+ self.default_presence_penalty = vllm_config.get("presence_penalty", 0.0)
97
+
98
+ # Create OpenAI client
99
+ self.client = OpenAI(
100
+ api_key=self.api_key,
101
+ base_url=self.base_url,
102
+ timeout=self.timeout,
103
+ )
104
+
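`__init__` resolves each setting through an `or`-chain: explicit argument, then YAML config, then environment variable or hard-coded default. A minimal sketch of that precedence (the helper name `resolve_setting` is ours); note that `or`-chaining treats falsy values such as `0` or `""` as unset:

```python
import os

def resolve_setting(explicit, config, key, env_var, default):
    """Mirrors the or-chain in __init__: argument > config > env > default.
    Caveat: falsy values (0, "", None) are treated as unset."""
    return explicit or config.get(key) or os.environ.get(env_var) or default

cfg = {"base_url": "http://10.0.0.5:8000/v1"}
print(resolve_setting(None, cfg, "base_url", "VLLM_BASE_URL", "http://localhost:8000/v1"))
```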
105
+ @staticmethod
106
+ def _load_config(config_file: Optional[str] = None) -> Dict[str, Any]:
107
+ """
108
+ Load configuration from YAML file
109
+
110
+ Args:
111
+ config_file: Path to config file
112
+
113
+ Returns:
114
+ Configuration dictionary
115
+ """
116
+ if config_file is None:
117
+ project_root = Path(__file__).parent.parent.parent
118
+ config_file = project_root / "shared" / "configs" / "llm_service_config.yaml"
119
+
120
+ config_path = Path(config_file)
121
+ if not config_path.exists():
122
+ # Return defaults if config file doesn't exist
123
+ return {
124
+ "vllm": {
125
+ "base_url": "http://localhost:8000/v1",
126
+ "api_key": "dummy-key",
127
+ "model_name": "Qwen/Qwen3-4B-Instruct-2507",
128
+ "timeout": 300,
129
+ "max_concurrent_requests": 8,
130
+ "max_retries": 3,
131
+ "retry_delay": 1.0,
132
+ "retry_backoff": 2.0,
133
+ "temperature": 0.7,
134
+ "top_p": 0.8,
135
+ "top_k": 20,
136
+ "max_tokens": 16384,
137
+ "presence_penalty": 0.0,
138
+ }
139
+ }
140
+
141
+ with open(config_path, 'r', encoding='utf-8') as f:
142
+ return yaml.safe_load(f) or {}
143
+
144
+ @classmethod
145
+ def _ensure_semaphore(cls):
146
+ """Thread-safe lazy initialization of semaphore to avoid pickle issues"""
147
+ if cls._request_semaphore is None:
148
+ with cls._semaphore_lock:
149
+ # Double-check pattern
150
+ if cls._request_semaphore is None:
151
+ cls._request_semaphore = Semaphore(cls._max_concurrent_requests)
152
+
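`_ensure_semaphore` uses the double-checked locking pattern: a lock-free fast path once initialized, and a re-check under the lock so racing threads build exactly one Semaphore. A standalone illustration (the class name `LazySemaphore` and the `_creations` counter are ours, for instrumentation only):

```python
from threading import Lock, Semaphore, Thread

class LazySemaphore:
    """Standalone illustration of the double-checked locking in _ensure_semaphore."""
    _sem = None
    _lock = Lock()
    _creations = 0  # instrumentation: how many times the Semaphore was built

    @classmethod
    def ensure(cls):
        if cls._sem is None:          # fast path: no lock once initialized
            with cls._lock:
                if cls._sem is None:  # re-check under the lock
                    cls._creations += 1
                    cls._sem = Semaphore(8)
        return cls._sem

threads = [Thread(target=LazySemaphore.ensure) for _ in range(32)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(LazySemaphore._creations)  # prints 1: only one thread builds the semaphore
```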
153
+ def _format_messages(self, messages: List[Union[ChatMessage, Dict[str, str]]]) -> List[Dict[str, str]]:
154
+ """Format messages for OpenAI API"""
155
+ formatted = []
156
+ for msg in messages:
157
+ if isinstance(msg, ChatMessage):
158
+ formatted.append({"role": msg.role, "content": msg.content})
159
+ elif isinstance(msg, dict):
160
+ formatted.append(msg)
161
+ else:
162
+ raise ValueError(f"Invalid message type: {type(msg)}")
163
+ return formatted
164
+
165
+ def generate(
166
+ self,
167
+ messages: List[Union[ChatMessage, Dict[str, str]]],
168
+ temperature: Optional[float] = None,
169
+ top_p: Optional[float] = None,
170
+ top_k: Optional[int] = None,
171
+ max_tokens: Optional[int] = None,
172
+ presence_penalty: Optional[float] = None,
173
+ **kwargs
174
+ ) -> str:
175
+ """
176
+ Generate text from messages
177
+
178
+ Args:
179
+ messages: List of chat messages
180
+ temperature: Sampling temperature (uses config default if None)
181
+ top_p: Top-p sampling parameter (uses config default if None)
182
+ top_k: Top-k sampling parameter (uses config default if None)
183
+ max_tokens: Maximum tokens to generate (uses config default if None)
184
+ presence_penalty: Presence penalty (uses config default if None)
185
+ **kwargs: Additional parameters
186
+
187
+ Returns:
188
+ Generated text
189
+ """
190
+ formatted_messages = self._format_messages(messages)
191
+
192
+ # Use provided values or fall back to config defaults
193
+ temperature = temperature if temperature is not None else self.default_temperature
194
+ top_p = top_p if top_p is not None else self.default_top_p
195
+ max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
196
+ presence_penalty = presence_penalty if presence_penalty is not None else self.default_presence_penalty
197
+
198
+ # Ensure semaphore is initialized (lazy, thread-safe)
199
+ self._ensure_semaphore()
200
+
201
+ # Use semaphore to limit concurrent requests
202
+ with VLLMService._request_semaphore:
203
+ last_exception = None
204
+
205
+ for retry_attempt in range(self.max_retries + 1):
206
+ try:
207
+ response = self.client.chat.completions.create(
208
+ model=self.model_name,
209
+ messages=formatted_messages,
210
+ temperature=temperature,
211
+ top_p=top_p,
212
+ max_tokens=max_tokens,
213
+ presence_penalty=presence_penalty,
214
+ **kwargs
215
+ )
216
+
217
+ return response.choices[0].message.content
218
+
219
+ except Exception as e:
220
+ last_exception = e
221
+
222
+ # Check if it's a server error (500, 502, 503, 504) that we should retry
223
+ should_retry = False
224
+ error_str = str(e).lower()
225
+
226
+ if any(code in error_str for code in ["500", "502", "503", "504"]):
227
+ should_retry = True
228
+ elif "server error" in error_str or "internal server error" in error_str:
229
+ should_retry = True
230
+
231
+ # Don't retry on last attempt
232
+ if retry_attempt < self.max_retries and should_retry:
233
+ # Calculate delay with exponential backoff and jitter
234
+ delay = self.retry_delay * (self.retry_backoff ** retry_attempt)
235
+ jitter = random.uniform(0, delay * 0.1) # 10% jitter
236
+ time.sleep(delay + jitter)
237
+ continue
238
+ else:
239
+ # Either not a retryable error or out of retries
240
+ raise last_exception
241
+
242
+ def stream_generate(
243
+ self,
244
+ messages: List[Union[ChatMessage, Dict[str, str]]],
245
+ temperature: Optional[float] = None,
246
+ top_p: Optional[float] = None,
247
+ top_k: Optional[int] = None,
248
+ max_tokens: Optional[int] = None,
249
+ presence_penalty: Optional[float] = None,
250
+ **kwargs
251
+ ):
252
+ """
253
+ Stream generate text from messages
254
+
255
+ Yields:
256
+ Generated text chunks
257
+ """
258
+ formatted_messages = self._format_messages(messages)
259
+
260
+ # Use provided values or fall back to config defaults
261
+ temperature = temperature if temperature is not None else self.default_temperature
262
+ top_p = top_p if top_p is not None else self.default_top_p
263
+ max_tokens = max_tokens if max_tokens is not None else self.default_max_tokens
264
+ presence_penalty = presence_penalty if presence_penalty is not None else self.default_presence_penalty
265
+
266
+ # Ensure semaphore is initialized (lazy, thread-safe)
267
+ self._ensure_semaphore()
268
+
269
+ # Use semaphore to limit concurrent requests
270
+ with VLLMService._request_semaphore:
271
+ last_exception = None
272
+
273
+ for retry_attempt in range(self.max_retries + 1):
274
+ try:
275
+ stream = self.client.chat.completions.create(
276
+ model=self.model_name,
277
+ messages=formatted_messages,
278
+ temperature=temperature,
279
+ top_p=top_p,
280
+ max_tokens=max_tokens,
281
+ presence_penalty=presence_penalty,
282
+ stream=True,
283
+ **kwargs
284
+ )
285
+
286
+ # Stream chunks
287
+ for chunk in stream:
288
+ if chunk.choices[0].delta.content:
289
+ yield chunk.choices[0].delta.content
290
+
291
+ return # Success, exit retry loop
292
+
293
+ except Exception as e:
294
+ last_exception = e
295
+
296
+ # Check if it's a server error that we should retry
297
+ should_retry = False
298
+ error_str = str(e).lower()
299
+
300
+ if any(code in error_str for code in ["500", "502", "503", "504"]):
301
+ should_retry = True
302
+ elif "server error" in error_str or "internal server error" in error_str:
303
+ should_retry = True
304
+
305
+ # Don't retry on last attempt
306
+ if retry_attempt < self.max_retries and should_retry:
307
+ # Calculate delay with exponential backoff and jitter
308
+ delay = self.retry_delay * (self.retry_backoff ** retry_attempt)
309
+ jitter = random.uniform(0, delay * 0.1) # 10% jitter
310
+ time.sleep(delay + jitter)
311
+ continue
312
+ else:
313
+ # Either not a retryable error or out of retries
314
+ raise last_exception
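With the defaults above (`retry_delay=1.0`, `retry_backoff=2.0`, `max_retries=3`), the sleep between attempts grows as 1s, 2s, 4s, each plus up to 10% random jitter. A small sketch of that schedule (the function name `backoff_delays` is ours):

```python
import random

def backoff_delays(max_retries=3, retry_delay=1.0, retry_backoff=2.0, seed=0):
    """Delays (seconds) the retry loop above would sleep between attempts:
    retry_delay * retry_backoff**attempt, plus up to 10% random jitter."""
    rng = random.Random(seed)
    delays = []
    for attempt in range(max_retries):
        base = retry_delay * (retry_backoff ** attempt)
        delays.append(base + rng.uniform(0, base * 0.1))
    return delays

print(backoff_delays())  # roughly [1.0, 2.0, 4.0] plus jitter
```

The jitter spreads out retries from many concurrent workers so they do not hammer a recovering server in lockstep.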
src/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """
2
+ Unified Review System
3
+
4
+ A comprehensive system for paper review generation and evaluation.
5
+ """
6
+ __version__ = "1.0.0"
src/evaluator/1_get_rubrics.py ADDED
@@ -0,0 +1,601 @@
1
+ """
2
+ Generate review-based rubrics by querying LLMs with concurrent requests.
3
+
4
+ This script:
5
+ 1. Reads the JSON file with review data
6
+ 2. Extracts entries with 'id', 'pred_fast_mode_baseline', 'paper_context', and 'decision'
7
+ 3. Loads the rubric generation prompt from prompts.yaml
8
+ 4. Loads LLM configuration from configs.yaml (supports gpt and vllm modes)
9
+ 5. For each entry, generates rubrics by replacing <<golden_review>> with the ground truth review
10
+ 6. Uses concurrent requests (ThreadPoolExecutor) for efficient LLM queries
11
+ 7. Extracts rubrics from LLM responses and saves to eval_rubrics.json
12
+
13
+ Output JSON file (eval_rubrics.json) contains a list of dicts with:
14
+ - id: Entry identifier
15
+ - paper_context: Paper content
16
+ - decision: Decision field from input
17
+ - golden_review: The pred_fast_mode_baseline review (ground truth)
18
+ - rubrics: List of rubric objects, each with title, description, and weight
19
+
20
+ Usage:
21
+ python 1_get_rubrics.py \
22
+ --json_path input.json \
23
+ --output_path eval_rubrics.json \
24
+ --yaml_path prompts.yaml \
25
+ --config_path configs.yaml \
26
+ --max_workers 5
27
+
28
+ The configs.yaml should specify either "gpt" or "vllm" mode and corresponding settings.
29
+ """
30
+ import json
31
+ import os
32
+ import sys
33
+ import argparse
34
+ import yaml
35
+ from typing import Dict, List, Any, Optional
36
+ from concurrent.futures import ThreadPoolExecutor, as_completed
37
+ from tqdm import tqdm
38
+ import pandas as pd
39
+ from dotenv import load_dotenv
40
+
41
+ # Add parent directory to path to import llm_service
42
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
43
+ # Import parse_llm_response from local llm_service module (for parsing LLM responses)
44
+ import llm_service as local_llm_service
45
+ parse_llm_response = local_llm_service.parse_llm_response
46
+
47
+ # Import from shared/utils for gpt/vllm support
48
+ # Add project root to path to enable absolute imports from shared.utils
49
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
50
+ if project_root not in sys.path:
51
+ sys.path.insert(0, project_root)
52
+
53
+ # Use absolute imports from shared.utils package
54
+ from shared.utils.llm_service import LLMService
55
+ from shared.utils.vllm_service import VLLMService
56
+ from shared.utils.gpt_service import GPTService
57
+
58
+ # Load environment variables
59
+ load_dotenv()
60
+
61
+ class ReviewProcessor:
62
+ """Handles the extraction and processing of reviews from different sources."""
63
+
64
+ @staticmethod
65
+ def extract_review_content(pred_context):
66
+ """
67
+ Extract the review content from the prediction context.
68
+
69
+ Args:
70
+ pred_context: Raw prediction data that contains the review
71
+
72
+ Returns:
73
+ str: Extracted review content
74
+ """
75
+ try:
76
+ # First attempt to extract from boxed format
77
+ return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
78
+ except Exception:
79
+ # Alternative extraction if the first method fails
80
+ if isinstance(pred_context, dict) and 'output' in pred_context:
81
+ return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
82
+ else:
83
+ # Return as is if extraction fails
84
+ return pred_context
85
+
86
+
87
+
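`extract_review_content`'s happy path pulls the text between `\boxed_review{` and the next `\n}` with a chained split, and falls back to returning the input unchanged when the marker is absent. A minimal demo of that split chain (the helper name `extract_boxed` is ours):

```python
def extract_boxed(pred: str) -> str:
    # Same split chain as extract_review_content's happy path
    return pred.split(r'\boxed_review{')[-1].split('\n}')[0]

raw = "preamble \\boxed_review{\nSolid paper; accept.\n} trailing"
print(extract_boxed(raw).strip())  # prints: Solid paper; accept.
```

Because `split` on a missing separator returns the whole string, input without the marker passes through unchanged, matching the method's fallback behavior.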
88
+ def load_json_data(json_path: str) -> List[Dict[str, Any]]:
89
+ """
90
+ Load JSON data from file.
91
+ Handles both list and dict formats.
92
+ """
93
+ with open(json_path, 'r', encoding='utf-8') as f:
94
+ data = json.load(f)
95
+
96
+ # Convert dict to list if needed
97
+ if isinstance(data, dict):
98
+ data = list(data.values())
99
+
100
+ return data
101
+
102
+
103
+ def load_prompt_template(yaml_path: str) -> str:
104
+ """
105
+ Load the rubric generation prompt from YAML file.
106
+ """
107
+ with open(yaml_path, 'r', encoding='utf-8') as f:
108
+ prompts = yaml.safe_load(f)
109
+
110
+ prompt_template = prompts.get('v2_rubric_generation_prompt', '')
111
+ rubric_template = prompts.get('rubrics', '')
112
+
113
+ prompt_template = prompt_template.replace('<<rubric_template>>', rubric_template)
114
+
115
+ return prompt_template
116
+
117
+
118
+ def clean_rubrics_json(json_str: str) -> str:
119
+ """
120
+ Clean JSON string by escaping unescaped double quotes inside string values.
121
+
122
+ This function handles cases where the model outputs double quotes inside
123
+ string values (especially in description fields) without proper escaping.
124
+
125
+ The expected format is a JSON array of objects with "title", "description", "weight" fields.
126
+ Strategy: Find each field's value and escape unescaped quotes inside it.
127
+ """
128
+ import re
129
+
130
+ # First, try to extract JSON array if wrapped in markdown code blocks
131
+ json_match = re.search(r'```json\s*(\[.*?\])\s*```', json_str, re.DOTALL)
132
+ if json_match:
133
+ json_str = json_match.group(1)
134
+ else:
135
+ # Try to find JSON array directly
136
+ json_match = re.search(r'(\[.*?\])', json_str, re.DOTALL)
137
+ if json_match:
138
+ json_str = json_match.group(1)
139
+
140
+ # Process the JSON string character by character to find and fix string values
141
+ # We'll look for patterns like "field": " and then find the matching closing quote
142
+ result = []
143
+ i = 0
144
+
145
+ while i < len(json_str):
146
+ # Look for field pattern: "field_name": "
147
+ field_match = re.search(r'"(title|description|weight)"\s*:\s*"', json_str[i:])
148
+ if not field_match:
149
+ # No more fields to process, append rest and break
150
+ result.append(json_str[i:])
151
+ break
152
+
153
+ # Append everything before the match
154
+ match_start = i + field_match.start()
155
+ result.append(json_str[i:match_start])
156
+
157
+ # Process the field value
158
+ value_start = i + field_match.end() # Position after opening quote
159
+
160
+ # Find the closing quote by scanning character by character
161
+ # The closing quote should be followed by comma, closing brace, or closing bracket
162
+ j = value_start
163
+ found_closing = False
164
+
165
+ while j < len(json_str):
166
+ if json_str[j] == '\\':
167
+ # Skip escaped character (could be \", \\, etc.)
168
+ if j + 1 < len(json_str):
169
+ j += 2
170
+ continue
171
+ else:
172
+ j += 1
173
+ break
174
+ elif json_str[j] == '"':
175
+ # Found a quote - check if it's the closing quote
176
+ # Look ahead (skip whitespace) to see if followed by comma, brace, or bracket
177
+ k = j + 1
178
+ while k < len(json_str) and json_str[k] in ' \t\n\r':
179
+ k += 1
180
+
181
+ if k < len(json_str) and json_str[k] in ',}]':
182
+ # This is the closing quote!
183
+ value_content = json_str[value_start:j]
184
+ closing_part = json_str[j:k+1] # " followed by , } or ]
185
+
186
+ # Fix unescaped quotes in value_content
187
+ # Strategy: preserve already-escaped quotes, escape others
188
+ fixed_content = value_content.replace('\\"', '__TEMP_ESC__')
189
+ fixed_content = fixed_content.replace('"', '\\"')
190
+ fixed_content = fixed_content.replace('__TEMP_ESC__', '\\"')
191
+
192
+ # Append the fixed field
193
+ result.append(json_str[match_start:value_start]) # "field": "
194
+ result.append(fixed_content) # fixed value content
195
+ result.append(closing_part) # " followed by punctuation
196
+
197
+ i = k + 1
198
+ found_closing = True
199
+ break
200
+ j += 1
201
+
202
+ if not found_closing:
203
+ # Couldn't find proper closing quote, append rest and break
204
+ result.append(json_str[match_start:])
205
+ break
206
+
207
+ return ''.join(result)
208
+
209
+
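The core trick in `clean_rubrics_json` is a three-step replace that escapes bare quotes inside a string value while leaving already-escaped ones alone: protect existing `\"` with a sentinel, escape the rest, then restore. Isolated (the helper name `escape_bare_quotes` is ours):

```python
import json

def escape_bare_quotes(value: str) -> str:
    """The sentinel trick from clean_rubrics_json: escape bare " characters
    while leaving already-escaped \\" sequences untouched."""
    tmp = value.replace('\\"', '__TEMP_ESC__')   # protect existing escapes
    tmp = tmp.replace('"', '\\"')                # escape everything else
    return tmp.replace('__TEMP_ESC__', '\\"')    # restore the protected ones

broken = 'calls the "novel" method'
print(json.loads('{"description": "' + escape_bare_quotes(broken) + '"}')["description"])
```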
210
+ def extract_rubrics_from_response(response: str) -> Optional[List[Dict[str, Any]]]:
211
+ """
212
+ Extract rubrics (JSON array) from LLM response.
213
+ Handles cases where description fields contain unescaped double quotes.
214
+ Returns None if parsing fails (silently, no error messages printed).
215
+ """
216
+ try:
217
+ # First, try using parse_llm_response (handles markdown blocks)
218
+ try:
219
+ parsed = parse_llm_response(response)
220
+
221
+ # Check if parsed result is a list (array of rubrics)
222
+ if isinstance(parsed, list):
223
+ return parsed
224
+
225
+ # If parsed result is a dict, check for common keys that might contain the array
226
+ if isinstance(parsed, dict):
227
+ # Check for common keys
228
+ for key in ['rubrics', 'rubric', 'items', 'criteria']:
229
+ if key in parsed and isinstance(parsed[key], list):
230
+ return parsed[key]
231
+
232
+ # If no key found, try to find the first list value
233
+ for value in parsed.values():
234
+ if isinstance(value, list):
235
+ return value
236
+ except Exception:
237
+ # parse_llm_response failed, try manual cleaning
238
+ pass
239
+
240
+ # If parse_llm_response failed, try manual extraction and cleaning
241
+ import re
242
+
243
+ # Try to find JSON array in response
244
+ json_match = re.search(r'\[.*?\]', response, re.DOTALL)
245
+ if json_match:
246
+ json_str = json_match.group(0)
247
+
248
+ # Try direct parsing first
249
+ try:
250
+ rubrics = json.loads(json_str)
251
+ if isinstance(rubrics, list):
252
+ return rubrics
253
+ except json.JSONDecodeError:
254
+ # JSON parsing failed, try cleaning
255
+ try:
256
+ cleaned_json = clean_rubrics_json(json_str)
257
+ rubrics = json.loads(cleaned_json)
258
+ if isinstance(rubrics, list):
259
+ return rubrics
260
+ except Exception:
261
+ # Last resort: try a more aggressive cleaning approach
262
+ try:
263
+ # Replace unescaped quotes in description fields more aggressively
264
+ # Pattern: "description": "..." where ... may contain quotes
265
+ def fix_description_quotes(match):
266
+ prefix = match.group(1) # "description": "
267
+ content = match.group(2) # the content
268
+ suffix = match.group(3) # closing quote
269
+
270
+ # Escape all quotes in content, but preserve escaped ones
271
+ # First, mark escaped quotes temporarily
272
+ content = content.replace('\\"', '__ESCAPED_QUOTE__')
273
+ # Escape all remaining quotes
274
+ content = content.replace('"', '\\"')
275
+ # Restore escaped quotes
276
+ content = content.replace('__ESCAPED_QUOTE__', '\\"')
277
+
278
+ return prefix + content + suffix
279
+
280
+ # More specific pattern for description field
281
+ desc_pattern = r'("description"\s*:\s*")(.*?)("(?:\s*[,}])?)'
282
+ fixed_json = re.sub(desc_pattern, fix_description_quotes, json_str, flags=re.DOTALL)
283
+
284
+ rubrics = json.loads(fixed_json)
285
+ if isinstance(rubrics, list):
286
+ return rubrics
287
+ except Exception:
288
+ pass
289
+
290
+ # If all else fails, return None (silently)
291
+ return None
292
+
293
+ except Exception:
294
+ # Any unexpected error, return None (silently)
295
+ return None
296
+
297
+
298
+ def generate_rubrics_for_entry(
299
+ entry: Dict[str, Any],
300
+ prompt_template: str,
301
+ llm_service: LLMService,
302
+ max_retries: int = 16
303
+ ) -> Dict[str, Any]:
304
+ """
305
+ Generate rubrics for a single entry with retry mechanism.
306
+
307
+ Args:
308
+ entry: Dictionary with 'id', 'pred_fast_mode_baseline', 'paper_context', 'decision'
309
+ prompt_template: Prompt template with <<golden_review>> placeholder
310
+ llm_service: LLMService instance (VLLMService or GPTService)
311
+ max_retries: Maximum number of retries if JSON parsing fails (default: 16)
312
+
313
+ Returns:
314
+ Dictionary with 'id', 'paper_context', 'decision', 'golden_review', 'rubrics' (list)
315
+ """
316
+ entry_id = entry.get('id', 'unknown')
317
+ golden_review = entry.get('pred_fast_mode_baseline', '')
318
+ paper_context = entry.get('paper_context', '')
319
+ decision = entry.get('decision', '')
320
+
321
+ # Replace placeholder in prompt template
322
+ prompt = prompt_template.replace('<<golden_review>>', golden_review)
323
+ prompt = prompt.replace('<<paper_context>>', paper_context)
324
+
325
+ # Convert prompt to messages format (shared/utils services use messages format)
326
+ messages = [{"role": "user", "content": prompt}]
327
+
328
+ # Retry loop for JSON parsing failures
329
+ last_error = None
330
+ for attempt in range(max_retries):
331
+ try:
332
+ # Generate response from LLM
333
+ response = llm_service.generate(messages=messages)
334
+
335
+ # Extract rubrics from response
336
+ rubrics_list = extract_rubrics_from_response(response)
337
+
338
+ # If successful, return the result (silently, no output during retries)
339
+ if rubrics_list is not None and isinstance(rubrics_list, list):
340
+ return {
341
+ 'id': entry_id,
342
+ 'paper_context': paper_context,
343
+ 'decision': decision,
344
+ 'golden_review': golden_review,
345
+ 'rubrics': rubrics_list
346
+ }
347
+
348
+ # If extraction failed, continue retrying silently
349
+ # Store the error message for the last attempt
350
+ if attempt == max_retries - 1:
351
+ last_error = "Failed to extract valid rubrics from response"
352
+
353
+ except Exception as e:
354
+ # Store the error (will be overwritten by subsequent attempts until the last one)
355
+ last_error = e
356
+
357
+ # All retries failed, output warning only once
358
+ if last_error:
359
+ print(f"[WARN] Failed to generate rubrics for entry {entry_id} after {max_retries} attempts: {last_error}")
360
+
361
+ # All retries failed, return with empty rubrics
362
+ result = {
363
+ 'id': entry_id,
364
+ 'paper_context': paper_context,
365
+ 'decision': decision,
366
+ 'golden_review': golden_review,
367
+ 'rubrics': [] # Empty list as fallback
368
+ }
369
+ if last_error:
370
+ result['error'] = str(last_error)
371
+ return result
372
+
373
+
374
+ def load_llm_config(config_path: str) -> Dict[str, Any]:
375
+ """
376
+ Load LLM configuration from YAML file.
377
+
378
+ Args:
379
+ config_path: Path to configs.yaml file
380
+
381
+ Returns:
382
+ Configuration dictionary
383
+ """
384
+ with open(config_path, 'r', encoding='utf-8') as f:
385
+ config = yaml.safe_load(f)
386
+ return config
387
+
388
+
389
+ def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
390
+ """
391
+ Create LLM service from configuration.
392
+
393
+ Args:
394
+ config: Configuration dictionary from configs.yaml
395
+
396
+ Returns:
397
+ LLMService instance (VLLMService or GPTService)
398
+ """
399
+ mode = config.get('mode', 'gpt').lower()
400
+
401
+ if mode == 'gpt':
402
+ gpt_config = config.get('gpt', {})
403
+ api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
404
+ if not api_key:
405
+ raise ValueError("GPT mode requires api_key in configs.yaml or OPENAI_API_KEY environment variable")
406
+
407
+ service = GPTService(
408
+ api_key=api_key,
409
+ model_name=gpt_config.get('model_name', 'gpt-4o'),
410
+ base_url=gpt_config.get('base_url'),
411
+ timeout=gpt_config.get('timeout', 300)
412
+ )
413
+ return service
414
+
415
+ elif mode == 'vllm':
416
+ vllm_config = config.get('vllm', {})
417
+ service = VLLMService(
418
+ base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
419
+ api_key=vllm_config.get('api_key', 'dummy-key'),
420
+ model_name=vllm_config.get('model_name'),
421
+ timeout=vllm_config.get('timeout', 300),
422
+ max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
423
+ max_retries=vllm_config.get('max_retries', 3),
424
+ retry_delay=vllm_config.get('retry_delay', 1.0),
425
+ retry_backoff=vllm_config.get('retry_backoff', 2.0)
426
+ )
427
+ return service
428
+
429
+ else:
430
+ raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
431
+
432
+
433
+ def parse_args():
434
+ """Parse command line arguments."""
435
+ parser = argparse.ArgumentParser(description="Generate review-based rubrics using LLMs")
436
+
437
+ # Input/Output paths
438
+ parser.add_argument("--json_path", type=str, required=True,
439
+ help="Path to input JSON file with review data")
440
+ parser.add_argument("--output_path", type=str, default=None,
441
+ help="Path to output JSON file (default: eval_rubrics.json in same dir as input)")
442
+ parser.add_argument("--yaml_path", type=str, default=None,
443
+ help="Path to prompts.yaml file (default: prompts.yaml in same dir as script)")
444
+ parser.add_argument("--config_path", type=str, default=None,
445
+ help="Path to configs.yaml file (default: configs.yaml in same dir as script)")
446
+
447
+ # Multi-threading
448
+ parser.add_argument("--max_workers", type=int, default=None,
449
+ help="Maximum number of worker threads (default: from MAX_WORKERS env var or 5)")
450
+
451
+     return parser.parse_args()
+
+
+ def main():
+     """Main execution function."""
+     args = parse_args()
+
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+
+     # File paths
+     json_path = args.json_path
+     if not os.path.isabs(json_path):
+         json_path = os.path.join(script_dir, json_path)
+
+     if args.output_path:
+         output_path = args.output_path
+         if not os.path.isabs(output_path):
+             output_path = os.path.join(script_dir, output_path)
+     else:
+         # Default: same directory as input JSON, with eval_rubrics.json name
+         output_dir = os.path.dirname(json_path)
+         output_path = os.path.join(output_dir, 'eval_rubrics.json')
+
+     if args.yaml_path:
+         yaml_path = args.yaml_path
+         if not os.path.isabs(yaml_path):
+             yaml_path = os.path.join(script_dir, yaml_path)
+     else:
+         yaml_path = os.path.join(script_dir, 'prompts.yaml')
+
+     if args.config_path:
+         config_path = args.config_path
+         if not os.path.isabs(config_path):
+             config_path = os.path.join(script_dir, config_path)
+     else:
+         config_path = os.path.join(script_dir, 'configs.yaml')
+
+     max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
+
+     # Check if files exist
+     if not os.path.exists(json_path):
+         raise FileNotFoundError(f"JSON file not found: {json_path}")
+     if not os.path.exists(yaml_path):
+         raise FileNotFoundError(f"YAML file not found: {yaml_path}")
+     if not os.path.exists(config_path):
+         raise FileNotFoundError(f"Config file not found: {config_path}")
+
+     print(f"Loading JSON data from {json_path}...")
+     data = load_json_data(json_path)
+     print(f"Loaded {len(data)} entries")
+
+     print(f"Loading prompt template from {yaml_path}...")
+     prompt_template = load_prompt_template(yaml_path)
+     if not prompt_template:
+         raise ValueError("Could not find 'v2_rubric_generation_prompt' in YAML file")
+     print("Prompt template loaded successfully")
+
+     # Load LLM configuration and create service
+     print(f"Loading LLM configuration from {config_path}...")
+     llm_config = load_llm_config(config_path)
+     llm_service = create_llm_service_from_config(llm_config)
+     mode = llm_config.get('mode', 'gpt')
+     print(f"LLM service initialized (mode: {mode})")
+     if hasattr(llm_service, 'model_name'):
+         print(f"Using model: {llm_service.model_name}")
+
+     # Extract required fields from each entry
+     print("Extracting required fields from entries...")
+     entries = []
+     for item in data:
+         if 'id' in item and 'pred_fast_mode_baseline' in item:
+             entries.append(item)
+         else:
+             print(f"[WARN] Skipping entry missing required fields: {item.get('id', 'unknown')}")
+
+     print(f"Processing {len(entries)} entries with {max_workers} workers...")
+
+     # Generate rubrics using concurrent processing
+     results = []
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         # Submit all tasks
+         future_to_entry = {
+             executor.submit(
+                 generate_rubrics_for_entry,
+                 entry,
+                 prompt_template,
+                 llm_service
+             ): entry
+             for entry in entries
+         }
+
+         # Process completed tasks with progress bar
+         for future in tqdm(as_completed(future_to_entry), total=len(entries), desc="Generating rubrics"):
+             try:
+                 result = future.result()
+                 results.append(result)
+             except Exception as e:
+                 entry = future_to_entry[future]
+                 entry_id = entry.get('id', 'unknown')
+                 print(f"\n[ERROR] Failed to process entry {entry_id}: {e}")
+                 # Add error entry with empty rubrics
+                 results.append({
+                     'id': entry_id,
+                     'paper_context': entry.get('paper_context', ''),
+                     'decision': entry.get('decision', ''),
+                     'golden_review': entry.get('pred_fast_mode_baseline', ''),
+                     'rubrics': [],
+                     'error': str(e)
+                 })
+
+     print(f"\nSuccessfully generated rubrics for {len(results)} entries")
+
+     # Save to JSON
+     with open(output_path, 'w', encoding='utf-8') as f:
+         json.dump(results, f, ensure_ascii=False, indent=2)
+     print(f"\nResults saved to {output_path}")
+
+     # Print summary statistics
+     print("\n" + "="*80)
+     print("SUMMARY STATISTICS")
+     print("="*80)
+     print(f"Total entries processed: {len(results)}")
+
+     # Count successful vs failed
+     successful = sum(1 for r in results if 'error' not in r and len(r.get('rubrics', [])) > 0)
+     failed = len(results) - successful
+     print(f"Successful: {successful}")
+     print(f"Failed: {failed}")
+
+     # Check rubrics statistics
+     rubric_counts = [len(r.get('rubrics', [])) for r in results if isinstance(r.get('rubrics'), list)]
+
+     if rubric_counts:
+         print("\nRubrics per entry:")
+         print(f"  Mean: {sum(rubric_counts) / len(rubric_counts):.2f}")
+         print(f"  Min: {min(rubric_counts)}")
+         print(f"  Max: {max(rubric_counts)}")
+
+
+ if __name__ == "__main__":
+     main()
+
+ """
+ Example usage:
+     python 1_generate_review_based_rubrics.py \
+         --json_path ./examples/input.json \
+         --output_path eval_rubrics.json \
+         --yaml_path prompts.yaml \
+         --config_path configs.yaml \
+         --max_workers 5
+ """
src/evaluator/2_evaluate.py ADDED
@@ -0,0 +1,1730 @@
+ """
+ Unified evaluation script for semantic (LLM-based) and auto_metric (rule-based) evaluation.
+
+ This script:
+ 1. Reads eval_rubrics.json (from 1_generate_review_based_rubrics.py) containing rubrics for each paper
+ 2. Reads input JSON file containing model reviews (supports multiple formats)
+ 3. Supports three evaluation modes:
+    - semantic: LLM-based rubrics evaluation (from 2_evaluate_direct.py)
+    - auto_metric: Rule-based metrics evaluation (from 3_rule_evaluate.py)
+    - both: Run both evaluations separately
+ 4. Supports strict mode: normalize scores to discrete scales before computing metrics (--strict_mode)
+ 5. Outputs separate JSON files for results and summaries
+
+ Usage:
+     # Semantic evaluation only
+     python 2_evaluate.py \
+         --rubrics_path eval_rubrics.json \
+         --reviews_path model_reviews.json \
+         --mode semantic \
+         --yaml_path prompts.yaml \
+         --config_path configs.yaml \
+         --semantic_output semantic_results.json \
+         --max_workers 5
+
+     # Auto-metric evaluation only
+     python 2_evaluate.py \
+         --rubrics_path eval_rubrics.json \
+         --reviews_path model_reviews.json \
+         --mode auto_metric \
+         --auto_metric_output auto_metric_results.json
+
+     # Auto-metric evaluation with strict mode (normalize scores to discrete scales)
+     python 2_evaluate.py \
+         --rubrics_path eval_rubrics.json \
+         --reviews_path model_reviews.json \
+         --mode auto_metric \
+         --auto_metric_output auto_metric_results.json \
+         --strict_mode
+
+     # Auto-metric evaluation with manually specified input format (refined)
+     python 2_evaluate.py \
+         --rubrics_path eval_rubrics.json \
+         --reviews_path model_reviews.json \
+         --mode auto_metric \
+         --auto_metric_output auto_metric_results.json \
+         --input_format refined
+
+     # Auto-metric evaluation with manually specified input format (original)
+     python 2_evaluate.py \
+         --rubrics_path eval_rubrics.json \
+         --reviews_path ours.json \
+         --mode auto_metric \
+         --auto_metric_output auto_metric_results.json \
+         --input_format original
+
+     # Both evaluations
+     python 2_evaluate.py \
+         --rubrics_path eval_rubrics.json \
+         --reviews_path model_reviews.json \
+         --mode both \
+         --yaml_path prompts.yaml \
+         --config_path configs.yaml \
+         --semantic_output semantic_results.json \
+         --auto_metric_output auto_metric_results.json \
+         --max_workers 32
+ """
+ from __future__ import annotations
+
+ import json
+ import os
+ import sys
+ import argparse
+ import yaml
+ import math
+ from typing import Dict, List, Any, Optional
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from tqdm import tqdm
+ from itertools import combinations
+ from scipy.stats import spearmanr
+ from sklearn.metrics import precision_recall_fscore_support
+
+ # Add parent directory to path
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ # Import parse_llm_response from local llm_service module
+ import llm_service as local_llm_service
+ parse_llm_response = local_llm_service.parse_llm_response
+
+ # Import from shared/utils for gpt/vllm support
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ if project_root not in sys.path:
+     sys.path.insert(0, project_root)
+
+ from shared.utils.llm_service import LLMService
+ from shared.utils.vllm_service import VLLMService
+ from shared.utils.gpt_service import GPTService
+ sys.path.insert(0, os.path.join(project_root, 'shared', 'utils'))
+ from json_parser import parse_review_markdown
+
+
+ class ReviewProcessor:
+     """Handles the extraction and processing of reviews from different sources."""
+
+     @staticmethod
+     def extract_review_content(pred_context):
+         """
+         Extract the review content from the prediction context.
+
+         Args:
+             pred_context: Raw prediction data that contains the review
+
+         Returns:
+             str: Extracted review content
+         """
+         try:
+             # First attempt to extract from boxed format
+             return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
+         except Exception:
+             # Alternative extraction if the first method fails
+             if isinstance(pred_context, dict) and 'output' in pred_context:
+                 return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
+             else:
+                 # Return as is if extraction fails
+                 return pred_context
+
+
+ # ============================================================================
+ # Semantic Evaluation Functions (from 2_evaluate_direct.py)
+ # ============================================================================
+
+ def load_prompt_template(yaml_path: str) -> str:
+     """Load the evaluator prompt from YAML file."""
+     with open(yaml_path, 'r', encoding='utf-8') as f:
+         prompts = yaml.safe_load(f)
+     return prompts.get('v1_evaluator_prompt', '')
+
+
+ def build_evaluation_prompt(
+     rubrics: List[Dict[str, Any]],
+     paper_content: str,
+     review: str,
+     prompt_template: str
+ ) -> str:
+     """Build the evaluation prompt by replacing placeholders."""
+     rubrics_json = json.dumps(rubrics, indent=4, ensure_ascii=False)
+     prompt = prompt_template.replace('{rubrics_json}', rubrics_json)
+     prompt = prompt.replace('<<paper_content>>', paper_content)
+     prompt = prompt.replace('<<review>>', review)
+     return prompt
+
+
+ def calculate_weighted_scores(
+     raw_scores: Dict[str, Dict[str, Any]],
+     rubrics: List[Dict[str, Any]]
+ ) -> Dict[str, float]:
+     """Calculate weighted scores for each rubric."""
+     rubric_weights = {r['title']: r['weight'] for r in rubrics}
+     weighted_scores = {}
+
+     for rubric_title, rubric_data in raw_scores.items():
+         if rubric_title not in rubric_weights:
+             continue
+
+         rubric_score = rubric_data.get('score', 0)
+         if isinstance(rubric_score, str):
+             try:
+                 rubric_score = int(rubric_score)
+             except ValueError:
+                 rubric_score = 0
+
+         if rubric_score not in [0, 1]:
+             rubric_score = 1 if rubric_score > 0 else 0
+
+         weight = rubric_weights[rubric_title]
+         weighted_scores[rubric_title] = rubric_score * weight
+
+     return weighted_scores
+
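To make the weighting step concrete, here is a self-contained sketch that restates the function above with toy inputs (the rubric titles, weights, and raw scores are invented for the example):

```python
# Self-contained restatement of the weighted-scoring step, with toy data.
def calculate_weighted_scores(raw_scores, rubrics):
    rubric_weights = {r['title']: r['weight'] for r in rubrics}
    weighted = {}
    for title, data in raw_scores.items():
        if title not in rubric_weights:
            continue  # drop rubric titles the evaluator invented
        score = data.get('score', 0)
        if isinstance(score, str):
            try:
                score = int(score)
            except ValueError:
                score = 0
        if score not in (0, 1):
            score = 1 if score > 0 else 0  # clamp to binary pass/fail
        weighted[title] = score * rubric_weights[title]
    return weighted

rubrics = [
    {'title': 'Identifies missing baselines', 'weight': 2.0},
    {'title': 'Notes limited datasets', 'weight': 1.0},
]
raw = {
    'Identifies missing baselines': {'score': 1},
    'Notes limited datasets': {'score': 0},
    'Unknown rubric': {'score': 1},  # ignored: not in the rubric list
}
print(calculate_weighted_scores(raw, rubrics))
# -> {'Identifies missing baselines': 2.0, 'Notes limited datasets': 0.0}
```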
+
+ def calculate_scores(raw_scores: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
+     """Calculate scores for each rubric."""
+     scores = {}
+     for rubric_title, rubric_data in raw_scores.items():
+         scores[rubric_title] = rubric_data.get('score', 0)
+     return scores
+
+
+ def evaluate_review_semantic(
+     entry: Dict[str, Any],
+     paper_content: str,
+     prompt_template: str,
+     llm_service: LLMService
+ ) -> Dict[str, Any]:
+     """Evaluate a single review using article-specific rubrics."""
+     entry_id = entry.get('id', 'unknown')
+     rubrics = entry.get('rubrics', [])
+     model_review = entry.get('model_review', '')
+
+     if not rubrics:
+         return {
+             'id': entry_id,
+             'raw_scores': {},
+             'weighted_scores': {},
+             'total_score': 0.0,
+             'error': 'No valid rubrics found',
+             'raw_response': ''
+         }
+
+     # Build prompt
+     prompt = build_evaluation_prompt(rubrics, paper_content, model_review, prompt_template)
+
+     # Call LLM
+     try:
+         messages = [{"role": "user", "content": prompt}]
+         response = llm_service.generate(messages=messages)
+
+         # Parse response
+         raw_scores = parse_llm_response(response)
+         weighted_scores = calculate_scores(raw_scores)
+         total_score = sum(weighted_scores.values())
+
+         return {
+             'id': entry_id,
+             'raw_scores': raw_scores,
+             'weighted_scores': weighted_scores,
+             'total_score': total_score,
+             'raw_response': response
+         }
+     except Exception as e:
+         print(f"[ERROR] Error evaluating review {entry_id}: {e}")
+         return {
+             'id': entry_id,
+             'raw_scores': {},
+             'weighted_scores': {},
+             'total_score': 0.0,
+             'error': str(e),
+             'raw_response': ''
+         }
+
+
+ def calculate_per_rubric_statistics(
+     valid_results: List[Dict[str, Any]],
+     rubric_titles: List[str]
+ ) -> Dict[str, Dict[str, float]]:
+     """Calculate per-rubric statistics from evaluation results."""
+     rubric_scores = {title: [] for title in rubric_titles}
+
+     for result in valid_results:
+         weighted_scores = result.get('weighted_scores', {})
+         if not isinstance(weighted_scores, dict):
+             continue
+
+         for rubric_title in rubric_titles:
+             if rubric_title in weighted_scores:
+                 score = weighted_scores[rubric_title]
+                 if isinstance(score, str):
+                     try:
+                         score = float(score)
+                     except ValueError:
+                         continue
+                 elif isinstance(score, (int, float)):
+                     score = float(score)
+                 else:
+                     continue
+                 rubric_scores[rubric_title].append(score)
+
+     per_rubric_stats = {}
+     for rubric_title in rubric_titles:
+         scores = rubric_scores[rubric_title]
+         if not scores:
+             continue
+
+         mean_score = sum(scores) / len(scores)
+         min_score = min(scores)
+         max_score = max(scores)
+         count = len(scores)
+
+         if rubric_title == "False or Contradictory Claims":
+             pass_count = sum(1 for s in scores if s >= 0)
+         else:
+             pass_count = sum(1 for s in scores if s >= 1)
+         pass_rate = pass_count / count if count > 0 else 0.0
+
+         per_rubric_stats[rubric_title] = {
+             'mean': mean_score,
+             'min': min_score,
+             'max': max_score,
+             'count': count,
+             'pass_rate': pass_rate
+         }
+
+     return per_rubric_stats
+
+
+ # ============================================================================
+ # Auto-Metric Evaluation Functions (from 3_rule_evaluate.py)
+ # ============================================================================
+
+ def extract_scores_from_review(review_text: str) -> Dict[str, Any]:
+     """Extract numeric scores and decision from a review markdown text."""
+     if not review_text:
+         return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
+
+     try:
+         parsed = parse_review_markdown(review_text)
+         decision = parsed.get('decision', '')
+         if decision:
+             decision_lower = decision.lower().strip()
+             if 'accept' in decision_lower:
+                 decision = 'accept'
+             elif 'reject' in decision_lower:
+                 decision = 'reject'
+             elif 'undecided' in decision_lower:
+                 decision = 'undecided'
+             else:
+                 decision = decision_lower
+         else:
+             decision = None
+
+         return {
+             'soundness': parsed.get('soundness'),
+             'presentation': parsed.get('presentation'),
+             'rating': parsed.get('rating'),
+             'confidence': parsed.get('confidence'),
+             'decision': decision
+         }
+     except Exception as e:
+         print(f"Warning: Failed to parse review text: {e}")
+         return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
+
+
+ def calculate_mse(predicted: float, ground_truth: float) -> Optional[float]:
+     """Calculate the squared error for a single value pair."""
+     if predicted is None or ground_truth is None:
+         return None
+     return (predicted - ground_truth) ** 2
+
+
+ def calculate_mae(predicted: float, ground_truth: float) -> Optional[float]:
+     """Calculate the absolute error for a single value pair."""
+     if predicted is None or ground_truth is None:
+         return None
+     return abs(predicted - ground_truth)
+
+
+ def normalize_to_discrete_scale(score: Optional[float], scale_type: str) -> Optional[float]:
+     """
+     Normalize a float score to the nearest discrete value based on scale type.
+     Uses round-half-up tie-breaking (e.g., 3.5 rounds to 4, 1.5 rounds to 2).
+
+     Args:
+         score: The float score to normalize (can be None)
+         scale_type: Either '0-5' for 0-5 scale (discrete: 0,1,2,3,4,5)
+                     or '0-10' for 0-10 scale (discrete: 0,2,4,6,8,10)
+
+     Returns:
+         Normalized discrete score, or None if input is None
+     """
+     if score is None:
+         return None
+
+     try:
+         score = float(score)
+     except (ValueError, TypeError):
+         return None
+
+     if scale_type == '0-5':
+         discrete_values = [0, 1, 2, 3, 4, 5]
+         score = max(0, min(5, score))  # Clamp to valid range
+     elif scale_type == '0-10':
+         discrete_values = [0, 2, 4, 6, 8, 10]
+         score = max(0, min(10, score))  # Clamp to valid range
+     else:
+         raise ValueError(f"Unknown scale_type: {scale_type}. Must be '0-5' or '0-10'")
+
+     # Find the nearest discrete value; on ties, prefer the higher value (round-half-up)
+     best_value = None
+     best_distance = float('inf')
+     for val in discrete_values:
+         distance = abs(val - score)
+         if distance < best_distance:
+             best_distance = distance
+             best_value = val
+         elif distance == best_distance and val > best_value:
+             best_value = val
+     return best_value
+
+
+ def normalize_scores_dict(scores: Dict[str, Optional[float]]) -> Dict[str, Optional[float]]:
+     """
+     Normalize all scores in a dictionary to their appropriate discrete scales.
+
+     Args:
+         scores: Dictionary with keys 'soundness', 'presentation', 'rating', 'confidence'
+
+     Returns:
+         Dictionary with normalized scores
+     """
+     normalized = {}
+
+     # soundness, presentation, confidence use 0-5 scale
+     for key in ['soundness', 'presentation', 'confidence']:
+         normalized[key] = normalize_to_discrete_scale(scores.get(key), '0-5')
+
+     # rating uses 0-10 scale
+     normalized['rating'] = normalize_to_discrete_scale(scores.get('rating'), '0-10')
+
+     return normalized
+
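The strict-mode snapping above can be illustrated with a small self-contained sketch; the generic `snap` helper below is a rephrasing for the example, not part of the script:

```python
# Snap a raw score to the nearest value on a discrete scale,
# breaking ties upward (round-half-up), as strict mode does.
def snap(score, values):
    if score is None:
        return None
    score = max(values[0], min(values[-1], float(score)))  # clamp to the scale
    best, best_d = None, float('inf')
    for v in values:
        d = abs(v - score)
        if d < best_d or (d == best_d and v > best):
            best, best_d = v, d
    return best

assert snap(3.5, [0, 1, 2, 3, 4, 5]) == 4       # tie rounds up
assert snap(2.4, [0, 1, 2, 3, 4, 5]) == 2       # nearest value wins
assert snap(7.0, [0, 2, 4, 6, 8, 10]) == 8      # tie between 6 and 8 rounds up
assert snap(11.2, [0, 2, 4, 6, 8, 10]) == 10    # clamped to the top of the scale
```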
+
+ def calculate_score_metrics(
+     model_scores: Dict[str, float],
+     ground_truth_scores: Dict[str, float],
+     normalize: bool = False
+ ) -> Dict[str, Any]:
+     """
+     Calculate MSE and MAE metrics for each scoring dimension.
+
+     Args:
+         model_scores: Dictionary with model scores
+         ground_truth_scores: Dictionary with ground truth scores
+         normalize: If True, normalize scores to discrete scales before computing metrics
+
+     Returns:
+         Dictionary with MSE, MAE metrics and optionally normalized scores
+     """
+     dimensions = ['soundness', 'presentation', 'rating', 'confidence']
+
+     # Normalize scores to discrete scales if requested
+     if normalize:
+         model_scores_normalized = normalize_scores_dict(model_scores)
+         gt_scores_normalized = normalize_scores_dict(ground_truth_scores)
+     else:
+         model_scores_normalized = model_scores
+         gt_scores_normalized = ground_truth_scores
+
+     mse_values = {}
+     mae_values = {}
+     valid_count = 0
+
+     for dim in dimensions:
+         # Use normalized scores for metric calculation
+         mse = calculate_mse(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
+         mae = calculate_mae(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
+         mse_values[f'{dim}_mse'] = mse
+         mae_values[f'{dim}_mae'] = mae
+         if mse is not None:
+             valid_count += 1
+
+     overall_error = sum(v for v in mse_values.values() if v is not None)
+
+     result = {
+         **mse_values,
+         **mae_values,
+         'overall_error': overall_error if valid_count > 0 else None,
+         'valid_dimensions': valid_count
+     }
+
+     # Include normalized scores in result for transparency (only if normalize=True)
+     if normalize:
+         result['model_scores_normalized'] = model_scores_normalized
+         result['gt_scores_normalized'] = gt_scores_normalized
+
+     return result
+
+
+ def normalize_score_value(value):
+     """Normalize a score value to float, handling string representations."""
+     if value is None:
+         return None
+     if isinstance(value, (int, float)):
+         return float(value)
+     if isinstance(value, str):
+         # Try to extract a numeric value from the string (e.g., "2.75" -> 2.75)
+         try:
+             import re
+             match = re.search(r'(\d+\.?\d*)', value)
+             if match:
+                 return float(match.group(1))
+         except Exception:
+             pass
+     return None
+
+
+ def normalize_decision(decision):
+     """Normalize a decision string to a standard format."""
+     if decision is None:
+         return None
+     decision_lower = str(decision).lower().strip()
+     if 'accept' in decision_lower:
+         return 'accept'
+     elif 'reject' in decision_lower:
+         return 'reject'
+     elif 'undecided' in decision_lower:
+         return 'undecided'
+     else:
+         return decision_lower
+
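A few concrete cases show how the decision normalization above behaves (self-contained restatement for illustration; the sample decision strings are invented):

```python
# Map free-form decision strings onto a small canonical vocabulary.
def normalize_decision(decision):
    if decision is None:
        return None
    d = str(decision).lower().strip()
    if 'accept' in d:
        return 'accept'
    elif 'reject' in d:
        return 'reject'
    elif 'undecided' in d:
        return 'undecided'
    return d  # unknown labels pass through lowercased

assert normalize_decision('Accept (poster)') == 'accept'
assert normalize_decision('  REJECT ') == 'reject'
assert normalize_decision('borderline') == 'borderline'  # passes through
assert normalize_decision(None) is None
```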
514
+
515
+ def extract_scores_from_dict(scores_dict: Dict[str, Any]) -> Dict[str, Any]:
516
+ """
517
+ Extract scores from a structured dictionary (scores or initial_scores format).
518
+
519
+ Args:
520
+ scores_dict: Dict containing scores (e.g., {'rating': 5.75, 'soundness': '2.75', ...})
521
+
522
+ Returns:
523
+ Dict with normalized scores: {'soundness', 'presentation', 'rating', 'confidence', 'decision'}
524
+ """
525
+ if not scores_dict:
526
+ return {
527
+ 'soundness': None,
528
+ 'presentation': None,
529
+ 'rating': None,
530
+ 'confidence': None,
531
+ 'decision': None
532
+ }
533
+
534
+ return {
535
+ 'soundness': normalize_score_value(scores_dict.get('soundness')),
536
+ 'presentation': normalize_score_value(scores_dict.get('presentation')),
537
+ 'rating': normalize_score_value(scores_dict.get('rating')),
538
+ 'confidence': normalize_score_value(scores_dict.get('confidence')),
539
+ 'decision': normalize_decision(scores_dict.get('decision'))
540
+ }
541
+
542
+
543
+ def evaluate_review_auto_metric(entry: Dict[str, Any], use_initial_scores: bool = False, strict_mode: bool = False) -> Dict[str, Any]:
544
+ """
545
+ Evaluate a single entry by extracting scores and calculating metrics.
546
+
547
+ Args:
548
+ entry: Evaluation entry containing model_review, scores, initial_scores, etc.
549
+ use_initial_scores: If True, use initial_scores instead of refined scores (for refined format)
550
+
551
+ Returns:
552
+ Dict containing evaluation metrics
553
+ """
554
+ entry_id = entry.get('id', 'unknown')
555
+ model_review = entry.get('model_review', '')
556
+ format_type = entry.get('format', 'unknown')
557
+
558
+ # Extract scores based on format
559
+ model_scores = {}
560
+ model_decision = None
561
+
562
+ if format_type == 'refined' and not use_initial_scores:
563
+ # Use refined scores from structured data
564
+ scores_dict = entry.get('scores', {})
565
+ model_data = extract_scores_from_dict(scores_dict)
566
+ model_scores = {
567
+ 'soundness': model_data.get('soundness'),
568
+ 'presentation': model_data.get('presentation'),
569
+ 'rating': model_data.get('rating'),
570
+ 'confidence': model_data.get('confidence')
571
+ }
572
+ model_decision = model_data.get('decision')
573
+ elif format_type == 'refined' and use_initial_scores:
574
+ # Use initial scores from structured data
575
+ initial_scores_dict = entry.get('initial_scores', {})
576
+ model_data = extract_scores_from_dict(initial_scores_dict)
577
+ model_scores = {
578
+ 'soundness': model_data.get('soundness'),
579
+ 'presentation': model_data.get('presentation'),
580
+ 'rating': model_data.get('rating'),
581
+ 'confidence': model_data.get('confidence')
582
+ }
583
+ model_decision = model_data.get('decision')
584
+ elif format_type == 'original':
585
+ # Use initial scores from structured data
586
+ initial_scores_dict = entry.get('initial_scores', {})
587
+ model_data = extract_scores_from_dict(initial_scores_dict)
588
+ model_scores = {
589
+ 'soundness': model_data.get('soundness'),
590
+ 'presentation': model_data.get('presentation'),
591
+ 'rating': model_data.get('rating'),
592
+ 'confidence': model_data.get('confidence')
593
+ }
594
+ model_decision = model_data.get('decision')
595
+
596
+ # Fallback: If confidence is missing from structured data, try to extract from review text
597
+ # (meta_review may not have confidence field, but review text might)
598
+ if model_scores.get('confidence') is None and model_review:
599
+ try:
600
+ review_data = extract_scores_from_review(model_review)
601
+ if review_data.get('confidence') is not None:
602
+ model_scores['confidence'] = review_data.get('confidence')
603
+ except Exception:
604
+ pass # Keep confidence as None if extraction fails
605
+ else:
606
+ # Fallback: extract from markdown review text
607
+ model_data = extract_scores_from_review(model_review)
608
+ model_scores = {
609
+ 'soundness': model_data.get('soundness'),
610
+ 'presentation': model_data.get('presentation'),
611
+ 'rating': model_data.get('rating'),
612
+ 'confidence': model_data.get('confidence')
613
+ }
614
+ model_decision = model_data.get('decision')
615
+
616
+ # Get ground truth scores from golden_review ONLY
617
+ # Ground truth must ONLY come from golden_review, never from model output
618
+ # If extraction fails, leave fields as None (do not use model_review as fallback)
619
+ ground_truth_review = entry.get('golden_review', '')
620
+ ground_truth_scores = {}
621
+ gt_decision = None
622
+
623
+ if not ground_truth_review:
624
+ print(f"Warning: No golden_review found for entry {entry_id}. Ground truth scores will be empty.")
625
+ else:
626
+ try:
627
+ # Extract scores from golden_review markdown text
628
+ gt_data = extract_scores_from_review(ground_truth_review)
629
+ if not gt_data:
630
+ print(f"Warning: Failed to parse golden_review for entry {entry_id}. Ground truth scores will be empty.")
631
+ else:
632
+ ground_truth_scores = {
633
+ 'soundness': gt_data.get('soundness'),
634
+ 'presentation': gt_data.get('presentation'),
635
+ 'rating': gt_data.get('rating'),
636
+ 'confidence': gt_data.get('confidence')
637
+ }
638
+ gt_decision = normalize_decision(gt_data.get('decision'))
639
+ # Note: If any field is None, it stays None - we do NOT use model_review as fallback
640
+ # Using model output as ground truth would inflate evaluation scores
641
+ except Exception as e:
642
+ print(f"Warning: Failed to extract scores from golden_review for {entry_id}: {e}")
643
+ print(f" Ground truth scores will be empty. Error: {str(e)}")
644
+
645
+ # Calculate MSE and MAE metrics (with optional normalization in strict mode)
646
+ score_metrics = calculate_score_metrics(model_scores, ground_truth_scores, normalize=strict_mode)
647
+
648
+ # Calculate decision accuracy
649
+ decision_match = False
650
+ decision_accuracy = None
651
+ if model_decision is not None and gt_decision is not None:
652
+ model_decision_normalized = normalize_decision(model_decision)
653
+ decision_match = (model_decision_normalized == gt_decision)
654
+ decision_accuracy = 1.0 if decision_match else 0.0
655
+
656
+ result = {
657
+ 'id': entry_id,
658
+ 'format': format_type,
659
+ 'model_soundness': model_scores.get('soundness'),
660
+ 'model_presentation': model_scores.get('presentation'),
661
+ 'model_rating': model_scores.get('rating'),
662
+ 'model_confidence': model_scores.get('confidence'),
663
+ 'model_decision': model_decision,
664
+ 'gt_soundness': ground_truth_scores.get('soundness'),
665
+ 'gt_presentation': ground_truth_scores.get('presentation'),
666
+ 'gt_rating': ground_truth_scores.get('rating'),
667
+ 'gt_confidence': ground_truth_scores.get('confidence'),
668
+ 'gt_decision': gt_decision,
669
+ 'decision_match': decision_match,
670
+ 'decision_accuracy': decision_accuracy,
671
+ **score_metrics
672
+ }
673
+
674
+ # Add prefix to indicate which scores were used
675
+ if format_type == 'refined':
676
+ if use_initial_scores:
677
+ result['score_type'] = 'initial'
678
+ else:
679
+ result['score_type'] = 'refined'
680
+ else:
681
+ result['score_type'] = 'auto'
682
+
683
+ return result
684
+
685
+
686
+ def calculate_pairwise_accuracies(paper_scores: List[Dict[str, float]]) -> Dict[str, float]:
687
+ """Calculate pairwise accuracy for each metric by comparing rankings."""
688
+ if len(paper_scores) < 2:
689
+ return {}
690
+
691
+ total_valid_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
692
+ correct_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
693
+
694
+ for paper1, paper2 in combinations(paper_scores, 2):
695
+ # Check rating ranking
696
+ if (paper1.get('true_rating') is not None and paper2.get('true_rating') is not None and
697
+ paper1.get('pred_rating') is not None and paper2.get('pred_rating') is not None):
698
+ total_valid_pairs['rating'] += 1
699
+ true_order = paper1['true_rating'] > paper2['true_rating']
700
+ pred_order = paper1['pred_rating'] > paper2['pred_rating']
701
+ if true_order == pred_order:
702
+ correct_pairs['rating'] += 1
703
+
704
+            # Apply the same pairwise comparison to the remaining dimensions
705
706
+ for metric in ['soundness', 'presentation', 'confidence']:
707
+ true_key = f'true_{metric}'
708
+ pred_key = f'pred_{metric}'
709
+ if (paper1.get(true_key) is not None and paper2.get(true_key) is not None and
710
+ paper1.get(pred_key) is not None and paper2.get(pred_key) is not None):
711
+ total_valid_pairs[metric] += 1
712
+ true_order = paper1[true_key] > paper2[true_key]
713
+ pred_order = paper1[pred_key] > paper2[pred_key]
714
+ if true_order == pred_order:
715
+ correct_pairs[metric] += 1
716
+
717
+ pairwise_accuracies = {
718
+ metric: correct_pairs[metric] / total_valid_pairs[metric] if total_valid_pairs[metric] > 0 else 0.0
719
+ for metric in ['rating', 'soundness', 'presentation', 'confidence']
720
+ }
721
+
722
+ return pairwise_accuracies
723
+
724
+
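The pairwise-accuracy idea above can be sketched in isolation. This is a minimal, self-contained version for a single metric; `pairwise_rating_accuracy` and the `true`/`pred` keys are illustrative names, not part of the script.

```python
from itertools import combinations

def pairwise_rating_accuracy(papers):
    # Fraction of paper pairs whose predicted order matches the true order
    total, correct = 0, 0
    for a, b in combinations(papers, 2):
        if None in (a["true"], b["true"], a["pred"], b["pred"]):
            continue  # skip pairs with missing scores, as the evaluator does
        total += 1
        if (a["true"] > b["true"]) == (a["pred"] > b["pred"]):
            correct += 1
    return correct / total if total else 0.0

papers = [
    {"true": 8, "pred": 7},
    {"true": 5, "pred": 6},
    {"true": 3, "pred": 2},
]
print(pairwise_rating_accuracy(papers))  # 1.0: every pair keeps its relative order
```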
725
+ # ============================================================================
726
+ # Data Loading Functions
727
+ # ============================================================================
728
+
729
+ def load_rubrics_json(rubrics_path: str) -> Dict[str, Dict[str, Any]]:
730
+ """Load rubrics JSON and create lookup by id."""
731
+ with open(rubrics_path, 'r', encoding='utf-8') as f:
732
+ data = json.load(f)
733
+
734
+ if isinstance(data, list):
735
+ return {item['id']: item for item in data}
736
+ elif isinstance(data, dict):
737
+ return data
738
+ else:
739
+ raise ValueError(f"Invalid rubrics JSON format: expected list or dict, got {type(data)}")
740
+
741
+
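The list-or-dict normalization performed by `load_rubrics_json` can be illustrated standalone; `rubrics_by_id` is a hypothetical helper name for this sketch.

```python
def rubrics_by_id(data):
    # Accept either a list of entries carrying an 'id' or an already id-keyed dict
    if isinstance(data, list):
        return {item["id"]: item for item in data}
    if isinstance(data, dict):
        return data
    raise ValueError(f"expected list or dict, got {type(data)}")

as_list = [{"id": "p1", "rubrics": []}, {"id": "p2", "rubrics": []}]
print(sorted(rubrics_by_id(as_list)))  # ['p1', 'p2']
```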
742
+ def load_model_reviews_json(reviews_path: str, format_override: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
743
+ """
744
+ Load model reviews JSON and extract reviews by id.
745
+
746
+ Supports two input formats:
747
+ 1. Refined format: Contains 'scores' and 'initial_scores' fields (from refinement pipeline)
748
+ 2. Original format: Contains 'model_prediction' with 'meta_review' and 'decision' (like ours.json)
749
+
750
+ Args:
751
+ reviews_path: Path to JSON file containing model reviews
752
+ format_override: Optional format override ('refined', 'original', or None for auto-detect)
753
+
754
+ Returns:
755
+ Dict mapping paper_id to dict containing:
756
+ - 'review': review text (markdown)
757
+ - 'scores': refined scores dict (if available)
758
+ - 'initial_scores': initial scores dict (if available)
759
+ - 'format': 'refined' or 'original'
760
+ """
761
+ with open(reviews_path, 'r', encoding='utf-8') as f:
762
+ data = json.load(f)
763
+
764
+ if isinstance(data, dict):
765
+ data = list(data.values())
766
+
767
+ reviews_dict = {}
768
+ for item in data:
769
+ item_id = None
770
+ review_text = ''
771
+ scores = None
772
+ initial_scores = None
773
+ format_type = None
774
+
775
+ # Use format override if provided, otherwise auto-detect
776
+ if format_override and format_override != 'auto':
777
+ # Force use specified format
778
+ if format_override == 'refined':
779
+ item_id = item.get('paper_id') or item.get('id')
780
+ if not item_id:
781
+ continue
782
+ format_type = 'refined'
783
+ review_text = item.get('review_markdown', '') or item.get('review', '')
784
+ scores = item.get('scores', {})
785
+ initial_scores = item.get('initial_scores', {})
786
+ elif format_override == 'original':
787
+ item_id = item.get('id')
788
+ if not item_id:
789
+ continue
790
+ format_type = 'original'
791
+ model_prediction = item.get('model_prediction', {})
792
+ meta_review = model_prediction.get('meta_review', {})
793
+ review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
794
+ initial_scores = {
795
+ 'rating': meta_review.get('rating'),
796
+ 'soundness': meta_review.get('soundness'),
797
+ 'presentation': meta_review.get('presentation'),
798
+ 'contribution': meta_review.get('contribution'),
799
+ 'decision': model_prediction.get('decision'),
800
+ }
801
+ else:
802
+ raise ValueError(f"Unknown format_override: {format_override}. Must be 'refined', 'original', or 'auto'")
803
+ else:
804
+ # Auto-detect format
805
+ if "paper_id" in item:
806
+            # Entry keyed by paper_id: refined (from refinement pipeline) or standard
807
+ item_id = item.get('paper_id')
808
+ if not item_id:
809
+ continue
810
+
811
+ # Check if this is refined format (has scores and initial_scores)
812
+ if 'scores' in item and 'initial_scores' in item:
813
+ format_type = 'refined'
814
+ review_text = item.get('review_markdown', '') or item.get('review', '')
815
+ scores = item.get('scores', {})
816
+ initial_scores = item.get('initial_scores', {})
817
+ else:
818
+ # Standard format with paper_id
819
+ format_type = 'standard'
820
+ review_text = item.get('review_markdown', '') or item.get('review', '')
821
+ elif "model_prediction" in item:
822
+ # Original format (like ours.json)
823
+ item_id = item.get('id')
824
+ if not item_id:
825
+ continue
826
+
827
+ format_type = 'original'
828
+ model_prediction = item.get('model_prediction', {})
829
+ meta_review = model_prediction.get('meta_review', {})
830
+
831
+ # Extract review content (prefer meta_review.content, fallback to raw_text)
832
+ review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
833
+
834
+ # Extract initial scores
835
+ initial_scores = {
836
+ 'rating': meta_review.get('rating'),
837
+ 'soundness': meta_review.get('soundness'),
838
+ 'presentation': meta_review.get('presentation'),
839
+ 'contribution': meta_review.get('contribution'),
840
+ 'decision': model_prediction.get('decision'),
841
+ }
842
+ else:
843
+ # Legacy format (pred_fast_mode)
844
+ item_id = item.get('id')
845
+ if not item_id:
846
+ continue
847
+
848
+ format_type = 'legacy'
849
+ review_dict = item.get('pred_fast_mode', {})
850
+ if isinstance(review_dict, dict):
851
+ review_text = review_dict
853
+ else:
854
+ review_text = str(review_dict)
855
+
856
+ # Extract review content from the review text field
857
+ try:
858
+ if review_text:
859
+ extracted_review = ReviewProcessor.extract_review_content(review_text)
860
+ else:
861
+ extracted_review = ''
862
+
863
+ reviews_dict[item_id] = {
864
+ 'review': extracted_review,
865
+ 'scores': scores,
866
+ 'initial_scores': initial_scores,
867
+ 'format': format_type
868
+ }
869
+ except Exception as e:
870
+ print(f"[WARN] Failed to extract review for {item_id}: {e}")
871
+ continue
872
+
873
+ return reviews_dict
874
+
875
+
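The auto-detection branch above keys off which fields an entry carries. A toy sketch, with field names taken from the loader but entry contents invented for illustration:

```python
# Minimal sketches of the auto-detected entry shapes
refined_entry = {
    "paper_id": "p1",
    "review_markdown": "## Review ...",
    "scores": {"rating": 6, "soundness": 3},
    "initial_scores": {"rating": 5, "soundness": 3},
}
original_entry = {
    "id": "p2",
    "model_prediction": {
        "meta_review": {"content": "## Review ...", "rating": 4},
        "decision": "Reject",
    },
}

def detect_format(item):
    # Same precedence as the loader: paper_id first, then model_prediction, else legacy
    if "paper_id" in item:
        return "refined" if "scores" in item and "initial_scores" in item else "standard"
    return "original" if "model_prediction" in item else "legacy"

print(detect_format(refined_entry), detect_format(original_entry))  # refined original
```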
876
+ def combine_rubrics_and_reviews(
877
+ rubrics_data: Dict[str, Dict[str, Any]],
878
+ reviews_dict: Dict[str, Dict[str, Any]]
879
+ ) -> List[Dict[str, Any]]:
880
+ """
881
+ Combine rubrics and reviews into evaluation entries.
882
+
883
+ Args:
884
+ rubrics_data: Dict mapping paper_id to rubric entry
885
+ reviews_dict: Dict mapping paper_id to dict containing 'review', 'scores', 'initial_scores', 'format'
886
+
887
+ Returns:
888
+ List of evaluation entries with model_review, scores, initial_scores, and format info
889
+ """
890
+ combined = []
891
+ missing_reviews = []
892
+
893
+ for paper_id, rubric_entry in rubrics_data.items():
894
+ review_data = reviews_dict.get(paper_id)
895
+ if not review_data or not review_data.get('review'):
896
+ missing_reviews.append(paper_id)
897
+ continue
898
+
899
+ entry = {
900
+ 'id': paper_id,
901
+ 'paper_context': rubric_entry.get('paper_context', ''),
902
+ 'decision': rubric_entry.get('decision', ''),
903
+ 'golden_review': rubric_entry.get('golden_review', ''),
904
+ 'rubrics': rubric_entry.get('rubrics', []),
905
+ 'model_review': review_data.get('review', ''),
906
+ 'scores': review_data.get('scores'), # Refined scores (if available)
907
+ 'initial_scores': review_data.get('initial_scores'), # Initial scores (if available)
908
+ 'format': review_data.get('format', 'unknown') # Format type
909
+ }
910
+ combined.append(entry)
911
+
912
+ if missing_reviews:
913
+ print(f"[WARN] {len(missing_reviews)} papers have no model review, skipping them")
914
+
915
+ return combined
916
+
917
+
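The join performed by `combine_rubrics_and_reviews` (keep papers with a non-empty review, record the rest as missing) reduces to a short sketch; the entries here are invented for illustration.

```python
rubrics = {"p1": {"rubrics": []}, "p2": {"rubrics": []}}
reviews = {"p1": {"review": "Solid paper.", "format": "original"}}

combined, missing = [], []
for paper_id, rubric in rubrics.items():
    review = reviews.get(paper_id)
    if not review or not review.get("review"):
        missing.append(paper_id)  # paper has no usable model review
        continue
    combined.append({"id": paper_id, **rubric, "model_review": review["review"]})

print([e["id"] for e in combined], missing)  # ['p1'] ['p2']
```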
918
+ # ============================================================================
919
+ # LLM Service Configuration
920
+ # ============================================================================
921
+
922
+ def load_llm_config(config_path: str) -> Dict[str, Any]:
923
+ """Load LLM configuration from YAML file."""
924
+ with open(config_path, 'r', encoding='utf-8') as f:
925
+ config = yaml.safe_load(f)
926
+ return config
927
+
928
+
929
+ def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
930
+ """Create LLM service from configuration."""
931
+ mode = config.get('mode', 'gpt').lower()
932
+
933
+ if mode == 'gpt':
934
+ gpt_config = config.get('gpt', {})
935
+ api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
936
+ if not api_key:
937
+ raise ValueError("GPT mode requires api_key in configs.yaml or OPENAI_API_KEY environment variable")
938
+
939
+ service = GPTService(
940
+ api_key=api_key,
941
+ model_name=gpt_config.get('model_name', 'gpt-4o'),
942
+ base_url=gpt_config.get('base_url'),
943
+ timeout=gpt_config.get('timeout', 300)
944
+ )
945
+ return service
946
+
947
+ elif mode == 'vllm':
948
+ vllm_config = config.get('vllm', {})
949
+ service = VLLMService(
950
+ base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
951
+ api_key=vllm_config.get('api_key', 'dummy-key'),
952
+ model_name=vllm_config.get('model_name'),
953
+ timeout=vllm_config.get('timeout', 300),
954
+ max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
955
+ max_retries=vllm_config.get('max_retries', 3),
956
+ retry_delay=vllm_config.get('retry_delay', 1.0),
957
+ retry_backoff=vllm_config.get('retry_backoff', 2.0)
958
+ )
959
+ return service
960
+
961
+ else:
962
+ raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
963
+
964
+
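The config keys and fallback defaults read by `create_llm_service_from_config` can be exercised without constructing a real service. The dict below mirrors the keys used above; `resolve_service_kwargs` is a hypothetical helper for this sketch only.

```python
# Hypothetical config dict mirroring the vllm section keys read above
config = {
    "mode": "vllm",
    "vllm": {
        "base_url": "http://localhost:8000/v1",
        "model_name": "my-model",
        "max_concurrent_requests": 32,
    },
}

def resolve_service_kwargs(config):
    # Reproduce the fallback defaults used in create_llm_service_from_config
    mode = config.get("mode", "gpt").lower()
    if mode == "vllm":
        c = config.get("vllm", {})
        return {
            "base_url": c.get("base_url", "http://localhost:8000/v1"),
            "timeout": c.get("timeout", 300),
            "max_concurrent_requests": c.get("max_concurrent_requests", 64),
        }
    raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")

print(resolve_service_kwargs(config))
```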
965
+ # ============================================================================
966
+ # Main Evaluation Functions
967
+ # ============================================================================
968
+
969
+ def run_semantic_evaluation(
970
+ evaluation_data: List[Dict[str, Any]],
971
+ prompt_template: str,
972
+ llm_service: LLMService,
973
+ max_workers: int
974
+ ) -> tuple:
975
+ """Run semantic evaluation and return results and summary."""
976
+ print(f"\n{'='*80}")
977
+ print("RUNNING SEMANTIC EVALUATION")
978
+ print(f"{'='*80}")
979
+ print(f"Evaluating {len(evaluation_data)} reviews using {max_workers} workers...")
980
+
981
+ results = []
982
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
983
+ future_to_entry = {
984
+ executor.submit(
985
+ evaluate_review_semantic,
986
+ entry,
987
+ entry['paper_context'],
988
+ prompt_template,
989
+ llm_service
990
+ ): entry
991
+ for entry in evaluation_data
992
+ }
993
+
994
+ for future in tqdm(as_completed(future_to_entry), total=len(evaluation_data), desc="Semantic evaluation"):
995
+ try:
996
+ result = future.result()
997
+ results.append(result)
998
+ except Exception as e:
999
+ entry = future_to_entry[future]
1000
+ print(f"\n[ERROR] Failed to process entry {entry.get('id', 'unknown')}: {e}")
1001
+ results.append({
1002
+ 'id': entry.get('id', 'unknown'),
1003
+ 'raw_scores': {},
1004
+ 'weighted_scores': {},
1005
+ 'total_score': 0.0,
1006
+ 'error': str(e),
1007
+ 'raw_response': ''
1008
+ })
1009
+
1010
+ # Calculate statistics
1011
+ valid_results = [r for r in results if 'error' not in r and r.get('weighted_scores')]
1012
+ review_scores = [r.get('total_score', 0.0) for r in valid_results]
1013
+
1014
+ summary = {
1015
+ 'total_entries': len(results),
1016
+ 'valid_entries': len(valid_results),
1017
+ 'failed_entries': len(results) - len(valid_results)
1018
+ }
1019
+
1020
+ if review_scores:
1021
+ summary['overall_score'] = {
1022
+ 'mean': sum(review_scores) / len(review_scores),
1023
+ 'min': min(review_scores),
1024
+ 'max': max(review_scores)
1025
+ }
1026
+
1027
+ # Calculate per-rubric statistics (extract rubric titles from first entry)
1028
+ if evaluation_data and evaluation_data[0].get('rubrics'):
1029
+ rubric_titles = [r['title'] for r in evaluation_data[0]['rubrics']]
1030
+ per_rubric_stats = calculate_per_rubric_statistics(valid_results, rubric_titles)
1031
+ summary['per_rubric_statistics'] = per_rubric_stats
1032
+
1033
+ return results, summary
1034
+
1035
+
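The fan-out pattern used by `run_semantic_evaluation` (submit one future per entry, collect as they complete, convert per-entry exceptions into error records) can be shown with a stand-in evaluator; `evaluate` here is a toy function, not the real `evaluate_review_semantic`.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def evaluate(entry):
    # Stand-in for evaluate_review_semantic: score is just the id length
    return {"id": entry["id"], "total_score": float(len(entry["id"]))}

entries = [{"id": "a"}, {"id": "bb"}, {"id": "ccc"}]
results = []
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = {executor.submit(evaluate, e): e for e in entries}
    for future in as_completed(futures):
        try:
            results.append(future.result())
        except Exception as exc:  # mirror the per-entry error handling above
            results.append({"id": futures[future]["id"], "error": str(exc)})

print(sorted(r["total_score"] for r in results))  # [1.0, 2.0, 3.0]
```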
1036
+ def run_auto_metric_evaluation(
1037
+ evaluation_data: List[Dict[str, Any]],
1038
+ strict_mode: bool = False
1039
+ ) -> tuple:
1040
+ """
1041
+ Run auto-metric evaluation and return results and summary.
1042
+
1043
+ For refined format (has scores and initial_scores), evaluates both:
1044
+ - Refined scores evaluation
1045
+ - Initial scores evaluation
1046
+
1047
+ For original format (only initial_scores), evaluates:
1048
+ - Initial scores evaluation only
1049
+
1050
+ Returns:
1051
+ Tuple of (results_list, summary_dict)
1052
+ - results_list: List of evaluation results (may contain both refined and initial results for refined format)
1053
+ - summary_dict: Summary statistics
1054
+ """
1055
+ print(f"\n{'='*80}")
1056
+ print("RUNNING AUTO-METRIC EVALUATION")
1057
+ print(f"{'='*80}")
1058
+ print(f"Evaluating {len(evaluation_data)} entries...")
1059
+
1060
+ # Detect format types
1061
+ refined_format_count = sum(1 for e in evaluation_data if e.get('format') == 'refined')
1062
+ original_format_count = sum(1 for e in evaluation_data if e.get('format') == 'original')
1063
+
1064
+ if refined_format_count > 0:
1065
+ print(f"Detected {refined_format_count} entries in refined format (will evaluate both refined and initial scores)")
1066
+ if original_format_count > 0:
1067
+ print(f"Detected {original_format_count} entries in original format (will evaluate initial scores only)")
1068
+
1069
+ results = []
1070
+ for entry in tqdm(evaluation_data, desc="Auto-metric evaluation"):
1071
+ format_type = entry.get('format', 'unknown')
1072
+
1073
+ if format_type == 'refined':
1074
+ # Evaluate both refined scores and initial scores
1075
+ try:
1076
+ entry_id = entry.get('id', 'unknown')
1077
+
1078
+ # Evaluate refined scores
1079
+ refined_result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
1080
+ refined_result['paper_id'] = entry_id # Keep original paper_id
1081
+ refined_result['id'] = f"{entry_id}_refined"
1082
+ results.append(refined_result)
1083
+
1084
+ # Evaluate initial scores
1085
+ initial_result = evaluate_review_auto_metric(entry, use_initial_scores=True, strict_mode=strict_mode)
1086
+ initial_result['paper_id'] = entry_id # Keep original paper_id
1087
+ initial_result['id'] = f"{entry_id}_initial"
1088
+ results.append(initial_result)
1089
+ except Exception as e:
1090
+ print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
1091
+ results.append({
1092
+ 'id': entry.get('id', 'unknown'),
1093
+ 'error': str(e)
1094
+ })
1095
+ else:
1096
+ # Evaluate initial scores only (or extract from markdown)
1097
+ try:
1098
+ result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
1099
+ results.append(result)
1100
+ except Exception as e:
1101
+ print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
1102
+ results.append({
1103
+ 'id': entry.get('id', 'unknown'),
1104
+ 'error': str(e)
1105
+ })
1106
+
1107
+ # Calculate statistics
1108
+ valid_results = [r for r in results if 'error' not in r]
1109
+ mse_results = [r for r in valid_results if r.get('overall_error') is not None]
1110
+
1111
+ # Separate refined and initial results for refined format
1112
+ refined_results = [r for r in valid_results if r.get('score_type') == 'refined']
1113
+ initial_results = [r for r in valid_results if r.get('score_type') == 'initial']
1114
+ auto_results = [r for r in valid_results if r.get('score_type') == 'auto' or r.get('score_type') is None]
1115
+
1116
+ summary = {
1117
+ 'total_entries': len(results),
1118
+ 'valid_entries': len(valid_results),
1119
+ 'mse_entries': len(mse_results),
1120
+ 'refined_results_count': len(refined_results),
1121
+ 'initial_results_count': len(initial_results),
1122
+ 'auto_results_count': len(auto_results)
1123
+ }
1124
+
1125
+ # Calculate MSE/MAE statistics
1126
+ # For refined format, only use refined results for overall statistics (avoid double counting)
1127
+ # For other formats, use all results
1128
+ if refined_format_count > 0:
1129
+ # Refined format: use only refined results for overall statistics
1130
+ stats_results = [r for r in refined_results if r.get('overall_error') is not None]
1131
+ else:
1132
+ # Original/other formats: use all results
1133
+ stats_results = mse_results
1134
+
1135
+ if stats_results:
1136
+ dimensions = ['soundness', 'presentation', 'confidence', 'rating']
1137
+ mse_stats = {}
1138
+ mae_stats = {}
1139
+
1140
+ for dim in dimensions:
1141
+ mse_list = [r.get(f'{dim}_mse') for r in stats_results if r.get(f'{dim}_mse') is not None]
1142
+ mae_list = [r.get(f'{dim}_mae') for r in stats_results if r.get(f'{dim}_mae') is not None]
1143
+
1144
+ mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
1145
+ mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
1146
+
1147
+ if mse_clean:
1148
+ mse_stats[dim] = {
1149
+ 'mean': sum(mse_clean) / len(mse_clean),
1150
+ 'count': len(mse_clean)
1151
+ }
1152
+ if mae_clean:
1153
+ mae_stats[dim] = {
1154
+ 'mean': sum(mae_clean) / len(mae_clean),
1155
+ 'count': len(mae_clean)
1156
+ }
1157
+
1158
+ overall_errors = [r.get('overall_error') for r in stats_results if r.get('overall_error') is not None]
1159
+ overall_clean = [x for x in overall_errors if x is not None and not (isinstance(x, float) and math.isnan(x))]
1160
+
1161
+ if overall_clean:
1162
+ summary['overall_error'] = {
1163
+ 'mean': sum(overall_clean) / len(overall_clean),
1164
+ 'count': len(overall_clean)
1165
+ }
1166
+
1167
+ summary['mse_statistics'] = mse_stats
1168
+ summary['mae_statistics'] = mae_stats
1169
+
1170
+ # Calculate separate statistics for refined and initial results
1171
+ if refined_results:
1172
+ refined_mse_results = [r for r in refined_results if r.get('overall_error') is not None]
1173
+ if refined_mse_results:
1174
+ refined_mse_stats = {}
1175
+ refined_mae_stats = {}
1176
+ for dim in dimensions:
1177
+ mse_list = [r.get(f'{dim}_mse') for r in refined_mse_results if r.get(f'{dim}_mse') is not None]
1178
+ mae_list = [r.get(f'{dim}_mae') for r in refined_mse_results if r.get(f'{dim}_mae') is not None]
1179
+ mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
1180
+ mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
1181
+ if mse_clean:
1182
+ refined_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
1183
+ if mae_clean:
1184
+ refined_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
1185
+ summary['refined_mse_statistics'] = refined_mse_stats
1186
+ summary['refined_mae_statistics'] = refined_mae_stats
1187
+
1188
+ if initial_results:
1189
+ initial_mse_results = [r for r in initial_results if r.get('overall_error') is not None]
1190
+ if initial_mse_results:
1191
+ initial_mse_stats = {}
1192
+ initial_mae_stats = {}
1193
+ for dim in dimensions:
1194
+ mse_list = [r.get(f'{dim}_mse') for r in initial_mse_results if r.get(f'{dim}_mse') is not None]
1195
+ mae_list = [r.get(f'{dim}_mae') for r in initial_mse_results if r.get(f'{dim}_mae') is not None]
1196
+ mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
1197
+ mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
1198
+ if mse_clean:
1199
+ initial_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
1200
+ if mae_clean:
1201
+ initial_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
1202
+ summary['initial_mse_statistics'] = initial_mse_stats
1203
+ summary['initial_mae_statistics'] = initial_mae_stats
1204
+
1205
+        # Helper: drop (true, pred) pairs with missing or NaN values
1206
+ def filter_valid_pairs(true_list, pred_list):
1207
+ filtered_true = []
1208
+ filtered_pred = []
1209
+ for t, p in zip(true_list, pred_list):
1210
+ if (t is not None and p is not None and
1211
+ not (isinstance(t, float) and math.isnan(t)) and
1212
+ not (isinstance(p, float) and math.isnan(p))):
1213
+ filtered_true.append(t)
1214
+ filtered_pred.append(p)
1215
+ return filtered_true, filtered_pred
1216
+
1217
+ # Calculate Spearman correlations
1218
+ # For refined format, calculate separately for refined and initial, and use refined for overall
1219
+ # For other formats, use all results
1220
+ if refined_format_count > 0:
1221
+ # Calculate refined spearman correlations
1222
+ refined_spearman_stats = {}
1223
+ dimensions = ['soundness', 'presentation', 'confidence', 'rating']
1224
+ for dim in dimensions:
1225
+ true_values = [r.get(f'gt_{dim}') for r in refined_results]
1226
+ pred_values = [r.get(f'model_{dim}') for r in refined_results]
1227
+ true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
1228
+
1229
+ if len(true_clean) >= 2 and len(pred_clean) >= 2:
1230
+ try:
1231
+ corr, _ = spearmanr(true_clean, pred_clean)
1232
+ if not math.isnan(corr):
1233
+ refined_spearman_stats[dim] = {
1234
+ 'correlation': corr,
1235
+ 'count': len(true_clean)
1236
+ }
1237
+ except Exception:
1238
+ pass
1239
+
1240
+ # Calculate initial spearman correlations
1241
+ initial_spearman_stats = {}
1242
+ for dim in dimensions:
1243
+ true_values = [r.get(f'gt_{dim}') for r in initial_results]
1244
+ pred_values = [r.get(f'model_{dim}') for r in initial_results]
1245
+ true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
1246
+
1247
+ if len(true_clean) >= 2 and len(pred_clean) >= 2:
1248
+ try:
1249
+ corr, _ = spearmanr(true_clean, pred_clean)
1250
+ if not math.isnan(corr):
1251
+ initial_spearman_stats[dim] = {
1252
+ 'correlation': corr,
1253
+ 'count': len(true_clean)
1254
+ }
1255
+ except Exception:
1256
+ pass
1257
+
1258
+ # Use refined for overall statistics (avoid double counting)
1259
+ summary['spearman_correlations'] = refined_spearman_stats
1260
+ summary['refined_spearman_correlations'] = refined_spearman_stats
1261
+ summary['initial_spearman_correlations'] = initial_spearman_stats
1262
+ else:
1263
+ # Original/other formats: use all results
1264
+ correlation_results = valid_results
1265
+ spearman_stats = {}
1266
+ dimensions = ['soundness', 'presentation', 'confidence', 'rating']
1267
+ for dim in dimensions:
1268
+ true_values = [r.get(f'gt_{dim}') for r in correlation_results]
1269
+ pred_values = [r.get(f'model_{dim}') for r in correlation_results]
1270
+ true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
1271
+
1272
+ if len(true_clean) >= 2 and len(pred_clean) >= 2:
1273
+ try:
1274
+ corr, _ = spearmanr(true_clean, pred_clean)
1275
+ if not math.isnan(corr):
1276
+ spearman_stats[dim] = {
1277
+ 'correlation': corr,
1278
+ 'count': len(true_clean)
1279
+ }
1280
+ except Exception:
1281
+ pass
1282
+
1283
+ summary['spearman_correlations'] = spearman_stats
1284
+
1285
+ # Calculate Decision metrics
1286
+ # For refined format, calculate separately for refined and initial, and use refined for overall
1287
+ # For other formats, use all results
1288
+ if refined_format_count > 0:
1289
+ # Calculate refined decision metrics
1290
+ refined_decision_results = [r for r in refined_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
1291
+ if refined_decision_results:
1292
+ true_decisions = []
1293
+ pred_decisions = []
1294
+ decision_acc = []
1295
+
1296
+ for r in refined_decision_results:
1297
+ gt_decision = str(r.get('gt_decision', '')).lower().strip()
1298
+ pred_decision = str(r.get('model_decision', '')).lower().strip()
1299
+
1300
+ if 'accept' in pred_decision:
1301
+ pred_binary = 1
1302
+ else:
1303
+ pred_binary = 0
1304
+
1305
+ if 'accept' in gt_decision:
1306
+ gt_binary = 1
1307
+ else:
1308
+ gt_binary = 0
1309
+
1310
+ true_decisions.append(gt_binary)
1311
+ pred_decisions.append(pred_binary)
1312
+
1313
+ if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
1314
+ decision_acc.append(1.0)
1315
+ else:
1316
+ decision_acc.append(0.0)
1317
+
1318
+ if decision_acc:
1319
+ decision_accuracy = sum(decision_acc) / len(decision_acc)
1320
+ try:
1321
+ _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
1322
+ refined_decision_metrics = {
1323
+ 'accuracy': decision_accuracy,
1324
+ 'f1_macro': f1_score,
1325
+ 'count': len(decision_acc)
1326
+ }
1327
+ except Exception:
1328
+ refined_decision_metrics = {
1329
+ 'accuracy': decision_accuracy,
1330
+ 'count': len(decision_acc)
1331
+ }
1332
+ summary['refined_decision_metrics'] = refined_decision_metrics
1333
+ summary['decision_metrics'] = refined_decision_metrics # Use refined for overall
1334
+
1335
+ # Calculate initial decision metrics
1336
+ initial_decision_results = [r for r in initial_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
1337
+ if initial_decision_results:
1338
+ true_decisions = []
1339
+ pred_decisions = []
1340
+ decision_acc = []
1341
+
1342
+ for r in initial_decision_results:
1343
+ gt_decision = str(r.get('gt_decision', '')).lower().strip()
1344
+ pred_decision = str(r.get('model_decision', '')).lower().strip()
1345
+
1346
+ if 'accept' in pred_decision:
1347
+ pred_binary = 1
1348
+ else:
1349
+ pred_binary = 0
1350
+
1351
+ if 'accept' in gt_decision:
1352
+ gt_binary = 1
1353
+ else:
1354
+ gt_binary = 0
1355
+
1356
+ true_decisions.append(gt_binary)
1357
+ pred_decisions.append(pred_binary)
1358
+
1359
+ if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
1360
+ decision_acc.append(1.0)
1361
+ else:
1362
+ decision_acc.append(0.0)
1363
+
1364
+ if decision_acc:
1365
+ decision_accuracy = sum(decision_acc) / len(decision_acc)
1366
+ try:
1367
+ _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
1368
+ initial_decision_metrics = {
1369
+ 'accuracy': decision_accuracy,
1370
+ 'f1_macro': f1_score,
1371
+ 'count': len(decision_acc)
1372
+ }
1373
+ except Exception:
1374
+ initial_decision_metrics = {
1375
+ 'accuracy': decision_accuracy,
1376
+ 'count': len(decision_acc)
1377
+ }
1378
+ summary['initial_decision_metrics'] = initial_decision_metrics
1379
+ else:
1380
+ # Original/other formats: use all results
1381
+ decision_results = [r for r in valid_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
1382
+ if decision_results:
1383
+ true_decisions = []
1384
+ pred_decisions = []
1385
+ decision_acc = []
1386
+
1387
+ for r in decision_results:
1388
+ gt_decision = str(r.get('gt_decision', '')).lower().strip()
1389
+ pred_decision = str(r.get('model_decision', '')).lower().strip()
1390
+
1391
+ if 'accept' in pred_decision:
1392
+ pred_binary = 1
1393
+ else:
1394
+ pred_binary = 0
1395
+
1396
+ if 'accept' in gt_decision:
1397
+ gt_binary = 1
1398
+ else:
1399
+ gt_binary = 0
1400
+
1401
+ true_decisions.append(gt_binary)
1402
+ pred_decisions.append(pred_binary)
1403
+
1404
+ if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
1405
+ decision_acc.append(1.0)
1406
+ else:
1407
+ decision_acc.append(0.0)
1408
+
1409
+ if decision_acc:
1410
+ decision_accuracy = sum(decision_acc) / len(decision_acc)
1411
+ try:
1412
+ _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
1413
+ summary['decision_metrics'] = {
1414
+ 'accuracy': decision_accuracy,
1415
+ 'f1_macro': f1_score,
1416
+ 'count': len(decision_acc)
1417
+ }
1418
+ except Exception:
1419
+ summary['decision_metrics'] = {
1420
+ 'accuracy': decision_accuracy,
1421
+ 'count': len(decision_acc)
1422
+ }
1423
+
1424
+ # Calculate Pairwise comparison
1425
+ # For refined format, only use refined results (avoid double counting)
1426
+ # For other formats, use all results
1427
+ if refined_format_count > 0:
1428
+ pairwise_results = refined_results
1429
+ else:
1430
+ pairwise_results = valid_results
1431
+
1432
+ paper_scores = []
1433
+ for r in pairwise_results:
1434
+ if (r.get('gt_rating') is not None and r.get('model_rating') is not None) or \
1435
+ (r.get('gt_soundness') is not None and r.get('model_soundness') is not None):
1436
+ paper_scores.append({
1437
+ 'true_rating': r.get('gt_rating'),
1438
+ 'pred_rating': r.get('model_rating'),
1439
+ 'true_soundness': r.get('gt_soundness'),
1440
+ 'pred_soundness': r.get('model_soundness'),
1441
+ 'true_presentation': r.get('gt_presentation'),
1442
+ 'pred_presentation': r.get('model_presentation'),
1443
+ 'true_confidence': r.get('gt_confidence'),
1444
+ 'pred_confidence': r.get('model_confidence')
1445
+ })
1446
+
1447
+ if len(paper_scores) >= 2:
1448
+ pairwise_accuracies = calculate_pairwise_accuracies(paper_scores)
1449
+ summary['pairwise_accuracies'] = pairwise_accuracies
1450
+
1451
+ return results, summary
1452
+
1453
+
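The decision binarization repeated in the blocks above (any string containing "accept" maps to 1, everything else to 0) is easy to isolate; `binarize_decision` is an illustrative name for this sketch.

```python
def binarize_decision(decision):
    # Same substring rule the evaluator applies when building true/pred labels
    return 1 if "accept" in str(decision).lower().strip() else 0

pairs = [("Accept (poster)", "accept"), ("Reject", "reject"), ("Accept (oral)", "reject")]
acc = sum(binarize_decision(p) == binarize_decision(g) for p, g in pairs) / len(pairs)
print(acc)  # 2 of 3 decisions match
```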
1454
+ # ============================================================================
1455
+ # Main Function
1456
+ # ============================================================================
1457
+
1458
+ def parse_args():
1459
+ """Parse command line arguments."""
1460
+ parser = argparse.ArgumentParser(description="Unified evaluation script for semantic and auto-metric evaluation")
1461
+
1462
+ # Input paths
1463
+ parser.add_argument("--rubrics_path", type=str, required=True,
1464
+ help="Path to eval_rubrics.json file (from 1_generate_review_based_rubrics.py)")
1465
+ parser.add_argument("--reviews_path", type=str, required=True,
1466
+                        help="Path to JSON file with model reviews (refined, original, or legacy pred_fast_mode format)")
1467
+
1468
+ # Evaluation mode
1469
+ parser.add_argument("--mode", type=str, choices=["semantic", "auto_metric", "both"], default="both",
1470
+ help="Evaluation mode: semantic (LLM-based), auto_metric (rule-based), or both")
1471
+
1472
+ # Output paths
1473
+ parser.add_argument("--semantic_output", type=str, default=None,
1474
+ help="Path to output JSON file for semantic evaluation results (required if mode is semantic or both)")
1475
+ parser.add_argument("--auto_metric_output", type=str, default=None,
1476
+ help="Path to output JSON file for auto-metric evaluation results (required if mode is auto_metric or both)")
1477
+
1478
+ # Semantic evaluation settings
1479
+ parser.add_argument("--yaml_path", type=str, default=None,
1480
+ help="Path to prompts.yaml file (required for semantic evaluation)")
1481
+ parser.add_argument("--config_path", type=str, default=None,
1482
+ help="Path to configs.yaml file (required for semantic evaluation)")
1483
+
1484
+ # Multi-threading
1485
+ parser.add_argument("--max_workers", type=int, default=None,
1486
+ help="Maximum number of worker threads for semantic evaluation (default: 5)")
1487
+
1488
+ # Strict mode (normalize scores to discrete scales)
1489
+ parser.add_argument("--strict_mode", action="store_true", default=False,
1490
+ help="Enable strict mode: normalize scores to discrete scales before computing metrics (default: False)")
1491
+
1492
+ # Input format override
1493
+ parser.add_argument("--input_format", type=str, choices=['auto', 'refined', 'original'], default='auto',
1494
+ help="Manually specify input JSON format: 'refined' (has scores and initial_scores), 'original' (has model_prediction), or 'auto' for auto-detection (default: 'auto')")
1495
+
1496
+ return parser.parse_args()
1497
+
1498
+
1499
+ def main():
1500
+ """Main execution function."""
1501
+ args = parse_args()
1502
+
1503
+ script_dir = os.path.dirname(os.path.abspath(__file__))
1504
+
1505
+ # Resolve paths
1506
+ rubrics_path = args.rubrics_path
1507
+ if not os.path.isabs(rubrics_path):
1508
+ rubrics_path = os.path.join(script_dir, rubrics_path)
1509
+
1510
+ reviews_path = args.reviews_path
1511
+ if not os.path.isabs(reviews_path):
1512
+ reviews_path = os.path.join(script_dir, reviews_path)
1513
+
1514
+ max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
1515
+
1516
+ # Validate mode and output paths
1517
+ if args.mode in ["semantic", "both"]:
1518
+ if not args.semantic_output:
1519
+ raise ValueError("--semantic_output is required when mode is 'semantic' or 'both'")
1520
+ if not args.yaml_path:
1521
+ raise ValueError("--yaml_path is required for semantic evaluation")
1522
+ if not args.config_path:
1523
+ raise ValueError("--config_path is required for semantic evaluation")
1524
+
1525
+ if args.mode in ["auto_metric", "both"]:
1526
+ if not args.auto_metric_output:
1527
+ raise ValueError("--auto_metric_output is required when mode is 'auto_metric' or 'both'")
1528
+
1529
+ # Check if files exist
1530
+ if not os.path.exists(rubrics_path):
1531
+ raise FileNotFoundError(f"Rubrics file not found: {rubrics_path}")
1532
+ if not os.path.exists(reviews_path):
1533
+ raise FileNotFoundError(f"Reviews file not found: {reviews_path}")
1534
+
1535
+ # Load data
1536
+ print(f"Loading rubrics from {rubrics_path}...")
1537
+ rubrics_data = load_rubrics_json(rubrics_path)
1538
+ print(f"Loaded {len(rubrics_data)} rubrics entries")
1539
+
1540
+ print(f"Loading model reviews from {reviews_path}...")
1541
+ if args.input_format != 'auto':
1542
+ print(f"Using manually specified format: {args.input_format}")
1543
+ else:
1544
+ print("Auto-detecting input format...")
1545
+ reviews_dict = load_model_reviews_json(reviews_path, format_override=args.input_format if args.input_format != 'auto' else None)
1546
+ print(f"Loaded {len(reviews_dict)} model reviews")
1547
+
1548
+ # Combine rubrics and reviews
1549
+ print("Combining rubrics and reviews...")
1550
+ evaluation_data = combine_rubrics_and_reviews(rubrics_data, reviews_dict)
1551
+ print(f"Prepared {len(evaluation_data)} entries for evaluation")
1552
+
1553
+ # Run evaluations based on mode
1554
+ if args.mode in ["semantic", "both"]:
1555
+ # Resolve semantic evaluation paths
1556
+ yaml_path = args.yaml_path
1557
+ if not os.path.isabs(yaml_path):
1558
+ yaml_path = os.path.join(script_dir, yaml_path)
1559
+
1560
+ config_path = args.config_path
1561
+ if not os.path.isabs(config_path):
1562
+ config_path = os.path.join(script_dir, config_path)
1563
+
1564
+ if not os.path.exists(yaml_path):
1565
+ raise FileNotFoundError(f"YAML file not found: {yaml_path}")
1566
+ if not os.path.exists(config_path):
1567
+ raise FileNotFoundError(f"Config file not found: {config_path}")
1568
+
1569
+ # Load prompt template
1570
+ print(f"Loading prompt template from {yaml_path}...")
1571
+ prompt_template = load_prompt_template(yaml_path)
1572
+ if not prompt_template:
1573
+ raise ValueError("Could not find 'v1_evaluator_prompt' in YAML file")
1574
+
1575
+ # Initialize LLM service
1576
+ print(f"Loading LLM configuration from {config_path}...")
1577
+ llm_config = load_llm_config(config_path)
1578
+ llm_service = create_llm_service_from_config(llm_config)
1579
+ mode = llm_config.get('mode', 'gpt')
1580
+ print(f"LLM service initialized (mode: {mode})")
1581
+ if hasattr(llm_service, 'model_name'):
1582
+ print(f"Using model: {llm_service.model_name}")
1583
+
1584
+ # Run semantic evaluation
1585
+ semantic_results, semantic_summary = run_semantic_evaluation(
1586
+ evaluation_data, prompt_template, llm_service, max_workers
1587
+ )
1588
+
1589
+ # Save semantic results
1590
+ semantic_output = args.semantic_output
1591
+ if not os.path.isabs(semantic_output):
1592
+ semantic_output = os.path.join(script_dir, semantic_output)
1593
+
1594
+ output_dir = os.path.dirname(semantic_output)
1595
+ os.makedirs(output_dir, exist_ok=True)
1596
+
1597
+ with open(semantic_output, 'w', encoding='utf-8') as f:
1598
+ json.dump(semantic_results, f, ensure_ascii=False, indent=2)
1599
+ print(f"\nSemantic evaluation results saved to {semantic_output}")
1600
+
1601
+ # Save semantic summary
1602
+ semantic_summary_path = semantic_output.replace('.json', '_summary.json')
1603
+ with open(semantic_summary_path, 'w', encoding='utf-8') as f:
1604
+ json.dump(semantic_summary, f, ensure_ascii=False, indent=2)
1605
+ print(f"Semantic evaluation summary saved to {semantic_summary_path}")
1606
+
1607
+ # Print semantic summary
1608
+ print("\n" + "="*80)
1609
+ print("SEMANTIC EVALUATION SUMMARY")
1610
+ print("="*80)
1611
+ print(f"Total entries: {semantic_summary['total_entries']}")
1612
+ print(f"Valid entries: {semantic_summary['valid_entries']}")
1613
+ print(f"Failed entries: {semantic_summary['failed_entries']}")
1614
+ if 'overall_score' in semantic_summary:
1615
+ score = semantic_summary['overall_score']
1616
+ print(f"\nOverall Score:")
1617
+ print(f" Mean: {score['mean']:.2f}")
1618
+ print(f" Min: {score['min']:.2f}")
1619
+ print(f" Max: {score['max']:.2f}")
1620
+
1621
+ if args.mode in ["auto_metric", "both"]:
1622
+ # Run auto-metric evaluation
1623
+ auto_metric_results, auto_metric_summary = run_auto_metric_evaluation(
1624
+ evaluation_data,
1625
+ strict_mode=args.strict_mode
1626
+ )
1627
+
1628
+ # Save auto-metric results
1629
+ auto_metric_output = args.auto_metric_output
1630
+ if not os.path.isabs(auto_metric_output):
1631
+ auto_metric_output = os.path.join(script_dir, auto_metric_output)
1632
+
1633
+ output_dir = os.path.dirname(auto_metric_output)
1634
+ os.makedirs(output_dir, exist_ok=True)
1635
+
1636
+ with open(auto_metric_output, 'w', encoding='utf-8') as f:
1637
+ json.dump(auto_metric_results, f, ensure_ascii=False, indent=2)
1638
+ print(f"\nAuto-metric evaluation results saved to {auto_metric_output}")
1639
+
1640
+ # Save auto-metric summary
1641
+ auto_metric_summary_path = auto_metric_output.replace('.json', '_summary.json')
1642
+ with open(auto_metric_summary_path, 'w', encoding='utf-8') as f:
1643
+ json.dump(auto_metric_summary, f, ensure_ascii=False, indent=2)
1644
+ print(f"Auto-metric evaluation summary saved to {auto_metric_summary_path}")
1645
+
1646
+ # Print auto-metric summary
1647
+ print("\n" + "="*80)
1648
+ print("AUTO-METRIC EVALUATION SUMMARY")
1649
+ print("="*80)
1650
+ print(f"Total entries: {auto_metric_summary['total_entries']}")
1651
+ print(f"Valid entries: {auto_metric_summary['valid_entries']}")
1652
+ print(f"MSE entries: {auto_metric_summary['mse_entries']}")
1653
+
1654
+ if 'mse_statistics' in auto_metric_summary:
1655
+ print("\nMSE Statistics:")
1656
+ for dim, stats in auto_metric_summary['mse_statistics'].items():
1657
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1658
+
1659
+ if 'mae_statistics' in auto_metric_summary:
1660
+ print("\nMAE Statistics:")
1661
+ for dim, stats in auto_metric_summary['mae_statistics'].items():
1662
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1663
+
1664
+ # Print refined and initial statistics if available
1665
+ if 'refined_mse_statistics' in auto_metric_summary:
1666
+ print("\nRefined Scores - MSE Statistics:")
1667
+ for dim, stats in auto_metric_summary['refined_mse_statistics'].items():
1668
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1669
+
1670
+ if 'refined_mae_statistics' in auto_metric_summary:
1671
+ print("\nRefined Scores - MAE Statistics:")
1672
+ for dim, stats in auto_metric_summary['refined_mae_statistics'].items():
1673
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1674
+
1675
+ if 'initial_mse_statistics' in auto_metric_summary:
1676
+ print("\nInitial Scores - MSE Statistics:")
1677
+ for dim, stats in auto_metric_summary['initial_mse_statistics'].items():
1678
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1679
+
1680
+ if 'initial_mae_statistics' in auto_metric_summary:
1681
+ print("\nInitial Scores - MAE Statistics:")
1682
+ for dim, stats in auto_metric_summary['initial_mae_statistics'].items():
1683
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1684
+
1685
+ if 'spearman_correlations' in auto_metric_summary:
1686
+ print("\nSpearman Correlations:")
1687
+ for dim, stats in auto_metric_summary['spearman_correlations'].items():
1688
+ print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
1689
+
1690
+ # Print refined and initial spearman correlations if available
1691
+ if 'refined_spearman_correlations' in auto_metric_summary:
1692
+ print("\nRefined Scores - Spearman Correlations:")
1693
+ for dim, stats in auto_metric_summary['refined_spearman_correlations'].items():
1694
+ print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
1695
+
1696
+ if 'initial_spearman_correlations' in auto_metric_summary:
1697
+ print("\nInitial Scores - Spearman Correlations:")
1698
+ for dim, stats in auto_metric_summary['initial_spearman_correlations'].items():
1699
+ print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
1700
+
1701
+ if 'decision_metrics' in auto_metric_summary:
1702
+ dm = auto_metric_summary['decision_metrics']
1703
+ print(f"\nDecision Metrics:")
1704
+ print(f" Accuracy: {dm['accuracy']:.4f} (n={dm['count']})")
1705
+ if 'f1_macro' in dm:
1706
+ print(f" F1 (macro): {dm['f1_macro']:.4f}")
1707
+
1708
+ # Print refined and initial decision metrics if available
1709
+ if 'refined_decision_metrics' in auto_metric_summary:
1710
+ print("\nRefined Scores - Decision Metrics:")
1711
+ rdm = auto_metric_summary['refined_decision_metrics']
1712
+ print(f" Accuracy: {rdm['accuracy']:.4f} (n={rdm['count']})")
1713
+ if 'f1_macro' in rdm:
1714
+ print(f" F1 (macro): {rdm['f1_macro']:.4f}")
1715
+
1716
+ if 'initial_decision_metrics' in auto_metric_summary:
1717
+ print("\nInitial Scores - Decision Metrics:")
1718
+ idm = auto_metric_summary['initial_decision_metrics']
1719
+ print(f" Accuracy: {idm['accuracy']:.4f} (n={idm['count']})")
1720
+ if 'f1_macro' in idm:
1721
+ print(f" F1 (macro): {idm['f1_macro']:.4f}")
1722
+
1723
+ print("\n" + "="*80)
1724
+ print("EVALUATION COMPLETE")
1725
+ print("="*80)
1726
+
1727
+
1728
+ if __name__ == "__main__":
1729
+ main()
1730
+
src/evaluator/2_evaluate_agenticreview.py ADDED
@@ -0,0 +1,1866 @@
1
+ """
2
+ Unified evaluation script for semantic (LLM-based) and auto_metric (rule-based) evaluation.
3
+
4
+ This script:
5
+ 1. Reads eval_rubrics.json (from 1_generate_review_based_rubrics.py) containing rubrics for each paper
6
+ 2. Reads input JSON file containing model reviews (supports multiple formats)
7
+ 3. Supports three evaluation modes:
8
+ - semantic: LLM-based rubrics evaluation (from 2_evaluate_direct.py)
9
+ - auto_metric: Rule-based metrics evaluation (from 3_rule_evaluate.py)
10
+ - both: Run both evaluations separately
11
+ 4. Supports strict mode: normalize scores to discrete scales before computing metrics (--strict_mode)
12
+ 5. Outputs separate JSON files for results and summaries
13
+
14
+ Usage:
15
+ # Semantic evaluation only
16
+ python 2_evaluate_agenticreview.py \
17
+ --rubrics_path eval_rubrics.json \
18
+ --reviews_path model_reviews.json \
19
+ --mode semantic \
20
+ --yaml_path prompts.yaml \
21
+ --config_path configs.yaml \
22
+ --semantic_output semantic_results.json \
23
+ --max_workers 5
24
+
25
+ # Auto-metric evaluation only
26
+ python 2_evaluate_agenticreview.py \
27
+ --rubrics_path eval_rubrics.json \
28
+ --reviews_path model_reviews.json \
29
+ --mode auto_metric \
30
+ --auto_metric_output auto_metric_results.json
31
+
32
+ # Auto-metric evaluation with strict mode (normalize scores to discrete scales)
33
+ python 2_evaluate_agenticreview.py \
34
+ --rubrics_path eval_rubrics.json \
35
+ --reviews_path model_reviews.json \
36
+ --mode auto_metric \
37
+ --auto_metric_output auto_metric_results.json \
38
+ --strict_mode
39
+
40
+ # Auto-metric evaluation with manually specified input format (refined)
41
+ python 2_evaluate_agenticreview.py \
42
+ --rubrics_path eval_rubrics.json \
43
+ --reviews_path model_reviews.json \
44
+ --mode auto_metric \
45
+ --auto_metric_output auto_metric_results.json \
46
+ --input_format refined
47
+
48
+ # Auto-metric evaluation with manually specified input format (original)
49
+ python 2_evaluate_agenticreview.py \
50
+ --rubrics_path eval_rubrics.json \
51
+ --reviews_path ours.json \
52
+ --mode auto_metric \
53
+ --auto_metric_output auto_metric_results.json \
54
+ --input_format original
55
+
56
+ # Both evaluations
57
+ python 2_evaluate_agenticreview.py \
58
+ --rubrics_path eval_rubrics.json \
59
+ --reviews_path model_reviews.json \
60
+ --mode both \
61
+ --yaml_path prompts.yaml \
62
+ --config_path configs.yaml \
63
+ --semantic_output semantic_results.json \
64
+ --auto_metric_output auto_metric_results.json \
65
+ --max_workers 32
66
+ """
67
+ from __future__ import annotations
68
+
69
+ import json
70
+ import os
71
+ import sys
72
+ import argparse
73
+ import yaml
74
+ import math
75
+ import re
76
+ from typing import Dict, List, Any, Optional
77
+ from concurrent.futures import ThreadPoolExecutor, as_completed
78
+ from tqdm import tqdm
79
+ from itertools import combinations
80
+ from scipy.stats import spearmanr
81
+ from sklearn.metrics import precision_recall_fscore_support
82
+
83
+ # Add parent directory to path
84
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
85
+ # Import parse_llm_response from local llm_service module
86
+ import llm_service as local_llm_service
87
+ parse_llm_response = local_llm_service.parse_llm_response
88
+
89
+ # Import from shared/utils for gpt/vllm support
90
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
91
+ if project_root not in sys.path:
92
+ sys.path.insert(0, project_root)
93
+
94
+ from shared.utils.llm_service import LLMService
95
+ from shared.utils.vllm_service import VLLMService
96
+ from shared.utils.gpt_service import GPTService
97
+ sys.path.insert(0, os.path.join(project_root, 'shared', 'utils'))
98
+ from json_parser import parse_review_markdown
99
+
100
+
101
+ def convert_ai_researcher(review: dict) -> tuple:
102
+ """
103
+ Convert the review text from ai-researcher format to unified review system format.
104
+ """
105
+
106
+ summary = review["Summary"]
107
+ strengths = "\n".join(f"- {s}" for s in review["Strengths"])
108
+ weaknesses = "\n".join(f"- {w}" for w in review["Weaknesses"])
109
+
110
+ # scores
111
+ originality = review["Originality"]
112
+ quality = review["Quality"]
113
+ clarity = review["Clarity"]
114
+ significance = review["Significance"]
115
+
116
+ questions = "\n".join(f"- {q}" for q in review["Questions"])
117
+ limitations = "\n".join(f"- {l}" for l in review["Limitations"])
118
+ ethical_concerns = review["Ethical Concerns"]
119
+
120
+ # scores again
121
+ soundness = review["Soundness"]
122
+ presentation = review["Presentation"]
123
+ contribution = review["Contribution"]
124
+ overall = review["Overall"]
125
+ confidence = review["Confidence"]
126
+
127
+ # final decision
128
+ decision = review["Decision"]
129
+
130
+ meta_review = {
131
+ "rating": overall,
132
+ "soundness": soundness,
133
+ "presentation": presentation,
134
+ "contribution": contribution,
135
+ "confidence": confidence,
136
+ "decision": decision.lower().strip(),
137
+ }
138
+
139
+ review_text = (
+ f"Summary: {summary}\nStrengths: {strengths}\nWeaknesses: {weaknesses}\n"
+ f"Originality: {originality}\nQuality: {quality}\nClarity: {clarity}\n"
+ f"Significance: {significance}\nQuestions: {questions}\nLimitations: {limitations}\n"
+ f"Ethical Concerns: {ethical_concerns}\nSoundness: {soundness}\n"
+ f"Presentation: {presentation}\nContribution: {contribution}\n"
+ f"Overall: {overall}\nConfidence: {confidence}\nDecision: {decision}"
+ )
+ return review_text, meta_review
140
+
141
+
142
+ def convert_agenticreview(review_text: str) -> tuple:
143
+ """
144
+ Convert the review text from agenticreview format to unified review system format.
145
+
146
+ The agenticreview format has text like:
147
+ "Overall rating: 5\n\nSignificance and novelty: ..."
148
+
149
+ Args:
150
+ review_text: Raw review text string
151
+
152
+ Returns:
153
+ Tuple of (formatted_review_text, meta_review_dict)
154
+ """
155
+ # Extract rating from "Overall rating: x" format
156
+ rating = None
157
+ rating_match = re.search(r'Overall\s+rating\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
158
+ if rating_match:
159
+ try:
160
+ rating = float(rating_match.group(1))
161
+ except (ValueError, IndexError):
162
+ pass
163
+
164
+ # If not found, try alternative patterns
165
+ if rating is None:
166
+ rating_match = re.search(r'(?:rating|score)\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
167
+ if rating_match:
168
+ try:
169
+ rating = float(rating_match.group(1))
170
+ except (ValueError, IndexError):
171
+ pass
172
+
173
+ # Try to extract from parse_review_markdown as fallback
174
+ if rating is None:
175
+ try:
176
+ parsed = parse_review_markdown(review_text)
177
+ rating = parsed.get('rating')
178
+ except Exception:
179
+ pass
180
+
181
+ # Create meta_review dict - agenticreview only has rating, no other scores
182
+ meta_review = {
183
+ "rating": rating,
184
+ "soundness": None,
185
+ "presentation": None,
186
+ "contribution": None,
187
+ "confidence": None,
188
+ "decision": None,
189
+ }
190
+
191
+ # Return the review text as-is (it's already in a readable format)
192
+ return review_text, meta_review
193
+
194
+
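The rating extraction above leans on a tolerant regex with several fallbacks. A minimal standalone sketch of the primary pattern (the sample review text is made up for illustration):

```python
import re

def extract_overall_rating(review_text: str):
    """Pull the numeric rating from 'Overall rating: x', case-insensitively."""
    match = re.search(r'Overall\s+rating\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
    return float(match.group(1)) if match else None

sample = "Overall rating: 5\n\nSignificance and novelty: strong empirical results."
print(extract_overall_rating(sample))  # 5.0
```

Because `[:=]` accepts either separator and the flag ignores case, variants like `overall RATING = 7.5` also parse; anything without the phrase falls through to the secondary patterns in the function above.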
195
+ class ReviewProcessor:
196
+ """Handles the extraction and processing of reviews from different sources."""
197
+
198
+ @staticmethod
199
+ def extract_review_content(pred_context):
200
+ """
201
+ Extract the review content from the prediction context.
202
+
203
+ Args:
204
+ pred_context: Raw prediction data that contains the review
205
+
206
+ Returns:
207
+ str: Extracted review content
208
+ """
209
+ try:
210
+ # First attempt to extract from boxed format
211
+ return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
212
+ except Exception:
213
+ # Alternative extraction if the first method fails
214
+ if isinstance(pred_context, dict) and 'output' in pred_context:
215
+ return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
216
+ else:
217
+ # Return as is if extraction fails
218
+ return pred_context
219
+
220
+
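The split-based extraction in `extract_review_content` can be illustrated on its own; the sample string below is hypothetical:

```python
def extract_boxed_review(text: str) -> str:
    # Take everything after the last literal '\boxed_review{' marker,
    # then cut at the first '\n}' that closes the box. If the marker is
    # absent, the input passes through unchanged.
    return text.split(r'\boxed_review{')[-1].split('\n}')[0]

raw = "Reasoning trace...\n\\boxed_review{Summary: solid paper.\n}\nFooter"
print(extract_boxed_review(raw))  # Summary: solid paper.
```

Note the graceful degradation: on plain text with no box, `split` returns the whole string, which matches the fallback behavior in the class above.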
221
+ # ============================================================================
222
+ # Semantic Evaluation Functions (from 2_evaluate_direct.py)
223
+ # ============================================================================
224
+
225
+ def load_prompt_template(yaml_path: str) -> str:
226
+ """Load the evaluator prompt from YAML file."""
227
+ with open(yaml_path, 'r', encoding='utf-8') as f:
228
+ prompts = yaml.safe_load(f)
229
+ return prompts.get('v1_evaluator_prompt', '')
230
+
231
+
232
+ def build_evaluation_prompt(
233
+ rubrics: List[Dict[str, Any]],
234
+ paper_content: str,
235
+ review: str,
236
+ prompt_template: str
237
+ ) -> str:
238
+ """Build the evaluation prompt by replacing placeholders."""
239
+ rubrics_json = json.dumps(rubrics, indent=4, ensure_ascii=False)
240
+ prompt = prompt_template.replace('{rubrics_json}', rubrics_json)
241
+ prompt = prompt.replace('<<paper_content>>', paper_content)
242
+ prompt = prompt.replace('<<review>>', review)
243
+ return prompt
244
+
245
+
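`build_evaluation_prompt` does plain string substitution rather than `str.format`, so literal braces elsewhere in the template are safe. A minimal sketch with a made-up template:

```python
import json

template = "Rubrics:\n{rubrics_json}\nPaper:\n<<paper_content>>\nReview:\n<<review>>"
rubrics = [{"title": "Clarity", "weight": 0.5}]

# Same three-step replacement as the function above.
prompt = template.replace('{rubrics_json}', json.dumps(rubrics, indent=4, ensure_ascii=False))
prompt = prompt.replace('<<paper_content>>', 'PAPER TEXT')
prompt = prompt.replace('<<review>>', 'REVIEW TEXT')
print(prompt)
```

Using `replace` keeps the template free of escaping concerns, at the cost of silently leaving a placeholder intact if its spelling drifts.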
246
+ def calculate_weighted_scores(
247
+ raw_scores: Dict[str, Dict[str, Any]],
248
+ rubrics: List[Dict[str, Any]]
249
+ ) -> Dict[str, float]:
250
+ """Calculate weighted scores for each rubric."""
251
+ rubric_weights = {r['title']: r['weight'] for r in rubrics}
252
+ weighted_scores = {}
253
+
254
+ for rubric_title, rubric_data in raw_scores.items():
255
+ if rubric_title not in rubric_weights:
256
+ continue
257
+
258
+ rubric_score = rubric_data.get('score', 0)
259
+ if isinstance(rubric_score, str):
260
+ try:
261
+ rubric_score = int(rubric_score)
262
+ except ValueError:
263
+ rubric_score = 0
264
+
265
+ if rubric_score not in [0, 1]:
266
+ rubric_score = 1 if rubric_score > 0 else 0
267
+
268
+ weight = rubric_weights[rubric_title]
269
+ weighted_scores[rubric_title] = rubric_score * weight
270
+
271
+ return weighted_scores
272
+
273
+
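The weighting logic above clamps each rubric score to {0, 1} and multiplies by the rubric's weight, skipping any rubric title the evaluator emits that is not in the rubric list. An illustrative sketch (function name and sample data are hypothetical):

```python
def weight_rubric_scores(raw_scores, rubrics):
    """Clamp each rubric score to 0/1, then multiply by the rubric's weight."""
    weights = {r['title']: r['weight'] for r in rubrics}
    weighted = {}
    for title, data in raw_scores.items():
        if title not in weights:
            continue  # ignore rubric titles the evaluator invented
        try:
            score = int(data.get('score', 0))
        except (ValueError, TypeError):
            score = 0
        score = 1 if score > 0 else 0
        weighted[title] = score * weights[title]
    return weighted

rubrics = [{'title': 'Clarity', 'weight': 2.0}, {'title': 'Novelty', 'weight': 3.0}]
raw = {'Clarity': {'score': 1}, 'Novelty': {'score': '0'}, 'Unknown': {'score': 1}}
print(weight_rubric_scores(raw, rubrics))  # Clarity -> 2.0, Novelty -> 0.0
```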
274
+ def calculate_scores(raw_scores: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
275
+ """Calculate scores for each rubric."""
276
+ scores = {}
277
+ for rubric_title, rubric_data in raw_scores.items():
278
+ scores[rubric_title] = rubric_data.get('score', 0)
279
+ return scores
280
+
281
+
282
+ def evaluate_review_semantic(
283
+ entry: Dict[str, Any],
284
+ paper_content: str,
285
+ prompt_template: str,
286
+ llm_service: LLMService
287
+ ) -> Dict[str, Any]:
288
+ """Evaluate a single review using article-specific rubrics."""
289
+ entry_id = entry.get('id', 'unknown')
290
+ rubrics = entry.get('rubrics', [])
291
+ model_review = entry.get('model_review', '')
292
+
293
+ if not rubrics:
294
+ return {
295
+ 'id': entry_id,
296
+ 'raw_scores': {},
297
+ 'weighted_scores': {},
298
+ 'total_score': 0.0,
299
+ 'error': 'No valid rubrics found',
300
+ 'raw_response': ''
301
+ }
302
+
303
+ # Build prompt
304
+ prompt = build_evaluation_prompt(rubrics, paper_content, model_review, prompt_template)
305
+
306
+ # Call LLM
307
+ try:
308
+ messages = [{"role": "user", "content": prompt}]
309
+ response = llm_service.generate(messages=messages)
310
+
311
+ # Parse response
312
+ raw_scores = parse_llm_response(response)
313
+ # Note: this uses the unweighted per-rubric scores from calculate_scores();
+ # calculate_weighted_scores() is available above if weight-adjusted totals are desired.
+ weighted_scores = calculate_scores(raw_scores)
314
+ total_score = sum(weighted_scores.values())
315
+
316
+ return {
317
+ 'id': entry_id,
318
+ 'raw_scores': raw_scores,
319
+ 'weighted_scores': weighted_scores,
320
+ 'total_score': total_score,
321
+ 'raw_response': response
322
+ }
323
+ except Exception as e:
324
+ print(f"[ERROR] Error evaluating review {entry_id}: {e}")
325
+ return {
326
+ 'id': entry_id,
327
+ 'raw_scores': {},
328
+ 'weighted_scores': {},
329
+ 'total_score': 0.0,
330
+ 'error': str(e),
331
+ 'raw_response': ''
332
+ }
333
+
334
+
335
+ def calculate_per_rubric_statistics(
336
+ valid_results: List[Dict[str, Any]],
337
+ rubric_titles: List[str]
338
+ ) -> Dict[str, Dict[str, float]]:
339
+ """Calculate per-rubric statistics from evaluation results."""
340
+ rubric_scores = {title: [] for title in rubric_titles}
341
+
342
+ for result in valid_results:
343
+ weighted_scores = result.get('weighted_scores', {})
344
+ if not isinstance(weighted_scores, dict):
345
+ continue
346
+
347
+ for rubric_title in rubric_titles:
348
+ if rubric_title in weighted_scores:
349
+ score = weighted_scores[rubric_title]
350
+ if isinstance(score, str):
351
+ try:
352
+ score = float(score)
353
+ except ValueError:
354
+ continue
355
+ elif isinstance(score, (int, float)):
356
+ score = float(score)
357
+ else:
358
+ continue
359
+ rubric_scores[rubric_title].append(score)
360
+
361
+ per_rubric_stats = {}
362
+ for rubric_title in rubric_titles:
363
+ scores = rubric_scores[rubric_title]
364
+ if not scores:
365
+ continue
366
+
367
+ mean_score = sum(scores) / len(scores)
368
+ min_score = min(scores)
369
+ max_score = max(scores)
370
+ count = len(scores)
371
+
372
+ if rubric_title == "False or Contradictory Claims":
373
+ pass_count = sum(1 for s in scores if s >= 0)
374
+ else:
375
+ pass_count = sum(1 for s in scores if s >= 1)
376
+ pass_rate = pass_count / count if count > 0 else 0.0
377
+
378
+ per_rubric_stats[rubric_title] = {
379
+ 'mean': mean_score,
380
+ 'min': min_score,
381
+ 'max': max_score,
382
+ 'count': count,
383
+ 'pass_rate': pass_rate
384
+ }
385
+
386
+ return per_rubric_stats
387
+
388
+
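The pass-rate computation above applies a different threshold to the "False or Contradictory Claims" rubric. Assuming the intent is that this rubric carries penalties (so a non-negative score counts as a pass, while ordinary rubrics require at least one full point), the rule can be sketched standalone:

```python
def pass_rate(scores, rubric_title):
    # Assumed semantics: the claims rubric is penalty-based, so any
    # non-negative score passes; other rubrics need a score >= 1.
    threshold = 0 if rubric_title == "False or Contradictory Claims" else 1
    if not scores:
        return 0.0
    return sum(1 for s in scores if s >= threshold) / len(scores)

print(pass_rate([1.0, 0.0, 2.0], "Clarity"))                    # 2 of 3 pass
print(pass_rate([0.0, -1.0], "False or Contradictory Claims"))  # 1 of 2 pass
```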
389
+ # ============================================================================
390
+ # Auto-Metric Evaluation Functions (from 3_rule_evaluate.py)
391
+ # ============================================================================
392
+
393
+ def extract_scores_from_review(review_text: str) -> Dict[str, Any]:
394
+ """Extract numeric scores and decision from a review markdown text."""
395
+ if not review_text:
396
+ return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
397
+
398
+ try:
399
+ parsed = parse_review_markdown(review_text)
400
+ decision = parsed.get('decision', '')
401
+ if decision:
402
+ decision_lower = decision.lower().strip()
403
+ if 'accept' in decision_lower:
404
+ decision = 'accept'
405
+ elif 'reject' in decision_lower:
406
+ decision = 'reject'
407
+ elif 'undecided' in decision_lower:
408
+ decision = 'undecided'
409
+ else:
410
+ decision = decision_lower
411
+ else:
412
+ decision = None
413
+
414
+ return {
415
+ 'soundness': parsed.get('soundness'),
416
+ 'presentation': parsed.get('presentation'),
417
+ 'rating': parsed.get('rating'),
418
+ 'confidence': parsed.get('confidence'),
419
+ 'decision': decision
420
+ }
421
+ except Exception as e:
422
+ print(f"Warning: Failed to parse review text: {e}")
423
+ return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
424
+
425
+
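The decision normalization inside `extract_scores_from_review` maps free-form decision strings onto a small label set by substring matching, checking 'accept' first. A standalone sketch of that rule (function name is illustrative):

```python
def normalize_decision(decision):
    """Map a free-form decision string to accept/reject/undecided, else pass through."""
    if not decision:
        return None
    d = decision.lower().strip()
    for label in ('accept', 'reject', 'undecided'):
        if label in d:
            return label
    return d  # unknown labels survive lowercased, as in the function above

print(normalize_decision("Strong Accept"))       # accept
print(normalize_decision("Borderline reject "))  # reject
print(normalize_decision(""))                    # None
```

One consequence of checking 'accept' first is that a phrase like "do not accept" still maps to 'accept'; the original function shares this behavior.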
426
+ def calculate_mse(predicted: float, ground_truth: float) -> Optional[float]:
427
+ """Calculate Mean Squared Error for a single value."""
428
+ if predicted is None or ground_truth is None:
429
+ return None
430
+ return (predicted - ground_truth) ** 2
431
+
432
+
433
+ def calculate_mae(predicted: float, ground_truth: float) -> Optional[float]:
434
+ """Calculate Mean Absolute Error for a single value."""
435
+ if predicted is None or ground_truth is None:
436
+ return None
437
+ return abs(predicted - ground_truth)
438
+
439
+
440
+def normalize_to_discrete_scale(score: Optional[float], scale_type: str) -> Optional[float]:
+    """
+    Normalize a float score to the nearest discrete value for the given scale type.
+    Ties break upward (round-half-up), e.g. 3.5 rounds to 4 and 1.5 rounds to 2.
+
+    Args:
+        score: The float score to normalize (may be None)
+        scale_type: Either '0-5' for the 0-5 scale (discrete: 0,1,2,3,4,5)
+            or '0-10' for the 0-10 scale (discrete: 0,2,4,6,8,10)
+
+    Returns:
+        Normalized discrete score, or None if the input is None or not numeric
+    """
+    if score is None:
+        return None
+
+    try:
+        score = float(score)
+    except (ValueError, TypeError):
+        return None
+
+    if scale_type == '0-5':
+        discrete_values = [0, 1, 2, 3, 4, 5]
+        score = max(0, min(5, score))
+    elif scale_type == '0-10':
+        discrete_values = [0, 2, 4, 6, 8, 10]
+        score = max(0, min(10, score))
+    else:
+        raise ValueError(f"Unknown scale_type: {scale_type}. Must be '0-5' or '0-10'")
+
+    # Find the nearest discrete value; on a tie, prefer the higher value (round-half-up)
+    best_value = None
+    best_distance = float('inf')
+    for val in discrete_values:
+        distance = abs(val - score)
+        if distance < best_distance or (distance == best_distance and val > best_value):
+            best_distance = distance
+            best_value = val
+    return best_value
+
+
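The snapping rule above can be sanity-checked with a compact equivalent. The `snap` helper below is hypothetical (not part of the script); it inlines the same clamp-then-nearest-with-upward-ties behavior:

```python
# Hypothetical compact equivalent of the round-half-up snapping used above:
# clamp to the allowed range, then pick the nearest allowed value,
# breaking ties toward the higher one.
def snap(score, allowed):
    score = max(allowed[0], min(allowed[-1], float(score)))
    return max(allowed, key=lambda v: (-abs(v - score), v))

print(snap(3.5, [0, 1, 2, 3, 4, 5]))    # -> 4 (tie rounds up)
print(snap(4.9, [0, 2, 4, 6, 8, 10]))   # -> 4
print(snap(7.2, [0, 2, 4, 6, 8, 10]))   # -> 8
```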
+def normalize_scores_dict(scores: Dict[str, Optional[float]]) -> Dict[str, Optional[float]]:
+    """
+    Normalize all scores in a dictionary to their appropriate discrete scales.
+
+    Args:
+        scores: Dictionary with keys 'soundness', 'presentation', 'rating', 'confidence'
+
+    Returns:
+        Dictionary with normalized scores
+    """
+    normalized = {}
+
+    # soundness, presentation, confidence use the 0-5 scale
+    for key in ['soundness', 'presentation', 'confidence']:
+        normalized[key] = normalize_to_discrete_scale(scores.get(key), '0-5')
+
+    # rating uses the 0-10 scale
+    normalized['rating'] = normalize_to_discrete_scale(scores.get('rating'), '0-10')
+
+    return normalized
+
+
+def calculate_score_metrics(
+    model_scores: Dict[str, float],
+    ground_truth_scores: Dict[str, float],
+    normalize: bool = False
+) -> Dict[str, Any]:
+    """
+    Calculate MSE and MAE metrics for each scoring dimension.
+
+    Args:
+        model_scores: Dictionary with model scores
+        ground_truth_scores: Dictionary with ground truth scores
+        normalize: If True, normalize scores to discrete scales before computing metrics
+
+    Returns:
+        Dictionary with MSE and MAE metrics and, when normalize=True, the normalized scores
+    """
+    dimensions = ['soundness', 'presentation', 'rating', 'confidence']
+
+    # Normalize scores to discrete scales if requested
+    if normalize:
+        model_scores_normalized = normalize_scores_dict(model_scores)
+        gt_scores_normalized = normalize_scores_dict(ground_truth_scores)
+    else:
+        model_scores_normalized = model_scores
+        gt_scores_normalized = ground_truth_scores
+
+    mse_values = {}
+    mae_values = {}
+    valid_count = 0
+
+    for dim in dimensions:
+        mse = calculate_mse(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
+        mae = calculate_mae(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
+        mse_values[f'{dim}_mse'] = mse
+        mae_values[f'{dim}_mae'] = mae
+        if mse is not None:
+            valid_count += 1
+
+    # Overall error is the sum of the per-dimension MSE values that could be computed
+    overall_error = sum(v for v in mse_values.values() if v is not None)
+
+    result = {
+        **mse_values,
+        **mae_values,
+        'overall_error': overall_error if valid_count > 0 else None,
+        'valid_dimensions': valid_count
+    }
+
+    # Include normalized scores in the result for transparency (only if normalize=True)
+    if normalize:
+        result['model_scores_normalized'] = model_scores_normalized
+        result['gt_scores_normalized'] = gt_scores_normalized
+
+    return result
+
+
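The per-dimension error computation reduces to squared and absolute differences over the dimensions where both sides have a value, with missing dimensions skipped. A minimal self-contained sketch (helper functions inlined, scores illustrative):

```python
# Illustrative model and ground-truth scores; 'confidence' has no ground truth
# and is skipped, matching the None-handling in the metric helpers above.
model = {'soundness': 3, 'presentation': 2, 'rating': 6, 'confidence': 4}
truth = {'soundness': 4, 'presentation': 2, 'rating': 8, 'confidence': None}

mse = {d: (model[d] - truth[d]) ** 2 for d in model if truth[d] is not None}
mae = {d: abs(model[d] - truth[d]) for d in model if truth[d] is not None}
print(mse)                # {'soundness': 1, 'presentation': 0, 'rating': 4}
print(sum(mse.values()))  # overall_error = 5
```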
+def normalize_score_value(value):
+    """Normalize a score value to float, handling string representations."""
+    if value is None:
+        return None
+    if isinstance(value, (int, float)):
+        return float(value)
+    if isinstance(value, str):
+        # Extract the first numeric token from the string (e.g., "2.75" -> 2.75)
+        match = re.search(r'(\d+\.?\d*)', value)
+        if match:
+            return float(match.group(1))
+    return None
+
+
+def normalize_decision(decision):
+    """Normalize a decision string to a standard form ('accept', 'reject', or 'undecided')."""
+    if decision is None:
+        return None
+    decision_lower = str(decision).lower().strip()
+    if 'accept' in decision_lower:
+        return 'accept'
+    elif 'reject' in decision_lower:
+        return 'reject'
+    elif 'undecided' in decision_lower:
+        return 'undecided'
+    else:
+        return decision_lower
+
+
+def extract_scores_from_dict(scores_dict: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Extract scores from a structured dictionary (scores or initial_scores format).
+
+    Args:
+        scores_dict: Dict containing scores (e.g., {'rating': 5.75, 'soundness': '2.75', ...})
+
+    Returns:
+        Dict with normalized scores: {'soundness', 'presentation', 'rating', 'confidence', 'decision'}
+    """
+    if not scores_dict:
+        return {
+            'soundness': None,
+            'presentation': None,
+            'rating': None,
+            'confidence': None,
+            'decision': None
+        }
+
+    return {
+        'soundness': normalize_score_value(scores_dict.get('soundness')),
+        'presentation': normalize_score_value(scores_dict.get('presentation')),
+        'rating': normalize_score_value(scores_dict.get('rating')),
+        'confidence': normalize_score_value(scores_dict.get('confidence')),
+        'decision': normalize_decision(scores_dict.get('decision'))
+    }
+
+
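Scores arrive as ints, floats, or strings, and decisions as free-form text. A small self-contained sketch of the normalization behavior (the `to_float` helper is hypothetical, inlining the regex extraction used above):

```python
import re

# Hypothetical inlined version of the string-to-float normalization above.
def to_float(value):
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        m = re.search(r'(\d+\.?\d*)', value)
        return float(m.group(1)) if m else None
    return None

print(to_float('2.75: somewhat sound'))       # -> 2.75
print(to_float(3))                            # -> 3.0
print('accept' in 'Accept (poster)'.lower())  # substring match used for decisions -> True
```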
+def evaluate_review_auto_metric(entry: Dict[str, Any], use_initial_scores: bool = False, strict_mode: bool = False) -> Dict[str, Any]:
+    """
+    Evaluate a single entry by extracting scores and calculating metrics.
+
+    Args:
+        entry: Evaluation entry containing model_review, scores, initial_scores, etc.
+        use_initial_scores: If True, use initial_scores instead of refined scores (for refined format)
+        strict_mode: If True, normalize scores to discrete scales before computing metrics
+
+    Returns:
+        Dict containing evaluation metrics
+    """
+    entry_id = entry.get('id', 'unknown')
+    model_review = entry.get('model_review', '')
+    format_type = entry.get('format', 'unknown')
+
+    # Extract model scores based on format
+    if format_type == 'refined' and not use_initial_scores:
+        # Use refined scores from structured data
+        model_data = extract_scores_from_dict(entry.get('scores', {}))
+    elif format_type == 'refined' and use_initial_scores:
+        # Use initial scores from structured data
+        model_data = extract_scores_from_dict(entry.get('initial_scores', {}))
+    elif format_type == 'original':
+        # Use initial scores from structured data
+        model_data = extract_scores_from_dict(entry.get('initial_scores', {}))
+    else:
+        # Fallback: extract from the markdown review text
+        model_data = extract_scores_from_review(model_review)
+
+    model_scores = {
+        'soundness': model_data.get('soundness'),
+        'presentation': model_data.get('presentation'),
+        'rating': model_data.get('rating'),
+        'confidence': model_data.get('confidence')
+    }
+    model_decision = model_data.get('decision')
+
+    # Fallback for the original format: if confidence is missing from structured data,
+    # try to extract it from the review text (meta_review may not carry a confidence field)
+    if format_type == 'original' and model_scores.get('confidence') is None and model_review:
+        try:
+            review_data = extract_scores_from_review(model_review)
+            if review_data.get('confidence') is not None:
+                model_scores['confidence'] = review_data.get('confidence')
+        except Exception:
+            pass  # Keep confidence as None if extraction fails
+
+    # Ground truth scores come from golden_review ONLY, never from model output.
+    # If extraction fails, the fields stay None; using model output as ground truth
+    # would inflate evaluation scores.
+    ground_truth_review = entry.get('golden_review', '')
+    ground_truth_scores = {}
+    gt_decision = None
+
+    if not ground_truth_review:
+        print(f"Warning: No golden_review found for entry {entry_id}. Ground truth scores will be empty.")
+    else:
+        try:
+            # Extract scores from the golden_review markdown text
+            gt_data = extract_scores_from_review(ground_truth_review)
+            if not gt_data:
+                print(f"Warning: Failed to parse golden_review for entry {entry_id}. Ground truth scores will be empty.")
+            else:
+                ground_truth_scores = {
+                    'soundness': gt_data.get('soundness'),
+                    'presentation': gt_data.get('presentation'),
+                    'rating': gt_data.get('rating'),
+                    'confidence': gt_data.get('confidence')
+                }
+                gt_decision = normalize_decision(gt_data.get('decision'))
+        except Exception as e:
+            print(f"Warning: Failed to extract scores from golden_review for {entry_id}: {e}")
+            print(f"  Ground truth scores will be empty. Error: {str(e)}")
+
+    # Calculate MSE and MAE metrics (with normalization in strict mode)
+    score_metrics = calculate_score_metrics(model_scores, ground_truth_scores, normalize=strict_mode)
+
+    # Calculate decision accuracy
+    decision_match = False
+    decision_accuracy = None
+    if model_decision is not None and gt_decision is not None:
+        model_decision_normalized = normalize_decision(model_decision)
+        decision_match = (model_decision_normalized == gt_decision)
+        decision_accuracy = 1.0 if decision_match else 0.0
+
+    result = {
+        'id': entry_id,
+        'format': format_type,
+        'model_soundness': model_scores.get('soundness'),
+        'model_presentation': model_scores.get('presentation'),
+        'model_rating': model_scores.get('rating'),
+        'model_confidence': model_scores.get('confidence'),
+        'model_decision': model_decision,
+        'gt_soundness': ground_truth_scores.get('soundness'),
+        'gt_presentation': ground_truth_scores.get('presentation'),
+        'gt_rating': ground_truth_scores.get('rating'),
+        'gt_confidence': ground_truth_scores.get('confidence'),
+        'gt_decision': gt_decision,
+        'decision_match': decision_match,
+        'decision_accuracy': decision_accuracy,
+        **score_metrics
+    }
+
+    # Record which scores were used
+    if format_type == 'refined':
+        result['score_type'] = 'initial' if use_initial_scores else 'refined'
+    else:
+        result['score_type'] = 'auto'
+
+    return result
+
+
+def calculate_pairwise_accuracies(paper_scores: List[Dict[str, float]]) -> Dict[str, float]:
+    """Calculate pairwise ranking accuracy for each metric by comparing predicted and true orderings."""
+    if len(paper_scores) < 2:
+        return {}
+
+    metrics = ['rating', 'soundness', 'presentation', 'confidence']
+    total_valid_pairs = {metric: 0 for metric in metrics}
+    correct_pairs = {metric: 0 for metric in metrics}
+
+    for paper1, paper2 in combinations(paper_scores, 2):
+        for metric in metrics:
+            true_key = f'true_{metric}'
+            pred_key = f'pred_{metric}'
+            # Only count pairs where both papers have both true and predicted values
+            if (paper1.get(true_key) is not None and paper2.get(true_key) is not None and
+                    paper1.get(pred_key) is not None and paper2.get(pred_key) is not None):
+                total_valid_pairs[metric] += 1
+                true_order = paper1[true_key] > paper2[true_key]
+                pred_order = paper1[pred_key] > paper2[pred_key]
+                if true_order == pred_order:
+                    correct_pairs[metric] += 1
+
+    pairwise_accuracies = {
+        metric: correct_pairs[metric] / total_valid_pairs[metric] if total_valid_pairs[metric] > 0 else 0.0
+        for metric in metrics
+    }
+
+    return pairwise_accuracies
+
+
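Pairwise accuracy counts a pair as correct when the predicted ordering of two papers agrees with the true ordering. A self-contained sketch on one metric with illustrative values:

```python
from itertools import combinations

# Three hypothetical papers: true vs predicted ratings.
papers = [
    {'true_rating': 8, 'pred_rating': 6},
    {'true_rating': 4, 'pred_rating': 5},
    {'true_rating': 6, 'pred_rating': 7},
]

total = correct = 0
for a, b in combinations(papers, 2):
    total += 1
    # A pair is correct when predicted and true orderings agree
    if (a['true_rating'] > b['true_rating']) == (a['pred_rating'] > b['pred_rating']):
        correct += 1
print(correct / total)  # 2 of 3 pairs agree
```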
+# ============================================================================
+# Data Loading Functions
+# ============================================================================
+
+def load_rubrics_json(rubrics_path: str) -> Dict[str, Dict[str, Any]]:
+    """Load a rubrics JSON file and build a lookup keyed by paper id."""
+    with open(rubrics_path, 'r', encoding='utf-8') as f:
+        data = json.load(f)
+
+    if isinstance(data, list):
+        return {item['id']: item for item in data}
+    elif isinstance(data, dict):
+        return data
+    else:
+        raise ValueError(f"Invalid rubrics JSON format: expected list or dict, got {type(data)}")
+
+
+def load_model_reviews_json(reviews_path: str, format_override: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
+    """
+    Load a model-reviews JSON file and extract reviews by id.
+
+    Supports two input formats:
+    1. Refined format: contains 'scores' and 'initial_scores' fields (from the refinement pipeline)
+    2. Original format: contains 'model_prediction' with 'meta_review' and 'decision' (like ours.json)
+
+    Args:
+        reviews_path: Path to the JSON file containing model reviews
+        format_override: Optional format override ('refined', 'original', or None for auto-detect)
+
+    Returns:
+        Dict mapping paper_id to a dict containing:
+        - 'review': review text (markdown)
+        - 'scores': refined scores dict (if available)
+        - 'initial_scores': initial scores dict (if available)
+        - 'format': format type ('refined', 'original', 'standard', or 'legacy')
+    """
+    with open(reviews_path, 'r', encoding='utf-8') as f:
+        data = json.load(f)
+
+    if isinstance(data, dict):
+        data = list(data.values())
+
+    reviews_dict = {}
+    for item in data:
+        item_id = None
+        review_text = ''
+        scores = None
+        initial_scores = None
+        format_type = None
+
+        # Use the format override if provided, otherwise auto-detect
+        if format_override and format_override != 'auto':
+            if format_override == 'refined':
+                item_id = item.get('paper_id') or item.get('id')
+                if not item_id:
+                    continue
+                format_type = 'refined'
+                review_text = item.get('review_markdown', '') or item.get('review', '')
+                scores = item.get('scores', {})
+                initial_scores = item.get('initial_scores', {})
+            elif format_override == 'original':
+                item_id = item.get('id')
+                if not item_id:
+                    continue
+                format_type = 'original'
+                model_prediction = item.get('model_prediction', {})
+                meta_review = model_prediction.get('meta_review', {})
+                review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
+                initial_scores = {
+                    'rating': meta_review.get('rating'),
+                    'soundness': meta_review.get('soundness'),
+                    'presentation': meta_review.get('presentation'),
+                    'contribution': meta_review.get('contribution'),
+                    'confidence': meta_review.get('confidence'),
+                    'decision': model_prediction.get('decision'),
+                }
+            else:
+                raise ValueError(f"Unknown format_override: {format_override}. Must be 'refined', 'original', or 'auto'")
+        else:
+            # Auto-detect format
+            if "paper_id" in item:
+                # Refined format (from the refinement pipeline)
+                item_id = item.get('paper_id')
+                if not item_id:
+                    continue
+
+                # Check whether this is the refined format (has scores and initial_scores)
+                if 'scores' in item and 'initial_scores' in item:
+                    format_type = 'refined'
+                    review_text = item.get('review_markdown', '') or item.get('review', '')
+                    scores = item.get('scores', {})
+                    initial_scores = item.get('initial_scores', {})
+                else:
+                    # Standard format with paper_id
+                    format_type = 'standard'
+                    review_text = item.get('review_markdown', '') or item.get('review', '')
+            elif "model_prediction" in item:
+                # Original format (like ours.json) or agenticreview format
+                item_id = item.get('id')
+                if not item_id:
+                    continue
+
+                format_type = 'original'
+                model_prediction = item.get('model_prediction', {})
+
+                review_text = model_prediction.get('raw_text', '')
+
+                if review_text is None:
+                    continue
+
+                # Detect format: agenticreview has raw_text as a string containing "Overall rating: x";
+                # the ai_researcher format has raw_text as a dict or a JSON string with structured fields
+                is_agenticreview = False
+                if isinstance(review_text, str):
+                    # Check whether it is a JSON string (ai_researcher format)
+                    try:
+                        parsed_json = json.loads(review_text)
+                        if isinstance(parsed_json, dict) and any(key in parsed_json for key in ["Summary", "Strengths", "Overall", "Decision"]):
+                            # ai_researcher format
+                            review_text, meta_review = convert_ai_researcher(parsed_json)
+                        else:
+                            # agenticreview format (plain text with "Overall rating: x")
+                            is_agenticreview = True
+                    except json.JSONDecodeError:
+                        # Not JSON: check for the "Overall rating:" pattern
+                        if re.search(r'Overall\s+rating\s*[:=]', review_text, re.IGNORECASE):
+                            is_agenticreview = True
+                        else:
+                            # Neither JSON nor agenticreview text: fall back to an empty review
+                            review_text = 'Empty Review'
+                            meta_review = {}
+                elif isinstance(review_text, dict):
+                    # ai_researcher format (dict)
+                    review_text, meta_review = convert_ai_researcher(review_text)
+                else:
+                    review_text = 'Empty Review'
+                    meta_review = {}
+
+                # Handle the agenticreview format
+                if is_agenticreview:
+                    review_text, meta_review = convert_agenticreview(review_text)
+
+                # Extract initial scores.
+                # meta_review (from convert_ai_researcher or convert_agenticreview) is the primary
+                # source; fall back to model_prediction.get('decision') when it carries no decision.
+                initial_scores = {
+                    'rating': meta_review.get('rating'),
+                    'soundness': meta_review.get('soundness'),
+                    'presentation': meta_review.get('presentation'),
+                    'contribution': meta_review.get('contribution'),
+                    'confidence': meta_review.get('confidence'),
+                    'decision': meta_review.get('decision') or model_prediction.get('decision'),
+                }
+            else:
+                # Legacy format (pred_fast_mode)
+                item_id = item.get('id')
+                if not item_id:
+                    continue
+
+                format_type = 'legacy'
+                review_dict = item.get('pred_fast_mode', {})
+                if isinstance(review_dict, dict):
+                    review_text = review_dict.get('raw_text', '')
+                else:
+                    review_text = str(review_dict)
+
+        # Store the extracted review
+        try:
+            extracted_review = review_text if review_text else ''
+
+            reviews_dict[item_id] = {
+                'review': extracted_review,
+                'scores': scores,
+                'initial_scores': initial_scores,
+                'format': format_type
+            }
+        except Exception as e:
+            print(f"[WARN] Failed to extract review for {item_id}: {e}")
+            continue
+
+    return reviews_dict
+
+
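Auto-detection keys off the presence of `paper_id` versus `model_prediction`. Two illustrative input items (all field values hypothetical) showing the shapes the loader distinguishes:

```python
# Refined-format item: detected via "paper_id" plus "scores"/"initial_scores".
refined_item = {
    "paper_id": "paper-001",
    "review_markdown": "## Summary\n...",
    "scores": {"rating": 6, "soundness": 3, "presentation": 3, "decision": "accept"},
    "initial_scores": {"rating": 5, "soundness": 3, "presentation": 2, "decision": "reject"},
}

# Original-format item: detected via "model_prediction".
original_item = {
    "id": "paper-002",
    "model_prediction": {"raw_text": "...review text... Overall rating: 7", "decision": "accept"},
}

print("paper_id" in refined_item, "model_prediction" in original_item)  # True True
```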
+def combine_rubrics_and_reviews(
+    rubrics_data: Dict[str, Dict[str, Any]],
+    reviews_dict: Dict[str, Dict[str, Any]]
+) -> List[Dict[str, Any]]:
+    """
+    Combine rubrics and reviews into evaluation entries.
+
+    Args:
+        rubrics_data: Dict mapping paper_id to rubric entry
+        reviews_dict: Dict mapping paper_id to dict containing 'review', 'scores', 'initial_scores', 'format'
+
+    Returns:
+        List of evaluation entries with model_review, scores, initial_scores, and format info
+    """
+    combined = []
+    missing_reviews = []
+
+    for paper_id, rubric_entry in rubrics_data.items():
+        review_data = reviews_dict.get(paper_id)
+        if not review_data or not review_data.get('review'):
+            missing_reviews.append(paper_id)
+            continue
+
+        entry = {
+            'id': paper_id,
+            'paper_context': rubric_entry.get('paper_context', ''),
+            'decision': rubric_entry.get('decision', ''),
+            'golden_review': rubric_entry.get('golden_review', ''),
+            'rubrics': rubric_entry.get('rubrics', []),
+            'model_review': review_data.get('review', ''),
+            'scores': review_data.get('scores'),  # Refined scores (if available)
+            'initial_scores': review_data.get('initial_scores'),  # Initial scores (if available)
+            'format': review_data.get('format', 'unknown')  # Format type
+        }
+        combined.append(entry)
+
+    if missing_reviews:
+        print(f"[WARN] {len(missing_reviews)} papers have no model review; skipping them")
+
+    return combined
+
+
+# ============================================================================
+# LLM Service Configuration
+# ============================================================================
+
+def load_llm_config(config_path: str) -> Dict[str, Any]:
+    """Load the LLM configuration from a YAML file."""
+    with open(config_path, 'r', encoding='utf-8') as f:
+        config = yaml.safe_load(f)
+    return config
+
+
+def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
+    """Create an LLM service from the configuration."""
+    mode = config.get('mode', 'gpt').lower()
+
+    if mode == 'gpt':
+        gpt_config = config.get('gpt', {})
+        api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
+        if not api_key:
+            raise ValueError("GPT mode requires api_key in the LLM config file or the OPENAI_API_KEY environment variable")
+
+        service = GPTService(
+            api_key=api_key,
+            model_name=gpt_config.get('model_name', 'gpt-4o'),
+            base_url=gpt_config.get('base_url'),
+            timeout=gpt_config.get('timeout', 300)
+        )
+        return service
+
+    elif mode == 'vllm':
+        vllm_config = config.get('vllm', {})
+        service = VLLMService(
+            base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
+            api_key=vllm_config.get('api_key', 'dummy-key'),
+            model_name=vllm_config.get('model_name'),
+            timeout=vllm_config.get('timeout', 300),
+            max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
+            max_retries=vllm_config.get('max_retries', 3),
+            retry_delay=vllm_config.get('retry_delay', 1.0),
+            retry_backoff=vllm_config.get('retry_backoff', 2.0)
+        )
+        return service
+
+    else:
+        raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
+
+
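A minimal YAML config matching the fields the factory reads (field names taken from the reader above; values are illustrative defaults, and the model names are placeholders, not confirmed choices):

```yaml
mode: vllm  # or "gpt"

gpt:
  model_name: gpt-4o
  timeout: 300
  # api_key omitted here: falls back to the OPENAI_API_KEY environment variable

vllm:
  base_url: http://localhost:8000/v1
  api_key: dummy-key
  model_name: your-served-model-name  # placeholder
  timeout: 300
  max_concurrent_requests: 64
  max_retries: 3
  retry_delay: 1.0
  retry_backoff: 2.0
```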
+# ============================================================================
+# Main Evaluation Functions
+# ============================================================================
+
+def run_semantic_evaluation(
+    evaluation_data: List[Dict[str, Any]],
+    prompt_template: str,
+    llm_service: LLMService,
+    max_workers: int
+) -> tuple:
+    """Run semantic evaluation and return results and a summary."""
+    print(f"\n{'='*80}")
+    print("RUNNING SEMANTIC EVALUATION")
+    print(f"{'='*80}")
+    print(f"Evaluating {len(evaluation_data)} reviews using {max_workers} workers...")
+
+    results = []
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_entry = {
+            executor.submit(
+                evaluate_review_semantic,
+                entry,
+                entry['paper_context'],
+                prompt_template,
+                llm_service
+            ): entry
+            for entry in evaluation_data
+        }
+
+        for future in tqdm(as_completed(future_to_entry), total=len(evaluation_data), desc="Semantic evaluation"):
+            try:
+                result = future.result()
+                results.append(result)
+            except Exception as e:
+                entry = future_to_entry[future]
+                print(f"\n[ERROR] Failed to process entry {entry.get('id', 'unknown')}: {e}")
+                results.append({
+                    'id': entry.get('id', 'unknown'),
+                    'raw_scores': {},
+                    'weighted_scores': {},
+                    'total_score': 0.0,
+                    'error': str(e),
+                    'raw_response': ''
+                })
+
+    # Calculate statistics
+    valid_results = [r for r in results if 'error' not in r and r.get('weighted_scores')]
+    review_scores = [r.get('total_score', 0.0) for r in valid_results]
+
+    summary = {
+        'total_entries': len(results),
+        'valid_entries': len(valid_results),
+        'failed_entries': len(results) - len(valid_results)
+    }
+
+    if review_scores:
+        summary['overall_score'] = {
+            'mean': sum(review_scores) / len(review_scores),
+            'min': min(review_scores),
+            'max': max(review_scores)
+        }
+
+    # Calculate per-rubric statistics (rubric titles are taken from the first entry,
+    # assuming all entries share the same rubric set)
+    if evaluation_data and evaluation_data[0].get('rubrics'):
+        rubric_titles = [r['title'] for r in evaluation_data[0]['rubrics']]
+        per_rubric_stats = calculate_per_rubric_statistics(valid_results, rubric_titles)
+        summary['per_rubric_statistics'] = per_rubric_stats
+
+    return results, summary
+
+
+def run_auto_metric_evaluation(
+    evaluation_data: List[Dict[str, Any]],
+    strict_mode: bool = False
+) -> tuple:
+    """
+    Run auto-metric evaluation and return results and a summary.
+
+    For the refined format (has scores and initial_scores), evaluates both:
+    - Refined scores
+    - Initial scores
+
+    For the original format (only initial_scores), evaluates:
+    - Initial scores only
+
+    Returns:
+        Tuple of (results_list, summary_dict)
+        - results_list: List of evaluation results (may contain both refined and initial results for the refined format)
+        - summary_dict: Summary statistics
+    """
+    print(f"\n{'='*80}")
+    print("RUNNING AUTO-METRIC EVALUATION")
+    print(f"{'='*80}")
+    print(f"Evaluating {len(evaluation_data)} entries...")
+
+    # Detect format types
+    refined_format_count = sum(1 for e in evaluation_data if e.get('format') == 'refined')
+    original_format_count = sum(1 for e in evaluation_data if e.get('format') == 'original')
+
+    if refined_format_count > 0:
+        print(f"Detected {refined_format_count} entries in refined format (will evaluate both refined and initial scores)")
+    if original_format_count > 0:
+        print(f"Detected {original_format_count} entries in original format (will evaluate initial scores only)")
+
+    results = []
+    for entry in tqdm(evaluation_data, desc="Auto-metric evaluation"):
+        format_type = entry.get('format', 'unknown')
+
+        if format_type == 'refined':
+            # Evaluate both refined and initial scores
+            try:
+                entry_id = entry.get('id', 'unknown')
+
+                refined_result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
+                refined_result['paper_id'] = entry_id  # Keep the original paper_id
+                refined_result['id'] = f"{entry_id}_refined"
+                results.append(refined_result)
+
+                initial_result = evaluate_review_auto_metric(entry, use_initial_scores=True, strict_mode=strict_mode)
+                initial_result['paper_id'] = entry_id  # Keep the original paper_id
+                initial_result['id'] = f"{entry_id}_initial"
+                results.append(initial_result)
+            except Exception as e:
+                print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
+                results.append({
+                    'id': entry.get('id', 'unknown'),
+                    'error': str(e)
+                })
+        else:
+            # Evaluate initial scores only (or extract from markdown)
+            try:
+                result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
+                results.append(result)
+            except Exception as e:
+                print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
+                results.append({
+                    'id': entry.get('id', 'unknown'),
+                    'error': str(e)
+                })
+
+    # Calculate statistics
+    valid_results = [r for r in results if 'error' not in r]
+    mse_results = [r for r in valid_results if r.get('overall_error') is not None]
+
+    # Separate refined and initial results for the refined format
+    refined_results = [r for r in valid_results if r.get('score_type') == 'refined']
+    initial_results = [r for r in valid_results if r.get('score_type') == 'initial']
+    auto_results = [r for r in valid_results if r.get('score_type') == 'auto' or r.get('score_type') is None]
+
+    summary = {
+        'total_entries': len(results),
+        'valid_entries': len(valid_results),
+        'mse_entries': len(mse_results),
+        'refined_results_count': len(refined_results),
+        'initial_results_count': len(initial_results),
+        'auto_results_count': len(auto_results)
+    }
+
+    # Calculate MSE/MAE statistics.
+    # For the refined format, only refined results feed the overall statistics
+    # (avoids double counting); for other formats, all results are used.
+    if refined_format_count > 0:
+        stats_results = [r for r in refined_results if r.get('overall_error') is not None]
+    else:
+        stats_results = mse_results
+
+    dimensions = ['soundness', 'presentation', 'confidence', 'rating']
+
+    if stats_results:
+        mse_stats = {}
+        mae_stats = {}
+
+        for dim in dimensions:
+            mse_list = [r.get(f'{dim}_mse') for r in stats_results if r.get(f'{dim}_mse') is not None]
+            mae_list = [r.get(f'{dim}_mae') for r in stats_results if r.get(f'{dim}_mae') is not None]
+
+            mse_clean = [x for x in mse_list if not (isinstance(x, float) and math.isnan(x))]
+            mae_clean = [x for x in mae_list if not (isinstance(x, float) and math.isnan(x))]
+
+            if mse_clean:
+                mse_stats[dim] = {
+                    'mean': sum(mse_clean) / len(mse_clean),
+                    'count': len(mse_clean)
+                }
+            if mae_clean:
+                mae_stats[dim] = {
+                    'mean': sum(mae_clean) / len(mae_clean),
+                    'count': len(mae_clean)
+                }
+
+        overall_errors = [r.get('overall_error') for r in stats_results if r.get('overall_error') is not None]
+        overall_clean = [x for x in overall_errors if not (isinstance(x, float) and math.isnan(x))]
+
+        if overall_clean:
+            summary['overall_error'] = {
+                'mean': sum(overall_clean) / len(overall_clean),
+                'count': len(overall_clean)
+            }
+
+        summary['mse_statistics'] = mse_stats
+        summary['mae_statistics'] = mae_stats
+
+    # Calculate separate statistics for the refined and initial results
+    if refined_results:
+        refined_mse_results = [r for r in refined_results if r.get('overall_error') is not None]
+        if refined_mse_results:
+            refined_mse_stats = {}
+            refined_mae_stats = {}
+            for dim in dimensions:
+                mse_list = [r.get(f'{dim}_mse') for r in refined_mse_results if r.get(f'{dim}_mse') is not None]
+                mae_list = [r.get(f'{dim}_mae') for r in refined_mse_results if r.get(f'{dim}_mae') is not None]
+                mse_clean = [x for x in mse_list if not (isinstance(x, float) and math.isnan(x))]
+                mae_clean = [x for x in mae_list if not (isinstance(x, float) and math.isnan(x))]
+                if mse_clean:
+                    refined_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
+                if mae_clean:
+                    refined_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
+            summary['refined_mse_statistics'] = refined_mse_stats
+            summary['refined_mae_statistics'] = refined_mae_stats
+
+    if initial_results:
+        initial_mse_results = [r for r in initial_results if r.get('overall_error') is not None]
+        if initial_mse_results:
+            initial_mse_stats = {}
+            initial_mae_stats = {}
+            for dim in dimensions:
+                mse_list = [r.get(f'{dim}_mse') for r in initial_mse_results if r.get(f'{dim}_mse') is not None]
+                mae_list = [r.get(f'{dim}_mae') for r in initial_mse_results if r.get(f'{dim}_mae') is not None]
+                mse_clean = [x for x in mse_list if not (isinstance(x, float) and math.isnan(x))]
+                mae_clean = [x for x in mae_list if not (isinstance(x, float) and math.isnan(x))]
+                if mse_clean:
+                    initial_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
+                if mae_clean:
+                    initial_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
1339
+ summary['initial_mse_statistics'] = initial_mse_stats
1340
+ summary['initial_mae_statistics'] = initial_mae_stats
1341
+
1342
+ # Calculate Spearman correlations
1343
+ def filter_valid_pairs(true_list, pred_list):
1344
+ filtered_true = []
1345
+ filtered_pred = []
1346
+ for t, p in zip(true_list, pred_list):
1347
+ if (t is not None and p is not None and
1348
+ not (isinstance(t, float) and math.isnan(t)) and
1349
+ not (isinstance(p, float) and math.isnan(p))):
1350
+ filtered_true.append(t)
1351
+ filtered_pred.append(p)
1352
+ return filtered_true, filtered_pred
1353
+
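The `filter_valid_pairs` helper drops any (ground-truth, prediction) pair in which either side is missing or NaN before a correlation is computed. A minimal standalone sketch (the helper is re-declared here so the snippet runs on its own):

```python
import math

def filter_valid_pairs(true_list, pred_list):
    # Keep only index-aligned pairs where both values are present and not NaN
    filtered_true, filtered_pred = [], []
    for t, p in zip(true_list, pred_list):
        if (t is not None and p is not None and
                not (isinstance(t, float) and math.isnan(t)) and
                not (isinstance(p, float) and math.isnan(p))):
            filtered_true.append(t)
            filtered_pred.append(p)
    return filtered_true, filtered_pred

gt = [5, None, 3, float('nan'), 4]
pred = [4, 2, None, 3, 5]
print(filter_valid_pairs(gt, pred))  # ([5, 4], [4, 5])
```

Filtering pairwise (rather than each list independently) keeps the two lists index-aligned, which `spearmanr` requires.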
1354
+     # Calculate Spearman correlations
+     # For refined format, calculate separately for refined and initial, and use refined for overall
+     # For other formats, use all results
+     if refined_format_count > 0:
+         # Calculate refined spearman correlations
+         refined_spearman_stats = {}
+         dimensions = ['soundness', 'presentation', 'confidence', 'rating']
+         for dim in dimensions:
+             true_values = [r.get(f'gt_{dim}') for r in refined_results]
+             pred_values = [r.get(f'model_{dim}') for r in refined_results]
+             true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
+
+             if len(true_clean) >= 2 and len(pred_clean) >= 2:
+                 try:
+                     corr, _ = spearmanr(true_clean, pred_clean)
+                     if not math.isnan(corr):
+                         refined_spearman_stats[dim] = {
+                             'correlation': corr,
+                             'count': len(true_clean)
+                         }
+                 except Exception:
+                     pass
+
+         # Calculate initial spearman correlations
+         initial_spearman_stats = {}
+         for dim in dimensions:
+             true_values = [r.get(f'gt_{dim}') for r in initial_results]
+             pred_values = [r.get(f'model_{dim}') for r in initial_results]
+             true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
+
+             if len(true_clean) >= 2 and len(pred_clean) >= 2:
+                 try:
+                     corr, _ = spearmanr(true_clean, pred_clean)
+                     if not math.isnan(corr):
+                         initial_spearman_stats[dim] = {
+                             'correlation': corr,
+                             'count': len(true_clean)
+                         }
+                 except Exception:
+                     pass
+
+         # Use refined for overall statistics (avoid double counting)
+         summary['spearman_correlations'] = refined_spearman_stats
+         summary['refined_spearman_correlations'] = refined_spearman_stats
+         summary['initial_spearman_correlations'] = initial_spearman_stats
+     else:
+         # Original/other formats: use all results
+         correlation_results = valid_results
+         spearman_stats = {}
+         dimensions = ['soundness', 'presentation', 'confidence', 'rating']
+         for dim in dimensions:
+             true_values = [r.get(f'gt_{dim}') for r in correlation_results]
+             pred_values = [r.get(f'model_{dim}') for r in correlation_results]
+             true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
+
+             if len(true_clean) >= 2 and len(pred_clean) >= 2:
+                 try:
+                     corr, _ = spearmanr(true_clean, pred_clean)
+                     if not math.isnan(corr):
+                         spearman_stats[dim] = {
+                             'correlation': corr,
+                             'count': len(true_clean)
+                         }
+                 except Exception:
+                     pass
+
+         summary['spearman_correlations'] = spearman_stats
+
+     # Calculate Decision metrics
+     # For refined format, calculate separately for refined and initial, and use refined for overall
+     # For other formats, use all results
+     if refined_format_count > 0:
+         # Calculate refined decision metrics
+         refined_decision_results = [r for r in refined_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
+         if refined_decision_results:
+             true_decisions = []
+             pred_decisions = []
+             decision_acc = []
+
+             for r in refined_decision_results:
+                 gt_decision = str(r.get('gt_decision', '')).lower().strip()
+                 pred_decision = str(r.get('model_decision', '')).lower().strip()
+
+                 if 'accept' in pred_decision:
+                     pred_binary = 1
+                 else:
+                     pred_binary = 0
+
+                 if 'accept' in gt_decision:
+                     gt_binary = 1
+                 else:
+                     gt_binary = 0
+
+                 true_decisions.append(gt_binary)
+                 pred_decisions.append(pred_binary)
+
+                 if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
+                     decision_acc.append(1.0)
+                 else:
+                     decision_acc.append(0.0)
+
+             if decision_acc:
+                 decision_accuracy = sum(decision_acc) / len(decision_acc)
+                 try:
+                     _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
+                     refined_decision_metrics = {
+                         'accuracy': decision_accuracy,
+                         'f1_macro': f1_score,
+                         'count': len(decision_acc)
+                     }
+                 except Exception:
+                     refined_decision_metrics = {
+                         'accuracy': decision_accuracy,
+                         'count': len(decision_acc)
+                     }
+                 summary['refined_decision_metrics'] = refined_decision_metrics
+                 summary['decision_metrics'] = refined_decision_metrics  # Use refined for overall
+
+         # Calculate initial decision metrics
+         initial_decision_results = [r for r in initial_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
+         if initial_decision_results:
+             true_decisions = []
+             pred_decisions = []
+             decision_acc = []
+
+             for r in initial_decision_results:
+                 gt_decision = str(r.get('gt_decision', '')).lower().strip()
+                 pred_decision = str(r.get('model_decision', '')).lower().strip()
+
+                 if 'accept' in pred_decision:
+                     pred_binary = 1
+                 else:
+                     pred_binary = 0
+
+                 if 'accept' in gt_decision:
+                     gt_binary = 1
+                 else:
+                     gt_binary = 0
+
+                 true_decisions.append(gt_binary)
+                 pred_decisions.append(pred_binary)
+
+                 if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
+                     decision_acc.append(1.0)
+                 else:
+                     decision_acc.append(0.0)
+
+             if decision_acc:
+                 decision_accuracy = sum(decision_acc) / len(decision_acc)
+                 try:
+                     _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
+                     initial_decision_metrics = {
+                         'accuracy': decision_accuracy,
+                         'f1_macro': f1_score,
+                         'count': len(decision_acc)
+                     }
+                 except Exception:
+                     initial_decision_metrics = {
+                         'accuracy': decision_accuracy,
+                         'count': len(decision_acc)
+                     }
+                 summary['initial_decision_metrics'] = initial_decision_metrics
+     else:
+         # Original/other formats: use all results
+         decision_results = [r for r in valid_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
+         if decision_results:
+             true_decisions = []
+             pred_decisions = []
+             decision_acc = []
+
+             for r in decision_results:
+                 gt_decision = str(r.get('gt_decision', '')).lower().strip()
+                 pred_decision = str(r.get('model_decision', '')).lower().strip()
+
+                 if 'accept' in pred_decision:
+                     pred_binary = 1
+                 else:
+                     pred_binary = 0
+
+                 if 'accept' in gt_decision:
+                     gt_binary = 1
+                 else:
+                     gt_binary = 0
+
+                 true_decisions.append(gt_binary)
+                 pred_decisions.append(pred_binary)
+
+                 if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
+                     decision_acc.append(1.0)
+                 else:
+                     decision_acc.append(0.0)
+
+             if decision_acc:
+                 decision_accuracy = sum(decision_acc) / len(decision_acc)
+                 try:
+                     _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
+                     summary['decision_metrics'] = {
+                         'accuracy': decision_accuracy,
+                         'f1_macro': f1_score,
+                         'count': len(decision_acc)
+                     }
+                 except Exception:
+                     summary['decision_metrics'] = {
+                         'accuracy': decision_accuracy,
+                         'count': len(decision_acc)
+                     }
+
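The decision scoring above binarizes free-text decisions by substring match on `'accept'` and then computes accuracy plus macro F1. A self-contained sketch of that mapping and the same sklearn call (toy decision strings only):

```python
from sklearn.metrics import precision_recall_fscore_support

def binarize_decision(decision):
    # Mirror the script's rule: any decision mentioning "accept" maps to 1
    return 1 if 'accept' in str(decision).lower().strip() else 0

gt = ['Accept (poster)', 'Reject', 'Accept (oral)', 'Reject']
pred = ['accept', 'accept', 'accept', 'reject']
true_bin = [binarize_decision(d) for d in gt]
pred_bin = [binarize_decision(d) for d in pred]

accuracy = sum(t == p for t, p in zip(true_bin, pred_bin)) / len(true_bin)
_, _, f1_macro, _ = precision_recall_fscore_support(true_bin, pred_bin, average='macro')
print(accuracy)  # 0.75
```

Substring matching makes the metric robust to venue-specific phrasings like "Accept (poster)" versus "accept".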
1561
+     # Calculate Pairwise comparison
+     # For refined format, only use refined results (avoid double counting)
+     # For other formats, use all results
+     if refined_format_count > 0:
+         pairwise_results = refined_results
+     else:
+         pairwise_results = valid_results
+
+     paper_scores = []
+     for r in pairwise_results:
+         if (r.get('gt_rating') is not None and r.get('model_rating') is not None) or \
+            (r.get('gt_soundness') is not None and r.get('model_soundness') is not None):
+             paper_scores.append({
+                 'true_rating': r.get('gt_rating'),
+                 'pred_rating': r.get('model_rating'),
+                 'true_soundness': r.get('gt_soundness'),
+                 'pred_soundness': r.get('model_soundness'),
+                 'true_presentation': r.get('gt_presentation'),
+                 'pred_presentation': r.get('model_presentation'),
+                 'true_confidence': r.get('gt_confidence'),
+                 'pred_confidence': r.get('model_confidence')
+             })
+
+     if len(paper_scores) >= 2:
+         pairwise_accuracies = calculate_pairwise_accuracies(paper_scores)
+         summary['pairwise_accuracies'] = pairwise_accuracies
+
+     return results, summary
+
+
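The pairwise comparison asks: for every pair of papers, does the model order them the same way the ground truth does on a given dimension? `calculate_pairwise_accuracies` is defined elsewhere in the file; the sketch below is a simplified, hypothetical single-dimension version of the idea, not the script's exact implementation:

```python
from itertools import combinations

def pairwise_rank_accuracy(scores, true_key, pred_key):
    # Fraction of paper pairs whose relative order agrees between
    # ground-truth and predicted scores (tied pairs are skipped).
    agree, total = 0, 0
    for a, b in combinations(scores, 2):
        dt = a[true_key] - b[true_key]
        dp = a[pred_key] - b[pred_key]
        if dt == 0 or dp == 0:
            continue
        total += 1
        if (dt > 0) == (dp > 0):
            agree += 1
    return agree / total if total else None

papers = [
    {'true_rating': 6, 'pred_rating': 5},
    {'true_rating': 4, 'pred_rating': 6},
    {'true_rating': 8, 'pred_rating': 7},
]
print(pairwise_rank_accuracy(papers, 'true_rating', 'pred_rating'))
```

Unlike MSE/MAE, this ranking view is insensitive to a model that is systematically harsh or lenient but still orders papers correctly.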
1591
+ # ============================================================================
+ # Main Function
+ # ============================================================================
+
+ def parse_args():
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(description="Unified evaluation script for semantic and auto-metric evaluation")
+
+     # Input paths
+     parser.add_argument("--rubrics_path", type=str, required=True,
+                         help="Path to eval_rubrics.json file (from 1_generate_review_based_rubrics.py)")
+     parser.add_argument("--reviews_path", type=str, required=True,
+                         help="Path to JSON file with model reviews (contains pred_fast_mode)")
+
+     # Evaluation mode
+     parser.add_argument("--mode", type=str, choices=["semantic", "auto_metric", "both"], default="both",
+                         help="Evaluation mode: semantic (LLM-based), auto_metric (rule-based), or both")
+
+     # Output paths
+     parser.add_argument("--semantic_output", type=str, default=None,
+                         help="Path to output JSON file for semantic evaluation results (required if mode is semantic or both)")
+     parser.add_argument("--auto_metric_output", type=str, default=None,
+                         help="Path to output JSON file for auto-metric evaluation results (required if mode is auto_metric or both)")
+
+     # Semantic evaluation settings
+     parser.add_argument("--yaml_path", type=str, default=None,
+                         help="Path to prompts.yaml file (required for semantic evaluation)")
+     parser.add_argument("--config_path", type=str, default=None,
+                         help="Path to configs.yaml file (required for semantic evaluation)")
+
+     # Multi-threading
+     parser.add_argument("--max_workers", type=int, default=None,
+                         help="Maximum number of worker threads for semantic evaluation (default: 5)")
+
+     # Strict mode (normalize scores to discrete scales)
+     parser.add_argument("--strict_mode", action="store_true", default=False,
+                         help="Enable strict mode: normalize scores to discrete scales before computing metrics (default: False)")
+
+     # Input format override
+     parser.add_argument("--input_format", type=str, choices=['auto', 'refined', 'original'], default='auto',
+                         help="Manually specify input JSON format: 'refined' (has scores and initial_scores), 'original' (has model_prediction), or 'auto' for auto-detection (default: 'auto')")
+
+     return parser.parse_args()
+
+
+ def main():
+     """Main execution function."""
+     args = parse_args()
+
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+
+     # Resolve paths
+     rubrics_path = args.rubrics_path
+     if not os.path.isabs(rubrics_path):
+         rubrics_path = os.path.join(script_dir, rubrics_path)
+
+     reviews_path = args.reviews_path
+     if not os.path.isabs(reviews_path):
+         reviews_path = os.path.join(script_dir, reviews_path)
+
+     max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
+
+     # Validate mode and output paths
+     if args.mode in ["semantic", "both"]:
+         if not args.semantic_output:
+             raise ValueError("--semantic_output is required when mode is 'semantic' or 'both'")
+         if not args.yaml_path:
+             raise ValueError("--yaml_path is required for semantic evaluation")
+         if not args.config_path:
+             raise ValueError("--config_path is required for semantic evaluation")
+
+     if args.mode in ["auto_metric", "both"]:
+         if not args.auto_metric_output:
+             raise ValueError("--auto_metric_output is required when mode is 'auto_metric' or 'both'")
+
+     # Check if files exist
+     if not os.path.exists(rubrics_path):
+         raise FileNotFoundError(f"Rubrics file not found: {rubrics_path}")
+     if not os.path.exists(reviews_path):
+         raise FileNotFoundError(f"Reviews file not found: {reviews_path}")
+
+     # Load data
+     print(f"Loading rubrics from {rubrics_path}...")
+     rubrics_data = load_rubrics_json(rubrics_path)
+     print(f"Loaded {len(rubrics_data)} rubrics entries")
+
+     print(f"Loading model reviews from {reviews_path}...")
+     if args.input_format != 'auto':
+         print(f"Using manually specified format: {args.input_format}")
+     else:
+         print("Auto-detecting input format...")
+     reviews_dict = load_model_reviews_json(reviews_path, format_override=args.input_format if args.input_format != 'auto' else None)
+     print(f"Loaded {len(reviews_dict)} model reviews")
+
+     # Combine rubrics and reviews
+     print("Combining rubrics and reviews...")
+     evaluation_data = combine_rubrics_and_reviews(rubrics_data, reviews_dict)
+     print(f"Prepared {len(evaluation_data)} entries for evaluation")
+
+     # Run evaluations based on mode
+     if args.mode in ["semantic", "both"]:
+         # Resolve semantic evaluation paths
+         yaml_path = args.yaml_path
+         if not os.path.isabs(yaml_path):
+             yaml_path = os.path.join(script_dir, yaml_path)
+
+         config_path = args.config_path
+         if not os.path.isabs(config_path):
+             config_path = os.path.join(script_dir, config_path)
+
+         if not os.path.exists(yaml_path):
+             raise FileNotFoundError(f"YAML file not found: {yaml_path}")
+         if not os.path.exists(config_path):
+             raise FileNotFoundError(f"Config file not found: {config_path}")
+
+         # Load prompt template
+         print(f"Loading prompt template from {yaml_path}...")
+         prompt_template = load_prompt_template(yaml_path)
+         if not prompt_template:
+             raise ValueError("Could not find 'v1_evaluator_prompt' in YAML file")
+
+         # Initialize LLM service
+         print(f"Loading LLM configuration from {config_path}...")
+         llm_config = load_llm_config(config_path)
+         llm_service = create_llm_service_from_config(llm_config)
+         mode = llm_config.get('mode', 'gpt')
+         print(f"LLM service initialized (mode: {mode})")
+         if hasattr(llm_service, 'model_name'):
+             print(f"Using model: {llm_service.model_name}")
+
+         # Run semantic evaluation
+         semantic_results, semantic_summary = run_semantic_evaluation(
+             evaluation_data, prompt_template, llm_service, max_workers
+         )
+
+         # Save semantic results
+         semantic_output = args.semantic_output
+         if not os.path.isabs(semantic_output):
+             semantic_output = os.path.join(script_dir, semantic_output)
+
+         output_dir = os.path.dirname(semantic_output)
+         os.makedirs(output_dir, exist_ok=True)
+
+         with open(semantic_output, 'w', encoding='utf-8') as f:
+             json.dump(semantic_results, f, ensure_ascii=False, indent=2)
+         print(f"\nSemantic evaluation results saved to {semantic_output}")
+
+         # Save semantic summary
+         semantic_summary_path = semantic_output.replace('.json', '_summary.json')
+         with open(semantic_summary_path, 'w', encoding='utf-8') as f:
+             json.dump(semantic_summary, f, ensure_ascii=False, indent=2)
+         print(f"Semantic evaluation summary saved to {semantic_summary_path}")
+
+         # Print semantic summary
+         print("\n" + "="*80)
+         print("SEMANTIC EVALUATION SUMMARY")
+         print("="*80)
+         print(f"Total entries: {semantic_summary['total_entries']}")
+         print(f"Valid entries: {semantic_summary['valid_entries']}")
+         print(f"Failed entries: {semantic_summary['failed_entries']}")
+         if 'overall_score' in semantic_summary:
+             score = semantic_summary['overall_score']
+             print(f"\nOverall Score:")
+             print(f"  Mean: {score['mean']:.2f}")
+             print(f"  Min: {score['min']:.2f}")
+             print(f"  Max: {score['max']:.2f}")
+
+     if args.mode in ["auto_metric", "both"]:
+         # Run auto-metric evaluation
+         auto_metric_results, auto_metric_summary = run_auto_metric_evaluation(
+             evaluation_data,
+             strict_mode=args.strict_mode
+         )
+
+         # Save auto-metric results
+         auto_metric_output = args.auto_metric_output
+         if not os.path.isabs(auto_metric_output):
+             auto_metric_output = os.path.join(script_dir, auto_metric_output)
+
+         output_dir = os.path.dirname(auto_metric_output)
+         os.makedirs(output_dir, exist_ok=True)
+
+         with open(auto_metric_output, 'w', encoding='utf-8') as f:
+             json.dump(auto_metric_results, f, ensure_ascii=False, indent=2)
+         print(f"\nAuto-metric evaluation results saved to {auto_metric_output}")
+
+         # Save auto-metric summary
+         auto_metric_summary_path = auto_metric_output.replace('.json', '_summary.json')
+         with open(auto_metric_summary_path, 'w', encoding='utf-8') as f:
+             json.dump(auto_metric_summary, f, ensure_ascii=False, indent=2)
+         print(f"Auto-metric evaluation summary saved to {auto_metric_summary_path}")
+
+         # Print auto-metric summary
+         print("\n" + "="*80)
+         print("AUTO-METRIC EVALUATION SUMMARY")
+         print("="*80)
+         print(f"Total entries: {auto_metric_summary['total_entries']}")
+         print(f"Valid entries: {auto_metric_summary['valid_entries']}")
+         print(f"MSE entries: {auto_metric_summary['mse_entries']}")
+
+         if 'mse_statistics' in auto_metric_summary:
+             print("\nMSE Statistics:")
+             for dim, stats in auto_metric_summary['mse_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'mae_statistics' in auto_metric_summary:
+             print("\nMAE Statistics:")
+             for dim, stats in auto_metric_summary['mae_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         # Print refined and initial statistics if available
+         if 'refined_mse_statistics' in auto_metric_summary:
+             print("\nRefined Scores - MSE Statistics:")
+             for dim, stats in auto_metric_summary['refined_mse_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'refined_mae_statistics' in auto_metric_summary:
+             print("\nRefined Scores - MAE Statistics:")
+             for dim, stats in auto_metric_summary['refined_mae_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'initial_mse_statistics' in auto_metric_summary:
+             print("\nInitial Scores - MSE Statistics:")
+             for dim, stats in auto_metric_summary['initial_mse_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'initial_mae_statistics' in auto_metric_summary:
+             print("\nInitial Scores - MAE Statistics:")
+             for dim, stats in auto_metric_summary['initial_mae_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'spearman_correlations' in auto_metric_summary:
+             print("\nSpearman Correlations:")
+             for dim, stats in auto_metric_summary['spearman_correlations'].items():
+                 print(f"  {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
+
+         # Print refined and initial spearman correlations if available
+         if 'refined_spearman_correlations' in auto_metric_summary:
+             print("\nRefined Scores - Spearman Correlations:")
+             for dim, stats in auto_metric_summary['refined_spearman_correlations'].items():
+                 print(f"  {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
+
+         if 'initial_spearman_correlations' in auto_metric_summary:
+             print("\nInitial Scores - Spearman Correlations:")
+             for dim, stats in auto_metric_summary['initial_spearman_correlations'].items():
+                 print(f"  {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
+
+         if 'decision_metrics' in auto_metric_summary:
+             dm = auto_metric_summary['decision_metrics']
+             print(f"\nDecision Metrics:")
+             print(f"  Accuracy: {dm['accuracy']:.4f} (n={dm['count']})")
+             if 'f1_macro' in dm:
+                 print(f"  F1 (macro): {dm['f1_macro']:.4f}")
+
+         # Print refined and initial decision metrics if available
+         if 'refined_decision_metrics' in auto_metric_summary:
+             print("\nRefined Scores - Decision Metrics:")
+             rdm = auto_metric_summary['refined_decision_metrics']
+             print(f"  Accuracy: {rdm['accuracy']:.4f} (n={rdm['count']})")
+             if 'f1_macro' in rdm:
+                 print(f"  F1 (macro): {rdm['f1_macro']:.4f}")
+
+         if 'initial_decision_metrics' in auto_metric_summary:
+             print("\nInitial Scores - Decision Metrics:")
+             idm = auto_metric_summary['initial_decision_metrics']
+             print(f"  Accuracy: {idm['accuracy']:.4f} (n={idm['count']})")
+             if 'f1_macro' in idm:
+                 print(f"  F1 (macro): {idm['f1_macro']:.4f}")
+
+     print("\n" + "="*80)
+     print("EVALUATION COMPLETE")
+     print("="*80)
+
+
+ if __name__ == "__main__":
+     main()
src/evaluator/2_evaluate_aiscientist.py ADDED
@@ -0,0 +1,1866 @@
1
+ """
2
+ Unified evaluation script for semantic (LLM-based) and auto_metric (rule-based) evaluation.
3
+
4
+ This script:
5
+ 1. Reads eval_rubrics.json (from 1_generate_review_based_rubrics.py) containing rubrics for each paper
6
+ 2. Reads input JSON file containing model reviews (supports multiple formats)
7
+ 3. Supports three evaluation modes:
8
+ - semantic: LLM-based rubrics evaluation (from 2_evaluate_direct.py)
9
+ - auto_metric: Rule-based metrics evaluation (from 3_rule_evaluate.py)
10
+ - both: Run both evaluations separately
11
+ 4. Supports strict mode: normalize scores to discrete scales before computing metrics (--strict_mode)
12
+ 5. Outputs separate JSON files for results and summaries
13
+
14
+ Usage:
15
+ # Semantic evaluation only
16
+ python 2_evaluate.py \
17
+ --rubrics_path eval_rubrics.json \
18
+ --reviews_path model_reviews.json \
19
+ --mode semantic \
20
+ --yaml_path prompts.yaml \
21
+ --config_path configs.yaml \
22
+ --semantic_output semantic_results.json \
23
+ --max_workers 5
24
+
25
+ # Auto-metric evaluation only
26
+ python 2_evaluate.py \
27
+ --rubrics_path eval_rubrics.json \
28
+ --reviews_path model_reviews.json \
29
+ --mode auto_metric \
30
+ --auto_metric_output auto_metric_results.json
31
+
32
+ # Auto-metric evaluation with strict mode (normalize scores to discrete scales)
33
+ python 2_evaluate.py \
34
+ --rubrics_path eval_rubrics.json \
35
+ --reviews_path model_reviews.json \
36
+ --mode auto_metric \
37
+ --auto_metric_output auto_metric_results.json \
38
+ --strict_mode
39
+
40
+ # Auto-metric evaluation with manually specified input format (refined)
41
+ python 2_evaluate.py \
42
+ --rubrics_path eval_rubrics.json \
43
+ --reviews_path model_reviews.json \
44
+ --mode auto_metric \
45
+ --auto_metric_output auto_metric_results.json \
46
+ --input_format refined
47
+
48
+ # Auto-metric evaluation with manually specified input format (original)
49
+ python 2_evaluate.py \
50
+ --rubrics_path eval_rubrics.json \
51
+ --reviews_path ours.json \
52
+ --mode auto_metric \
53
+ --auto_metric_output auto_metric_results.json \
54
+ --input_format original
55
+
56
+ # Both evaluations
57
+ python 2_evaluate.py \
58
+ --rubrics_path eval_rubrics.json \
59
+ --reviews_path model_reviews.json \
60
+ --mode both \
61
+ --yaml_path prompts.yaml \
62
+ --config_path configs.yaml \
63
+ --semantic_output semantic_results.json \
64
+ --auto_metric_output auto_metric_results.json \
65
+ --max_workers 32
66
+ """
67
+ from __future__ import annotations
68
+
69
+ import json
70
+ import os
71
+ import sys
72
+ import argparse
73
+ import yaml
74
+ import math
75
+ import re
76
+ from typing import Dict, List, Any, Optional
77
+ from concurrent.futures import ThreadPoolExecutor, as_completed
78
+ from tqdm import tqdm
79
+ from itertools import combinations
80
+ from scipy.stats import spearmanr
81
+ from sklearn.metrics import precision_recall_fscore_support
82
+
83
+ # Add parent directory to path
84
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
85
+ # Import parse_llm_response from local llm_service module
86
+ import llm_service as local_llm_service
87
+ parse_llm_response = local_llm_service.parse_llm_response
88
+
89
+ # Import from shared/utils for gpt/vllm support
90
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
91
+ if project_root not in sys.path:
92
+ sys.path.insert(0, project_root)
93
+
94
+ from shared.utils.llm_service import LLMService
95
+ from shared.utils.vllm_service import VLLMService
96
+ from shared.utils.gpt_service import GPTService
97
+ sys.path.insert(0, os.path.join(project_root, 'shared', 'utils'))
98
+ from json_parser import parse_review_markdown
99
+
100
+
101
+ def convert_ai_researcher(review: dict) -> tuple:
+     """
+     Convert a review from the ai-researcher format to the unified review system format.
+
+     Returns:
+         Tuple of (formatted_review_text, meta_review_dict)
+     """
105
+
106
+ summary = review["Summary"]
107
+ strengths = "\n".join(f"- {s}" for s in review["Strengths"])
108
+ weaknesses = "\n".join(f"- {w}" for w in review["Weaknesses"])
109
+
110
+ # scores
111
+ originality = review["Originality"]
112
+ quality = review["Quality"]
113
+ clarity = review["Clarity"]
114
+ significance = review["Significance"]
115
+
116
+ questions = "\n".join(f"- {q}" for q in review["Questions"])
117
+ limitations = "\n".join(f"- {l}" for l in review["Limitations"])
118
+ ethical_concerns = review["Ethical Concerns"]
119
+
120
+     # additional reviewer-assigned scores
121
+ soundness = review["Soundness"]
122
+ presentation = review["Presentation"]
123
+ contribution = review["Contribution"]
124
+ overall = review["Overall"]
125
+ confidence = review["Confidence"]
126
+
127
+ # final decision
128
+ decision = review["Decision"]
129
+
130
+ meta_review = {
131
+ "rating": overall,
132
+ "soundness": soundness,
133
+ "presentation": presentation,
134
+ "contribution": contribution,
135
+ "confidence": confidence,
136
+ "decision": decision.lower().strip(),
137
+ }
138
+
139
+ return f"Summary: {summary}\nStrengths: {strengths}\nWeaknesses: {weaknesses}\nOriginality: {originality}\nQuality: {quality}\nClarity: {clarity}\nSignificance: {significance}\nQuestions: {questions}\nLimitations: {limitations}\nEthical Concerns: {ethical_concerns}\nSoundness: {soundness}\nPresentation: {presentation}\nContribution: {contribution}\nOverall: {overall}\nConfidence: {confidence}\nDecision: {decision}", meta_review
140
+
141
+
142
+ def convert_agenticreview(review_text: str) -> tuple:
143
+ """
144
+ Convert the review text from agenticreview format to unified review system format.
145
+
146
+ The agenticreview format has text like:
147
+ "Overall rating: 5\n\nSignificance and novelty: ..."
148
+
149
+ Args:
150
+ review_text: Raw review text string
151
+
152
+ Returns:
153
+ Tuple of (formatted_review_text, meta_review_dict)
154
+ """
155
+ # Extract rating from "Overall rating: x" format
156
+ rating = None
157
+ rating_match = re.search(r'Overall\s+rating\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
158
+ if rating_match:
159
+ try:
160
+ rating = float(rating_match.group(1))
161
+ except (ValueError, IndexError):
162
+ pass
163
+
164
+ # If not found, try alternative patterns
165
+ if rating is None:
166
+ rating_match = re.search(r'(?:rating|score)\s*[:=]\s*(\d+\.?\d*)', review_text, re.IGNORECASE)
167
+ if rating_match:
168
+ try:
169
+ rating = float(rating_match.group(1))
170
+ except (ValueError, IndexError):
171
+ pass
172
+
173
+ # Try to extract from parse_review_markdown as fallback
174
+ if rating is None:
175
+ try:
176
+ parsed = parse_review_markdown(review_text)
177
+ rating = parsed.get('rating')
178
+ except Exception:
179
+ pass
180
+
181
+ # Create meta_review dict - agenticreview only has rating, no other scores
182
+ meta_review = {
183
+ "rating": rating,
184
+ "soundness": None,
185
+ "presentation": None,
186
+ "contribution": None,
187
+ "confidence": None,
188
+ "decision": None,
189
+ }
190
+
191
+ # Return the review text as-is (it's already in a readable format)
192
+ return review_text, meta_review
193
+
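The "Overall rating" extraction above can be checked in isolation; this is a minimal sketch using the same regex pattern as the function (example review text is made up):

```python
import re

# Same primary pattern as convert_agenticreview: capture the numeric rating
pat = re.compile(r'Overall\s+rating\s*[:=]\s*(\d+\.?\d*)', re.IGNORECASE)
m = pat.search("Overall rating: 5\n\nSignificance and novelty: ...")
rating = float(m.group(1))
print(rating)  # -> 5.0
```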
194
+
195
+ class ReviewProcessor:
196
+ """Handles the extraction and processing of reviews from different sources."""
197
+
198
+ @staticmethod
199
+ def extract_review_content(pred_context):
200
+ """
201
+ Extract the review content from the prediction context.
202
+
203
+ Args:
204
+ pred_context: Raw prediction data that contains the review
205
+
206
+ Returns:
207
+ str: Extracted review content
208
+ """
209
+ try:
210
+ # First attempt to extract from boxed format
211
+ return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
212
+ except Exception:
213
+ # Alternative extraction if the first method fails
214
+ if isinstance(pred_context, dict) and 'output' in pred_context:
215
+ return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
216
+ else:
217
+ # Return as is if extraction fails
218
+ return pred_context
219
+
220
+
221
+ # ============================================================================
222
+ # Semantic Evaluation Functions (from 2_evaluate_direct.py)
223
+ # ============================================================================
224
+
225
+ def load_prompt_template(yaml_path: str) -> str:
226
+ """Load the evaluator prompt from YAML file."""
227
+ with open(yaml_path, 'r', encoding='utf-8') as f:
228
+ prompts = yaml.safe_load(f)
229
+ return prompts.get('v1_evaluator_prompt', '')
230
+
231
+
232
+ def build_evaluation_prompt(
233
+ rubrics: List[Dict[str, Any]],
234
+ paper_content: str,
235
+ review: str,
236
+ prompt_template: str
237
+ ) -> str:
238
+ """Build the evaluation prompt by replacing placeholders."""
239
+ rubrics_json = json.dumps(rubrics, indent=4, ensure_ascii=False)
240
+ prompt = prompt_template.replace('{rubrics_json}', rubrics_json)
241
+ prompt = prompt.replace('<<paper_content>>', paper_content)
242
+ prompt = prompt.replace('<<review>>', review)
243
+ return prompt
244
+
245
+
246
+ def calculate_weighted_scores(
247
+ raw_scores: Dict[str, Dict[str, Any]],
248
+ rubrics: List[Dict[str, Any]]
249
+ ) -> Dict[str, float]:
250
+ """Calculate weighted scores for each rubric."""
251
+ rubric_weights = {r['title']: r['weight'] for r in rubrics}
252
+ weighted_scores = {}
253
+
254
+ for rubric_title, rubric_data in raw_scores.items():
255
+ if rubric_title not in rubric_weights:
256
+ continue
257
+
258
+ rubric_score = rubric_data.get('score', 0)
259
+ if isinstance(rubric_score, str):
260
+ try:
261
+ rubric_score = int(rubric_score)
262
+ except ValueError:
263
+ rubric_score = 0
264
+
265
+ if rubric_score not in [0, 1]:
266
+ rubric_score = 1 if rubric_score > 0 else 0
267
+
268
+ weight = rubric_weights[rubric_title]
269
+ weighted_scores[rubric_title] = rubric_score * weight
270
+
271
+ return weighted_scores
272
+
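A toy run of the weighting rule implemented above (rubric titles and weights are invented for illustration): binary rubric scores are clamped to 0/1, multiplied by per-rubric weights, and rubrics absent from the rubric list are skipped.

```python
# Hypothetical rubrics and raw scores, mirroring calculate_weighted_scores
rubrics = [{"title": "Clarity", "weight": 2.0}, {"title": "Novelty", "weight": 3.0}]
raw = {"Clarity": {"score": 1}, "Novelty": {"score": "0"}, "Unknown": {"score": 1}}

weights = {r["title"]: r["weight"] for r in rubrics}
out = {}
for title, data in raw.items():
    if title not in weights:
        continue  # rubric not in the rubric list: skipped
    s = data.get("score", 0)
    s = int(s) if isinstance(s, str) else s  # tolerate string scores
    s = 1 if s > 0 else 0                    # clamp to binary
    out[title] = s * weights[title]
print(out)  # -> {'Clarity': 2.0, 'Novelty': 0.0}
```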
273
+
274
+ def calculate_scores(raw_scores: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
275
+ """Calculate scores for each rubric."""
276
+ scores = {}
277
+ for rubric_title, rubric_data in raw_scores.items():
278
+ scores[rubric_title] = rubric_data.get('score', 0)
279
+ return scores
280
+
281
+
282
+ def evaluate_review_semantic(
283
+ entry: Dict[str, Any],
284
+ paper_content: str,
285
+ prompt_template: str,
286
+ llm_service: LLMService
287
+ ) -> Dict[str, Any]:
288
+ """Evaluate a single review using article-specific rubrics."""
289
+ entry_id = entry.get('id', 'unknown')
290
+ rubrics = entry.get('rubrics', [])
291
+ model_review = entry.get('model_review', '')
292
+
293
+ if not rubrics:
294
+ return {
295
+ 'id': entry_id,
296
+ 'raw_scores': {},
297
+ 'weighted_scores': {},
298
+ 'total_score': 0.0,
299
+ 'error': 'No valid rubrics found',
300
+ 'raw_response': ''
301
+ }
302
+
303
+ # Build prompt
304
+ prompt = build_evaluation_prompt(rubrics, paper_content, model_review, prompt_template)
305
+
306
+ # Call LLM
307
+ try:
308
+ messages = [{"role": "user", "content": prompt}]
309
+ response = llm_service.generate(messages=messages)
310
+
311
+         # Parse response
+         raw_scores = parse_llm_response(response)
+         # Note: calculate_scores keeps the raw scores unweighted; use
+         # calculate_weighted_scores instead to apply per-rubric weights.
+         weighted_scores = calculate_scores(raw_scores)
+         total_score = sum(weighted_scores.values())
315
+
316
+ return {
317
+ 'id': entry_id,
318
+ 'raw_scores': raw_scores,
319
+ 'weighted_scores': weighted_scores,
320
+ 'total_score': total_score,
321
+ 'raw_response': response
322
+ }
323
+ except Exception as e:
324
+ print(f"[ERROR] Error evaluating review {entry_id}: {e}")
325
+ return {
326
+ 'id': entry_id,
327
+ 'raw_scores': {},
328
+ 'weighted_scores': {},
329
+ 'total_score': 0.0,
330
+ 'error': str(e),
331
+ 'raw_response': ''
332
+ }
333
+
334
+
335
+ def calculate_per_rubric_statistics(
336
+ valid_results: List[Dict[str, Any]],
337
+ rubric_titles: List[str]
338
+ ) -> Dict[str, Dict[str, float]]:
339
+ """Calculate per-rubric statistics from evaluation results."""
340
+ rubric_scores = {title: [] for title in rubric_titles}
341
+
342
+ for result in valid_results:
343
+ weighted_scores = result.get('weighted_scores', {})
344
+ if not isinstance(weighted_scores, dict):
345
+ continue
346
+
347
+ for rubric_title in rubric_titles:
348
+ if rubric_title in weighted_scores:
349
+ score = weighted_scores[rubric_title]
350
+ if isinstance(score, str):
351
+ try:
352
+ score = float(score)
353
+ except ValueError:
354
+ continue
355
+ elif isinstance(score, (int, float)):
356
+ score = float(score)
357
+ else:
358
+ continue
359
+ rubric_scores[rubric_title].append(score)
360
+
361
+ per_rubric_stats = {}
362
+ for rubric_title in rubric_titles:
363
+ scores = rubric_scores[rubric_title]
364
+ if not scores:
365
+ continue
366
+
367
+ mean_score = sum(scores) / len(scores)
368
+ min_score = min(scores)
369
+ max_score = max(scores)
370
+ count = len(scores)
371
+
372
+ if rubric_title == "False or Contradictory Claims":
373
+ pass_count = sum(1 for s in scores if s >= 0)
374
+ else:
375
+ pass_count = sum(1 for s in scores if s >= 1)
376
+ pass_rate = pass_count / count if count > 0 else 0.0
377
+
378
+ per_rubric_stats[rubric_title] = {
379
+ 'mean': mean_score,
380
+ 'min': min_score,
381
+ 'max': max_score,
382
+ 'count': count,
383
+ 'pass_rate': pass_rate
384
+ }
385
+
386
+ return per_rubric_stats
387
+
388
+
389
+ # ============================================================================
390
+ # Auto-Metric Evaluation Functions (from 3_rule_evaluate.py)
391
+ # ============================================================================
392
+
393
+ def extract_scores_from_review(review_text: str) -> Dict[str, Any]:
394
+ """Extract numeric scores and decision from a review markdown text."""
395
+ if not review_text:
396
+ return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
397
+
398
+ try:
399
+ parsed = parse_review_markdown(review_text)
400
+ decision = parsed.get('decision', '')
401
+ if decision:
402
+ decision_lower = decision.lower().strip()
403
+ if 'accept' in decision_lower:
404
+ decision = 'accept'
405
+ elif 'reject' in decision_lower:
406
+ decision = 'reject'
407
+ elif 'undecided' in decision_lower:
408
+ decision = 'undecided'
409
+ else:
410
+ decision = decision_lower
411
+ else:
412
+ decision = None
413
+
414
+ return {
415
+ 'soundness': parsed.get('soundness'),
416
+ 'presentation': parsed.get('presentation'),
417
+ 'rating': parsed.get('rating'),
418
+ 'confidence': parsed.get('confidence'),
419
+ 'decision': decision
420
+ }
421
+ except Exception as e:
422
+ print(f"Warning: Failed to parse review text: {e}")
423
+ return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
424
+
425
+
426
+ def calculate_mse(predicted: float, ground_truth: float) -> Optional[float]:
427
+ """Calculate Mean Squared Error for a single value."""
428
+ if predicted is None or ground_truth is None:
429
+ return None
430
+ return (predicted - ground_truth) ** 2
431
+
432
+
433
+ def calculate_mae(predicted: float, ground_truth: float) -> Optional[float]:
434
+ """Calculate Mean Absolute Error for a single value."""
435
+ if predicted is None or ground_truth is None:
436
+ return None
437
+ return abs(predicted - ground_truth)
438
+
439
+
440
+ def normalize_to_discrete_scale(score: Optional[float], scale_type: str) -> Optional[float]:
441
+ """
442
+ Normalize a float score to the nearest discrete value based on scale type.
443
+ Uses round-half-up tie-breaking (e.g., 3.5 rounds to 4, 1.5 rounds to 2).
444
+
445
+ Args:
446
+ score: The float score to normalize (can be None)
447
+ scale_type: Either '0-5' for 0-5 scale (discrete: 0,1,2,3,4,5)
448
+ or '0-10' for 0-10 scale (discrete: 0,2,4,6,8,10)
449
+
450
+ Returns:
451
+ Normalized discrete score, or None if input is None
452
+ """
453
+ if score is None:
454
+ return None
455
+
456
+ try:
457
+ score = float(score)
458
+ except (ValueError, TypeError):
459
+ return None
460
+
461
+ if scale_type == '0-5':
462
+ # Discrete values: 0, 1, 2, 3, 4, 5
463
+ discrete_values = [0, 1, 2, 3, 4, 5]
464
+ # Clamp to valid range
465
+ score = max(0, min(5, score))
466
+ # Find nearest discrete value, with round-half-up tie-breaking
467
+ # For ties, prefer the higher value
468
+ best_value = None
469
+ best_distance = float('inf')
470
+ for val in discrete_values:
471
+ distance = abs(val - score)
472
+ if distance < best_distance:
473
+ best_distance = distance
474
+ best_value = val
475
+ elif distance == best_distance and val > best_value:
476
+ # Tie-breaking: prefer higher value (round-half-up)
477
+ best_value = val
478
+ return best_value
479
+ elif scale_type == '0-10':
480
+ # Discrete values: 0, 2, 4, 6, 8, 10
481
+ discrete_values = [0, 2, 4, 6, 8, 10]
482
+ # Clamp to valid range
483
+ score = max(0, min(10, score))
484
+ # Find nearest discrete value, with round-half-up tie-breaking
485
+ best_value = None
486
+ best_distance = float('inf')
487
+ for val in discrete_values:
488
+ distance = abs(val - score)
489
+ if distance < best_distance:
490
+ best_distance = distance
491
+ best_value = val
492
+ elif distance == best_distance and val > best_value:
493
+ # Tie-breaking: prefer higher value (round-half-up)
494
+ best_value = val
495
+ return best_value
496
+ else:
497
+ raise ValueError(f"Unknown scale_type: {scale_type}. Must be '0-5' or '0-10'")
498
+
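The snapping rule above can be exercised with concrete values; this is a small sketch with a hypothetical helper name `snap` (not part of the script):

```python
def snap(score, values):
    # nearest discrete value; ties broken upward (round-half-up)
    return min(values, key=lambda v: (abs(v - score), -v))

print(snap(3.5, [0, 1, 2, 3, 4, 5]))   # 0-5 scale: tie between 3 and 4 -> 4
print(snap(1.0, [0, 2, 4, 6, 8, 10]))  # 0-10 even scale: tie between 0 and 2 -> 2
print(snap(7.2, [0, 2, 4, 6, 8, 10]))  # nearest even value -> 8
```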
499
+
500
+ def normalize_scores_dict(scores: Dict[str, Optional[float]]) -> Dict[str, Optional[float]]:
501
+ """
502
+ Normalize all scores in a dictionary to their appropriate discrete scales.
503
+
504
+ Args:
505
+ scores: Dictionary with keys 'soundness', 'presentation', 'rating', 'confidence'
506
+
507
+ Returns:
508
+ Dictionary with normalized scores
509
+ """
510
+ normalized = {}
511
+
512
+ # soundness, presentation, confidence use 0-5 scale
513
+ for key in ['soundness', 'presentation', 'confidence']:
514
+ normalized[key] = normalize_to_discrete_scale(scores.get(key), '0-5')
515
+
516
+ # rating uses 0-10 scale
517
+ normalized['rating'] = normalize_to_discrete_scale(scores.get('rating'), '0-10')
518
+
519
+ return normalized
520
+
521
+
522
+ def calculate_score_metrics(
523
+ model_scores: Dict[str, float],
524
+ ground_truth_scores: Dict[str, float],
525
+ normalize: bool = False
526
+ ) -> Dict[str, Any]:
527
+ """
528
+ Calculate MSE and MAE metrics for each scoring dimension.
529
+
530
+ Args:
531
+ model_scores: Dictionary with model scores
532
+ ground_truth_scores: Dictionary with ground truth scores
533
+ normalize: If True, normalize scores to discrete scales before computing metrics
534
+
535
+ Returns:
536
+ Dictionary with MSE, MAE metrics and optionally normalized scores
537
+ """
538
+ dimensions = ['soundness', 'presentation', 'rating', 'confidence']
539
+
540
+ # Normalize scores to discrete scales if requested
541
+ if normalize:
542
+ model_scores_normalized = normalize_scores_dict(model_scores)
543
+ gt_scores_normalized = normalize_scores_dict(ground_truth_scores)
544
+ else:
545
+ model_scores_normalized = model_scores
546
+ gt_scores_normalized = ground_truth_scores
547
+
548
+ mse_values = {}
549
+ mae_values = {}
550
+ valid_count = 0
551
+
552
+ for dim in dimensions:
553
+ # Use normalized scores for metric calculation
554
+ mse = calculate_mse(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
555
+ mae = calculate_mae(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
556
+ mse_values[f'{dim}_mse'] = mse
557
+ mae_values[f'{dim}_mae'] = mae
558
+ if mse is not None:
559
+ valid_count += 1
560
+
561
+ overall_error = sum([v for v in mse_values.values() if v is not None])
562
+
563
+ result = {
564
+ **mse_values,
565
+ **mae_values,
566
+ 'overall_error': overall_error if valid_count > 0 else None,
567
+ 'valid_dimensions': valid_count
568
+ }
569
+
570
+ # Include normalized scores in result for transparency (only if normalize=True)
571
+ if normalize:
572
+ result['model_scores_normalized'] = model_scores_normalized
573
+ result['gt_scores_normalized'] = gt_scores_normalized
574
+
575
+ return result
576
+
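A hand check of the per-dimension error bookkeeping above: each dimension gets a squared error, `None` when either side is missing, and the non-`None` values are summed into `overall_error` (scores below are made up):

```python
dims = ("soundness", "presentation", "rating", "confidence")
model = {"soundness": 3, "presentation": None, "rating": 8, "confidence": 4}
gt = {"soundness": 4, "presentation": 3, "rating": 6, "confidence": 4}

# Per-dimension squared error; missing values propagate as None
mse = {d: (model[d] - gt[d]) ** 2 if model[d] is not None and gt[d] is not None else None
       for d in dims}
overall_error = sum(v for v in mse.values() if v is not None)
print(mse, overall_error)  # presentation stays None; overall_error = 1 + 4 + 0 = 5
```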
577
+
578
+ def normalize_score_value(value):
579
+ """Normalize score value to float, handling string representations."""
580
+ if value is None:
581
+ return None
582
+ if isinstance(value, (int, float)):
583
+ return float(value)
584
+ if isinstance(value, str):
585
+         # Try to extract a numeric value from the string (e.g., "2.75" -> 2.75);
+         # re is already imported at module level
+         try:
+             match = re.search(r'(\d+\.?\d*)', value)
+             if match:
+                 return float(match.group(1))
+         except Exception:
+             pass
593
+ return None
594
+
595
+
596
+ def normalize_decision(decision):
597
+ """Normalize decision string to standard format."""
598
+ if decision is None:
599
+ return None
600
+ decision_lower = str(decision).lower().strip()
601
+ if 'accept' in decision_lower:
602
+ return 'accept'
603
+ elif 'reject' in decision_lower:
604
+ return 'reject'
605
+ elif 'undecided' in decision_lower:
606
+ return 'undecided'
607
+ else:
608
+ return decision_lower
609
+
610
+
611
+ def extract_scores_from_dict(scores_dict: Dict[str, Any]) -> Dict[str, Any]:
612
+ """
613
+ Extract scores from a structured dictionary (scores or initial_scores format).
614
+
615
+ Args:
616
+ scores_dict: Dict containing scores (e.g., {'rating': 5.75, 'soundness': '2.75', ...})
617
+
618
+ Returns:
619
+ Dict with normalized scores: {'soundness', 'presentation', 'rating', 'confidence', 'decision'}
620
+ """
621
+ if not scores_dict:
622
+ return {
623
+ 'soundness': None,
624
+ 'presentation': None,
625
+ 'rating': None,
626
+ 'confidence': None,
627
+ 'decision': None
628
+ }
629
+
630
+ return {
631
+ 'soundness': normalize_score_value(scores_dict.get('soundness')),
632
+ 'presentation': normalize_score_value(scores_dict.get('presentation')),
633
+ 'rating': normalize_score_value(scores_dict.get('rating')),
634
+ 'confidence': normalize_score_value(scores_dict.get('confidence')),
635
+ 'decision': normalize_decision(scores_dict.get('decision'))
636
+ }
637
+
638
+
639
+ def evaluate_review_auto_metric(entry: Dict[str, Any], use_initial_scores: bool = False, strict_mode: bool = False) -> Dict[str, Any]:
640
+ """
641
+ Evaluate a single entry by extracting scores and calculating metrics.
642
+
643
+ Args:
644
+ entry: Evaluation entry containing model_review, scores, initial_scores, etc.
645
+ use_initial_scores: If True, use initial_scores instead of refined scores (for refined format)
646
+
647
+ Returns:
648
+ Dict containing evaluation metrics
649
+ """
650
+ entry_id = entry.get('id', 'unknown')
651
+ model_review = entry.get('model_review', '')
652
+ format_type = entry.get('format', 'unknown')
653
+
654
+ # Extract scores based on format
655
+ model_scores = {}
656
+ model_decision = None
657
+
658
+ if format_type == 'refined' and not use_initial_scores:
659
+ # Use refined scores from structured data
660
+ scores_dict = entry.get('scores', {})
661
+ model_data = extract_scores_from_dict(scores_dict)
662
+ model_scores = {
663
+ 'soundness': model_data.get('soundness'),
664
+ 'presentation': model_data.get('presentation'),
665
+ 'rating': model_data.get('rating'),
666
+ 'confidence': model_data.get('confidence')
667
+ }
668
+ model_decision = model_data.get('decision')
669
+ elif format_type == 'refined' and use_initial_scores:
670
+ # Use initial scores from structured data
671
+ initial_scores_dict = entry.get('initial_scores', {})
672
+ model_data = extract_scores_from_dict(initial_scores_dict)
673
+ model_scores = {
674
+ 'soundness': model_data.get('soundness'),
675
+ 'presentation': model_data.get('presentation'),
676
+ 'rating': model_data.get('rating'),
677
+ 'confidence': model_data.get('confidence')
678
+ }
679
+ model_decision = model_data.get('decision')
680
+ elif format_type == 'original':
681
+ # Use initial scores from structured data
682
+ initial_scores_dict = entry.get('initial_scores', {})
683
+ model_data = extract_scores_from_dict(initial_scores_dict)
684
+ model_scores = {
685
+ 'soundness': model_data.get('soundness'),
686
+ 'presentation': model_data.get('presentation'),
687
+ 'rating': model_data.get('rating'),
688
+ 'confidence': model_data.get('confidence')
689
+ }
690
+ model_decision = model_data.get('decision')
691
+
692
+ # Fallback: If confidence is missing from structured data, try to extract from review text
693
+ # (meta_review may not have confidence field, but review text might)
694
+ if model_scores.get('confidence') is None and model_review:
695
+ try:
696
+ review_data = extract_scores_from_review(model_review)
697
+ if review_data.get('confidence') is not None:
698
+ model_scores['confidence'] = review_data.get('confidence')
699
+ except Exception:
700
+ pass # Keep confidence as None if extraction fails
701
+ else:
702
+ # Fallback: extract from markdown review text
703
+ model_data = extract_scores_from_review(model_review)
704
+ model_scores = {
705
+ 'soundness': model_data.get('soundness'),
706
+ 'presentation': model_data.get('presentation'),
707
+ 'rating': model_data.get('rating'),
708
+ 'confidence': model_data.get('confidence')
709
+ }
710
+ model_decision = model_data.get('decision')
711
+
712
+ # Get ground truth scores from golden_review ONLY
713
+ # Ground truth must ONLY come from golden_review, never from model output
714
+ # If extraction fails, leave fields as None (do not use model_review as fallback)
715
+ ground_truth_review = entry.get('golden_review', '')
716
+ ground_truth_scores = {}
717
+ gt_decision = None
718
+
719
+ if not ground_truth_review:
720
+ print(f"Warning: No golden_review found for entry {entry_id}. Ground truth scores will be empty.")
721
+ else:
722
+ try:
723
+ # Extract scores from golden_review markdown text
724
+ gt_data = extract_scores_from_review(ground_truth_review)
725
+ if not gt_data:
726
+ print(f"Warning: Failed to parse golden_review for entry {entry_id}. Ground truth scores will be empty.")
727
+ else:
728
+ ground_truth_scores = {
729
+ 'soundness': gt_data.get('soundness'),
730
+ 'presentation': gt_data.get('presentation'),
731
+ 'rating': gt_data.get('rating'),
732
+ 'confidence': gt_data.get('confidence')
733
+ }
734
+ gt_decision = normalize_decision(gt_data.get('decision'))
735
+ # Note: If any field is None, it stays None - we do NOT use model_review as fallback
736
+ # Using model output as ground truth would inflate evaluation scores
737
+ except Exception as e:
738
+ print(f"Warning: Failed to extract scores from golden_review for {entry_id}: {e}")
739
+ print(f" Ground truth scores will be empty. Error: {str(e)}")
740
+
741
+ # Calculate MSE and MAE metrics (with optional normalization in strict mode)
742
+ score_metrics = calculate_score_metrics(model_scores, ground_truth_scores, normalize=strict_mode)
743
+
744
+ # Calculate decision accuracy
745
+ decision_match = False
746
+ decision_accuracy = None
747
+ if model_decision is not None and gt_decision is not None:
748
+ model_decision_normalized = normalize_decision(model_decision)
749
+ decision_match = (model_decision_normalized == gt_decision)
750
+ decision_accuracy = 1.0 if decision_match else 0.0
751
+
752
+ result = {
753
+ 'id': entry_id,
754
+ 'format': format_type,
755
+ 'model_soundness': model_scores.get('soundness'),
756
+ 'model_presentation': model_scores.get('presentation'),
757
+ 'model_rating': model_scores.get('rating'),
758
+ 'model_confidence': model_scores.get('confidence'),
759
+ 'model_decision': model_decision,
760
+ 'gt_soundness': ground_truth_scores.get('soundness'),
761
+ 'gt_presentation': ground_truth_scores.get('presentation'),
762
+ 'gt_rating': ground_truth_scores.get('rating'),
763
+ 'gt_confidence': ground_truth_scores.get('confidence'),
764
+ 'gt_decision': gt_decision,
765
+ 'decision_match': decision_match,
766
+ 'decision_accuracy': decision_accuracy,
767
+ **score_metrics
768
+ }
769
+
770
+ # Add prefix to indicate which scores were used
771
+ if format_type == 'refined':
772
+ if use_initial_scores:
773
+ result['score_type'] = 'initial'
774
+ else:
775
+ result['score_type'] = 'refined'
776
+ else:
777
+ result['score_type'] = 'auto'
778
+
779
+ return result
780
+
781
+
782
+ def calculate_pairwise_accuracies(paper_scores: List[Dict[str, float]]) -> Dict[str, float]:
783
+ """Calculate pairwise accuracy for each metric by comparing rankings."""
784
+ if len(paper_scores) < 2:
785
+ return {}
786
+
787
+ total_valid_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
788
+ correct_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
789
+
790
+ for paper1, paper2 in combinations(paper_scores, 2):
791
+ # Check rating ranking
792
+ if (paper1.get('true_rating') is not None and paper2.get('true_rating') is not None and
793
+ paper1.get('pred_rating') is not None and paper2.get('pred_rating') is not None):
794
+ total_valid_pairs['rating'] += 1
795
+ true_order = paper1['true_rating'] > paper2['true_rating']
796
+ pred_order = paper1['pred_rating'] > paper2['pred_rating']
797
+ if true_order == pred_order:
798
+ correct_pairs['rating'] += 1
799
+
800
+         # Same ordering comparison for the remaining score dimensions
802
+ for metric in ['soundness', 'presentation', 'confidence']:
803
+ true_key = f'true_{metric}'
804
+ pred_key = f'pred_{metric}'
805
+ if (paper1.get(true_key) is not None and paper2.get(true_key) is not None and
806
+ paper1.get(pred_key) is not None and paper2.get(pred_key) is not None):
807
+ total_valid_pairs[metric] += 1
808
+ true_order = paper1[true_key] > paper2[true_key]
809
+ pred_order = paper1[pred_key] > paper2[pred_key]
810
+ if true_order == pred_order:
811
+ correct_pairs[metric] += 1
812
+
813
+ pairwise_accuracies = {
814
+ metric: correct_pairs[metric] / total_valid_pairs[metric] if total_valid_pairs[metric] > 0 else 0.0
815
+ for metric in ['rating', 'soundness', 'presentation', 'confidence']
816
+ }
817
+
818
+ return pairwise_accuracies
819
+
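The pairwise idea above can be checked on a toy set: a pair counts as correct when the predicted ordering of two papers matches the ground-truth ordering (ratings below are invented):

```python
from itertools import combinations

papers = [
    {"true_rating": 8, "pred_rating": 6},
    {"true_rating": 4, "pred_rating": 5},
    {"true_rating": 2, "pred_rating": 7},
]
correct = total = 0
for a, b in combinations(papers, 2):
    total += 1
    # ordering agreement between ground truth and prediction
    correct += (a["true_rating"] > b["true_rating"]) == (a["pred_rating"] > b["pred_rating"])
print(correct, total)  # only the (8,4) vs (6,5) pair is ordered correctly -> 1 of 3
```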
820
+
821
+ # ============================================================================
822
+ # Data Loading Functions
823
+ # ============================================================================
824
+
825
+ def load_rubrics_json(rubrics_path: str) -> Dict[str, Dict[str, Any]]:
826
+ """Load rubrics JSON and create lookup by id."""
827
+ with open(rubrics_path, 'r', encoding='utf-8') as f:
828
+ data = json.load(f)
829
+
830
+ if isinstance(data, list):
831
+ return {item['id']: item for item in data}
832
+ elif isinstance(data, dict):
833
+ return data
834
+ else:
835
+ raise ValueError(f"Invalid rubrics JSON format: expected list or dict, got {type(data)}")
836
+
837
+
838
+ def load_model_reviews_json(reviews_path: str, format_override: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
839
+ """
840
+ Load model reviews JSON and extract reviews by id.
841
+
842
+ Supports two input formats:
843
+ 1. Refined format: Contains 'scores' and 'initial_scores' fields (from refinement pipeline)
844
+ 2. Original format: Contains 'model_prediction' with 'meta_review' and 'decision' (like ours.json)
845
+
846
+ Args:
847
+ reviews_path: Path to JSON file containing model reviews
848
+ format_override: Optional format override ('refined', 'original', or None for auto-detect)
849
+
850
+ Returns:
851
+ Dict mapping paper_id to dict containing:
852
+ - 'review': review text (markdown)
853
+ - 'scores': refined scores dict (if available)
854
+ - 'initial_scores': initial scores dict (if available)
855
+ - 'format': 'refined' or 'original'
856
+ """
857
+ with open(reviews_path, 'r', encoding='utf-8') as f:
858
+ data = json.load(f)
859
+
860
+ if isinstance(data, dict):
861
+ data = list(data.values())
862
+
863
+ reviews_dict = {}
864
+ for item in data:
865
+ item_id = None
866
+ review_text = ''
867
+ scores = None
868
+ initial_scores = None
869
+ format_type = None
870
+
871
+ # Use format override if provided, otherwise auto-detect
872
+ if format_override and format_override != 'auto':
873
+ # Force use specified format
874
+ if format_override == 'refined':
875
+ item_id = item.get('paper_id') or item.get('id')
876
+ if not item_id:
877
+ continue
878
+ format_type = 'refined'
879
+ review_text = item.get('review_markdown', '') or item.get('review', '')
880
+ scores = item.get('scores', {})
881
+ initial_scores = item.get('initial_scores', {})
882
+ elif format_override == 'original':
883
+ item_id = item.get('id')
884
+ if not item_id:
885
+ continue
886
+ format_type = 'original'
887
+ model_prediction = item.get('model_prediction', {})
888
+ meta_review = model_prediction.get('meta_review', {})
889
+ review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
890
+ initial_scores = {
891
+ 'rating': meta_review.get('rating'),
892
+ 'soundness': meta_review.get('soundness'),
893
+ 'presentation': meta_review.get('presentation'),
894
+ 'contribution': meta_review.get('contribution'),
895
+ 'decision': model_prediction.get('decision'),
896
+ }
897
+ else:
898
+ raise ValueError(f"Unknown format_override: {format_override}. Must be 'refined', 'original', or 'auto'")
899
+ else:
900
+ # Auto-detect format
901
+ if "paper_id" in item:
902
+ # Refined format (from refinement pipeline)
903
+ item_id = item.get('paper_id')
904
+ if not item_id:
905
+ continue
906
+
907
+ # Check if this is refined format (has scores and initial_scores)
908
+ if 'scores' in item and 'initial_scores' in item:
909
+ format_type = 'refined'
910
+ review_text = item.get('review_markdown', '') or item.get('review', '')
911
+ scores = item.get('scores', {})
912
+ initial_scores = item.get('initial_scores', {})
913
+ else:
914
+ # Standard format with paper_id
915
+ format_type = 'standard'
916
+ review_text = item.get('review_markdown', '') or item.get('review', '')
917
+ elif "model_prediction" in item:
918
+ # Original format (like ours.json)
919
+ item_id = item.get('id')
920
+ if not item_id:
921
+ continue
922
+
923
+ format_type = 'original'
924
+ model_prediction = item.get('model_prediction', {})
925
+
926
+ review_text = model_prediction.get('raw_text', '')
927
+
928
+ if review_text is None:
929
+ continue
930
+
931
+ # Detect format: agenticreview has raw_text as string with "Overall rating: x"
932
+ # ai_researcher format has raw_text as dict or JSON string with structured fields
933
+ is_agenticreview = False
934
+ if isinstance(review_text, str):
935
+ # Check if it's a JSON string (ai_researcher format)
936
+ try:
937
+ parsed_json = json.loads(review_text)
938
+ if isinstance(parsed_json, dict) and any(key in parsed_json for key in ["Summary", "Strengths", "Overall", "Decision"]):
939
+ # It's ai_researcher format
940
+ review_text = parsed_json
941
+ review_text, meta_review = convert_ai_researcher(review_text)
942
+ else:
943
+ # It's agenticreview format (plain text with "Overall rating: x")
944
+ is_agenticreview = True
945
+ except (json.JSONDecodeError, TypeError):
946
+ # Not JSON, check if it contains "Overall rating:" pattern
947
+ if re.search(r'Overall\s+rating\s*[:=]', review_text, re.IGNORECASE):
948
+ is_agenticreview = True
949
+ else:
950
+ # Try to parse as ai_researcher anyway
951
+ try:
952
+ review_text = json.loads(review_text)
953
+ review_text, meta_review = convert_ai_researcher(review_text)
954
+                             except Exception:
955
+ review_text = 'Empty Review'
956
+ meta_review = {}
957
+ elif isinstance(review_text, dict):
958
+ # It's ai_researcher format (dict)
959
+ review_text, meta_review = convert_ai_researcher(review_text)
960
+ else:
961
+ review_text = 'Empty Review'
962
+ meta_review = {}
963
+
964
+ # Handle agenticreview format
965
+ if is_agenticreview:
966
+ review_text, meta_review = convert_agenticreview(review_text)
967
+
968
+ # Extract initial scores
969
+ # Use meta_review as primary source (from convert_ai_researcher or convert_agenticreview)
970
+ # Fallback to model_prediction.get('decision') if not in meta_review
971
+ initial_scores = {
972
+ 'rating': meta_review.get('rating'),
973
+ 'soundness': meta_review.get('soundness'),
974
+ 'presentation': meta_review.get('presentation'),
975
+ 'contribution': meta_review.get('contribution'),
976
+ 'confidence': meta_review.get('confidence'),
977
+ 'decision': meta_review.get('decision') or model_prediction.get('decision'),
978
+ }
979
+ else:
980
+ # Legacy format (pred_fast_mode)
981
+ item_id = item.get('id')
982
+ if not item_id:
983
+ continue
984
+
985
+ format_type = 'legacy'
986
+ review_dict = item.get('pred_fast_mode', {})
987
+ if isinstance(review_dict, dict):
988
+ review_text = review_dict.get('raw_text', '')
989
+ else:
990
+ review_text = str(review_dict)
991
+
992
+ # Extract review content from the review text field
993
+ try:
994
+ if review_text:
995
+ # extracted_review = ReviewProcessor.extract_review_content(review_text)
996
+ extracted_review = review_text
997
+ else:
998
+ extracted_review = ''
999
+
1000
+ reviews_dict[item_id] = {
1001
+ 'review': extracted_review,
1002
+ 'scores': scores,
1003
+ 'initial_scores': initial_scores,
1004
+ 'format': format_type
1005
+ }
1006
+ except Exception as e:
1007
+ print(f"[WARN] Failed to extract review for {item_id}: {e}")
1008
+ continue
1009
+
1010
+ return reviews_dict
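The string-vs-JSON branching above can be distilled into a small classifier. A minimal sketch (the helper name `detect_review_format` is ours, not part of the script):

```python
import json
import re

def detect_review_format(raw_text: str) -> str:
    """Classify a raw review string: 'ai_researcher' if it is a JSON object
    with structured fields, 'agenticreview' if it is plain text carrying an
    'Overall rating:' line, else 'unknown'."""
    try:
        parsed = json.loads(raw_text)
        if isinstance(parsed, dict) and any(
            k in parsed for k in ("Summary", "Strengths", "Overall", "Decision")
        ):
            return "ai_researcher"
    except (json.JSONDecodeError, TypeError):
        pass
    if re.search(r"Overall\s+rating\s*[:=]", raw_text, re.IGNORECASE):
        return "agenticreview"
    return "unknown"

print(detect_review_format('{"Summary": "...", "Decision": "Accept"}'))  # ai_researcher
print(detect_review_format("Great paper. Overall rating: 7"))            # agenticreview
```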
+
+
+ def combine_rubrics_and_reviews(
+     rubrics_data: Dict[str, Dict[str, Any]],
+     reviews_dict: Dict[str, Dict[str, Any]]
+ ) -> List[Dict[str, Any]]:
+     """
+     Combine rubrics and reviews into evaluation entries.
+
+     Args:
+         rubrics_data: Dict mapping paper_id to rubric entry
+         reviews_dict: Dict mapping paper_id to a dict containing 'review', 'scores', 'initial_scores', 'format'
+
+     Returns:
+         List of evaluation entries with model_review, scores, initial_scores, and format info
+     """
+     combined = []
+     missing_reviews = []
+
+     for paper_id, rubric_entry in rubrics_data.items():
+         review_data = reviews_dict.get(paper_id)
+         if not review_data or not review_data.get('review'):
+             missing_reviews.append(paper_id)
+             continue
+
+         entry = {
+             'id': paper_id,
+             'paper_context': rubric_entry.get('paper_context', ''),
+             'decision': rubric_entry.get('decision', ''),
+             'golden_review': rubric_entry.get('golden_review', ''),
+             'rubrics': rubric_entry.get('rubrics', []),
+             'model_review': review_data.get('review', ''),
+             'scores': review_data.get('scores'),  # Refined scores (if available)
+             'initial_scores': review_data.get('initial_scores'),  # Initial scores (if available)
+             'format': review_data.get('format', 'unknown')  # Format type
+         }
+         combined.append(entry)
+
+     if missing_reviews:
+         print(f"[WARN] {len(missing_reviews)} papers have no model review; skipping them")
+
+     return combined
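The merge above is a plain join of two dicts keyed by paper id, dropping papers without a usable review. A stripped-down sketch of that pattern (`combine_by_id` is an illustrative name):

```python
def combine_by_id(rubrics: dict, reviews: dict) -> list:
    """Pair each rubric entry with its model review by paper id,
    skipping papers that have no (or an empty) review."""
    combined, missing = [], []
    for paper_id, rubric in rubrics.items():
        review = reviews.get(paper_id)
        if not review or not review.get("review"):
            missing.append(paper_id)
            continue
        combined.append({"id": paper_id, **rubric, **review})
    return combined

rubrics = {"p1": {"rubrics": []}, "p2": {"rubrics": []}}
reviews = {"p1": {"review": "solid work", "format": "original"}}
print(combine_by_id(rubrics, reviews))  # one entry, for "p1" only
```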
+
+
+ # ============================================================================
+ # LLM Service Configuration
+ # ============================================================================
+
+ def load_llm_config(config_path: str) -> Dict[str, Any]:
+     """Load LLM configuration from a YAML file."""
+     with open(config_path, 'r', encoding='utf-8') as f:
+         config = yaml.safe_load(f)
+     return config
+
+
+ def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
+     """Create an LLM service from configuration."""
+     mode = config.get('mode', 'gpt').lower()
+
+     if mode == 'gpt':
+         gpt_config = config.get('gpt', {})
+         api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
+         if not api_key:
+             raise ValueError("GPT mode requires api_key in config.yaml or the OPENAI_API_KEY environment variable")
+
+         service = GPTService(
+             api_key=api_key,
+             model_name=gpt_config.get('model_name', 'gpt-4o'),
+             base_url=gpt_config.get('base_url'),
+             timeout=gpt_config.get('timeout', 300)
+         )
+         return service
+
+     elif mode == 'vllm':
+         vllm_config = config.get('vllm', {})
+         service = VLLMService(
+             base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
+             api_key=vllm_config.get('api_key', 'dummy-key'),
+             model_name=vllm_config.get('model_name'),
+             timeout=vllm_config.get('timeout', 300),
+             max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
+             max_retries=vllm_config.get('max_retries', 3),
+             retry_delay=vllm_config.get('retry_delay', 1.0),
+             retry_backoff=vllm_config.get('retry_backoff', 2.0)
+         )
+         return service
+
+     else:
+         raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
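Taken together, the `.get()` calls above imply a config file shaped roughly like the following. This is a sketch reconstructed from those calls; the values are illustrative, not taken from the repo:

```yaml
mode: vllm            # 'gpt' or 'vllm'
gpt:
  api_key: null       # falls back to the OPENAI_API_KEY environment variable
  model_name: gpt-4o
  base_url: null
  timeout: 300
vllm:
  base_url: http://localhost:8000/v1
  api_key: dummy-key
  model_name: my-served-model
  timeout: 300
  max_concurrent_requests: 64
  max_retries: 3
  retry_delay: 1.0
  retry_backoff: 2.0
```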
+
+
+ # ============================================================================
+ # Main Evaluation Functions
+ # ============================================================================
+
+ def run_semantic_evaluation(
+     evaluation_data: List[Dict[str, Any]],
+     prompt_template: str,
+     llm_service: LLMService,
+     max_workers: int
+ ) -> tuple:
+     """Run semantic evaluation and return results and summary."""
+     print(f"\n{'='*80}")
+     print("RUNNING SEMANTIC EVALUATION")
+     print(f"{'='*80}")
+     print(f"Evaluating {len(evaluation_data)} reviews using {max_workers} workers...")
+
+     results = []
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         future_to_entry = {
+             executor.submit(
+                 evaluate_review_semantic,
+                 entry,
+                 entry['paper_context'],
+                 prompt_template,
+                 llm_service
+             ): entry
+             for entry in evaluation_data
+         }
+
+         for future in tqdm(as_completed(future_to_entry), total=len(evaluation_data), desc="Semantic evaluation"):
+             try:
+                 result = future.result()
+                 results.append(result)
+             except Exception as e:
+                 entry = future_to_entry[future]
+                 print(f"\n[ERROR] Failed to process entry {entry.get('id', 'unknown')}: {e}")
+                 results.append({
+                     'id': entry.get('id', 'unknown'),
+                     'raw_scores': {},
+                     'weighted_scores': {},
+                     'total_score': 0.0,
+                     'error': str(e),
+                     'raw_response': ''
+                 })
+
+     # Calculate statistics
+     valid_results = [r for r in results if 'error' not in r and r.get('weighted_scores')]
+     review_scores = [r.get('total_score', 0.0) for r in valid_results]
+
+     summary = {
+         'total_entries': len(results),
+         'valid_entries': len(valid_results),
+         'failed_entries': len(results) - len(valid_results)
+     }
+
+     if review_scores:
+         summary['overall_score'] = {
+             'mean': sum(review_scores) / len(review_scores),
+             'min': min(review_scores),
+             'max': max(review_scores)
+         }
+
+     # Calculate per-rubric statistics (extract rubric titles from the first entry)
+     if evaluation_data and evaluation_data[0].get('rubrics'):
+         rubric_titles = [r['title'] for r in evaluation_data[0]['rubrics']]
+         per_rubric_stats = calculate_per_rubric_statistics(valid_results, rubric_titles)
+         summary['per_rubric_statistics'] = per_rubric_stats
+
+     return results, summary
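The fan-out pattern above (submit one future per entry, collect with `as_completed`, degrade failures into error records instead of aborting) can be shown in isolation. A minimal sketch with a hypothetical `work` callable standing in for `evaluate_review_semantic`:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def fan_out(entries, work, max_workers=4):
    """Submit one task per entry; collect results as they finish and
    turn per-task failures into error records rather than raising."""
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(work, e): e for e in entries}
        for future in as_completed(futures):
            entry = futures[future]
            try:
                results.append(future.result())
            except Exception as exc:
                results.append({"id": entry.get("id", "unknown"), "error": str(exc)})
    return results

def score(entry):
    # Stand-in for the real per-review evaluator.
    if entry["id"] == "bad":
        raise ValueError("boom")
    return {"id": entry["id"], "total_score": 1.0}

out = fan_out([{"id": "a"}, {"id": "bad"}], score)
```

Because results arrive in completion order, downstream code should key on `id` rather than position.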
+
+
+ def run_auto_metric_evaluation(
+     evaluation_data: List[Dict[str, Any]],
+     strict_mode: bool = False
+ ) -> tuple:
+     """
+     Run auto-metric evaluation and return results and summary.
+
+     For the refined format (has scores and initial_scores), evaluates both:
+     - Refined scores
+     - Initial scores
+
+     For the original format (only initial_scores), evaluates:
+     - Initial scores only
+
+     Returns:
+         Tuple of (results_list, summary_dict)
+         - results_list: List of evaluation results (may contain both refined and initial results for the refined format)
+         - summary_dict: Summary statistics
+     """
+     print(f"\n{'='*80}")
+     print("RUNNING AUTO-METRIC EVALUATION")
+     print(f"{'='*80}")
+     print(f"Evaluating {len(evaluation_data)} entries...")
+
+     # Detect format types
+     refined_format_count = sum(1 for e in evaluation_data if e.get('format') == 'refined')
+     original_format_count = sum(1 for e in evaluation_data if e.get('format') == 'original')
+
+     if refined_format_count > 0:
+         print(f"Detected {refined_format_count} entries in refined format (will evaluate both refined and initial scores)")
+     if original_format_count > 0:
+         print(f"Detected {original_format_count} entries in original format (will evaluate initial scores only)")
+
+     results = []
+     for entry in tqdm(evaluation_data, desc="Auto-metric evaluation"):
+         format_type = entry.get('format', 'unknown')
+
+         if format_type == 'refined':
+             # Evaluate both refined scores and initial scores
+             try:
+                 entry_id = entry.get('id', 'unknown')
+
+                 # Evaluate refined scores
+                 refined_result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
+                 refined_result['paper_id'] = entry_id  # Keep original paper_id
+                 refined_result['id'] = f"{entry_id}_refined"
+                 results.append(refined_result)
+
+                 # Evaluate initial scores
+                 initial_result = evaluate_review_auto_metric(entry, use_initial_scores=True, strict_mode=strict_mode)
+                 initial_result['paper_id'] = entry_id  # Keep original paper_id
+                 initial_result['id'] = f"{entry_id}_initial"
+                 results.append(initial_result)
+             except Exception as e:
+                 print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
+                 results.append({
+                     'id': entry.get('id', 'unknown'),
+                     'error': str(e)
+                 })
+         else:
+             # Evaluate initial scores only (or extract from markdown)
+             try:
+                 result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
+                 results.append(result)
+             except Exception as e:
+                 print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
+                 results.append({
+                     'id': entry.get('id', 'unknown'),
+                     'error': str(e)
+                 })
+
+     # Calculate statistics
+     valid_results = [r for r in results if 'error' not in r]
+     mse_results = [r for r in valid_results if r.get('overall_error') is not None]
+
+     # Separate refined and initial results for the refined format
+     refined_results = [r for r in valid_results if r.get('score_type') == 'refined']
+     initial_results = [r for r in valid_results if r.get('score_type') == 'initial']
+     auto_results = [r for r in valid_results if r.get('score_type') == 'auto' or r.get('score_type') is None]
+
+     summary = {
+         'total_entries': len(results),
+         'valid_entries': len(valid_results),
+         'mse_entries': len(mse_results),
+         'refined_results_count': len(refined_results),
+         'initial_results_count': len(initial_results),
+         'auto_results_count': len(auto_results)
+     }
+
+     # Calculate MSE/MAE statistics.
+     # For the refined format, only use refined results for overall statistics (avoid double counting);
+     # for other formats, use all results.
+     if refined_format_count > 0:
+         stats_results = [r for r in refined_results if r.get('overall_error') is not None]
+     else:
+         stats_results = mse_results
+
+     if stats_results:
+         dimensions = ['soundness', 'presentation', 'confidence', 'rating']
+         mse_stats = {}
+         mae_stats = {}
+
+         for dim in dimensions:
+             mse_list = [r.get(f'{dim}_mse') for r in stats_results if r.get(f'{dim}_mse') is not None]
+             mae_list = [r.get(f'{dim}_mae') for r in stats_results if r.get(f'{dim}_mae') is not None]
+
+             mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+             mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+
+             if mse_clean:
+                 mse_stats[dim] = {
+                     'mean': sum(mse_clean) / len(mse_clean),
+                     'count': len(mse_clean)
+                 }
+             if mae_clean:
+                 mae_stats[dim] = {
+                     'mean': sum(mae_clean) / len(mae_clean),
+                     'count': len(mae_clean)
+                 }
+
+         overall_errors = [r.get('overall_error') for r in stats_results if r.get('overall_error') is not None]
+         overall_clean = [x for x in overall_errors if x is not None and not (isinstance(x, float) and math.isnan(x))]
+
+         if overall_clean:
+             summary['overall_error'] = {
+                 'mean': sum(overall_clean) / len(overall_clean),
+                 'count': len(overall_clean)
+             }
+
+         summary['mse_statistics'] = mse_stats
+         summary['mae_statistics'] = mae_stats
+
+     # Calculate separate statistics for refined and initial results
+     if refined_results:
+         refined_mse_results = [r for r in refined_results if r.get('overall_error') is not None]
+         if refined_mse_results:
+             refined_mse_stats = {}
+             refined_mae_stats = {}
+             for dim in ['soundness', 'presentation', 'confidence', 'rating']:
+                 mse_list = [r.get(f'{dim}_mse') for r in refined_mse_results if r.get(f'{dim}_mse') is not None]
+                 mae_list = [r.get(f'{dim}_mae') for r in refined_mse_results if r.get(f'{dim}_mae') is not None]
+                 mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+                 mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+                 if mse_clean:
+                     refined_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
+                 if mae_clean:
+                     refined_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
+             summary['refined_mse_statistics'] = refined_mse_stats
+             summary['refined_mae_statistics'] = refined_mae_stats
+
+     if initial_results:
+         initial_mse_results = [r for r in initial_results if r.get('overall_error') is not None]
+         if initial_mse_results:
+             initial_mse_stats = {}
+             initial_mae_stats = {}
+             for dim in ['soundness', 'presentation', 'confidence', 'rating']:
+                 mse_list = [r.get(f'{dim}_mse') for r in initial_mse_results if r.get(f'{dim}_mse') is not None]
+                 mae_list = [r.get(f'{dim}_mae') for r in initial_mse_results if r.get(f'{dim}_mae') is not None]
+                 mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+                 mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+                 if mse_clean:
+                     initial_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
+                 if mae_clean:
+                     initial_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
+             summary['initial_mse_statistics'] = initial_mse_stats
+             summary['initial_mae_statistics'] = initial_mae_stats
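Every mean in the statistics above first drops `None` and NaN entries. That recurring filter-then-average step, extracted into one helper for illustration (`nan_safe_mean` is our name, not the script's):

```python
import math

def nan_safe_mean(values):
    """Mean over entries that are neither None nor NaN; None when nothing survives."""
    clean = [v for v in values
             if v is not None and not (isinstance(v, float) and math.isnan(v))]
    if not clean:
        return None
    return sum(clean) / len(clean)

print(nan_safe_mean([1.0, None, float("nan"), 3.0]))  # 2.0
```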
+
+     # Filter out (ground truth, prediction) pairs where either value is None or NaN
+     def filter_valid_pairs(true_list, pred_list):
+         filtered_true = []
+         filtered_pred = []
+         for t, p in zip(true_list, pred_list):
+             if (t is not None and p is not None and
+                     not (isinstance(t, float) and math.isnan(t)) and
+                     not (isinstance(p, float) and math.isnan(p))):
+                 filtered_true.append(t)
+                 filtered_pred.append(p)
+         return filtered_true, filtered_pred
+
+     # Calculate Spearman correlations.
+     # For the refined format, calculate separately for refined and initial, and use refined for overall;
+     # for other formats, use all results.
+     def compute_spearman_stats(correlation_results):
+         """Per-dimension Spearman correlation between ground-truth and predicted scores."""
+         stats = {}
+         for dim in ['soundness', 'presentation', 'confidence', 'rating']:
+             true_values = [r.get(f'gt_{dim}') for r in correlation_results]
+             pred_values = [r.get(f'model_{dim}') for r in correlation_results]
+             true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
+             if len(true_clean) >= 2 and len(pred_clean) >= 2:
+                 try:
+                     corr, _ = spearmanr(true_clean, pred_clean)
+                     if not math.isnan(corr):
+                         stats[dim] = {
+                             'correlation': corr,
+                             'count': len(true_clean)
+                         }
+                 except Exception:
+                     pass
+         return stats
+
+     if refined_format_count > 0:
+         refined_spearman_stats = compute_spearman_stats(refined_results)
+         initial_spearman_stats = compute_spearman_stats(initial_results)
+         # Use refined for overall statistics (avoid double counting)
+         summary['spearman_correlations'] = refined_spearman_stats
+         summary['refined_spearman_correlations'] = refined_spearman_stats
+         summary['initial_spearman_correlations'] = initial_spearman_stats
+     else:
+         # Original/other formats: use all results
+         summary['spearman_correlations'] = compute_spearman_stats(valid_results)
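The key detail in the pair filtering above is that ground truth and predictions must be dropped at the same index positions, otherwise the correlation is computed over misaligned lists. A standalone illustration of that invariant:

```python
import math

def filter_valid_pairs(true_list, pred_list):
    """Drop index positions where either side is None or NaN, keeping the
    remaining ground-truth/prediction values aligned."""
    pairs = [
        (t, p) for t, p in zip(true_list, pred_list)
        if t is not None and p is not None
        and not (isinstance(t, float) and math.isnan(t))
        and not (isinstance(p, float) and math.isnan(p))
    ]
    true_clean = [t for t, _ in pairs]
    pred_clean = [p for _, p in pairs]
    return true_clean, pred_clean

# Only index 0 survives: the others have a None or NaN on one side.
print(filter_valid_pairs([4, None, 3, float("nan")], [5, 2, None, 1]))  # ([4], [5])
```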
1421
+
1422
+ # Calculate Decision metrics
1423
+ # For refined format, calculate separately for refined and initial, and use refined for overall
1424
+ # For other formats, use all results
1425
+ if refined_format_count > 0:
1426
+ # Calculate refined decision metrics
1427
+ refined_decision_results = [r for r in refined_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
1428
+ if refined_decision_results:
1429
+ true_decisions = []
1430
+ pred_decisions = []
1431
+ decision_acc = []
1432
+
1433
+ for r in refined_decision_results:
1434
+ gt_decision = str(r.get('gt_decision', '')).lower().strip()
1435
+ pred_decision = str(r.get('model_decision', '')).lower().strip()
1436
+
1437
+ if 'accept' in pred_decision:
1438
+ pred_binary = 1
1439
+ else:
1440
+ pred_binary = 0
1441
+
1442
+ if 'accept' in gt_decision:
1443
+ gt_binary = 1
1444
+ else:
1445
+ gt_binary = 0
1446
+
1447
+ true_decisions.append(gt_binary)
1448
+ pred_decisions.append(pred_binary)
1449
+
1450
+ if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
1451
+ decision_acc.append(1.0)
1452
+ else:
1453
+ decision_acc.append(0.0)
1454
+
1455
+ if decision_acc:
1456
+ decision_accuracy = sum(decision_acc) / len(decision_acc)
1457
+ try:
1458
+ _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
1459
+ refined_decision_metrics = {
1460
+ 'accuracy': decision_accuracy,
1461
+ 'f1_macro': f1_score,
1462
+ 'count': len(decision_acc)
1463
+ }
1464
+ except Exception:
1465
+ refined_decision_metrics = {
1466
+ 'accuracy': decision_accuracy,
1467
+ 'count': len(decision_acc)
1468
+ }
1469
+ summary['refined_decision_metrics'] = refined_decision_metrics
1470
+ summary['decision_metrics'] = refined_decision_metrics # Use refined for overall
1471
+
1472
+ # Calculate initial decision metrics
1473
+ initial_decision_results = [r for r in initial_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
1474
+ if initial_decision_results:
1475
+ true_decisions = []
1476
+ pred_decisions = []
1477
+ decision_acc = []
1478
+
1479
+ for r in initial_decision_results:
1480
+ gt_decision = str(r.get('gt_decision', '')).lower().strip()
1481
+ pred_decision = str(r.get('model_decision', '')).lower().strip()
1482
+
1483
+ if 'accept' in pred_decision:
1484
+ pred_binary = 1
1485
+ else:
1486
+ pred_binary = 0
1487
+
1488
+ if 'accept' in gt_decision:
1489
+ gt_binary = 1
1490
+ else:
1491
+ gt_binary = 0
1492
+
1493
+ true_decisions.append(gt_binary)
1494
+ pred_decisions.append(pred_binary)
1495
+
1496
+ if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
1497
+ decision_acc.append(1.0)
1498
+ else:
1499
+ decision_acc.append(0.0)
1500
+
1501
+ if decision_acc:
1502
+ decision_accuracy = sum(decision_acc) / len(decision_acc)
1503
+ try:
1504
+ _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
1505
+ initial_decision_metrics = {
1506
+ 'accuracy': decision_accuracy,
1507
+ 'f1_macro': f1_score,
1508
+ 'count': len(decision_acc)
1509
+ }
1510
+ except Exception:
1511
+ initial_decision_metrics = {
1512
+ 'accuracy': decision_accuracy,
1513
+ 'count': len(decision_acc)
1514
+ }
1515
+ summary['initial_decision_metrics'] = initial_decision_metrics
1516
+ else:
1517
+ # Original/other formats: use all results
1518
+ decision_results = [r for r in valid_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
1519
+ if decision_results:
1520
+ true_decisions = []
1521
+ pred_decisions = []
1522
+ decision_acc = []
1523
+
1524
+ for r in decision_results:
1525
+ gt_decision = str(r.get('gt_decision', '')).lower().strip()
1526
+ pred_decision = str(r.get('model_decision', '')).lower().strip()
1527
+
1528
+ if 'accept' in pred_decision:
1529
+ pred_binary = 1
1530
+ else:
1531
+ pred_binary = 0
1532
+
1533
+ if 'accept' in gt_decision:
1534
+ gt_binary = 1
1535
+ else:
1536
+ gt_binary = 0
1537
+
1538
+ true_decisions.append(gt_binary)
1539
+ pred_decisions.append(pred_binary)
1540
+
1541
+ if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
1542
+ decision_acc.append(1.0)
1543
+ else:
1544
+ decision_acc.append(0.0)
1545
+
1546
+ if decision_acc:
1547
+ decision_accuracy = sum(decision_acc) / len(decision_acc)
1548
+ try:
1549
+ _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
1550
+ summary['decision_metrics'] = {
1551
+ 'accuracy': decision_accuracy,
1552
+ 'f1_macro': f1_score,
1553
+ 'count': len(decision_acc)
1554
+ }
1555
+ except Exception:
1556
+ summary['decision_metrics'] = {
1557
+ 'accuracy': decision_accuracy,
1558
+ 'count': len(decision_acc)
1559
+ }
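The decision handling above reduces free-form decision strings to a binary label by substring test, plus a looser agreement check for accuracy. Isolated for illustration (both helper names are ours):

```python
def binarize_decision(decision) -> int:
    """Map a free-form decision string to 1 (accept) / 0 (reject or other),
    mirroring the substring test used above."""
    return 1 if 'accept' in str(decision).lower().strip() else 0

def decisions_agree(pred: str, gt: str) -> bool:
    """Exact match, or both sides mention 'accept', or both mention 'reject'."""
    pred, gt = pred.lower().strip(), gt.lower().strip()
    return pred == gt or ('accept' in pred and 'accept' in gt) or ('reject' in pred and 'reject' in gt)

print(binarize_decision("Accept (poster)"))  # 1
print(decisions_agree("Reject", "reject"))   # True
```

Note the substring test treats any unrecognized string as a reject, which is why ground-truth decisions outside the accept/reject vocabulary still land in the binary confusion matrix.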
+
+     # Calculate pairwise comparison.
+     # For the refined format, only use refined results (avoid double counting);
+     # for other formats, use all results.
+     if refined_format_count > 0:
+         pairwise_results = refined_results
+     else:
+         pairwise_results = valid_results
+
+     paper_scores = []
+     for r in pairwise_results:
+         if (r.get('gt_rating') is not None and r.get('model_rating') is not None) or \
+            (r.get('gt_soundness') is not None and r.get('model_soundness') is not None):
+             paper_scores.append({
+                 'true_rating': r.get('gt_rating'),
+                 'pred_rating': r.get('model_rating'),
+                 'true_soundness': r.get('gt_soundness'),
+                 'pred_soundness': r.get('model_soundness'),
+                 'true_presentation': r.get('gt_presentation'),
+                 'pred_presentation': r.get('model_presentation'),
+                 'true_confidence': r.get('gt_confidence'),
+                 'pred_confidence': r.get('model_confidence')
+             })
+
+     if len(paper_scores) >= 2:
+         pairwise_accuracies = calculate_pairwise_accuracies(paper_scores)
+         summary['pairwise_accuracies'] = pairwise_accuracies
+
+     return results, summary
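`calculate_pairwise_accuracies` is defined elsewhere in the repo; as a sketch of what a pairwise ranking metric over `paper_scores` typically computes (our assumption about its semantics, not a copy of the real implementation): the fraction of paper pairs whose predicted ordering matches the ground-truth ordering.

```python
from itertools import combinations

def pairwise_rank_accuracy(scores, true_key="true_rating", pred_key="pred_rating"):
    """Fraction of paper pairs whose predicted ordering matches the
    ground-truth ordering; pairs with missing values or ground-truth
    ties are skipped. Illustrative stand-in only."""
    correct = total = 0
    for a, b in combinations(scores, 2):
        if None in (a[true_key], b[true_key], a[pred_key], b[pred_key]):
            continue
        if a[true_key] == b[true_key]:  # ground-truth tie: no ordering to test
            continue
        total += 1
        correct += (a[true_key] > b[true_key]) == (a[pred_key] > b[pred_key])
    return correct / total if total else None

papers = [
    {"true_rating": 6, "pred_rating": 7},
    {"true_rating": 3, "pred_rating": 4},
    {"true_rating": 8, "pred_rating": 5},
]
print(pairwise_rank_accuracy(papers))
```

On this toy input two of the three pairs are ordered correctly; the (6, 8) pair is inverted by the predictions.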
+
+
+ # ============================================================================
+ # Main Function
+ # ============================================================================
+
+ def parse_args():
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(description="Unified evaluation script for semantic and auto-metric evaluation")
+
+     # Input paths
+     parser.add_argument("--rubrics_path", type=str, required=True,
+                         help="Path to eval_rubrics.json file (from 1_generate_review_based_rubrics.py)")
+     parser.add_argument("--reviews_path", type=str, required=True,
+                         help="Path to JSON file with model reviews (contains pred_fast_mode)")
+
+     # Evaluation mode
+     parser.add_argument("--mode", type=str, choices=["semantic", "auto_metric", "both"], default="both",
+                         help="Evaluation mode: semantic (LLM-based), auto_metric (rule-based), or both")
+
+     # Output paths
+     parser.add_argument("--semantic_output", type=str, default=None,
+                         help="Path to output JSON file for semantic evaluation results (required if mode is semantic or both)")
+     parser.add_argument("--auto_metric_output", type=str, default=None,
+                         help="Path to output JSON file for auto-metric evaluation results (required if mode is auto_metric or both)")
+
+     # Semantic evaluation settings
+     parser.add_argument("--yaml_path", type=str, default=None,
+                         help="Path to prompts.yaml file (required for semantic evaluation)")
+     parser.add_argument("--config_path", type=str, default=None,
+                         help="Path to config.yaml file (required for semantic evaluation)")
+
+     # Multi-threading
+     parser.add_argument("--max_workers", type=int, default=None,
+                         help="Maximum number of worker threads for semantic evaluation (default: 5)")
+
+     # Strict mode (normalize scores to discrete scales)
+     parser.add_argument("--strict_mode", action="store_true", default=False,
+                         help="Enable strict mode: normalize scores to discrete scales before computing metrics (default: False)")
+
+     # Input format override
+     parser.add_argument("--input_format", type=str, choices=['auto', 'refined', 'original'], default='auto',
+                         help="Manually specify input JSON format: 'refined' (has scores and initial_scores), 'original' (has model_prediction), or 'auto' for auto-detection (default: 'auto')")
+
+     return parser.parse_args()
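A minimal mirror of the parser above, showing how the required flags, choices, and the `--strict_mode` toggle behave (only a subset of the real arguments, driven here with an explicit argv list for demonstration):

```python
import argparse

parser = argparse.ArgumentParser(description="Unified evaluation script")
parser.add_argument("--rubrics_path", required=True)
parser.add_argument("--reviews_path", required=True)
parser.add_argument("--mode", choices=["semantic", "auto_metric", "both"], default="both")
parser.add_argument("--strict_mode", action="store_true")

args = parser.parse_args([
    "--rubrics_path", "eval_rubrics.json",
    "--reviews_path", "ours.json",
    "--mode", "auto_metric",
])
print(args.mode, args.strict_mode)  # auto_metric False
```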
+
+
+ def main():
+     """Main execution function."""
+     args = parse_args()
+
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+
+     # Resolve paths
+     rubrics_path = args.rubrics_path
+     if not os.path.isabs(rubrics_path):
+         rubrics_path = os.path.join(script_dir, rubrics_path)
+
+     reviews_path = args.reviews_path
+     if not os.path.isabs(reviews_path):
+         reviews_path = os.path.join(script_dir, reviews_path)
+
+     max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
+
+     # Validate mode and output paths
+     if args.mode in ["semantic", "both"]:
+         if not args.semantic_output:
+             raise ValueError("--semantic_output is required when mode is 'semantic' or 'both'")
+         if not args.yaml_path:
+             raise ValueError("--yaml_path is required for semantic evaluation")
+         if not args.config_path:
+             raise ValueError("--config_path is required for semantic evaluation")
+
+     if args.mode in ["auto_metric", "both"]:
+         if not args.auto_metric_output:
+             raise ValueError("--auto_metric_output is required when mode is 'auto_metric' or 'both'")
+
+     # Check that the input files exist
+     if not os.path.exists(rubrics_path):
+         raise FileNotFoundError(f"Rubrics file not found: {rubrics_path}")
+     if not os.path.exists(reviews_path):
+         raise FileNotFoundError(f"Reviews file not found: {reviews_path}")
+
+     # Load data
+     print(f"Loading rubrics from {rubrics_path}...")
+     rubrics_data = load_rubrics_json(rubrics_path)
+     print(f"Loaded {len(rubrics_data)} rubrics entries")
+
+     print(f"Loading model reviews from {reviews_path}...")
+     if args.input_format != 'auto':
+         print(f"Using manually specified format: {args.input_format}")
+     else:
+         print("Auto-detecting input format...")
+     reviews_dict = load_model_reviews_json(reviews_path, format_override=args.input_format if args.input_format != 'auto' else None)
+     print(f"Loaded {len(reviews_dict)} model reviews")
+
+     # Combine rubrics and reviews
+     print("Combining rubrics and reviews...")
+     evaluation_data = combine_rubrics_and_reviews(rubrics_data, reviews_dict)
+     print(f"Prepared {len(evaluation_data)} entries for evaluation")
+
+     # Run evaluations based on mode
+     if args.mode in ["semantic", "both"]:
+         # Resolve semantic evaluation paths
+         yaml_path = args.yaml_path
+         if not os.path.isabs(yaml_path):
+             yaml_path = os.path.join(script_dir, yaml_path)
+
+         config_path = args.config_path
+         if not os.path.isabs(config_path):
+             config_path = os.path.join(script_dir, config_path)
+
+         if not os.path.exists(yaml_path):
+             raise FileNotFoundError(f"YAML file not found: {yaml_path}")
+         if not os.path.exists(config_path):
+             raise FileNotFoundError(f"Config file not found: {config_path}")
+
+         # Load prompt template
+         print(f"Loading prompt template from {yaml_path}...")
+         prompt_template = load_prompt_template(yaml_path)
+         if not prompt_template:
+             raise ValueError("Could not find 'v1_evaluator_prompt' in YAML file")
+
+         # Initialize LLM service
+         print(f"Loading LLM configuration from {config_path}...")
+         llm_config = load_llm_config(config_path)
+         llm_service = create_llm_service_from_config(llm_config)
+         mode = llm_config.get('mode', 'gpt')
+         print(f"LLM service initialized (mode: {mode})")
+         if hasattr(llm_service, 'model_name'):
+             print(f"Using model: {llm_service.model_name}")
+
+         # Run semantic evaluation
+         semantic_results, semantic_summary = run_semantic_evaluation(
+             evaluation_data, prompt_template, llm_service, max_workers
+         )
+
+         # Save semantic results
+         semantic_output = args.semantic_output
+         if not os.path.isabs(semantic_output):
+             semantic_output = os.path.join(script_dir, semantic_output)
+
+         output_dir = os.path.dirname(semantic_output)
+         if output_dir:
+             os.makedirs(output_dir, exist_ok=True)
+
+         with open(semantic_output, 'w', encoding='utf-8') as f:
+             json.dump(semantic_results, f, ensure_ascii=False, indent=2)
+         print(f"\nSemantic evaluation results saved to {semantic_output}")
+
+         # Save semantic summary
+         semantic_summary_path = semantic_output.replace('.json', '_summary.json')
+         with open(semantic_summary_path, 'w', encoding='utf-8') as f:
+             json.dump(semantic_summary, f, ensure_ascii=False, indent=2)
+         print(f"Semantic evaluation summary saved to {semantic_summary_path}")
+
+         # Print semantic summary
+         print("\n" + "="*80)
+         print("SEMANTIC EVALUATION SUMMARY")
+         print("="*80)
+         print(f"Total entries: {semantic_summary['total_entries']}")
+         print(f"Valid entries: {semantic_summary['valid_entries']}")
+         print(f"Failed entries: {semantic_summary['failed_entries']}")
+         if 'overall_score' in semantic_summary:
+             score = semantic_summary['overall_score']
+             print("\nOverall Score:")
+             print(f"  Mean: {score['mean']:.2f}")
+             print(f"  Min: {score['min']:.2f}")
+             print(f"  Max: {score['max']:.2f}")
1757
+
1758
+ if args.mode in ["auto_metric", "both"]:
1759
+ # Run auto-metric evaluation
1760
+ auto_metric_results, auto_metric_summary = run_auto_metric_evaluation(
1761
+ evaluation_data,
1762
+ strict_mode=args.strict_mode
1763
+ )
1764
+
1765
+ # Save auto-metric results
1766
+ auto_metric_output = args.auto_metric_output
1767
+ if not os.path.isabs(auto_metric_output):
1768
+ auto_metric_output = os.path.join(script_dir, auto_metric_output)
1769
+
1770
+ output_dir = os.path.dirname(auto_metric_output)
1771
+ os.makedirs(output_dir, exist_ok=True)
1772
+
1773
+ with open(auto_metric_output, 'w', encoding='utf-8') as f:
1774
+ json.dump(auto_metric_results, f, ensure_ascii=False, indent=2)
1775
+ print(f"\nAuto-metric evaluation results saved to {auto_metric_output}")
1776
+
1777
+ # Save auto-metric summary
1778
+ auto_metric_summary_path = auto_metric_output.replace('.json', '_summary.json')
1779
+ with open(auto_metric_summary_path, 'w', encoding='utf-8') as f:
1780
+ json.dump(auto_metric_summary, f, ensure_ascii=False, indent=2)
1781
+ print(f"Auto-metric evaluation summary saved to {auto_metric_summary_path}")
1782
+
1783
+ # Print auto-metric summary
1784
+ print("\n" + "="*80)
1785
+ print("AUTO-METRIC EVALUATION SUMMARY")
1786
+ print("="*80)
1787
+ print(f"Total entries: {auto_metric_summary['total_entries']}")
1788
+ print(f"Valid entries: {auto_metric_summary['valid_entries']}")
1789
+ print(f"MSE entries: {auto_metric_summary['mse_entries']}")
1790
+
1791
+ if 'mse_statistics' in auto_metric_summary:
1792
+ print("\nMSE Statistics:")
1793
+ for dim, stats in auto_metric_summary['mse_statistics'].items():
1794
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1795
+
1796
+ if 'mae_statistics' in auto_metric_summary:
1797
+ print("\nMAE Statistics:")
1798
+ for dim, stats in auto_metric_summary['mae_statistics'].items():
1799
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1800
+
1801
+ # Print refined and initial statistics if available
1802
+ if 'refined_mse_statistics' in auto_metric_summary:
1803
+ print("\nRefined Scores - MSE Statistics:")
1804
+ for dim, stats in auto_metric_summary['refined_mse_statistics'].items():
1805
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1806
+
1807
+ if 'refined_mae_statistics' in auto_metric_summary:
1808
+ print("\nRefined Scores - MAE Statistics:")
1809
+ for dim, stats in auto_metric_summary['refined_mae_statistics'].items():
1810
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1811
+
1812
+ if 'initial_mse_statistics' in auto_metric_summary:
1813
+ print("\nInitial Scores - MSE Statistics:")
1814
+ for dim, stats in auto_metric_summary['initial_mse_statistics'].items():
1815
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1816
+
1817
+ if 'initial_mae_statistics' in auto_metric_summary:
1818
+ print("\nInitial Scores - MAE Statistics:")
1819
+ for dim, stats in auto_metric_summary['initial_mae_statistics'].items():
1820
+ print(f" {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
1821
+
1822
+ if 'spearman_correlations' in auto_metric_summary:
1823
+ print("\nSpearman Correlations:")
1824
+ for dim, stats in auto_metric_summary['spearman_correlations'].items():
1825
+ print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
1826
+
1827
+ # Print refined and initial spearman correlations if available
1828
+ if 'refined_spearman_correlations' in auto_metric_summary:
1829
+ print("\nRefined Scores - Spearman Correlations:")
1830
+ for dim, stats in auto_metric_summary['refined_spearman_correlations'].items():
1831
+ print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
1832
+
1833
+ if 'initial_spearman_correlations' in auto_metric_summary:
1834
+ print("\nInitial Scores - Spearman Correlations:")
1835
+ for dim, stats in auto_metric_summary['initial_spearman_correlations'].items():
1836
+ print(f" {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
1837
+
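The Spearman numbers printed above can be sanity-checked by hand. The sketch below reimplements the rank correlation in pure Python on a toy, invented pair of score lists (the script itself uses `scipy.stats.spearmanr`); the closed-form `1 - 6*d2/(n*(n^2-1))` formula used here is only valid when there are no tied ranks.

```python
# Pure-Python rank-correlation sketch (toy data; no ties, so the
# closed-form Spearman formula applies).
def ranks(xs):
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    r = [0] * len(xs)
    for rank, i in enumerate(order, start=1):
        r[i] = rank
    return r

pred = [6, 4, 8, 2]
truth = [5, 3, 10, 1]
n = len(pred)
d2 = sum((a - b) ** 2 for a, b in zip(ranks(pred), ranks(truth)))
rho = 1 - 6 * d2 / (n * (n ** 2 - 1))
# identical orderings -> rho == 1.0
```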
1838
+ if 'decision_metrics' in auto_metric_summary:
1839
+ dm = auto_metric_summary['decision_metrics']
1840
+ print(f"\nDecision Metrics:")
1841
+ print(f" Accuracy: {dm['accuracy']:.4f} (n={dm['count']})")
1842
+ if 'f1_macro' in dm:
1843
+ print(f" F1 (macro): {dm['f1_macro']:.4f}")
1844
+
1845
+ # Print refined and initial decision metrics if available
1846
+ if 'refined_decision_metrics' in auto_metric_summary:
1847
+ print("\nRefined Scores - Decision Metrics:")
1848
+ rdm = auto_metric_summary['refined_decision_metrics']
1849
+ print(f" Accuracy: {rdm['accuracy']:.4f} (n={rdm['count']})")
1850
+ if 'f1_macro' in rdm:
1851
+ print(f" F1 (macro): {rdm['f1_macro']:.4f}")
1852
+
1853
+ if 'initial_decision_metrics' in auto_metric_summary:
1854
+ print("\nInitial Scores - Decision Metrics:")
1855
+ idm = auto_metric_summary['initial_decision_metrics']
1856
+ print(f" Accuracy: {idm['accuracy']:.4f} (n={idm['count']})")
1857
+ if 'f1_macro' in idm:
1858
+ print(f" F1 (macro): {idm['f1_macro']:.4f}")
1859
+
1860
+ print("\n" + "="*80)
1861
+ print("EVALUATION COMPLETE")
1862
+ print("="*80)
1863
+
1864
+
1865
+ if __name__ == "__main__":
1866
+ main()
src/evaluator/2_evaluate_cyclereviewer.py ADDED
@@ -0,0 +1,1837 @@
1
+ """
2
+ Unified evaluation script for semantic (LLM-based) and auto_metric (rule-based) evaluation of CycleReviewer-format reviews.
3
+
4
+ This script:
5
+ 1. Reads eval_rubrics.json (from 1_generate_review_based_rubrics.py) containing rubrics for each paper
6
+ 2. Reads input JSON file containing model reviews (supports multiple formats)
7
+ 3. Supports three evaluation modes:
8
+ - semantic: LLM-based rubrics evaluation (from 2_evaluate_direct.py)
9
+ - auto_metric: Rule-based metrics evaluation (from 3_rule_evaluate.py)
10
+ - both: Run both evaluations separately
11
+ 4. Supports strict mode: normalize scores to discrete scales before computing metrics (--strict_mode)
12
+ 5. Outputs separate JSON files for results and summaries
13
+
14
+ Usage:
15
+ # Semantic evaluation only
16
+ python 2_evaluate_cyclereviewer.py \
17
+ --rubrics_path eval_rubrics.json \
18
+ --reviews_path model_reviews.json \
19
+ --mode semantic \
20
+ --yaml_path prompts.yaml \
21
+ --config_path configs.yaml \
22
+ --semantic_output semantic_results.json \
23
+ --max_workers 5
24
+
25
+ # Auto-metric evaluation only
26
+ python 2_evaluate_cyclereviewer.py \
27
+ --rubrics_path eval_rubrics.json \
28
+ --reviews_path model_reviews.json \
29
+ --mode auto_metric \
30
+ --auto_metric_output auto_metric_results.json
31
+
32
+ # Auto-metric evaluation with strict mode (normalize scores to discrete scales)
33
+ python 2_evaluate_cyclereviewer.py \
34
+ --rubrics_path eval_rubrics.json \
35
+ --reviews_path model_reviews.json \
36
+ --mode auto_metric \
37
+ --auto_metric_output auto_metric_results.json \
38
+ --strict_mode
39
+
40
+ # Auto-metric evaluation with manually specified input format (refined)
41
+ python 2_evaluate_cyclereviewer.py \
42
+ --rubrics_path eval_rubrics.json \
43
+ --reviews_path model_reviews.json \
44
+ --mode auto_metric \
45
+ --auto_metric_output auto_metric_results.json \
46
+ --input_format refined
47
+
48
+ # Auto-metric evaluation with manually specified input format (original)
49
+ python 2_evaluate_cyclereviewer.py \
50
+ --rubrics_path eval_rubrics.json \
51
+ --reviews_path ours.json \
52
+ --mode auto_metric \
53
+ --auto_metric_output auto_metric_results.json \
54
+ --input_format original
55
+
56
+ # Both evaluations
57
+ python 2_evaluate_cyclereviewer.py \
58
+ --rubrics_path eval_rubrics.json \
59
+ --reviews_path model_reviews.json \
60
+ --mode both \
61
+ --yaml_path prompts.yaml \
62
+ --config_path configs.yaml \
63
+ --semantic_output semantic_results.json \
64
+ --auto_metric_output auto_metric_results.json \
65
+ --max_workers 32
66
+ """
67
+ from __future__ import annotations
68
+
69
+ import json
70
+ import os
71
+ import sys
72
+ import argparse
73
+ import yaml
74
+ import math
75
+ import re
76
+ from typing import Dict, List, Any, Optional
77
+ from concurrent.futures import ThreadPoolExecutor, as_completed
78
+ from tqdm import tqdm
79
+ from itertools import combinations
80
+ from scipy.stats import spearmanr
81
+ from sklearn.metrics import precision_recall_fscore_support
82
+
83
+ # Add parent directory to path
84
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
85
+ # Import parse_llm_response from local llm_service module
86
+ import llm_service as local_llm_service
87
+ parse_llm_response = local_llm_service.parse_llm_response
88
+
89
+ # Import from shared/utils for gpt/vllm support
90
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
91
+ if project_root not in sys.path:
92
+ sys.path.insert(0, project_root)
93
+
94
+ from shared.utils.llm_service import LLMService
95
+ from shared.utils.vllm_service import VLLMService
96
+ from shared.utils.gpt_service import GPTService
97
+ sys.path.insert(0, os.path.join(project_root, 'shared', 'utils'))
98
+ from json_parser import parse_review_markdown
99
+
100
+ def convert_cyclereviewer(review_text: str) -> tuple:
101
+ """
102
+ Convert the review text from cyclereviewer format to unified review system format.
103
+
104
+ The cyclereviewer format has markdown sections like:
105
+ "## Rating\n\n3: reject, not good enough\n\n## Confidence\n\n4: You are confident...\n\n"
106
+
107
+ Args:
108
+ review_text: Raw review text string (markdown format)
109
+
110
+ Returns:
111
+ Tuple of (formatted_review_text, meta_review_dict)
112
+ """
113
+ # Use parse_review_markdown to extract scores from markdown sections
114
+ parsed = {}
115
+ try:
116
+ parsed = parse_review_markdown(review_text)
117
+ except Exception:
118
+ pass
119
+
120
+ # Extract rating - can be from "## Rating\n\n3: reject..." or "## Score: 6: ..."
121
+ rating = parsed.get('rating')
122
+ if rating is None:
123
+ # Try to extract from "## Rating\n\n3: reject..." format
124
+ rating_match = re.search(r'##\s*Rating\s*\n\n\s*(\d+\.?\d*)\s*:', review_text, re.IGNORECASE | re.MULTILINE)
125
+ if rating_match:
126
+ try:
127
+ rating = float(rating_match.group(1))
128
+ except (ValueError, IndexError):
129
+ pass
130
+
131
+ # Try "## Score: 6: ..." format
132
+ if rating is None:
133
+ score_match = re.search(r'##\s*Score\s*:\s*(\d+\.?\d*)\s*:', review_text, re.IGNORECASE | re.MULTILINE)
134
+ if score_match:
135
+ try:
136
+ rating = float(score_match.group(1))
137
+ except (ValueError, IndexError):
138
+ pass
139
+
140
+ # Extract confidence
141
+ confidence = parsed.get('confidence')
142
+ if confidence is None:
143
+ # Try to extract from "## Confidence\n\n4: You are confident..." format
144
+ confidence_match = re.search(r'##\s*Confidence\s*\n\n\s*(\d+\.?\d*)\s*:', review_text, re.IGNORECASE | re.MULTILINE)
145
+ if confidence_match:
146
+ try:
147
+ confidence = float(confidence_match.group(1))
148
+ except (ValueError, IndexError):
149
+ pass
150
+
151
+ # Extract decision from rating text (e.g., "3: reject, not good enough")
152
+ decision = None
153
+ if rating is not None:
154
+ # Look for decision in rating section
155
+ rating_section_match = re.search(r'##\s*Rating\s*\n\n(.*?)(?=\n##|$)', review_text, re.IGNORECASE | re.DOTALL)
156
+ if rating_section_match:
157
+ rating_content = rating_section_match.group(1)
158
+ # Extract decision from text like "3: reject, not good enough"
159
+ decision_match = re.search(r':\s*(accept|reject|undecided)', rating_content, re.IGNORECASE)
160
+ if decision_match:
161
+ decision = decision_match.group(1).lower()
162
+
163
+ # Also try Score section
164
+ if decision is None:
165
+ score_section_match = re.search(r'##\s*Score\s*:\s*\d+\s*:\s*(.*?)(?=\n##|$)', review_text, re.IGNORECASE | re.DOTALL)
166
+ if score_section_match:
167
+ score_content = score_section_match.group(1)
168
+ decision_match = re.search(r'(accept|reject|undecided)', score_content, re.IGNORECASE)
169
+ if decision_match:
170
+ decision = decision_match.group(1).lower()
171
+
172
+ # Extract soundness, presentation from parsed data or markdown
173
+ soundness = parsed.get('soundness')
174
+ presentation = parsed.get('presentation')
175
+ contribution = parsed.get('contribution')
176
+
177
+ # Create meta_review dict
178
+ meta_review = {
179
+ "rating": rating,
180
+ "soundness": soundness,
181
+ "presentation": presentation,
182
+ "contribution": contribution,
183
+ "confidence": confidence,
184
+ "decision": decision,
185
+ }
186
+
187
+ # Return the review text as-is (it's already in markdown format)
188
+ return review_text, meta_review
189
+
190
+
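The Rating-section regex used above can be exercised on a tiny made-up snippet; the pattern below is copied from `convert_cyclereviewer`, and the review text is invented for illustration.

```python
import re

# Toy CycleReviewer-style markdown snippet (invented for this sketch).
text = "## Rating\n\n3: reject, not good enough\n\n## Confidence\n\n4: You are confident"

m = re.search(r'##\s*Rating\s*\n\n\s*(\d+\.?\d*)\s*:', text,
              re.IGNORECASE | re.MULTILINE)
rating = float(m.group(1)) if m else None
# -> 3.0
```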
191
+ class ReviewProcessor:
192
+ """Handles the extraction and processing of reviews from different sources."""
193
+
194
+ @staticmethod
195
+ def extract_review_content(pred_context):
196
+ """
197
+ Extract the review content from the prediction context.
198
+
199
+ Args:
200
+ pred_context: Raw prediction data that contains the review
201
+
202
+ Returns:
203
+ str: Extracted review content
204
+ """
205
+ try:
206
+ # First attempt to extract from boxed format
207
+ return pred_context.split(r'\boxed_review{')[-1].split('\n}')[0]
208
+ except Exception:
209
+ # Alternative extraction if the first method fails
210
+ if isinstance(pred_context, dict) and 'output' in pred_context:
211
+ return pred_context['output'].split(r'\boxed_review{')[-1].split('\n}')[0]
212
+ else:
213
+ # Return as is if extraction fails
214
+ return pred_context
215
+
216
+
217
+ # ============================================================================
218
+ # Semantic Evaluation Functions (from 2_evaluate_direct.py)
219
+ # ============================================================================
220
+
221
+ def load_prompt_template(yaml_path: str) -> str:
222
+ """Load the evaluator prompt from YAML file."""
223
+ with open(yaml_path, 'r', encoding='utf-8') as f:
224
+ prompts = yaml.safe_load(f)
225
+ return prompts.get('v1_evaluator_prompt', '')
226
+
227
+
228
+ def build_evaluation_prompt(
229
+ rubrics: List[Dict[str, Any]],
230
+ paper_content: str,
231
+ review: str,
232
+ prompt_template: str
233
+ ) -> str:
234
+ """Build the evaluation prompt by replacing placeholders."""
235
+ rubrics_json = json.dumps(rubrics, indent=4, ensure_ascii=False)
236
+ prompt = prompt_template.replace('{rubrics_json}', rubrics_json)
237
+ prompt = prompt.replace('<<paper_content>>', paper_content)
238
+ prompt = prompt.replace('<<review>>', review)
239
+ return prompt
240
+
241
+
242
+ def calculate_weighted_scores(
243
+ raw_scores: Dict[str, Dict[str, Any]],
244
+ rubrics: List[Dict[str, Any]]
245
+ ) -> Dict[str, float]:
246
+ """Calculate weighted scores for each rubric."""
247
+ rubric_weights = {r['title']: r['weight'] for r in rubrics}
248
+ weighted_scores = {}
249
+
250
+ for rubric_title, rubric_data in raw_scores.items():
251
+ if rubric_title not in rubric_weights:
252
+ continue
253
+
254
+ rubric_score = rubric_data.get('score', 0)
255
+ if isinstance(rubric_score, str):
256
+ try:
257
+ rubric_score = int(rubric_score)
258
+ except ValueError:
259
+ rubric_score = 0
260
+
261
+ if rubric_score not in [0, 1]:
262
+ rubric_score = 1 if rubric_score > 0 else 0
263
+
264
+ weight = rubric_weights[rubric_title]
265
+ weighted_scores[rubric_title] = rubric_score * weight
266
+
267
+ return weighted_scores
268
+
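A tiny worked example of the weighting logic above (rubric titles and weights are invented; note that rubrics absent from the weight map are skipped, exactly as in `calculate_weighted_scores`):

```python
# Binary rubric scores multiplied by per-rubric weights; unknown titles dropped.
rubrics = [{'title': 'Clarity', 'weight': 2.0}, {'title': 'Novelty', 'weight': 1.0}]
raw = {'Clarity': {'score': 1}, 'Novelty': {'score': 0}, 'Unknown': {'score': 1}}

weights = {r['title']: r['weight'] for r in rubrics}
weighted = {t: d['score'] * weights[t] for t, d in raw.items() if t in weights}
total = sum(weighted.values())
# weighted == {'Clarity': 2.0, 'Novelty': 0.0}; total == 2.0
```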
269
+
270
+ def calculate_scores(raw_scores: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
271
+ """Calculate scores for each rubric."""
272
+ scores = {}
273
+ for rubric_title, rubric_data in raw_scores.items():
274
+ scores[rubric_title] = rubric_data.get('score', 0)
275
+ return scores
276
+
277
+
278
+ def evaluate_review_semantic(
279
+ entry: Dict[str, Any],
280
+ paper_content: str,
281
+ prompt_template: str,
282
+ llm_service: LLMService
283
+ ) -> Dict[str, Any]:
284
+ """Evaluate a single review using article-specific rubrics."""
285
+ entry_id = entry.get('id', 'unknown')
286
+ rubrics = entry.get('rubrics', [])
287
+ model_review = entry.get('model_review', '')
288
+
289
+ if not rubrics:
290
+ return {
291
+ 'id': entry_id,
292
+ 'raw_scores': {},
293
+ 'weighted_scores': {},
294
+ 'total_score': 0.0,
295
+ 'error': 'No valid rubrics found',
296
+ 'raw_response': ''
297
+ }
298
+
299
+ # Build prompt
300
+ prompt = build_evaluation_prompt(rubrics, paper_content, model_review, prompt_template)
301
+
302
+ # Call LLM
303
+ try:
304
+ messages = [{"role": "user", "content": prompt}]
305
+ response = llm_service.generate(messages=messages)
306
+
307
+ # Parse response
308
+ raw_scores = parse_llm_response(response)
309
+ weighted_scores = calculate_scores(raw_scores)
310
+ total_score = sum(weighted_scores.values())
311
+
312
+ return {
313
+ 'id': entry_id,
314
+ 'raw_scores': raw_scores,
315
+ 'weighted_scores': weighted_scores,
316
+ 'total_score': total_score,
317
+ 'raw_response': response
318
+ }
319
+ except Exception as e:
320
+ print(f"[ERROR] Error evaluating review {entry_id}: {e}")
321
+ return {
322
+ 'id': entry_id,
323
+ 'raw_scores': {},
324
+ 'weighted_scores': {},
325
+ 'total_score': 0.0,
326
+ 'error': str(e),
327
+ 'raw_response': ''
328
+ }
329
+
330
+
331
+ def calculate_per_rubric_statistics(
332
+ valid_results: List[Dict[str, Any]],
333
+ rubric_titles: List[str]
334
+ ) -> Dict[str, Dict[str, float]]:
335
+ """Calculate per-rubric statistics from evaluation results."""
336
+ rubric_scores = {title: [] for title in rubric_titles}
337
+
338
+ for result in valid_results:
339
+ weighted_scores = result.get('weighted_scores', {})
340
+ if not isinstance(weighted_scores, dict):
341
+ continue
342
+
343
+ for rubric_title in rubric_titles:
344
+ if rubric_title in weighted_scores:
345
+ score = weighted_scores[rubric_title]
346
+ if isinstance(score, str):
347
+ try:
348
+ score = float(score)
349
+ except ValueError:
350
+ continue
351
+ elif isinstance(score, (int, float)):
352
+ score = float(score)
353
+ else:
354
+ continue
355
+ rubric_scores[rubric_title].append(score)
356
+
357
+ per_rubric_stats = {}
358
+ for rubric_title in rubric_titles:
359
+ scores = rubric_scores[rubric_title]
360
+ if not scores:
361
+ continue
362
+
363
+ mean_score = sum(scores) / len(scores)
364
+ min_score = min(scores)
365
+ max_score = max(scores)
366
+ count = len(scores)
367
+
368
+ if rubric_title == "False or Contradictory Claims":
369
+ pass_count = sum(1 for s in scores if s >= 0)
370
+ else:
371
+ pass_count = sum(1 for s in scores if s >= 1)
372
+ pass_rate = pass_count / count if count > 0 else 0.0
373
+
374
+ per_rubric_stats[rubric_title] = {
375
+ 'mean': mean_score,
376
+ 'min': min_score,
377
+ 'max': max_score,
378
+ 'count': count,
379
+ 'pass_rate': pass_rate
380
+ }
381
+
382
+ return per_rubric_stats
383
+
384
+
385
+ # ============================================================================
386
+ # Auto-Metric Evaluation Functions (from 3_rule_evaluate.py)
387
+ # ============================================================================
388
+
389
+ def extract_scores_from_review(review_text: str) -> Dict[str, Any]:
390
+ """Extract numeric scores and decision from a review markdown text."""
391
+ if not review_text:
392
+ return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
393
+
394
+ try:
395
+ parsed = parse_review_markdown(review_text)
396
+ decision = parsed.get('decision', '')
397
+ if decision:
398
+ decision_lower = decision.lower().strip()
399
+ if 'accept' in decision_lower:
400
+ decision = 'accept'
401
+ elif 'reject' in decision_lower:
402
+ decision = 'reject'
403
+ elif 'undecided' in decision_lower:
404
+ decision = 'undecided'
405
+ else:
406
+ decision = decision_lower
407
+ else:
408
+ decision = None
409
+
410
+ return {
411
+ 'soundness': parsed.get('soundness'),
412
+ 'presentation': parsed.get('presentation'),
413
+ 'rating': parsed.get('rating'),
414
+ 'confidence': parsed.get('confidence'),
415
+ 'decision': decision
416
+ }
417
+ except Exception as e:
418
+ print(f"Warning: Failed to parse review text: {e}")
419
+ return {'soundness': None, 'presentation': None, 'rating': None, 'confidence': None, 'decision': None}
420
+
421
+
422
+ def calculate_mse(predicted: float, ground_truth: float) -> Optional[float]:
423
+ """Calculate Mean Squared Error for a single value."""
424
+ if predicted is None or ground_truth is None:
425
+ return None
426
+ return (predicted - ground_truth) ** 2
427
+
428
+
429
+ def calculate_mae(predicted: float, ground_truth: float) -> Optional[float]:
430
+ """Calculate Mean Absolute Error for a single value."""
431
+ if predicted is None or ground_truth is None:
432
+ return None
433
+ return abs(predicted - ground_truth)
434
+
435
+
436
+ def normalize_to_discrete_scale(score: Optional[float], scale_type: str) -> Optional[float]:
437
+ """
438
+ Normalize a float score to the nearest discrete value based on scale type.
439
+ Uses round-half-up tie-breaking (e.g., 3.5 rounds to 4, 1.5 rounds to 2).
440
+
441
+ Args:
442
+ score: The float score to normalize (can be None)
443
+ scale_type: Either '0-5' for 0-5 scale (discrete: 0,1,2,3,4,5)
444
+ or '0-10' for 0-10 scale (discrete: 0,2,4,6,8,10)
445
+
446
+ Returns:
447
+ Normalized discrete score, or None if input is None
448
+ """
449
+ if score is None:
450
+ return None
451
+
452
+ try:
453
+ score = float(score)
454
+ except (ValueError, TypeError):
455
+ return None
456
+
457
+ if scale_type == '0-5':
458
+ # Discrete values: 0, 1, 2, 3, 4, 5
459
+ discrete_values = [0, 1, 2, 3, 4, 5]
460
+ # Clamp to valid range
461
+ score = max(0, min(5, score))
462
+ # Find nearest discrete value, with round-half-up tie-breaking
463
+ # For ties, prefer the higher value
464
+ best_value = None
465
+ best_distance = float('inf')
466
+ for val in discrete_values:
467
+ distance = abs(val - score)
468
+ if distance < best_distance:
469
+ best_distance = distance
470
+ best_value = val
471
+ elif distance == best_distance and val > best_value:
472
+ # Tie-breaking: prefer higher value (round-half-up)
473
+ best_value = val
474
+ return best_value
475
+ elif scale_type == '0-10':
476
+ # Discrete values: 0, 2, 4, 6, 8, 10
477
+ discrete_values = [0, 2, 4, 6, 8, 10]
478
+ # Clamp to valid range
479
+ score = max(0, min(10, score))
480
+ # Find nearest discrete value, with round-half-up tie-breaking
481
+ best_value = None
482
+ best_distance = float('inf')
483
+ for val in discrete_values:
484
+ distance = abs(val - score)
485
+ if distance < best_distance:
486
+ best_distance = distance
487
+ best_value = val
488
+ elif distance == best_distance and val > best_value:
489
+ # Tie-breaking: prefer higher value (round-half-up)
490
+ best_value = val
491
+ return best_value
492
+ else:
493
+ raise ValueError(f"Unknown scale_type: {scale_type}. Must be '0-5' or '0-10'")
494
+
495
+
496
+ def normalize_scores_dict(scores: Dict[str, Optional[float]]) -> Dict[str, Optional[float]]:
497
+ """
498
+ Normalize all scores in a dictionary to their appropriate discrete scales.
499
+
500
+ Args:
501
+ scores: Dictionary with keys 'soundness', 'presentation', 'rating', 'confidence'
502
+
503
+ Returns:
504
+ Dictionary with normalized scores
505
+ """
506
+ normalized = {}
507
+
508
+ # soundness, presentation, confidence use 0-5 scale
509
+ for key in ['soundness', 'presentation', 'confidence']:
510
+ normalized[key] = normalize_to_discrete_scale(scores.get(key), '0-5')
511
+
512
+ # rating uses 0-10 scale
513
+ normalized['rating'] = normalize_to_discrete_scale(scores.get('rating'), '0-10')
514
+
515
+ return normalized
516
+
517
+
518
+ def calculate_score_metrics(
519
+ model_scores: Dict[str, float],
520
+ ground_truth_scores: Dict[str, float],
521
+ normalize: bool = False
522
+ ) -> Dict[str, Any]:
523
+ """
524
+ Calculate MSE and MAE metrics for each scoring dimension.
525
+
526
+ Args:
527
+ model_scores: Dictionary with model scores
528
+ ground_truth_scores: Dictionary with ground truth scores
529
+ normalize: If True, normalize scores to discrete scales before computing metrics
530
+
531
+ Returns:
532
+ Dictionary with MSE, MAE metrics and optionally normalized scores
533
+ """
534
+ dimensions = ['soundness', 'presentation', 'rating', 'confidence']
535
+
536
+ # Normalize scores to discrete scales if requested
537
+ if normalize:
538
+ model_scores_normalized = normalize_scores_dict(model_scores)
539
+ gt_scores_normalized = normalize_scores_dict(ground_truth_scores)
540
+ else:
541
+ model_scores_normalized = model_scores
542
+ gt_scores_normalized = ground_truth_scores
543
+
544
+ mse_values = {}
545
+ mae_values = {}
546
+ valid_count = 0
547
+
548
+ for dim in dimensions:
549
+ # Use normalized scores for metric calculation
550
+ mse = calculate_mse(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
551
+ mae = calculate_mae(model_scores_normalized.get(dim), gt_scores_normalized.get(dim))
552
+ mse_values[f'{dim}_mse'] = mse
553
+ mae_values[f'{dim}_mae'] = mae
554
+ if mse is not None:
555
+ valid_count += 1
556
+
557
+ overall_error = sum([v for v in mse_values.values() if v is not None])
558
+
559
+ result = {
560
+ **mse_values,
561
+ **mae_values,
562
+ 'overall_error': overall_error if valid_count > 0 else None,
563
+ 'valid_dimensions': valid_count
564
+ }
565
+
566
+ # Include normalized scores in result for transparency (only if normalize=True)
567
+ if normalize:
568
+ result['model_scores_normalized'] = model_scores_normalized
569
+ result['gt_scores_normalized'] = gt_scores_normalized
570
+
571
+ return result
572
+
573
+
574
+ def normalize_score_value(value):
575
+ """Normalize score value to float, handling string representations."""
576
+ if value is None:
577
+ return None
578
+ if isinstance(value, (int, float)):
579
+ return float(value)
580
+ if isinstance(value, str):
581
+ # Try to extract numeric value from string (e.g., "2.75" -> 2.75)
582
+ try:
583
+ import re
584
+ match = re.search(r'(\d+\.?\d*)', value)
585
+ if match:
586
+ return float(match.group(1))
587
+ except Exception:
588
+ pass
589
+ return None
590
+
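The string-to-float coercion above tolerates values like `"2.75"` or `"3: reject"`. A self-contained behavior sketch (reimplemented here rather than imported; inputs are invented):

```python
import re

# Coerce numeric-ish values to float; pull the first number out of strings.
def coerce(value):
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        m = re.search(r'(\d+\.?\d*)', value)
        return float(m.group(1)) if m else None
    return None
```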
591
+
592
+ def normalize_decision(decision):
593
+ """Normalize decision string to standard format."""
594
+ if decision is None:
595
+ return None
596
+ decision_lower = str(decision).lower().strip()
597
+ if 'accept' in decision_lower:
598
+ return 'accept'
599
+ elif 'reject' in decision_lower:
600
+ return 'reject'
601
+ elif 'undecided' in decision_lower:
602
+ return 'undecided'
603
+ else:
604
+ return decision_lower
605
+
+
+ def extract_scores_from_dict(scores_dict: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Extract scores from a structured dictionary (scores or initial_scores format).
+
+     Args:
+         scores_dict: Dict containing scores (e.g., {'rating': 5.75, 'soundness': '2.75', ...})
+
+     Returns:
+         Dict with normalized scores: {'soundness', 'presentation', 'rating', 'confidence', 'decision'}
+     """
+     if not scores_dict:
+         return {
+             'soundness': None,
+             'presentation': None,
+             'rating': None,
+             'confidence': None,
+             'decision': None
+         }
+
+     return {
+         'soundness': normalize_score_value(scores_dict.get('soundness')),
+         'presentation': normalize_score_value(scores_dict.get('presentation')),
+         'rating': normalize_score_value(scores_dict.get('rating')),
+         'confidence': normalize_score_value(scores_dict.get('confidence')),
+         'decision': normalize_decision(scores_dict.get('decision'))
+     }
+
+
+ def evaluate_review_auto_metric(entry: Dict[str, Any], use_initial_scores: bool = False, strict_mode: bool = False) -> Dict[str, Any]:
+     """
+     Evaluate a single entry by extracting scores and calculating metrics.
+
+     Args:
+         entry: Evaluation entry containing model_review, scores, initial_scores, etc.
+         use_initial_scores: If True, use initial_scores instead of refined scores (refined format only)
+         strict_mode: If True, normalize scores before computing MSE/MAE
+
+     Returns:
+         Dict containing evaluation metrics
+     """
+     entry_id = entry.get('id', 'unknown')
+     model_review = entry.get('model_review', '')
+     format_type = entry.get('format', 'unknown')
+
+     # Extract scores based on format
+     model_scores = {}
+     model_decision = None
+
+     if format_type == 'refined' and not use_initial_scores:
+         # Use refined scores from structured data
+         scores_dict = entry.get('scores', {})
+         model_data = extract_scores_from_dict(scores_dict)
+         model_scores = {
+             'soundness': model_data.get('soundness'),
+             'presentation': model_data.get('presentation'),
+             'rating': model_data.get('rating'),
+             'confidence': model_data.get('confidence')
+         }
+         model_decision = model_data.get('decision')
+     elif format_type == 'refined' and use_initial_scores:
+         # Use initial scores from structured data
+         initial_scores_dict = entry.get('initial_scores', {})
+         model_data = extract_scores_from_dict(initial_scores_dict)
+         model_scores = {
+             'soundness': model_data.get('soundness'),
+             'presentation': model_data.get('presentation'),
+             'rating': model_data.get('rating'),
+             'confidence': model_data.get('confidence')
+         }
+         model_decision = model_data.get('decision')
+     elif format_type == 'original':
+         # Use initial scores from structured data
+         initial_scores_dict = entry.get('initial_scores', {})
+         model_data = extract_scores_from_dict(initial_scores_dict)
+         model_scores = {
+             'soundness': model_data.get('soundness'),
+             'presentation': model_data.get('presentation'),
+             'rating': model_data.get('rating'),
+             'confidence': model_data.get('confidence')
+         }
+         model_decision = model_data.get('decision')
+
+         # Fallback: if confidence is missing from the structured data, try to extract it from the
+         # review text (meta_review may lack a confidence field, but the review text might have one)
+         if model_scores.get('confidence') is None and model_review:
+             try:
+                 review_data = extract_scores_from_review(model_review)
+                 if review_data.get('confidence') is not None:
+                     model_scores['confidence'] = review_data.get('confidence')
+             except Exception:
+                 pass  # Keep confidence as None if extraction fails
+     else:
+         # Fallback: extract everything from the markdown review text
+         model_data = extract_scores_from_review(model_review)
+         model_scores = {
+             'soundness': model_data.get('soundness'),
+             'presentation': model_data.get('presentation'),
+             'rating': model_data.get('rating'),
+             'confidence': model_data.get('confidence')
+         }
+         model_decision = model_data.get('decision')
+
+     # Ground truth scores must come from golden_review ONLY, never from model output.
+     # If extraction fails, the fields stay None (no fallback to model_review).
+     ground_truth_review = entry.get('golden_review', '')
+     ground_truth_scores = {}
+     gt_decision = None
+
+     if not ground_truth_review:
+         print(f"Warning: No golden_review found for entry {entry_id}. Ground truth scores will be empty.")
+     else:
+         try:
+             # Extract scores from the golden_review markdown text
+             gt_data = extract_scores_from_review(ground_truth_review)
+             if not gt_data:
+                 print(f"Warning: Failed to parse golden_review for entry {entry_id}. Ground truth scores will be empty.")
+             else:
+                 ground_truth_scores = {
+                     'soundness': gt_data.get('soundness'),
+                     'presentation': gt_data.get('presentation'),
+                     'rating': gt_data.get('rating'),
+                     'confidence': gt_data.get('confidence')
+                 }
+                 gt_decision = normalize_decision(gt_data.get('decision'))
+                 # Note: any field that is None stays None. Using model output as
+                 # ground truth would inflate evaluation scores.
+         except Exception as e:
+             print(f"Warning: Failed to extract scores from golden_review for {entry_id}: {e}")
+             print(f"  Ground truth scores will be empty. Error: {str(e)}")
+
+     # Calculate MSE and MAE metrics (with optional normalization in strict mode)
+     score_metrics = calculate_score_metrics(model_scores, ground_truth_scores, normalize=strict_mode)
+
+     # Calculate decision accuracy
+     decision_match = False
+     decision_accuracy = None
+     if model_decision is not None and gt_decision is not None:
+         model_decision_normalized = normalize_decision(model_decision)
+         decision_match = (model_decision_normalized == gt_decision)
+         decision_accuracy = 1.0 if decision_match else 0.0
+
+     result = {
+         'id': entry_id,
+         'format': format_type,
+         'model_soundness': model_scores.get('soundness'),
+         'model_presentation': model_scores.get('presentation'),
+         'model_rating': model_scores.get('rating'),
+         'model_confidence': model_scores.get('confidence'),
+         'model_decision': model_decision,
+         'gt_soundness': ground_truth_scores.get('soundness'),
+         'gt_presentation': ground_truth_scores.get('presentation'),
+         'gt_rating': ground_truth_scores.get('rating'),
+         'gt_confidence': ground_truth_scores.get('confidence'),
+         'gt_decision': gt_decision,
+         'decision_match': decision_match,
+         'decision_accuracy': decision_accuracy,
+         **score_metrics
+     }
+
+     # Record which scores were used
+     if format_type == 'refined':
+         result['score_type'] = 'initial' if use_initial_scores else 'refined'
+     else:
+         result['score_type'] = 'auto'
+
+     return result
+
+
+ def calculate_pairwise_accuracies(paper_scores: List[Dict[str, float]]) -> Dict[str, float]:
+     """Calculate pairwise accuracy for each metric by comparing rankings."""
+     if len(paper_scores) < 2:
+         return {}
+
+     total_valid_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
+     correct_pairs = {'rating': 0, 'soundness': 0, 'presentation': 0, 'confidence': 0}
+
+     for paper1, paper2 in combinations(paper_scores, 2):
+         # Check rating ranking
+         if (paper1.get('true_rating') is not None and paper2.get('true_rating') is not None and
+                 paper1.get('pred_rating') is not None and paper2.get('pred_rating') is not None):
+             total_valid_pairs['rating'] += 1
+             true_order = paper1['true_rating'] > paper2['true_rating']
+             pred_order = paper1['pred_rating'] > paper2['pred_rating']
+             if true_order == pred_order:
+                 correct_pairs['rating'] += 1
+
+         # Same logic for the remaining dimensions
+         for metric in ['soundness', 'presentation', 'confidence']:
+             true_key = f'true_{metric}'
+             pred_key = f'pred_{metric}'
+             if (paper1.get(true_key) is not None and paper2.get(true_key) is not None and
+                     paper1.get(pred_key) is not None and paper2.get(pred_key) is not None):
+                 total_valid_pairs[metric] += 1
+                 true_order = paper1[true_key] > paper2[true_key]
+                 pred_order = paper1[pred_key] > paper2[pred_key]
+                 if true_order == pred_order:
+                     correct_pairs[metric] += 1
+
+     pairwise_accuracies = {
+         metric: correct_pairs[metric] / total_valid_pairs[metric] if total_valid_pairs[metric] > 0 else 0.0
+         for metric in ['rating', 'soundness', 'presentation', 'confidence']
+     }
+
+     return pairwise_accuracies
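The ranking comparison above reduces to one reusable helper per metric. Here is a self-contained sketch with hypothetical paper dicts carrying the `true_`/`pred_` keys the function expects:

```python
# Sketch of pairwise ranking accuracy over illustrative paper score dicts.
from itertools import combinations

def pairwise_accuracy(papers, key):
    """Fraction of paper pairs whose predicted ordering matches the true ordering."""
    total = correct = 0
    for a, b in combinations(papers, 2):
        ta, tb = a.get(f'true_{key}'), b.get(f'true_{key}')
        pa, pb = a.get(f'pred_{key}'), b.get(f'pred_{key}')
        if None in (ta, tb, pa, pb):
            continue  # skip pairs with missing scores
        total += 1
        if (ta > tb) == (pa > pb):
            correct += 1
    return correct / total if total else 0.0

papers = [
    {'true_rating': 6.0, 'pred_rating': 5.0},
    {'true_rating': 4.0, 'pred_rating': 6.0},
    {'true_rating': 8.0, 'pred_rating': 7.0},
]
print(pairwise_accuracy(papers, 'rating'))  # -> 0.6666666666666666
```

Note that ties are not treated specially: a tie in both the true and predicted values makes both `>` comparisons False and therefore counts as a correct pair, matching the comparison in the function above.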
+
+
+ # ============================================================================
+ # Data Loading Functions
+ # ============================================================================
+
+ def load_rubrics_json(rubrics_path: str) -> Dict[str, Dict[str, Any]]:
+     """Load the rubrics JSON and build a lookup keyed by paper id."""
+     with open(rubrics_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     if isinstance(data, list):
+         return {item['id']: item for item in data}
+     elif isinstance(data, dict):
+         return data
+     else:
+         raise ValueError(f"Invalid rubrics JSON format: expected list or dict, got {type(data)}")
+
+
+ def load_model_reviews_json(reviews_path: str, format_override: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
+     """
+     Load model reviews JSON and extract reviews by id.
+
+     Supports two input formats:
+     1. Refined format: contains 'scores' and 'initial_scores' fields (from the refinement pipeline)
+     2. Original format: contains 'model_prediction' with 'meta_review' and 'decision' (like ours.json)
+
+     Args:
+         reviews_path: Path to JSON file containing model reviews
+         format_override: Optional format override ('refined', 'original', or None for auto-detect)
+
+     Returns:
+         Dict mapping paper_id to a dict containing:
+         - 'review': review text (markdown)
+         - 'scores': refined scores dict (if available)
+         - 'initial_scores': initial scores dict (if available)
+         - 'format': 'refined' or 'original'
+     """
+     with open(reviews_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     if isinstance(data, dict):
+         data = list(data.values())
+
+     reviews_dict = {}
+     for item in data:
+         item_id = None
+         review_text = ''
+         scores = None
+         initial_scores = None
+         format_type = None
+
+         # Use the format override if provided, otherwise auto-detect
+         if format_override and format_override != 'auto':
+             # Force the specified format
+             if format_override == 'refined':
+                 item_id = item.get('paper_id') or item.get('id')
+                 if not item_id:
+                     continue
+                 format_type = 'refined'
+                 review_text = item.get('review_markdown', '') or item.get('review', '')
+                 scores = item.get('scores', {})
+                 initial_scores = item.get('initial_scores', {})
+             elif format_override == 'original':
+                 item_id = item.get('id')
+                 if not item_id:
+                     continue
+                 format_type = 'original'
+                 model_prediction = item.get('model_prediction', {})
+                 meta_review = model_prediction.get('meta_review', {})
+                 review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
+                 initial_scores = {
+                     'rating': meta_review.get('rating'),
+                     'soundness': meta_review.get('soundness'),
+                     'presentation': meta_review.get('presentation'),
+                     'contribution': meta_review.get('contribution'),
+                     'decision': model_prediction.get('decision'),
+                 }
+             else:
+                 raise ValueError(f"Unknown format_override: {format_override}. Must be 'refined', 'original', or 'auto'")
+         else:
+             # Auto-detect format
+             if "paper_id" in item:
+                 # Refined format (from the refinement pipeline)
+                 item_id = item.get('paper_id')
+                 if not item_id:
+                     continue
+
+                 # Check whether this is the refined format (has scores and initial_scores)
+                 if 'scores' in item and 'initial_scores' in item:
+                     format_type = 'refined'
+                     review_text = item.get('review_markdown', '') or item.get('review', '')
+                     scores = item.get('scores', {})
+                     initial_scores = item.get('initial_scores', {})
+                 else:
+                     # Standard format with paper_id
+                     format_type = 'standard'
+                     review_text = item.get('review_markdown', '') or item.get('review', '')
+             elif "model_prediction" in item:
+                 # Original format (like ours.json) or cyclereviewer format
+                 item_id = item.get('id')
+                 if not item_id:
+                     continue
+
+                 format_type = 'original'
+                 model_prediction = item.get('model_prediction', {})
+                 meta_review = model_prediction.get('meta_review', {})
+
+                 # Extract review content (prefer meta_review.content, fall back to raw_text)
+                 review_text = meta_review.get('content', '') or model_prediction.get('raw_text', '')
+
+                 # Detect cyclereviewer format: raw_text is a markdown string containing
+                 # "## Rating" or "## Score:" patterns
+                 is_cyclereviewer = False
+                 if isinstance(review_text, str) and review_text:
+                     if (re.search(r'##\s*(Rating|Score)\s*:', review_text, re.IGNORECASE) or
+                             re.search(r'##\s*Rating\s*\n\n\s*\d+\s*:', review_text, re.IGNORECASE | re.MULTILINE)):
+                         is_cyclereviewer = True
+
+                 # Handle cyclereviewer format
+                 if is_cyclereviewer:
+                     review_text, meta_review = convert_cyclereviewer(review_text)
+
+                 # Extract initial scores.
+                 # Use meta_review as the primary source (from convert_cyclereviewer or the original
+                 # meta_review), falling back to model_prediction.get('decision') if not in meta_review.
+                 initial_scores = {
+                     'rating': meta_review.get('rating'),
+                     'soundness': meta_review.get('soundness'),
+                     'presentation': meta_review.get('presentation'),
+                     'contribution': meta_review.get('contribution'),
+                     'confidence': meta_review.get('confidence'),
+                     'decision': meta_review.get('decision') or model_prediction.get('decision'),
+                 }
+             else:
+                 # Legacy format (pred_fast_mode)
+                 item_id = item.get('id')
+                 if not item_id:
+                     continue
+
+                 format_type = 'legacy'
+                 review_dict = item.get('pred_fast_mode', {})
+                 if isinstance(review_dict, dict):
+                     review_text = review_dict
+                 else:
+                     review_text = str(review_dict)
+
+         # Extract review content from the review text field
+         try:
+             if review_text:
+                 extracted_review = ReviewProcessor.extract_review_content(review_text)
+             else:
+                 extracted_review = ''
+
+             reviews_dict[item_id] = {
+                 'review': extracted_review,
+                 'scores': scores,
+                 'initial_scores': initial_scores,
+                 'format': format_type
+             }
+         except Exception as e:
+             print(f"[WARN] Failed to extract review for {item_id}: {e}")
+             continue
+
+     return reviews_dict
+
+
+ def combine_rubrics_and_reviews(
+     rubrics_data: Dict[str, Dict[str, Any]],
+     reviews_dict: Dict[str, Dict[str, Any]]
+ ) -> List[Dict[str, Any]]:
+     """
+     Combine rubrics and reviews into evaluation entries.
+
+     Args:
+         rubrics_data: Dict mapping paper_id to rubric entry
+         reviews_dict: Dict mapping paper_id to a dict containing 'review', 'scores', 'initial_scores', 'format'
+
+     Returns:
+         List of evaluation entries with model_review, scores, initial_scores, and format info
+     """
+     combined = []
+     missing_reviews = []
+
+     for paper_id, rubric_entry in rubrics_data.items():
+         review_data = reviews_dict.get(paper_id)
+         if not review_data or not review_data.get('review'):
+             missing_reviews.append(paper_id)
+             continue
+
+         entry = {
+             'id': paper_id,
+             'paper_context': rubric_entry.get('paper_context', ''),
+             'decision': rubric_entry.get('decision', ''),
+             'golden_review': rubric_entry.get('golden_review', ''),
+             'rubrics': rubric_entry.get('rubrics', []),
+             'model_review': review_data.get('review', ''),
+             'scores': review_data.get('scores'),  # Refined scores (if available)
+             'initial_scores': review_data.get('initial_scores'),  # Initial scores (if available)
+             'format': review_data.get('format', 'unknown')  # Format type
+         }
+         combined.append(entry)
+
+     if missing_reviews:
+         print(f"[WARN] {len(missing_reviews)} papers have no model review, skipping them")
+
+     return combined
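The join above keeps only papers present in both lookups with a non-empty review. A small self-contained sketch of the same pattern, using illustrative data rather than the real rubric schema:

```python
# Sketch of the rubric/review join: inner join on paper id, skipping empty reviews.
def combine(rubrics, reviews):
    combined, missing = [], []
    for paper_id, rubric in rubrics.items():
        review = reviews.get(paper_id)
        if not review or not review.get('review'):
            missing.append(paper_id)  # no usable model review for this paper
            continue
        combined.append({'id': paper_id, **rubric, 'model_review': review['review']})
    return combined, missing

rubrics = {'p1': {'decision': 'accept'}, 'p2': {'decision': 'reject'}}
reviews = {'p1': {'review': 'Solid work.'}, 'p2': {'review': ''}}
combined, missing = combine(rubrics, reviews)
print([e['id'] for e in combined], missing)  # -> ['p1'] ['p2']
```

Iterating over the rubrics (the ground-truth side) rather than the reviews means extra reviews without rubrics are silently ignored, while rubrics without reviews are counted and reported.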
+
+
+ # ============================================================================
+ # LLM Service Configuration
+ # ============================================================================
+
+ def load_llm_config(config_path: str) -> Dict[str, Any]:
+     """Load the LLM configuration from a YAML file."""
+     with open(config_path, 'r', encoding='utf-8') as f:
+         config = yaml.safe_load(f)
+     return config
+
+
+ def create_llm_service_from_config(config: Dict[str, Any]) -> LLMService:
+     """Create an LLM service from the configuration."""
+     mode = config.get('mode', 'gpt').lower()
+
+     if mode == 'gpt':
+         gpt_config = config.get('gpt', {})
+         api_key = gpt_config.get('api_key') or os.getenv('OPENAI_API_KEY')
+         if not api_key:
+             raise ValueError("GPT mode requires api_key in the config YAML or the OPENAI_API_KEY environment variable")
+
+         service = GPTService(
+             api_key=api_key,
+             model_name=gpt_config.get('model_name', 'gpt-4o'),
+             base_url=gpt_config.get('base_url'),
+             timeout=gpt_config.get('timeout', 300)
+         )
+         return service
+
+     elif mode == 'vllm':
+         vllm_config = config.get('vllm', {})
+         service = VLLMService(
+             base_url=vllm_config.get('base_url', 'http://localhost:8000/v1'),
+             api_key=vllm_config.get('api_key', 'dummy-key'),
+             model_name=vllm_config.get('model_name'),
+             timeout=vllm_config.get('timeout', 300),
+             max_concurrent_requests=vllm_config.get('max_concurrent_requests', 64),
+             max_retries=vllm_config.get('max_retries', 3),
+             retry_delay=vllm_config.get('retry_delay', 1.0),
+             retry_backoff=vllm_config.get('retry_backoff', 2.0)
+         )
+         return service
+
+     else:
+         raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")
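The factory above dispatches on the `mode` key of the loaded config. A minimal sketch of that dispatch pattern, with stand-in classes because `GPTService` and `VLLMService` are project code:

```python
# Sketch of mode-based service dispatch; DummyGPT/DummyVLLM are stand-ins.
class DummyGPT:
    def __init__(self, model_name):
        self.model_name = model_name

class DummyVLLM:
    def __init__(self, base_url):
        self.base_url = base_url

def create_service(config):
    mode = config.get('mode', 'gpt').lower()
    if mode == 'gpt':
        return DummyGPT(config.get('gpt', {}).get('model_name', 'gpt-4o'))
    if mode == 'vllm':
        return DummyVLLM(config.get('vllm', {}).get('base_url', 'http://localhost:8000/v1'))
    raise ValueError(f"Unknown mode: {mode}. Must be 'gpt' or 'vllm'")

svc = create_service({'mode': 'vllm', 'vllm': {'base_url': 'http://host:8001/v1'}})
print(type(svc).__name__, svc.base_url)  # -> DummyVLLM http://host:8001/v1
```

Defaulting every `.get()` lookup keeps a partially filled YAML usable, at the cost of silently falling back when a key is misspelled.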
+
+
+ # ============================================================================
+ # Main Evaluation Functions
+ # ============================================================================
+
+ def run_semantic_evaluation(
+     evaluation_data: List[Dict[str, Any]],
+     prompt_template: str,
+     llm_service: LLMService,
+     max_workers: int
+ ) -> tuple:
+     """Run semantic evaluation and return results and summary."""
+     print(f"\n{'='*80}")
+     print("RUNNING SEMANTIC EVALUATION")
+     print(f"{'='*80}")
+     print(f"Evaluating {len(evaluation_data)} reviews using {max_workers} workers...")
+
+     results = []
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         future_to_entry = {
+             executor.submit(
+                 evaluate_review_semantic,
+                 entry,
+                 entry['paper_context'],
+                 prompt_template,
+                 llm_service
+             ): entry
+             for entry in evaluation_data
+         }
+
+         for future in tqdm(as_completed(future_to_entry), total=len(evaluation_data), desc="Semantic evaluation"):
+             try:
+                 result = future.result()
+                 results.append(result)
+             except Exception as e:
+                 entry = future_to_entry[future]
+                 print(f"\n[ERROR] Failed to process entry {entry.get('id', 'unknown')}: {e}")
+                 results.append({
+                     'id': entry.get('id', 'unknown'),
+                     'raw_scores': {},
+                     'weighted_scores': {},
+                     'total_score': 0.0,
+                     'error': str(e),
+                     'raw_response': ''
+                 })
+
+     # Calculate statistics
+     valid_results = [r for r in results if 'error' not in r and r.get('weighted_scores')]
+     review_scores = [r.get('total_score', 0.0) for r in valid_results]
+
+     summary = {
+         'total_entries': len(results),
+         'valid_entries': len(valid_results),
+         'failed_entries': len(results) - len(valid_results)
+     }
+
+     if review_scores:
+         summary['overall_score'] = {
+             'mean': sum(review_scores) / len(review_scores),
+             'min': min(review_scores),
+             'max': max(review_scores)
+         }
+
+     # Calculate per-rubric statistics (rubric titles are taken from the first entry)
+     if evaluation_data and evaluation_data[0].get('rubrics'):
+         rubric_titles = [r['title'] for r in evaluation_data[0]['rubrics']]
+         per_rubric_stats = calculate_per_rubric_statistics(valid_results, rubric_titles)
+         summary['per_rubric_statistics'] = per_rubric_stats
+
+     return results, summary
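The submit/as_completed pattern above, one future per entry with exceptions turned into error records, can be sketched self-contained (the `score_entry` worker here is a stand-in for the LLM call):

```python
# Sketch of the thread-pool fan-out/fan-in with per-task error capture.
from concurrent.futures import ThreadPoolExecutor, as_completed

def score_entry(entry):
    """Stand-in worker: fails for one entry to show the error path."""
    if entry['id'] == 'bad':
        raise RuntimeError("boom")
    return {'id': entry['id'], 'total_score': entry['value'] * 2}

entries = [{'id': 'a', 'value': 1}, {'id': 'bad', 'value': 0}, {'id': 'b', 'value': 3}]
results = []
with ThreadPoolExecutor(max_workers=2) as executor:
    # Map each future back to its input so failures can be attributed.
    future_to_entry = {executor.submit(score_entry, e): e for e in entries}
    for future in as_completed(future_to_entry):
        try:
            results.append(future.result())
        except Exception as exc:
            results.append({'id': future_to_entry[future]['id'], 'error': str(exc)})

valid = [r for r in results if 'error' not in r]
print(sorted(r['id'] for r in valid))  # -> ['a', 'b']
```

Because `as_completed` yields futures in completion order, results arrive unordered; the `future_to_entry` mapping is what lets a failure be tied back to the entry that caused it, exactly as in the function above.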
+
+
+ def run_auto_metric_evaluation(
+     evaluation_data: List[Dict[str, Any]],
+     strict_mode: bool = False
+ ) -> tuple:
+     """
+     Run auto-metric evaluation and return results and summary.
+
+     For the refined format (has scores and initial_scores), evaluates both:
+     - refined scores
+     - initial scores
+
+     For the original format (only initial_scores), evaluates:
+     - initial scores only
+
+     Returns:
+         Tuple of (results_list, summary_dict)
+         - results_list: list of evaluation results (may contain both refined and initial results for the refined format)
+         - summary_dict: summary statistics
+     """
+     print(f"\n{'='*80}")
+     print("RUNNING AUTO-METRIC EVALUATION")
+     print(f"{'='*80}")
+     print(f"Evaluating {len(evaluation_data)} entries...")
+
+     # Detect format types
+     refined_format_count = sum(1 for e in evaluation_data if e.get('format') == 'refined')
+     original_format_count = sum(1 for e in evaluation_data if e.get('format') == 'original')
+
+     if refined_format_count > 0:
+         print(f"Detected {refined_format_count} entries in refined format (will evaluate both refined and initial scores)")
+     if original_format_count > 0:
+         print(f"Detected {original_format_count} entries in original format (will evaluate initial scores only)")
+
+     results = []
+     for entry in tqdm(evaluation_data, desc="Auto-metric evaluation"):
+         format_type = entry.get('format', 'unknown')
+
+         if format_type == 'refined':
+             # Evaluate both refined scores and initial scores
+             try:
+                 entry_id = entry.get('id', 'unknown')
+
+                 # Refined scores
+                 refined_result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
+                 refined_result['paper_id'] = entry_id  # Keep the original paper_id
+                 refined_result['id'] = f"{entry_id}_refined"
+                 results.append(refined_result)
+
+                 # Initial scores
+                 initial_result = evaluate_review_auto_metric(entry, use_initial_scores=True, strict_mode=strict_mode)
+                 initial_result['paper_id'] = entry_id  # Keep the original paper_id
+                 initial_result['id'] = f"{entry_id}_initial"
+                 results.append(initial_result)
+             except Exception as e:
+                 print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
+                 results.append({
+                     'id': entry.get('id', 'unknown'),
+                     'error': str(e)
+                 })
+         else:
+             # Evaluate initial scores only (or extract from markdown)
+             try:
+                 result = evaluate_review_auto_metric(entry, use_initial_scores=False, strict_mode=strict_mode)
+                 results.append(result)
+             except Exception as e:
+                 print(f"Error evaluating entry {entry.get('id', 'unknown')}: {e}")
+                 results.append({
+                     'id': entry.get('id', 'unknown'),
+                     'error': str(e)
+                 })
+
+     # Calculate statistics
+     valid_results = [r for r in results if 'error' not in r]
+     mse_results = [r for r in valid_results if r.get('overall_error') is not None]
+
+     # Separate refined and initial results (refined format only)
+     refined_results = [r for r in valid_results if r.get('score_type') == 'refined']
+     initial_results = [r for r in valid_results if r.get('score_type') == 'initial']
+     auto_results = [r for r in valid_results if r.get('score_type') == 'auto' or r.get('score_type') is None]
+
+     summary = {
+         'total_entries': len(results),
+         'valid_entries': len(valid_results),
+         'mse_entries': len(mse_results),
+         'refined_results_count': len(refined_results),
+         'initial_results_count': len(initial_results),
+         'auto_results_count': len(auto_results)
+     }
+
+     # Calculate MSE/MAE statistics.
+     # For the refined format, use only the refined results for the overall statistics
+     # (to avoid double counting); for other formats, use all results.
+     if refined_format_count > 0:
+         stats_results = [r for r in refined_results if r.get('overall_error') is not None]
+     else:
+         stats_results = mse_results
+
+     if stats_results:
+         dimensions = ['soundness', 'presentation', 'confidence', 'rating']
+         mse_stats = {}
+         mae_stats = {}
+
+         for dim in dimensions:
+             mse_list = [r.get(f'{dim}_mse') for r in stats_results if r.get(f'{dim}_mse') is not None]
+             mae_list = [r.get(f'{dim}_mae') for r in stats_results if r.get(f'{dim}_mae') is not None]
+
+             mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+             mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+
+             if mse_clean:
+                 mse_stats[dim] = {
+                     'mean': sum(mse_clean) / len(mse_clean),
+                     'count': len(mse_clean)
+                 }
+             if mae_clean:
+                 mae_stats[dim] = {
+                     'mean': sum(mae_clean) / len(mae_clean),
+                     'count': len(mae_clean)
+                 }
+
+         overall_errors = [r.get('overall_error') for r in stats_results if r.get('overall_error') is not None]
+         overall_clean = [x for x in overall_errors if x is not None and not (isinstance(x, float) and math.isnan(x))]
+
+         if overall_clean:
+             summary['overall_error'] = {
+                 'mean': sum(overall_clean) / len(overall_clean),
+                 'count': len(overall_clean)
+             }
+
+         summary['mse_statistics'] = mse_stats
+         summary['mae_statistics'] = mae_stats
+
+     # Calculate separate statistics for refined and initial results
+     dimensions = ['soundness', 'presentation', 'confidence', 'rating']  # (re)defined so this block works even when no overall stats were computed
+     if refined_results:
+         refined_mse_results = [r for r in refined_results if r.get('overall_error') is not None]
+         if refined_mse_results:
+             refined_mse_stats = {}
+             refined_mae_stats = {}
+             for dim in dimensions:
+                 mse_list = [r.get(f'{dim}_mse') for r in refined_mse_results if r.get(f'{dim}_mse') is not None]
+                 mae_list = [r.get(f'{dim}_mae') for r in refined_mse_results if r.get(f'{dim}_mae') is not None]
+                 mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+                 mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+                 if mse_clean:
+                     refined_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
+                 if mae_clean:
+                     refined_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
+             summary['refined_mse_statistics'] = refined_mse_stats
+             summary['refined_mae_statistics'] = refined_mae_stats
+
+     if initial_results:
+         initial_mse_results = [r for r in initial_results if r.get('overall_error') is not None]
+         if initial_mse_results:
+             initial_mse_stats = {}
+             initial_mae_stats = {}
+             for dim in dimensions:
+                 mse_list = [r.get(f'{dim}_mse') for r in initial_mse_results if r.get(f'{dim}_mse') is not None]
+                 mae_list = [r.get(f'{dim}_mae') for r in initial_mse_results if r.get(f'{dim}_mae') is not None]
+                 mse_clean = [x for x in mse_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+                 mae_clean = [x for x in mae_list if x is not None and not (isinstance(x, float) and math.isnan(x))]
+                 if mse_clean:
+                     initial_mse_stats[dim] = {'mean': sum(mse_clean) / len(mse_clean), 'count': len(mse_clean)}
+                 if mae_clean:
+                     initial_mae_stats[dim] = {'mean': sum(mae_clean) / len(mae_clean), 'count': len(mae_clean)}
+             summary['initial_mse_statistics'] = initial_mse_stats
+             summary['initial_mae_statistics'] = initial_mae_stats
+
+     # Helper for the Spearman correlations: drop pairs where either value is missing or NaN
+     def filter_valid_pairs(true_list, pred_list):
+         filtered_true = []
+         filtered_pred = []
+         for t, p in zip(true_list, pred_list):
+             if (t is not None and p is not None and
+                     not (isinstance(t, float) and math.isnan(t)) and
+                     not (isinstance(p, float) and math.isnan(p))):
+                 filtered_true.append(t)
+                 filtered_pred.append(p)
+         return filtered_true, filtered_pred
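The None/NaN filtering in `filter_valid_pairs` matters because a single NaN would otherwise propagate into the correlation. A self-contained sketch of the same filter:

```python
# Sketch of pairwise None/NaN filtering before computing a correlation.
import math

def filter_valid_pairs(true_list, pred_list):
    """Keep only index-aligned pairs where both values are present and finite."""
    filtered_true, filtered_pred = [], []
    for t, p in zip(true_list, pred_list):
        if (t is not None and p is not None and
                not (isinstance(t, float) and math.isnan(t)) and
                not (isinstance(p, float) and math.isnan(p))):
            filtered_true.append(t)
            filtered_pred.append(p)
    return filtered_true, filtered_pred

true_vals = [3.0, None, 4.0, math.nan, 5.0]
pred_vals = [2.5, 3.0, None, 4.0, 4.5]
print(filter_valid_pairs(true_vals, pred_vals))  # -> ([3.0, 5.0], [2.5, 4.5])
```

Filtering the two lists jointly (rather than separately) keeps them index-aligned, which is what a rank correlation over paired observations requires.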
+
+     # Calculate Spearman correlations.
+     # For the refined format, calculate separately for refined and initial scores and use
+     # the refined correlations for the overall statistics; for other formats, use all results.
+     if refined_format_count > 0:
+         # Refined Spearman correlations
+         refined_spearman_stats = {}
+         dimensions = ['soundness', 'presentation', 'confidence', 'rating']
+         for dim in dimensions:
+             true_values = [r.get(f'gt_{dim}') for r in refined_results]
+             pred_values = [r.get(f'model_{dim}') for r in refined_results]
+             true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
+
+             if len(true_clean) >= 2 and len(pred_clean) >= 2:
+                 try:
+                     corr, _ = spearmanr(true_clean, pred_clean)
+                     if not math.isnan(corr):
+                         refined_spearman_stats[dim] = {
+                             'correlation': corr,
+                             'count': len(true_clean)
+                         }
+                 except Exception:
+                     pass
+
+         # Initial Spearman correlations
+         initial_spearman_stats = {}
+         for dim in dimensions:
+             true_values = [r.get(f'gt_{dim}') for r in initial_results]
+             pred_values = [r.get(f'model_{dim}') for r in initial_results]
+             true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
+
+             if len(true_clean) >= 2 and len(pred_clean) >= 2:
+                 try:
+                     corr, _ = spearmanr(true_clean, pred_clean)
+                     if not math.isnan(corr):
+                         initial_spearman_stats[dim] = {
+                             'correlation': corr,
+                             'count': len(true_clean)
+                         }
+                 except Exception:
+                     pass
+
+         # Use the refined correlations for the overall statistics (avoid double counting)
+         summary['spearman_correlations'] = refined_spearman_stats
+         summary['refined_spearman_correlations'] = refined_spearman_stats
+         summary['initial_spearman_correlations'] = initial_spearman_stats
+     else:
+         # Original/other formats: use all results
+         correlation_results = valid_results
+         spearman_stats = {}
+         dimensions = ['soundness', 'presentation', 'confidence', 'rating']
+         for dim in dimensions:
+             true_values = [r.get(f'gt_{dim}') for r in correlation_results]
+             pred_values = [r.get(f'model_{dim}') for r in correlation_results]
+             true_clean, pred_clean = filter_valid_pairs(true_values, pred_values)
+
+             if len(true_clean) >= 2 and len(pred_clean) >= 2:
+                 try:
+                     corr, _ = spearmanr(true_clean, pred_clean)
+                     if not math.isnan(corr):
+                         spearman_stats[dim] = {
+                             'correlation': corr,
+                             'count': len(true_clean)
+                         }
+                 except Exception:
+                     pass
+
+         summary['spearman_correlations'] = spearman_stats
+
1392
+     # Calculate Decision metrics
+     # For refined format, calculate separately for refined and initial, and use refined for overall
+     # For other formats, use all results
+     if refined_format_count > 0:
+         # Calculate refined decision metrics
+         refined_decision_results = [r for r in refined_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
+         if refined_decision_results:
+             true_decisions = []
+             pred_decisions = []
+             decision_acc = []
+
+             for r in refined_decision_results:
+                 gt_decision = str(r.get('gt_decision', '')).lower().strip()
+                 pred_decision = str(r.get('model_decision', '')).lower().strip()
+
+                 if 'accept' in pred_decision:
+                     pred_binary = 1
+                 else:
+                     pred_binary = 0
+
+                 if 'accept' in gt_decision:
+                     gt_binary = 1
+                 else:
+                     gt_binary = 0
+
+                 true_decisions.append(gt_binary)
+                 pred_decisions.append(pred_binary)
+
+                 if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
+                     decision_acc.append(1.0)
+                 else:
+                     decision_acc.append(0.0)
+
+             if decision_acc:
+                 decision_accuracy = sum(decision_acc) / len(decision_acc)
+                 try:
+                     _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
+                     refined_decision_metrics = {
+                         'accuracy': decision_accuracy,
+                         'f1_macro': f1_score,
+                         'count': len(decision_acc)
+                     }
+                 except Exception:
+                     refined_decision_metrics = {
+                         'accuracy': decision_accuracy,
+                         'count': len(decision_acc)
+                     }
+                 summary['refined_decision_metrics'] = refined_decision_metrics
+                 summary['decision_metrics'] = refined_decision_metrics  # Use refined for overall
+
+         # Calculate initial decision metrics
+         initial_decision_results = [r for r in initial_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
+         if initial_decision_results:
+             true_decisions = []
+             pred_decisions = []
+             decision_acc = []
+
+             for r in initial_decision_results:
+                 gt_decision = str(r.get('gt_decision', '')).lower().strip()
+                 pred_decision = str(r.get('model_decision', '')).lower().strip()
+
+                 if 'accept' in pred_decision:
+                     pred_binary = 1
+                 else:
+                     pred_binary = 0
+
+                 if 'accept' in gt_decision:
+                     gt_binary = 1
+                 else:
+                     gt_binary = 0
+
+                 true_decisions.append(gt_binary)
+                 pred_decisions.append(pred_binary)
+
+                 if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
+                     decision_acc.append(1.0)
+                 else:
+                     decision_acc.append(0.0)
+
+             if decision_acc:
+                 decision_accuracy = sum(decision_acc) / len(decision_acc)
+                 try:
+                     _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
+                     initial_decision_metrics = {
+                         'accuracy': decision_accuracy,
+                         'f1_macro': f1_score,
+                         'count': len(decision_acc)
+                     }
+                 except Exception:
+                     initial_decision_metrics = {
+                         'accuracy': decision_accuracy,
+                         'count': len(decision_acc)
+                     }
+                 summary['initial_decision_metrics'] = initial_decision_metrics
+     else:
+         # Original/other formats: use all results
+         decision_results = [r for r in valid_results if r.get('gt_decision') is not None and r.get('model_decision') is not None]
+         if decision_results:
+             true_decisions = []
+             pred_decisions = []
+             decision_acc = []
+
+             for r in decision_results:
+                 gt_decision = str(r.get('gt_decision', '')).lower().strip()
+                 pred_decision = str(r.get('model_decision', '')).lower().strip()
+
+                 if 'accept' in pred_decision:
+                     pred_binary = 1
+                 else:
+                     pred_binary = 0
+
+                 if 'accept' in gt_decision:
+                     gt_binary = 1
+                 else:
+                     gt_binary = 0
+
+                 true_decisions.append(gt_binary)
+                 pred_decisions.append(pred_binary)
+
+                 if pred_decision == gt_decision or ('accept' in pred_decision and 'accept' in gt_decision) or ('reject' in pred_decision and 'reject' in gt_decision):
+                     decision_acc.append(1.0)
+                 else:
+                     decision_acc.append(0.0)
+
+             if decision_acc:
+                 decision_accuracy = sum(decision_acc) / len(decision_acc)
+                 try:
+                     _, _, f1_score, _ = precision_recall_fscore_support(true_decisions, pred_decisions, average='macro')
+                     summary['decision_metrics'] = {
+                         'accuracy': decision_accuracy,
+                         'f1_macro': f1_score,
+                         'count': len(decision_acc)
+                     }
+                 except Exception:
+                     summary['decision_metrics'] = {
+                         'accuracy': decision_accuracy,
+                         'count': len(decision_acc)
+                     }
+
+     # Calculate Pairwise comparison
+     # For refined format, only use refined results (avoid double counting)
+     # For other formats, use all results
+     if refined_format_count > 0:
+         pairwise_results = refined_results
+     else:
+         pairwise_results = valid_results
+
+     paper_scores = []
+     for r in pairwise_results:
+         if (r.get('gt_rating') is not None and r.get('model_rating') is not None) or \
+            (r.get('gt_soundness') is not None and r.get('model_soundness') is not None):
+             paper_scores.append({
+                 'true_rating': r.get('gt_rating'),
+                 'pred_rating': r.get('model_rating'),
+                 'true_soundness': r.get('gt_soundness'),
+                 'pred_soundness': r.get('model_soundness'),
+                 'true_presentation': r.get('gt_presentation'),
+                 'pred_presentation': r.get('model_presentation'),
+                 'true_confidence': r.get('gt_confidence'),
+                 'pred_confidence': r.get('model_confidence')
+             })
+
+     if len(paper_scores) >= 2:
+         pairwise_accuracies = calculate_pairwise_accuracies(paper_scores)
+         summary['pairwise_accuracies'] = pairwise_accuracies
+
+     return results, summary
+
+
+ # ============================================================================
+ # Main Function
+ # ============================================================================
+
+ def parse_args():
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(description="Unified evaluation script for semantic and auto-metric evaluation")
+
+     # Input paths
+     parser.add_argument("--rubrics_path", type=str, required=True,
+                         help="Path to eval_rubrics.json file (from 1_generate_review_based_rubrics.py)")
+     parser.add_argument("--reviews_path", type=str, required=True,
+                         help="Path to JSON file with model reviews (contains pred_fast_mode)")
+
+     # Evaluation mode
+     parser.add_argument("--mode", type=str, choices=["semantic", "auto_metric", "both"], default="both",
+                         help="Evaluation mode: semantic (LLM-based), auto_metric (rule-based), or both")
+
+     # Output paths
+     parser.add_argument("--semantic_output", type=str, default=None,
+                         help="Path to output JSON file for semantic evaluation results (required if mode is semantic or both)")
+     parser.add_argument("--auto_metric_output", type=str, default=None,
+                         help="Path to output JSON file for auto-metric evaluation results (required if mode is auto_metric or both)")
+
+     # Semantic evaluation settings
+     parser.add_argument("--yaml_path", type=str, default=None,
+                         help="Path to prompts.yaml file (required for semantic evaluation)")
+     parser.add_argument("--config_path", type=str, default=None,
+                         help="Path to configs.yaml file (required for semantic evaluation)")
+
+     # Multi-threading
+     parser.add_argument("--max_workers", type=int, default=None,
+                         help="Maximum number of worker threads for semantic evaluation (default: 5)")
+
+     # Strict mode (normalize scores to discrete scales)
+     parser.add_argument("--strict_mode", action="store_true", default=False,
+                         help="Enable strict mode: normalize scores to discrete scales before computing metrics (default: False)")
+
+     # Input format override
+     parser.add_argument("--input_format", type=str, choices=['auto', 'refined', 'original'], default='auto',
+                         help="Manually specify input JSON format: 'refined' (has scores and initial_scores), 'original' (has model_prediction), or 'auto' for auto-detection (default: 'auto')")
+
+     return parser.parse_args()
+
+
+ def main():
+     """Main execution function."""
+     args = parse_args()
+
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+
+     # Resolve paths
+     rubrics_path = args.rubrics_path
+     if not os.path.isabs(rubrics_path):
+         rubrics_path = os.path.join(script_dir, rubrics_path)
+
+     reviews_path = args.reviews_path
+     if not os.path.isabs(reviews_path):
+         reviews_path = os.path.join(script_dir, reviews_path)
+
+     max_workers = args.max_workers or int(os.getenv("MAX_WORKERS", "5"))
+
+     # Validate mode and output paths
+     if args.mode in ["semantic", "both"]:
+         if not args.semantic_output:
+             raise ValueError("--semantic_output is required when mode is 'semantic' or 'both'")
+         if not args.yaml_path:
+             raise ValueError("--yaml_path is required for semantic evaluation")
+         if not args.config_path:
+             raise ValueError("--config_path is required for semantic evaluation")
+
+     if args.mode in ["auto_metric", "both"]:
+         if not args.auto_metric_output:
+             raise ValueError("--auto_metric_output is required when mode is 'auto_metric' or 'both'")
+
+     # Check if files exist
+     if not os.path.exists(rubrics_path):
+         raise FileNotFoundError(f"Rubrics file not found: {rubrics_path}")
+     if not os.path.exists(reviews_path):
+         raise FileNotFoundError(f"Reviews file not found: {reviews_path}")
+
+     # Load data
+     print(f"Loading rubrics from {rubrics_path}...")
+     rubrics_data = load_rubrics_json(rubrics_path)
+     print(f"Loaded {len(rubrics_data)} rubrics entries")
+
+     print(f"Loading model reviews from {reviews_path}...")
+     if args.input_format != 'auto':
+         print(f"Using manually specified format: {args.input_format}")
+     else:
+         print("Auto-detecting input format...")
+     reviews_dict = load_model_reviews_json(reviews_path, format_override=args.input_format if args.input_format != 'auto' else None)
+     print(f"Loaded {len(reviews_dict)} model reviews")
+
+     # Combine rubrics and reviews
+     print("Combining rubrics and reviews...")
+     evaluation_data = combine_rubrics_and_reviews(rubrics_data, reviews_dict)
+     print(f"Prepared {len(evaluation_data)} entries for evaluation")
+
+     # Run evaluations based on mode
+     if args.mode in ["semantic", "both"]:
+         # Resolve semantic evaluation paths
+         yaml_path = args.yaml_path
+         if not os.path.isabs(yaml_path):
+             yaml_path = os.path.join(script_dir, yaml_path)
+
+         config_path = args.config_path
+         if not os.path.isabs(config_path):
+             config_path = os.path.join(script_dir, config_path)
+
+         if not os.path.exists(yaml_path):
+             raise FileNotFoundError(f"YAML file not found: {yaml_path}")
+         if not os.path.exists(config_path):
+             raise FileNotFoundError(f"Config file not found: {config_path}")
+
+         # Load prompt template
+         print(f"Loading prompt template from {yaml_path}...")
+         prompt_template = load_prompt_template(yaml_path)
+         if not prompt_template:
+             raise ValueError("Could not find 'v1_evaluator_prompt' in YAML file")
+
+         # Initialize LLM service
+         print(f"Loading LLM configuration from {config_path}...")
+         llm_config = load_llm_config(config_path)
+         llm_service = create_llm_service_from_config(llm_config)
+         mode = llm_config.get('mode', 'gpt')
+         print(f"LLM service initialized (mode: {mode})")
+         if hasattr(llm_service, 'model_name'):
+             print(f"Using model: {llm_service.model_name}")
+
+         # Run semantic evaluation
+         semantic_results, semantic_summary = run_semantic_evaluation(
+             evaluation_data, prompt_template, llm_service, max_workers
+         )
+
+         # Save semantic results
+         semantic_output = args.semantic_output
+         if not os.path.isabs(semantic_output):
+             semantic_output = os.path.join(script_dir, semantic_output)
+
+         output_dir = os.path.dirname(semantic_output)
+         os.makedirs(output_dir, exist_ok=True)
+
+         with open(semantic_output, 'w', encoding='utf-8') as f:
+             json.dump(semantic_results, f, ensure_ascii=False, indent=2)
+         print(f"\nSemantic evaluation results saved to {semantic_output}")
+
+         # Save semantic summary
+         semantic_summary_path = semantic_output.replace('.json', '_summary.json')
+         with open(semantic_summary_path, 'w', encoding='utf-8') as f:
+             json.dump(semantic_summary, f, ensure_ascii=False, indent=2)
+         print(f"Semantic evaluation summary saved to {semantic_summary_path}")
+
+         # Print semantic summary
+         print("\n" + "="*80)
+         print("SEMANTIC EVALUATION SUMMARY")
+         print("="*80)
+         print(f"Total entries: {semantic_summary['total_entries']}")
+         print(f"Valid entries: {semantic_summary['valid_entries']}")
+         print(f"Failed entries: {semantic_summary['failed_entries']}")
+         if 'overall_score' in semantic_summary:
+             score = semantic_summary['overall_score']
+             print(f"\nOverall Score:")
+             print(f"  Mean: {score['mean']:.2f}")
+             print(f"  Min: {score['min']:.2f}")
+             print(f"  Max: {score['max']:.2f}")
+
+     if args.mode in ["auto_metric", "both"]:
+         # Run auto-metric evaluation
+         auto_metric_results, auto_metric_summary = run_auto_metric_evaluation(
+             evaluation_data,
+             strict_mode=args.strict_mode
+         )
+
+         # Save auto-metric results
+         auto_metric_output = args.auto_metric_output
+         if not os.path.isabs(auto_metric_output):
+             auto_metric_output = os.path.join(script_dir, auto_metric_output)
+
+         output_dir = os.path.dirname(auto_metric_output)
+         os.makedirs(output_dir, exist_ok=True)
+
+         with open(auto_metric_output, 'w', encoding='utf-8') as f:
+             json.dump(auto_metric_results, f, ensure_ascii=False, indent=2)
+         print(f"\nAuto-metric evaluation results saved to {auto_metric_output}")
+
+         # Save auto-metric summary
+         auto_metric_summary_path = auto_metric_output.replace('.json', '_summary.json')
+         with open(auto_metric_summary_path, 'w', encoding='utf-8') as f:
+             json.dump(auto_metric_summary, f, ensure_ascii=False, indent=2)
+         print(f"Auto-metric evaluation summary saved to {auto_metric_summary_path}")
+
+         # Print auto-metric summary
+         print("\n" + "="*80)
+         print("AUTO-METRIC EVALUATION SUMMARY")
+         print("="*80)
+         print(f"Total entries: {auto_metric_summary['total_entries']}")
+         print(f"Valid entries: {auto_metric_summary['valid_entries']}")
+         print(f"MSE entries: {auto_metric_summary['mse_entries']}")
+
+         if 'mse_statistics' in auto_metric_summary:
+             print("\nMSE Statistics:")
+             for dim, stats in auto_metric_summary['mse_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'mae_statistics' in auto_metric_summary:
+             print("\nMAE Statistics:")
+             for dim, stats in auto_metric_summary['mae_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         # Print refined and initial statistics if available
+         if 'refined_mse_statistics' in auto_metric_summary:
+             print("\nRefined Scores - MSE Statistics:")
+             for dim, stats in auto_metric_summary['refined_mse_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'refined_mae_statistics' in auto_metric_summary:
+             print("\nRefined Scores - MAE Statistics:")
+             for dim, stats in auto_metric_summary['refined_mae_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'initial_mse_statistics' in auto_metric_summary:
+             print("\nInitial Scores - MSE Statistics:")
+             for dim, stats in auto_metric_summary['initial_mse_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'initial_mae_statistics' in auto_metric_summary:
+             print("\nInitial Scores - MAE Statistics:")
+             for dim, stats in auto_metric_summary['initial_mae_statistics'].items():
+                 print(f"  {dim.capitalize()}: Mean={stats['mean']:.4f}, Count={stats['count']}")
+
+         if 'spearman_correlations' in auto_metric_summary:
+             print("\nSpearman Correlations:")
+             for dim, stats in auto_metric_summary['spearman_correlations'].items():
+                 print(f"  {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
+
+         # Print refined and initial spearman correlations if available
+         if 'refined_spearman_correlations' in auto_metric_summary:
+             print("\nRefined Scores - Spearman Correlations:")
+             for dim, stats in auto_metric_summary['refined_spearman_correlations'].items():
+                 print(f"  {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
+
+         if 'initial_spearman_correlations' in auto_metric_summary:
+             print("\nInitial Scores - Spearman Correlations:")
+             for dim, stats in auto_metric_summary['initial_spearman_correlations'].items():
+                 print(f"  {dim.capitalize()}: {stats['correlation']:.4f} (n={stats['count']})")
+
+         if 'decision_metrics' in auto_metric_summary:
+             dm = auto_metric_summary['decision_metrics']
+             print(f"\nDecision Metrics:")
+             print(f"  Accuracy: {dm['accuracy']:.4f} (n={dm['count']})")
+             if 'f1_macro' in dm:
+                 print(f"  F1 (macro): {dm['f1_macro']:.4f}")
+
+         # Print refined and initial decision metrics if available
+         if 'refined_decision_metrics' in auto_metric_summary:
+             print("\nRefined Scores - Decision Metrics:")
+             rdm = auto_metric_summary['refined_decision_metrics']
+             print(f"  Accuracy: {rdm['accuracy']:.4f} (n={rdm['count']})")
+             if 'f1_macro' in rdm:
+                 print(f"  F1 (macro): {rdm['f1_macro']:.4f}")
+
+         if 'initial_decision_metrics' in auto_metric_summary:
+             print("\nInitial Scores - Decision Metrics:")
+             idm = auto_metric_summary['initial_decision_metrics']
+             print(f"  Accuracy: {idm['accuracy']:.4f} (n={idm['count']})")
+             if 'f1_macro' in idm:
+                 print(f"  F1 (macro): {idm['f1_macro']:.4f}")
+
+     print("\n" + "="*80)
+     print("EVALUATION COMPLETE")
+     print("="*80)
+
+
+ if __name__ == "__main__":
+     main()
+
src/evaluator/configs.yaml ADDED
@@ -0,0 +1,38 @@
+ # LLM Service Configuration for Rubric Generation
+ # Choose one mode: "gpt" or "vllm"
+ mode: "vllm"  # Options: "gpt" or "vllm"
+
+ # GPT API Configuration (used when mode="gpt")
+ gpt:
+   api_key: "your-api-key-here"  # Replace with your actual OpenAI API key or set OPENAI_API_KEY env var
+   model_name: "gpt-4o"  # Options: gpt-4o, gpt-4-turbo, gpt-3.5-turbo, gpt-5, etc.
+   base_url: null  # Default: https://api.openai.com/v1
+   timeout: 300
+
+   # Default sampling parameters
+   temperature: 0.7
+   top_p: 0.95
+   max_tokens: 16384
+   presence_penalty: 0.0
+
+ # vLLM Service Configuration (used when mode="vllm")
+ vllm:
+   base_url: "http://localhost:8000/"  # vLLM server base URL
+   api_key: "dummy-key"  # Not used for local vLLM, but required by OpenAI client
+   model_name: "openai/gpt-oss-120b"  # Model name on vLLM server
+   timeout: 300
+
+   # Rate limiting: Maximum concurrent requests to vLLM server
+   max_concurrent_requests: 64
+
+   # Retry configuration for server errors
+   max_retries: 3
+   retry_delay: 1.0
+   retry_backoff: 2.0
+
+   # Default sampling parameters
+   temperature: 0.7
+   top_p: 0.8
+   top_k: 20
+   max_tokens: 16384
+   presence_penalty: 0.0