Spaces:
Sleeping
Sleeping
Upload llm.py
Browse files
llm.py
CHANGED
|
@@ -180,86 +180,188 @@ def build_extraction_template(interviewee_type: str) -> str:
|
|
| 180 |
|
| 181 |
def parse_structured_response(text: str, interviewee_type: str) -> Dict:
|
| 182 |
"""Extract structured data from LLM response"""
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
| 184 |
# Try to find JSON block
|
| 185 |
json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)
|
| 186 |
-
|
| 187 |
if json_match:
|
|
|
|
| 188 |
try:
|
| 189 |
data = json.loads(json_match.group())
|
| 190 |
-
log(f"Successfully extracted JSON: {data}")
|
| 191 |
return data
|
| 192 |
-
except json.JSONDecodeError:
|
| 193 |
-
log("
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
| 195 |
# Fallback: Extract from text using patterns
|
| 196 |
data = {}
|
| 197 |
-
|
| 198 |
if interviewee_type == "HCP":
|
|
|
|
| 199 |
# Extract diagnoses
|
| 200 |
diag_pattern = r'(?:diagnos[ei]s|condition):\s*([^\n]+)'
|
| 201 |
data["diagnoses"] = re.findall(diag_pattern, text, re.IGNORECASE)
|
| 202 |
-
|
| 203 |
# Extract prescriptions
|
| 204 |
rx_pattern = r'(?:prescri[bp]\w*|medication):\s*([^\n]+)'
|
| 205 |
data["prescriptions"] = re.findall(rx_pattern, text, re.IGNORECASE)
|
| 206 |
-
|
| 207 |
# Extract treatment rationale
|
| 208 |
treat_pattern = r'(?:treatment|therapy|rationale):\s*([^\n]+)'
|
| 209 |
data["treatment_rationale"] = re.findall(treat_pattern, text, re.IGNORECASE)
|
| 210 |
-
|
| 211 |
elif interviewee_type == "Patient":
|
|
|
|
| 212 |
# Extract symptoms
|
| 213 |
symptom_pattern = r'(?:symptom|complaint|experienc\w*):\s*([^\n]+)'
|
| 214 |
data["symptoms"] = re.findall(symptom_pattern, text, re.IGNORECASE)
|
| 215 |
-
|
| 216 |
# Extract concerns
|
| 217 |
concern_pattern = r'(?:concern|worry|question|anxious):\s*([^\n]+)'
|
| 218 |
data["concerns"] = re.findall(concern_pattern, text, re.IGNORECASE)
|
| 219 |
-
|
| 220 |
# Extract side effects
|
| 221 |
se_pattern = r'(?:side effect|adverse|reaction):\s*([^\n]+)'
|
| 222 |
data["side_effects"] = re.findall(se_pattern, text, re.IGNORECASE)
|
| 223 |
-
|
| 224 |
# Clean and deduplicate
|
| 225 |
for key in data:
|
| 226 |
data[key] = list(set([item.strip() for item in data[key] if item.strip()]))
|
| 227 |
-
|
| 228 |
-
log(f"
|
|
|
|
| 229 |
return data
|
| 230 |
|
| 231 |
|
| 232 |
-
def query_llm_hf_api(prompt: str, max_tokens: int =
|
| 233 |
-
"""Use Hugging Face Inference API
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
try:
|
| 235 |
-
from
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
# Use
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
except Exception as e:
|
| 255 |
import traceback
|
| 256 |
full_error = traceback.format_exc()
|
| 257 |
-
|
| 258 |
-
print(f"[HF API Full Error]\n{full_error}") # Print to console
|
| 259 |
return f"[Error] HF API failed: {e}"
|
| 260 |
|
| 261 |
|
| 262 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
"""Local model optimized for L4 GPU"""
|
| 264 |
try:
|
| 265 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
@@ -346,12 +448,17 @@ Be specific and include relevant details (dosages, durations, severity levels, e
|
|
| 346 |
log(f"Prompt truncated to {len(full_prompt)} characters")
|
| 347 |
|
| 348 |
def generate():
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
else:
|
| 354 |
-
return query_llm_local(full_prompt, max_tokens=
|
| 355 |
|
| 356 |
# Execute with timeout
|
| 357 |
with ThreadPoolExecutor(max_workers=1) as executor:
|
|
@@ -359,13 +466,30 @@ Be specific and include relevant details (dosages, durations, severity levels, e
|
|
| 359 |
try:
|
| 360 |
response = future.result(timeout=timeout)
|
| 361 |
log(f"LLM response received ({len(response)} chars)")
|
| 362 |
-
|
| 363 |
# Extract structured data if requested
|
| 364 |
structured_data = {}
|
|
|
|
| 365 |
if extract_structured:
|
| 366 |
structured_data = parse_structured_response(response, interviewee_type)
|
| 367 |
-
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
except ThreadTimeout:
|
| 371 |
log("LLM generation timed out")
|
|
|
|
def parse_structured_response(text: str, interviewee_type: str) -> Dict:
    """Extract structured data from an LLM response.

    First tries to parse a JSON object embedded in *text*; if no JSON block
    is found or parsing fails, falls back to regex extraction of labelled
    lines for the given interviewee type.

    Args:
        text: Raw LLM response text.
        interviewee_type: Either "HCP" or "Patient"; selects the fallback
            regex patterns. Any other value yields an empty dict from the
            fallback path.

    Returns:
        Dict of extracted fields: the parsed JSON object, or a mapping of
        field name -> list of unique matched strings (first-seen order).
    """
    log(f"Parsing response ({len(text)} chars) for type: {interviewee_type}")
    log(f"Response preview: {text[:500]}...")

    # Try to find a JSON block (pattern tolerates one level of nested braces)
    json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text, re.DOTALL)

    if json_match:
        log(f"Found JSON match: {json_match.group()[:200]}...")
        try:
            data = json.loads(json_match.group())
            log(f"Successfully extracted JSON with {len(data)} fields: {list(data.keys())}")
            return data
        except json.JSONDecodeError as e:
            log(f"JSON parsing failed: {e}")
            log(f"Attempted to parse: {json_match.group()[:300]}")
    else:
        log("No JSON block found in response, using regex fallback")

    # Fallback: extract fields from free text using "label: value" patterns
    data = {}

    if interviewee_type == "HCP":
        log("Using HCP extraction patterns...")
        # Extract diagnoses
        diag_pattern = r'(?:diagnos[ei]s|condition):\s*([^\n]+)'
        data["diagnoses"] = re.findall(diag_pattern, text, re.IGNORECASE)

        # Extract prescriptions
        rx_pattern = r'(?:prescri[bp]\w*|medication):\s*([^\n]+)'
        data["prescriptions"] = re.findall(rx_pattern, text, re.IGNORECASE)

        # Extract treatment rationale
        treat_pattern = r'(?:treatment|therapy|rationale):\s*([^\n]+)'
        data["treatment_rationale"] = re.findall(treat_pattern, text, re.IGNORECASE)

    elif interviewee_type == "Patient":
        log("Using Patient extraction patterns...")
        # Extract symptoms
        symptom_pattern = r'(?:symptom|complaint|experienc\w*):\s*([^\n]+)'
        data["symptoms"] = re.findall(symptom_pattern, text, re.IGNORECASE)

        # Extract concerns
        concern_pattern = r'(?:concern|worry|question|anxious):\s*([^\n]+)'
        data["concerns"] = re.findall(concern_pattern, text, re.IGNORECASE)

        # Extract side effects
        se_pattern = r'(?:side effect|adverse|reaction):\s*([^\n]+)'
        data["side_effects"] = re.findall(se_pattern, text, re.IGNORECASE)

    # Clean and deduplicate while preserving first-seen order.
    # (list(set(...)) made the result order nondeterministic across runs;
    # dict.fromkeys dedups in insertion order.)
    for key in data:
        data[key] = list(dict.fromkeys(item.strip() for item in data[key] if item.strip()))

    log(f"Fallback extraction result: {len(data)} fields, {sum(len(v) for v in data.values())} total items")
    log(f"Extracted fields: {data}")
    return data
def query_llm_hf_api(prompt: str, max_tokens: int = 1500) -> str:
    """Use Hugging Face Inference API with proper authentication.

    Configuration is read from the environment:
      HUGGINGFACE_TOKEN (required), HF_MODEL (default Phi-3-mini-4k-instruct),
      LLM_TEMPERATURE (default 0.5), LLM_TIMEOUT seconds (default 60).

    Args:
        prompt: Full prompt text sent as the model input.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The generated text, or a string starting with "[Error]" on failure.
    """
    hf_token = os.getenv("HUGGINGFACE_TOKEN", "")

    if not hf_token:
        error_msg = "[Error] HUGGINGFACE_TOKEN not set in environment!"
        print(error_msg)
        return error_msg

    # SECURITY: never echo the token (or a prefix of it) to the console --
    # it is a secret credential and console output often ends up in logs.
    print("[HF API] Token found, using Bearer authentication")

    try:
        # Imported lazily so the cheap missing-token check above needs no deps.
        import requests

        # Get model from environment variable (default to Phi-3 if not set)
        hf_model = os.getenv("HF_MODEL", "microsoft/Phi-3-mini-4k-instruct")
        api_url = f"https://api-inference.huggingface.co/models/{hf_model}"

        # Use Bearer token in Authorization header
        headers = {
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json",
        }

        # Get temperature from environment
        temperature = float(os.getenv("LLM_TEMPERATURE", "0.5"))

        # Use the FULL prompt (don't truncate - the model can handle it)
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_tokens,  # Use parameter passed to function
                "temperature": temperature,
                "return_full_text": False,
            },
        }

        # Get timeout from environment
        timeout = int(os.getenv("LLM_TIMEOUT", "60"))

        print(f"[HF API] Calling {hf_model} ({max_tokens} tokens, temp={temperature})...")
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)

        print(f"[HF API] Status code: {response.status_code}")

        if response.status_code == 200:
            result = response.json()
            # Text-generation endpoints return [{"generated_text": ...}]
            if isinstance(result, list) and len(result) > 0:
                generated_text = result[0].get("generated_text", "")
                print(f"[HF API] Response: {len(generated_text)} characters")
                print(f"[HF API] First 200 chars: {generated_text[:200]}")
                return generated_text
            else:
                print(f"[HF API] Unexpected response format: {result}")
                return "[Error] Unexpected API response format"
        elif response.status_code == 401:
            # Do NOT print the token here: the previous version leaked the
            # full credential to the console on auth failure.
            print("[HF API] 401 Unauthorized - Token invalid or expired")
            print(f"[HF API] Response: {response.text[:500]}")
            return "[Error] Invalid HuggingFace token - create a new one at https://huggingface.co/settings/tokens"
        else:
            print(f"[HF API] Failed with status {response.status_code}")
            print(f"[HF API] Response: {response.text[:500]}")
            return f"[Error] API returned status {response.status_code}"

    except Exception as e:
        import traceback
        full_error = traceback.format_exc()
        print(f"[HF API] Error:\n{full_error}")
        return f"[Error] HF API failed: {e}"
def query_llm_lmstudio(prompt: str, max_tokens: int = 1500) -> str:
    """Query LM Studio local server (OpenAI-compatible API).

    Configuration is read from the environment:
      LMSTUDIO_URL (default http://localhost:1234/v1/chat/completions),
      LLM_TEMPERATURE (default 0.7), LLM_TIMEOUT seconds (default 120).

    Args:
        prompt: User prompt sent as a single chat message.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The generated text, or a string starting with "[Error]" on failure.
    """
    import requests

    lmstudio_url = os.getenv("LMSTUDIO_URL", "http://localhost:1234/v1/chat/completions")

    print(f"[LM Studio] Calling {lmstudio_url}...")

    try:
        payload = {
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": float(os.getenv("LLM_TEMPERATURE", "0.7")),
            "max_tokens": max_tokens,
            "stream": False
        }

        # Timeout is env-configurable for consistency with query_llm_hf_api
        # (it was previously hard-coded to 120 seconds; 120 remains the default).
        timeout = int(os.getenv("LLM_TIMEOUT", "120"))
        response = requests.post(lmstudio_url, json=payload, timeout=timeout)

        print(f"[LM Studio] Status code: {response.status_code}")

        if response.status_code == 200:
            # OpenAI chat-completions shape: choices[0].message.content
            result = response.json()
            generated_text = result["choices"][0]["message"]["content"]
            print(f"[LM Studio] Response: {len(generated_text)} characters")
            print(f"[LM Studio] First 300 chars: {generated_text[:300]}")
            return generated_text
        else:
            error_msg = f"[Error] LM Studio returned status {response.status_code}: {response.text[:200]}"
            print(f"[LM Studio] {error_msg}")
            return error_msg

    except requests.exceptions.ConnectionError:
        error_msg = "[Error] Cannot connect to LM Studio. Make sure:\n1. LM Studio is running\n2. Server is started (in LM Studio's Server tab)\n3. A model is loaded\n4. Server is on http://localhost:1234"
        print(f"[LM Studio] {error_msg}")
        return error_msg
    except Exception as e:
        error_msg = f"[Error] LM Studio failed: {e}"
        print(f"[LM Studio] {error_msg}")
        import traceback
        traceback.print_exc()
        return error_msg
+
|
| 363 |
+
|
| 364 |
+
def query_llm_local(prompt: str, max_tokens: int = 1500) -> str:
|
| 365 |
"""Local model optimized for L4 GPU"""
|
| 366 |
try:
|
| 367 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
| 448 |
log(f"Prompt truncated to {len(full_prompt)} characters")
|
| 449 |
|
| 450 |
def generate():
|
| 451 |
+
# Check environment variables dynamically (not using module-level USE_HF_API)
|
| 452 |
+
use_lmstudio = os.getenv("USE_LMSTUDIO", "False").lower() == "true"
|
| 453 |
+
use_hf_api = os.getenv("USE_HF_API", "False").lower() == "true"
|
| 454 |
+
hf_token = os.getenv("HUGGINGFACE_TOKEN", "")
|
| 455 |
+
|
| 456 |
+
if use_lmstudio:
|
| 457 |
+
return query_llm_lmstudio(full_prompt, max_tokens=2000)
|
| 458 |
+
elif use_hf_api and hf_token:
|
| 459 |
+
return query_llm_hf_api(full_prompt, max_tokens=1500)
|
| 460 |
else:
|
| 461 |
+
return query_llm_local(full_prompt, max_tokens=1500)
|
| 462 |
|
| 463 |
# Execute with timeout
|
| 464 |
with ThreadPoolExecutor(max_workers=1) as executor:
|
|
|
|
| 466 |
try:
|
| 467 |
response = future.result(timeout=timeout)
|
| 468 |
log(f"LLM response received ({len(response)} chars)")
|
| 469 |
+
|
| 470 |
# Extract structured data if requested
|
| 471 |
structured_data = {}
|
| 472 |
+
clean_response = response
|
| 473 |
if extract_structured:
|
| 474 |
structured_data = parse_structured_response(response, interviewee_type)
|
| 475 |
+
|
| 476 |
+
# Remove JSON blocks from the narrative text (handle nested braces)
|
| 477 |
+
# Remove all {....} blocks repeatedly until none remain
|
| 478 |
+
prev_response = ""
|
| 479 |
+
while prev_response != clean_response:
|
| 480 |
+
prev_response = clean_response
|
| 481 |
+
clean_response = re.sub(r'\{[^{}]*\}', '', clean_response, flags=re.DOTALL)
|
| 482 |
+
|
| 483 |
+
# Also remove common JSON artifacts
|
| 484 |
+
clean_response = re.sub(r'###\s*JSON\s*Structure:', '', clean_response, flags=re.IGNORECASE)
|
| 485 |
+
clean_response = re.sub(r'###\s*Analysis:', '', clean_response, flags=re.IGNORECASE)
|
| 486 |
+
clean_response = re.sub(r'###\s*Response:', '', clean_response, flags=re.IGNORECASE)
|
| 487 |
+
clean_response = re.sub(r'Please provide.*?structured JSON.*', '', clean_response, flags=re.IGNORECASE|re.DOTALL)
|
| 488 |
+
|
| 489 |
+
clean_response = clean_response.strip()
|
| 490 |
+
log(f"Cleaned response: {len(clean_response)} chars (removed JSON)")
|
| 491 |
+
|
| 492 |
+
return clean_response, structured_data
|
| 493 |
|
| 494 |
except ThreadTimeout:
|
| 495 |
log("LLM generation timed out")
|