jmisak committed on
Commit
61c1961
·
verified ·
1 Parent(s): 9619c6a

Upload llm.py

Browse files
Files changed (1) hide show
  1. llm.py +171 -47
llm.py CHANGED
@@ -180,86 +180,188 @@ def build_extraction_template(interviewee_type: str) -> str:
180
 
181
  def parse_structured_response(text: str, interviewee_type: str) -> Dict:
182
  """Extract structured data from LLM response"""
183
-
 
 
 
184
  # Try to find JSON block
185
  json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
186
-
187
  if json_match:
 
188
  try:
189
  data = json.loads(json_match.group())
190
- log(f"Successfully extracted JSON: {data}")
191
  return data
192
- except json.JSONDecodeError:
193
- log("Failed to parse JSON from response")
194
-
 
 
 
195
  # Fallback: Extract from text using patterns
196
  data = {}
197
-
198
  if interviewee_type == "HCP":
 
199
  # Extract diagnoses
200
  diag_pattern = r'(?:diagnos[ei]s|condition):\s*([^\n]+)'
201
  data["diagnoses"] = re.findall(diag_pattern, text, re.IGNORECASE)
202
-
203
  # Extract prescriptions
204
  rx_pattern = r'(?:prescri[bp]\w*|medication):\s*([^\n]+)'
205
  data["prescriptions"] = re.findall(rx_pattern, text, re.IGNORECASE)
206
-
207
  # Extract treatment rationale
208
  treat_pattern = r'(?:treatment|therapy|rationale):\s*([^\n]+)'
209
  data["treatment_rationale"] = re.findall(treat_pattern, text, re.IGNORECASE)
210
-
211
  elif interviewee_type == "Patient":
 
212
  # Extract symptoms
213
  symptom_pattern = r'(?:symptom|complaint|experienc\w*):\s*([^\n]+)'
214
  data["symptoms"] = re.findall(symptom_pattern, text, re.IGNORECASE)
215
-
216
  # Extract concerns
217
  concern_pattern = r'(?:concern|worry|question|anxious):\s*([^\n]+)'
218
  data["concerns"] = re.findall(concern_pattern, text, re.IGNORECASE)
219
-
220
  # Extract side effects
221
  se_pattern = r'(?:side effect|adverse|reaction):\s*([^\n]+)'
222
  data["side_effects"] = re.findall(se_pattern, text, re.IGNORECASE)
223
-
224
  # Clean and deduplicate
225
  for key in data:
226
  data[key] = list(set([item.strip() for item in data[key] if item.strip()]))
227
-
228
- log(f"Extracted data from text: {data}")
 
229
  return data
230
 
231
 
232
- def query_llm_hf_api(prompt: str, max_tokens: int = 500) -> str:
233
- """Use Hugging Face Inference API for better quality"""
 
 
 
 
 
 
 
 
 
 
 
 
234
  try:
235
- from huggingface_hub import InferenceClient
236
-
237
- client = InferenceClient(token=HF_TOKEN)
238
-
239
- # Use chat completions instead
240
- messages = [
241
- {"role": "system", "content": "You are an expert transcript analyzer. Provide detailed, structured analysis."},
242
- {"role": "user", "content": prompt}
243
- ]
244
-
245
- response = client.chat_completion(
246
- messages=messages,
247
- model="microsoft/Phi-3-mini-4k-instruct",
248
- max_tokens=max_tokens,
249
- temperature=0.3
250
- )
251
-
252
- return response.choices[0].message.content.strip()
253
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  except Exception as e:
255
  import traceback
256
  full_error = traceback.format_exc()
257
- log(f"HF API error: {e}\n{full_error}")
258
- print(f"[HF API Full Error]\n{full_error}") # Print to console
259
  return f"[Error] HF API failed: {e}"
260
 
261
 
262
- def query_llm_local(prompt: str, max_tokens: int = 500) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  """Local model optimized for L4 GPU"""
264
  try:
265
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@@ -346,12 +448,17 @@ Be specific and include relevant details (dosages, durations, severity levels, e
346
  log(f"Prompt truncated to {len(full_prompt)} characters")
347
 
348
  def generate():
349
- if os.getenv("USE_LMSTUDIO", "False").lower() == "true":
350
- return query_llm_lmstudio(full_prompt, max_tokens=600)
351
- elif USE_HF_API and HF_TOKEN:
352
- return query_llm_hf_api(full_prompt, max_tokens=600)
 
 
 
 
 
353
  else:
354
- return query_llm_local(full_prompt, max_tokens=600)
355
 
356
  # Execute with timeout
357
  with ThreadPoolExecutor(max_workers=1) as executor:
@@ -359,13 +466,30 @@ Be specific and include relevant details (dosages, durations, severity levels, e
359
  try:
360
  response = future.result(timeout=timeout)
361
  log(f"LLM response received ({len(response)} chars)")
362
-
363
  # Extract structured data if requested
364
  structured_data = {}
 
365
  if extract_structured:
366
  structured_data = parse_structured_response(response, interviewee_type)
367
-
368
- return response, structured_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  except ThreadTimeout:
371
  log("LLM generation timed out")
 
180
 
181
def parse_structured_response(text: str, interviewee_type: str) -> Dict:
    """Extract structured data from an LLM response.

    First tries to parse a JSON object embedded in *text*; if none is
    found or parsing fails, falls back to regex extraction keyed on the
    interviewee type.

    Args:
        text: Raw LLM response text.
        interviewee_type: "HCP" or "Patient"; selects the fallback regex
            patterns. Any other value yields an empty dict from the
            fallback path.

    Returns:
        Dict of extracted fields: either the parsed JSON object as-is,
        or regex-extracted field names mapped to de-duplicated lists of
        matched strings.
    """
    log(f"Parsing response ({len(text)} chars) for type: {interviewee_type}")
    log(f"Response preview: {text[:500]}...")

    # Try to find a JSON block (pattern handles one level of nested braces).
    json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)

    if json_match:
        log(f"Found JSON match: {json_match.group()[:200]}...")
        try:
            data = json.loads(json_match.group())
            log(f"✅ Successfully extracted JSON with {len(data)} fields: {list(data.keys())}")
            return data
        except json.JSONDecodeError as e:
            log(f"❌ JSON parsing failed: {e}")
            log(f"Attempted to parse: {json_match.group()[:300]}")
    else:
        log("⚠️ No JSON block found in response, using regex fallback")

    # Fallback: extract fields from free text using line-anchored patterns.
    data = {}

    if interviewee_type == "HCP":
        log("Using HCP extraction patterns...")
        data["diagnoses"] = re.findall(
            r'(?:diagnos[ei]s|condition):\s*([^\n]+)', text, re.IGNORECASE)
        data["prescriptions"] = re.findall(
            r'(?:prescri[bp]\w*|medication):\s*([^\n]+)', text, re.IGNORECASE)
        data["treatment_rationale"] = re.findall(
            r'(?:treatment|therapy|rationale):\s*([^\n]+)', text, re.IGNORECASE)
    elif interviewee_type == "Patient":
        log("Using Patient extraction patterns...")
        data["symptoms"] = re.findall(
            r'(?:symptom|complaint|experienc\w*):\s*([^\n]+)', text, re.IGNORECASE)
        data["concerns"] = re.findall(
            r'(?:concern|worry|question|anxious):\s*([^\n]+)', text, re.IGNORECASE)
        data["side_effects"] = re.findall(
            r'(?:side effect|adverse|reaction):\s*([^\n]+)', text, re.IGNORECASE)

    # Clean and de-duplicate while preserving first-seen order.
    # (The previous list(set(...)) gave a non-deterministic ordering,
    # which made downstream output unstable across runs.)
    for key in data:
        data[key] = list(dict.fromkeys(
            item.strip() for item in data[key] if item.strip()))

    log(f"Fallback extraction result: {len(data)} fields, "
        f"{sum(len(v) for v in data.values())} total items")
    log(f"Extracted fields: {data}")
    return data
240
 
241
 
242
def query_llm_hf_api(prompt: str, max_tokens: int = 1500) -> str:
    """Query the Hugging Face Inference API with Bearer-token auth.

    Configuration is read from the environment:
      - HUGGINGFACE_TOKEN: required API token.
      - HF_MODEL: model id (default "microsoft/Phi-3-mini-4k-instruct").
      - LLM_TEMPERATURE: sampling temperature (default 0.5).
      - LLM_TIMEOUT: request timeout in seconds (default 60).

    Args:
        prompt: Full prompt text sent as the model input.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The generated text, or a "[Error] ..." string on any failure
        (missing token, HTTP error, unexpected payload, exception).
    """
    hf_token = os.getenv("HUGGINGFACE_TOKEN", "")

    if not hf_token:
        error_msg = "[Error] HUGGINGFACE_TOKEN not set in environment!"
        print(f"❌ {error_msg}")
        return error_msg

    # Never print the full token; a short prefix is enough to identify it
    # without leaking a live credential into the logs.
    print(f"[HF API] Using token for authentication: {hf_token[:8]}...")

    try:
        # Imported lazily so a missing package is reported as an error
        # string instead of an unhandled ImportError.
        import requests

        hf_model = os.getenv("HF_MODEL", "microsoft/Phi-3-mini-4k-instruct")
        api_url = f"https://api-inference.huggingface.co/models/{hf_model}"

        headers = {
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json",
        }

        temperature = float(os.getenv("LLM_TEMPERATURE", "0.5"))

        # Send the full prompt untruncated; the model handles long input.
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                "return_full_text": False,
            },
        }

        timeout = int(os.getenv("LLM_TIMEOUT", "60"))

        print(f"[HF API] Calling {hf_model} ({max_tokens} tokens, temp={temperature})...")
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
        print(f"[HF API] Status code: {response.status_code}")

        if response.status_code == 200:
            result = response.json()
            if isinstance(result, list) and len(result) > 0:
                generated_text = result[0].get("generated_text", "")
                print(f"[HF API] ✅ Response: {len(generated_text)} characters")
                print(f"[HF API] First 200 chars: {generated_text[:200]}")
                return generated_text
            print(f"[HF API] Unexpected response format: {result}")
            return "[Error] Unexpected API response format"
        elif response.status_code == 401:
            # SECURITY: do NOT echo the token here — the previous version
            # printed the full credential into the logs on auth failure.
            print("[HF API] ❌ 401 Unauthorized - Token invalid or expired")
            print(f"[HF API] Response: {response.text[:500]}")
            return "[Error] Invalid HuggingFace token - create a new one at https://huggingface.co/settings/tokens"
        else:
            print(f"[HF API] Failed with status {response.status_code}")
            print(f"[HF API] Response: {response.text[:500]}")
            return f"[Error] API returned status {response.status_code}"

    except Exception as e:
        import traceback
        print(f"[HF API] Error:\n{traceback.format_exc()}")
        return f"[Error] HF API failed: {e}"
313
 
314
 
315
def query_llm_lmstudio(prompt: str, max_tokens: int = 1500) -> str:
    """Query a local LM Studio server via its OpenAI-compatible API.

    Configuration is read from the environment:
      - LMSTUDIO_URL: chat-completions endpoint
        (default "http://localhost:1234/v1/chat/completions").
      - LLM_TEMPERATURE: sampling temperature (default 0.7).

    Args:
        prompt: User message content.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The generated text, or a "[Error] ..." string on connection
        failure, non-200 status, or any other exception.
    """
    import requests

    lmstudio_url = os.getenv("LMSTUDIO_URL", "http://localhost:1234/v1/chat/completions")
    print(f"[LM Studio] Calling {lmstudio_url}...")

    try:
        payload = {
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": float(os.getenv("LLM_TEMPERATURE", "0.7")),
            "max_tokens": max_tokens,
            "stream": False,
        }

        response = requests.post(lmstudio_url, json=payload, timeout=120)
        print(f"[LM Studio] Status code: {response.status_code}")

        if response.status_code == 200:
            result = response.json()
            generated_text = result["choices"][0]["message"]["content"]
            print(f"[LM Studio] ✓ Response: {len(generated_text)} characters")
            print(f"[LM Studio] First 300 chars: {generated_text[:300]}")
            return generated_text

        error_msg = f"[Error] LM Studio returned status {response.status_code}: {response.text[:200]}"
        print(f"[LM Studio] {error_msg}")
        return error_msg

    except requests.exceptions.ConnectionError:
        error_msg = "[Error] Cannot connect to LM Studio. Make sure:\n1. LM Studio is running\n2. Server is started (in LM Studio's Server tab)\n3. A model is loaded\n4. Server is on http://localhost:1234"
        print(f"[LM Studio] {error_msg}")
        return error_msg
    except Exception as e:
        import traceback
        error_msg = f"[Error] LM Studio failed: {e}"
        print(f"[LM Studio] {error_msg}")
        traceback.print_exc()
        return error_msg
362
+
363
+
364
+ def query_llm_local(prompt: str, max_tokens: int = 1500) -> str:
365
  """Local model optimized for L4 GPU"""
366
  try:
367
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
448
  log(f"Prompt truncated to {len(full_prompt)} characters")
449
 
450
  def generate():
451
+ # Check environment variables dynamically (not using module-level USE_HF_API)
452
+ use_lmstudio = os.getenv("USE_LMSTUDIO", "False").lower() == "true"
453
+ use_hf_api = os.getenv("USE_HF_API", "False").lower() == "true"
454
+ hf_token = os.getenv("HUGGINGFACE_TOKEN", "")
455
+
456
+ if use_lmstudio:
457
+ return query_llm_lmstudio(full_prompt, max_tokens=2000)
458
+ elif use_hf_api and hf_token:
459
+ return query_llm_hf_api(full_prompt, max_tokens=1500)
460
  else:
461
+ return query_llm_local(full_prompt, max_tokens=1500)
462
 
463
  # Execute with timeout
464
  with ThreadPoolExecutor(max_workers=1) as executor:
 
466
  try:
467
  response = future.result(timeout=timeout)
468
  log(f"LLM response received ({len(response)} chars)")
469
+
470
  # Extract structured data if requested
471
  structured_data = {}
472
+ clean_response = response
473
  if extract_structured:
474
  structured_data = parse_structured_response(response, interviewee_type)
475
+
476
+ # Remove JSON blocks from the narrative text (handle nested braces)
477
+ # Remove all {....} blocks repeatedly until none remain
478
+ prev_response = ""
479
+ while prev_response != clean_response:
480
+ prev_response = clean_response
481
+ clean_response = re.sub(r'\{[^{}]*\}', '', clean_response, flags=re.DOTALL)
482
+
483
+ # Also remove common JSON artifacts
484
+ clean_response = re.sub(r'###\s*JSON\s*Structure:', '', clean_response, flags=re.IGNORECASE)
485
+ clean_response = re.sub(r'###\s*Analysis:', '', clean_response, flags=re.IGNORECASE)
486
+ clean_response = re.sub(r'###\s*Response:', '', clean_response, flags=re.IGNORECASE)
487
+ clean_response = re.sub(r'Please provide.*?structured JSON.*', '', clean_response, flags=re.IGNORECASE|re.DOTALL)
488
+
489
+ clean_response = clean_response.strip()
490
+ log(f"Cleaned response: {len(clean_response)} chars (removed JSON)")
491
+
492
+ return clean_response, structured_data
493
 
494
  except ThreadTimeout:
495
  log("LLM generation timed out")