Tirath5504 commited on
Commit
6c03205
·
verified ·
1 Parent(s): 64acd41

Update pipeline/disagreement_resolution.py

Browse files
Files changed (1) hide show
  1. pipeline/disagreement_resolution.py +114 -95
pipeline/disagreement_resolution.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import os
 
3
  from typing import List, Dict
4
  from openai import OpenAI
5
  from pydantic import BaseModel
@@ -14,6 +15,16 @@ client = OpenAI(
14
  api_key=os.getenv("OPENROUTER_API_KEY"),
15
  )
16
 
 
 
 
 
 
 
 
 
 
 
17
  class ResolutionDetails(BaseModel):
18
  accepted_critique_points: Dict[str, List[str]]
19
  rejected_critique_points: Dict[str, List[str]]
@@ -33,22 +44,13 @@ def construct_resolution_prompt(
33
  ) -> tuple:
34
  """
35
  Construct prompt for disagreement resolution
36
-
37
- Args:
38
- paper_title: Title of the paper
39
- paper_abstract: Abstract of the paper
40
- disagreement: Disagreement analysis results
41
- combined_critiques: Combined critique points
42
- sota_results: State-of-the-art findings
43
- retrieved_evidence: Retrieved evidence per category
44
-
45
- Returns:
46
- Tuple of (system_prompt, user_prompt)
47
  """
48
  system_prompt = """
49
  You are an AI specialized in resolving academic peer review disagreements.
50
  Your task is to analyze critiques, verify evidence, and provide a structured resolution.
51
 
 
 
52
  Respond in the following JSON format:
53
  {
54
  "accepted_critique_points": {"category": ["critique_1", "critique_2"]},
@@ -73,14 +75,11 @@ def construct_resolution_prompt(
73
  - **Novelty:** {', '.join(disagreement_details.get('Novelty', ['N/A']))}
74
 
75
  ### **Supporting Information**
76
- **Combined Critique Points from Reviews:**
77
- {json.dumps(combined_critiques, indent=2)}
78
 
79
- **State-of-the-Art (SoTA) Findings:**
80
- {sota_results[:2000]}
81
 
82
- **Retrieved Evidence:**
83
- {json.dumps(retrieved_evidence, indent=2)[:2000]}
84
 
85
  ### **Resolution Task**
86
  1. Validate critique points and categorize them into accepted or rejected.
@@ -92,6 +91,37 @@ def construct_resolution_prompt(
92
 
93
  return system_prompt, user_prompt
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  async def resolve_single_disagreement(
96
  paper_title: str,
97
  paper_abstract: str,
@@ -99,22 +129,10 @@ async def resolve_single_disagreement(
99
  combined_critiques: Dict,
100
  sota_results: str,
101
  retrieved_evidence: Dict,
102
- retries: int = 5
103
  ) -> Dict:
104
  """
105
- Resolve a single disagreement using DeepSeek-R1
106
-
107
- Args:
108
- paper_title: Paper title
109
- paper_abstract: Paper abstract
110
- disagreement: Disagreement analysis
111
- combined_critiques: Combined critique points
112
- sota_results: SoTA findings
113
- retrieved_evidence: Evidence per category
114
- retries: Maximum retry attempts
115
-
116
- Returns:
117
- Resolution results
118
  """
119
  system_prompt, user_prompt = construct_resolution_prompt(
120
  paper_title,
@@ -130,61 +148,72 @@ async def resolve_single_disagreement(
130
  {"role": "user", "content": user_prompt},
131
  ]
132
 
133
- for attempt in range(retries):
134
- try:
135
- response = await asyncio.to_thread(
136
- client.chat.completions.create,
137
- model="deepseek/deepseek-r1",
138
- messages=messages,
139
- response_format={"type": "json_object"},
140
- )
141
-
142
- if not response.choices or not response.choices[0].message.content.strip():
143
- raise ValueError("Empty response from DeepSeek-R1")
144
-
145
- # Parse response (remove potential prefix)
146
- content = response.choices[0].message.content.strip()
147
- if content.startswith("```json"):
148
- content = content[7:-3].strip()
149
- elif content.startswith("```"):
150
- content = content[3:-3].strip()
151
-
152
- llm_output = json.loads(content)
153
-
154
- # Validate required keys
155
- required_keys = {
156
- "accepted_critique_points",
157
- "rejected_critique_points",
158
- "final_resolution_summary"
159
- }
160
-
161
- if not required_keys.issubset(llm_output.keys()):
162
- raise ValueError(f"Missing keys. Present: {llm_output.keys()}")
163
-
164
- # Validate structure
165
- resolution = DisagreementResolutionResult(
166
- review_pair=disagreement.get('review_pair', [0, 1]),
167
- resolution_details=ResolutionDetails(**llm_output)
168
- )
169
-
170
- return resolution.model_dump()
171
-
172
- except Exception as e:
173
- wait_time = 2 ** attempt
174
- print(f"Resolution attempt {attempt + 1} failed: {e}")
175
-
176
- if attempt < retries - 1:
177
- await asyncio.sleep(wait_time)
178
- else:
179
- return {
180
- "review_pair": disagreement.get('review_pair', [0, 1]),
181
- "resolution_details": {
182
- "accepted_critique_points": {},
183
- "rejected_critique_points": {},
184
- "final_resolution_summary": f"Error: {str(e)}"
185
- },
186
- "error": str(e)
187
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  async def resolve_disagreements(
190
  paper_title: str,
@@ -195,16 +224,6 @@ async def resolve_disagreements(
195
  ) -> List[Dict]:
196
  """
197
  Resolve all disagreements
198
-
199
- Args:
200
- paper_title: Paper title
201
- paper_abstract: Paper abstract
202
- disagreements: List of disagreement analyses
203
- critique_points: List of critique points
204
- search_results: Search and retrieval results
205
-
206
- Returns:
207
- List of resolution results
208
  """
209
  if not disagreements:
210
  return []
 
1
  import json
2
  import os
3
+ import re
4
  from typing import List, Dict
5
  from openai import OpenAI
6
  from pydantic import BaseModel
 
15
  api_key=os.getenv("OPENROUTER_API_KEY"),
16
  )
17
 
18
+ # Priority list of models to try
19
+ # 1. DeepSeek R1 (Best reasoning, most expensive)
20
+ # 2. DeepSeek R1 Distill (Good reasoning, cheaper)
21
+ # 3. Gemini 2.0 Flash (Free/Cheap, very fast fallback)
22
+ MODELS = [
23
+ "deepseek/deepseek-r1",
24
+ "deepseek/deepseek-r1-distill-llama-70b",
25
+ "google/gemini-2.0-flash-exp:free"
26
+ ]
27
+
28
  class ResolutionDetails(BaseModel):
29
  accepted_critique_points: Dict[str, List[str]]
30
  rejected_critique_points: Dict[str, List[str]]
 
44
  ) -> tuple:
45
  """
46
  Construct prompt for disagreement resolution
 
 
 
 
 
 
 
 
 
 
 
47
  """
48
  system_prompt = """
49
  You are an AI specialized in resolving academic peer review disagreements.
50
  Your task is to analyze critiques, verify evidence, and provide a structured resolution.
51
 
52
+ IMPORTANT: detailed reasoning is allowed, but the FINAL output must be valid JSON only.
53
+
54
  Respond in the following JSON format:
55
  {
56
  "accepted_critique_points": {"category": ["critique_1", "critique_2"]},
 
75
  - **Novelty:** {', '.join(disagreement_details.get('Novelty', ['N/A']))}
76
 
77
  ### **Supporting Information**
78
+ **Combined Critique Points from Reviews:** {json.dumps(combined_critiques, indent=2)}
 
79
 
80
+ **State-of-the-Art (SoTA) Findings:** {sota_results[:2000]}
 
81
 
82
+ **Retrieved Evidence:** {json.dumps(retrieved_evidence, indent=2)[:2000]}
 
83
 
84
  ### **Resolution Task**
85
  1. Validate critique points and categorize them into accepted or rejected.
 
91
 
92
  return system_prompt, user_prompt
93
 
94
+ def extract_json_from_text(text: str) -> Dict:
95
+ """
96
+ Robustly extract JSON from text that might contain markdown or thinking traces.
97
+ """
98
+ try:
99
+ # 1. Try straightforward parse
100
+ return json.loads(text)
101
+ except json.JSONDecodeError:
102
+ pass
103
+
104
+ # 2. Try removing markdown code blocks
105
+ if "```json" in text:
106
+ pattern = r"```json(.*?)```"
107
+ match = re.search(pattern, text, re.DOTALL)
108
+ if match:
109
+ try:
110
+ return json.loads(match.group(1).strip())
111
+ except:
112
+ pass
113
+
114
+ # 3. Regex search for the outermost curly braces
115
+ # This handles cases where DeepSeek outputs <think>...</think> before the JSON
116
+ try:
117
+ match = re.search(r"(\{.*\})", text, re.DOTALL)
118
+ if match:
119
+ return json.loads(match.group(1))
120
+ except:
121
+ pass
122
+
123
+ raise ValueError("Could not extract valid JSON from model response")
124
+
125
  async def resolve_single_disagreement(
126
  paper_title: str,
127
  paper_abstract: str,
 
129
  combined_critiques: Dict,
130
  sota_results: str,
131
  retrieved_evidence: Dict,
132
+ retries: int = 3 # Reduced retries since we have model fallback
133
  ) -> Dict:
134
  """
135
+ Resolve a single disagreement with Model Fallback and Token Limits
 
 
 
 
 
 
 
 
 
 
 
 
136
  """
137
  system_prompt, user_prompt = construct_resolution_prompt(
138
  paper_title,
 
148
  {"role": "user", "content": user_prompt},
149
  ]
150
 
151
+ last_exception = None
152
+
153
+ # Loop through available models in case of error (402 Payment, 429 Rate Limit)
154
+ for model in MODELS:
155
+ print(f"Attempting resolution with model: {model}")
156
+
157
+ for attempt in range(retries):
158
+ try:
159
+ response = await asyncio.to_thread(
160
+ client.chat.completions.create,
161
+ model=model,
162
+ messages=messages,
163
+ # CRITICAL FIX: Limit max_tokens to prevent "Insufficient Credits" error
164
+ # OpenRouter reserves credits based on this number.
165
+ max_tokens=4096,
166
+ response_format={"type": "json_object"},
167
+ )
168
+
169
+ if not response.choices or not response.choices[0].message.content.strip():
170
+ raise ValueError("Empty response from AI")
171
+
172
+ content = response.choices[0].message.content.strip()
173
+ llm_output = extract_json_from_text(content)
174
+
175
+ # Validate required keys
176
+ required_keys = {
177
+ "accepted_critique_points",
178
+ "rejected_critique_points",
179
+ "final_resolution_summary"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  }
181
+
182
+ if not required_keys.issubset(llm_output.keys()):
183
+ raise ValueError(f"Missing keys. Present: {llm_output.keys()}")
184
+
185
+ # Validate structure
186
+ resolution = DisagreementResolutionResult(
187
+ review_pair=disagreement.get('review_pair', [0, 1]),
188
+ resolution_details=ResolutionDetails(**llm_output)
189
+ )
190
+
191
+ return resolution.model_dump()
192
+
193
+ except Exception as e:
194
+ last_exception = e
195
+ error_msg = str(e)
196
+ print(f"Model {model} - Attempt {attempt + 1} failed: {error_msg}")
197
+
198
+ # Immediate fallback on payment errors
199
+ if "402" in error_msg or "insufficient_quota" in error_msg:
200
+ print("Insufficient credits detected. Switching to cheaper model...")
201
+ break # Break retry loop to go to next model
202
+
203
+ wait_time = 2 ** attempt
204
+ if attempt < retries - 1:
205
+ await asyncio.sleep(wait_time)
206
+
207
+ # If all models and retries fail
208
+ return {
209
+ "review_pair": disagreement.get('review_pair', [0, 1]),
210
+ "resolution_details": {
211
+ "accepted_critique_points": {},
212
+ "rejected_critique_points": {},
213
+ "final_resolution_summary": f"Failed to resolve disagreement after trying multiple models. Final Error: {str(last_exception)}"
214
+ },
215
+ "error": str(last_exception)
216
+ }
217
 
218
  async def resolve_disagreements(
219
  paper_title: str,
 
224
  ) -> List[Dict]:
225
  """
226
  Resolve all disagreements
 
 
 
 
 
 
 
 
 
 
227
  """
228
  if not disagreements:
229
  return []