SsebaA committed on
Commit
bdb791b
·
verified ·
1 Parent(s): b03f59f

Update vips_classifier.py

Browse files
Files changed (1) hide show
  1. vips_classifier.py +92 -34
vips_classifier.py CHANGED
@@ -103,27 +103,65 @@ def parse_vips_response(response: str) -> dict:
103
 
104
 
105
  def parse_combined_response(response: str) -> dict:
106
- """Split combined response by section headers and parse each block."""
107
- headers = {
108
- "zero_shot": "===ZERO-SHOT===",
109
- "few_shot": "===FEW-SHOT===",
110
- "chain_of_thought": "===CHAIN-OF-THOUGHT===",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  }
112
- results = {}
113
- header_positions = {k: response.find(h) for k, h in headers.items()}
114
-
115
- for key, header in headers.items():
116
- pos = header_positions[key]
117
- if pos == -1:
118
- results[key] = {k: "Ingen relevant information." for k in ["V", "I", "P", "S"]}
119
- continue
120
- # Section ends at next header
121
- next_pos = [p for p in header_positions.values() if p > pos]
122
- end = min(next_pos) if next_pos else len(response)
123
- section = response[pos:end]
124
- results[key] = parse_vips_response(section)
125
-
126
- return results
127
 
128
 
129
  def format_vips_for_display(vips: dict) -> str:
@@ -139,18 +177,38 @@ def format_vips_for_display(vips: dict) -> str:
139
 
140
  def classify_all(english_text: str, mistral_client) -> dict:
141
  """
142
- Run all three VIPS strategies in ONE Mistral API call.
143
- 3x faster than sequential calls. Retry handles rate limits automatically.
144
  """
145
- logger.info("Classifying with all 3 strategies in single API call...")
146
- try:
147
- raw = mistral_client.generate(
148
- prompt=build_prompt_combined(english_text),
149
- max_tokens=1000,
150
- temperature=Config.LLM_TEMPERATURE,
151
- )
152
- return parse_combined_response(raw)
153
- except Exception as e:
154
- logger.error(f"Combined classification failed: {e}")
155
- empty = {k: f"[FEL: {e}]" for k in ["V", "I", "P", "S"]}
156
- return {"zero_shot": empty, "few_shot": empty, "chain_of_thought": empty}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
 
105
  def parse_combined_response(response: str) -> dict:
106
+ """
107
+ Parse combined response containing 3 VIPS sections.
108
+ Robust against marker variations: ===ZERO-SHOT===, ZERO-SHOT:, ### ZERO-SHOT, etc.
109
+ Splits by finding the 3 VIPS blocks (each has V/I/P/S lines).
110
+ """
111
+ import re
112
+
113
+ # Try to find section headers with flexible matching
114
+ # Match: ZERO-SHOT, ZERO SHOT, Zero-shot, ZEROSHOT, Method 1, etc.
115
+ patterns = [
116
+ ("zero_shot", r'(?i)(?:=+\s*)?(?:###\s*)?(?:method\s*1|zero[\s\-]?shot)(?:\s*=+)?'),
117
+ ("few_shot", r'(?i)(?:=+\s*)?(?:###\s*)?(?:method\s*2|few[\s\-]?shot)(?:\s*=+)?'),
118
+ ("chain_of_thought", r'(?i)(?:=+\s*)?(?:###\s*)?(?:method\s*3|chain[\s\-]?of[\s\-]?thought|cot)(?:\s*=+)?'),
119
+ ]
120
+
121
+ # Find position of each section
122
+ positions = {}
123
+ for key, pattern in patterns:
124
+ matches = list(re.finditer(pattern, response))
125
+ if matches:
126
+ # Take the first occurrence (the header, not VIPS content)
127
+ positions[key] = matches[0].start()
128
+
129
+ # If we found all 3, split by them
130
+ if len(positions) == 3:
131
+ sorted_keys = sorted(positions.keys(), key=lambda k: positions[k])
132
+ results = {}
133
+ for i, key in enumerate(sorted_keys):
134
+ start = positions[key]
135
+ end = positions[sorted_keys[i+1]] if i+1 < len(sorted_keys) else len(response)
136
+ section = response[start:end]
137
+ results[key] = parse_vips_response(section)
138
+ return results
139
+
140
+ # Fallback: split by VIPS blocks
141
+ # Find all "V (Välbefinnande):" or "V:" lines and their positions
142
+ v_matches = list(re.finditer(r'^V\s*(?:\(|:)', response, re.MULTILINE))
143
+
144
+ if len(v_matches) >= 3:
145
+ # We have at least 3 V-blocks, treat them as the 3 strategies in order
146
+ results = {}
147
+ keys = ["zero_shot", "few_shot", "chain_of_thought"]
148
+ for i, key in enumerate(keys):
149
+ if i < len(v_matches):
150
+ start = v_matches[i].start()
151
+ end = v_matches[i+1].start() if i+1 < len(v_matches) else len(response)
152
+ section = response[start:end]
153
+ results[key] = parse_vips_response(section)
154
+ else:
155
+ results[key] = {k: "Ingen relevant information." for k in ["V", "I", "P", "S"]}
156
+ return results
157
+
158
+ # Final fallback: same result for all 3
159
+ single = parse_vips_response(response)
160
+ return {
161
+ "zero_shot": single,
162
+ "few_shot": single,
163
+ "chain_of_thought": single,
164
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  def format_vips_for_display(vips: dict) -> str:
 
177
 
178
  def classify_all(english_text: str, mistral_client) -> dict:
179
  """
180
+ Run all three prompt strategies in PARALLEL (Scaleway is fast ~2-3s).
181
+ Each strategy gets its own API call, ensuring distinct results.
182
  """
183
+ from concurrent.futures import ThreadPoolExecutor, as_completed
184
+
185
+ strategies = {
186
+ "zero_shot": (build_prompt_zero_shot, 300),
187
+ "few_shot": (build_prompt_few_shot, 300),
188
+ "chain_of_thought": (build_prompt_chain_of_thought, 500),
189
+ }
190
+
191
+ def run_one(name, prompt_fn, max_tokens):
192
+ try:
193
+ raw = mistral_client.generate(
194
+ prompt=prompt_fn(english_text),
195
+ max_tokens=max_tokens,
196
+ temperature=Config.LLM_TEMPERATURE,
197
+ )
198
+ logger.info(f"{name}: {len(raw)} chars")
199
+ return name, parse_vips_response(raw)
200
+ except Exception as e:
201
+ logger.error(f"{name} failed: {e}")
202
+ return name, {k: f"[FEL: {e}]" for k in ["V", "I", "P", "S"]}
203
+
204
+ results = {}
205
+ with ThreadPoolExecutor(max_workers=3) as executor:
206
+ futures = {
207
+ executor.submit(run_one, name, fn, tokens): name
208
+ for name, (fn, tokens) in strategies.items()
209
+ }
210
+ for future in as_completed(futures):
211
+ name, result = future.result()
212
+ results[name] = result
213
+
214
+ return results