Update app.py
Browse files
app.py
CHANGED
|
@@ -336,7 +336,7 @@ def run_vlm_and_get_features(face_path: str, eye_path: str, prompt: Optional[str
|
|
| 336 |
return parsed_features, text_out
|
| 337 |
|
| 338 |
# -----------------------
|
| 339 |
-
# Gradio / LLM helper (
|
| 340 |
# -----------------------
|
| 341 |
def run_llm_on_vlm(vlm_features_or_raw: Any,
|
| 342 |
max_new_tokens: int = 1024,
|
|
@@ -346,19 +346,21 @@ def run_llm_on_vlm(vlm_features_or_raw: Any,
|
|
| 346 |
system_prompt: Optional[str] = None,
|
| 347 |
developer_prompt: Optional[str] = None) -> Dict[str, Any]:
|
| 348 |
"""
|
| 349 |
-
Call the remote LLM Space's /chat endpoint.
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
After the LLM returns, we use a regex-based extractor to pull numeric values and strings,
|
| 355 |
-
reconstruct a clean JSON dict with numeric defaults (no NaN).
|
| 356 |
"""
|
| 357 |
if not GRADIO_AVAILABLE:
|
| 358 |
raise RuntimeError("gradio_client not installed. Add gradio_client to requirements.txt")
|
| 359 |
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
|
|
|
| 362 |
model_identity = model_identity or LLM_MODEL_IDENTITY
|
| 363 |
system_prompt = system_prompt or LLM_SYSTEM_PROMPT
|
| 364 |
developer_prompt = developer_prompt or LLM_DEVELOPER_PROMPT
|
|
@@ -383,103 +385,123 @@ def run_llm_on_vlm(vlm_features_or_raw: Any,
|
|
| 383 |
f"{vlm_json_str}\n"
|
| 384 |
"===END VLM OUTPUT===\n\n"
|
| 385 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
logger.info("Calling LLM Space %s with strict JSON-only prompt", LLM_GRADIO_SPACE)
|
| 391 |
-
result = client.predict(
|
| 392 |
-
input_data=input_payload_str,
|
| 393 |
-
max_new_tokens=float(max_new_tokens),
|
| 394 |
-
model_identity=model_identity,
|
| 395 |
-
system_prompt=system_prompt,
|
| 396 |
-
developer_prompt=developer_prompt,
|
| 397 |
-
reasoning_effort=reasoning_effort,
|
| 398 |
-
temperature=float(temperature),
|
| 399 |
-
top_p=0.9,
|
| 400 |
-
top_k=50,
|
| 401 |
-
repetition_penalty=1.0,
|
| 402 |
-
api_name="/chat"
|
| 403 |
-
)
|
| 404 |
-
except Exception as e:
|
| 405 |
-
logger.exception("LLM call failed")
|
| 406 |
-
raise RuntimeError(f"LLM call failed: {e}")
|
| 407 |
-
|
| 408 |
-
# Normalize result to string
|
| 409 |
-
if isinstance(result, (dict, list)):
|
| 410 |
-
text_out = json.dumps(result)
|
| 411 |
-
else:
|
| 412 |
-
text_out = str(result)
|
| 413 |
-
|
| 414 |
-
if not text_out or len(text_out.strip()) == 0:
|
| 415 |
-
raise RuntimeError("LLM returned empty response")
|
| 416 |
-
|
| 417 |
-
# LOG raw output for debugging / auditing
|
| 418 |
-
logger.info("LLM raw output:\n%s", text_out)
|
| 419 |
-
|
| 420 |
-
# Use regex-based extraction (robust)
|
| 421 |
-
try:
|
| 422 |
-
parsed = extract_json_via_regex(text_out)
|
| 423 |
-
except Exception as e:
|
| 424 |
-
logger.exception("Regex JSON extraction failed")
|
| 425 |
-
# As a last fallback, attempt naive JSON parsing; if that fails, raise with raw output
|
| 426 |
try:
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
| 434 |
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
|
| 457 |
-
#
|
| 458 |
-
|
| 459 |
-
rs = float(parsed.get("risk_score", 0.0))
|
| 460 |
-
parsed["risk_score"] = round(max(0.0, min(100.0, rs)), 2)
|
| 461 |
-
except Exception:
|
| 462 |
-
parsed["risk_score"] = 0.0
|
| 463 |
-
|
| 464 |
-
# confidence clamp 0..1
|
| 465 |
-
parsed["confidence"] = safe_prob(parsed.get("confidence", 0.0))
|
| 466 |
-
|
| 467 |
-
# Ensure summary/recommendation are strings
|
| 468 |
-
parsed["summary"] = str(parsed.get("summary", "") or "").strip()
|
| 469 |
-
parsed["recommendation"] = str(parsed.get("recommendation", "") or "").strip()
|
| 470 |
-
|
| 471 |
-
# Optional: add flags indicating missing values (useful for frontend)
|
| 472 |
-
for k in [
|
| 473 |
-
"jaundice_probability",
|
| 474 |
-
"anemia_probability",
|
| 475 |
-
"hydration_issue_probability",
|
| 476 |
-
"neurological_issue_probability",
|
| 477 |
-
"confidence",
|
| 478 |
-
"risk_score"
|
| 479 |
-
]:
|
| 480 |
-
parsed[f"{k}_was_missing"] = False # extractor already returned defaults; mark as False
|
| 481 |
-
|
| 482 |
-
return parsed
|
| 483 |
|
| 484 |
# -----------------------
|
| 485 |
# API endpoints
|
|
|
|
| 336 |
return parsed_features, text_out
|
| 337 |
|
| 338 |
# -----------------------
|
| 339 |
+
# Gradio / LLM helper (defensive, with retry + clamps)
|
| 340 |
# -----------------------
|
| 341 |
def run_llm_on_vlm(vlm_features_or_raw: Any,
|
| 342 |
max_new_tokens: int = 1024,
|
|
|
|
| 346 |
system_prompt: Optional[str] = None,
|
| 347 |
developer_prompt: Optional[str] = None) -> Dict[str, Any]:
|
| 348 |
"""
|
| 349 |
+
Call the remote LLM Space's /chat endpoint with defensive input handling and a single retry.
|
| 350 |
+
- Coerces types (int for tokens), clamps ranges where remote spaces often expect them.
|
| 351 |
+
- Retries once with safe defaults if the Space rejects the inputs (e.g. temperature too low).
|
| 352 |
+
- Logs and returns regex-extracted JSON as before.
|
|
|
|
|
|
|
|
|
|
| 353 |
"""
|
| 354 |
if not GRADIO_AVAILABLE:
|
| 355 |
raise RuntimeError("gradio_client not installed. Add gradio_client to requirements.txt")
|
| 356 |
|
| 357 |
+
# Try to import AppError for specific handling; fallback to Exception if unavailable
|
| 358 |
+
try:
|
| 359 |
+
from gradio_client import AppError # type: ignore
|
| 360 |
+
except Exception:
|
| 361 |
+
AppError = Exception # fallback
|
| 362 |
|
| 363 |
+
client = get_gradio_client_for_space(LLM_GRADIO_SPACE)
|
| 364 |
model_identity = model_identity or LLM_MODEL_IDENTITY
|
| 365 |
system_prompt = system_prompt or LLM_SYSTEM_PROMPT
|
| 366 |
developer_prompt = developer_prompt or LLM_DEVELOPER_PROMPT
|
|
|
|
| 385 |
f"{vlm_json_str}\n"
|
| 386 |
"===END VLM OUTPUT===\n\n"
|
| 387 |
)
|
| 388 |
+
input_payload_str = instruction
|
| 389 |
+
|
| 390 |
+
# Defensive coercion / clamps
|
| 391 |
+
try_max_new_tokens = int(max_new_tokens) if max_new_tokens is not None else 1024
|
| 392 |
+
if try_max_new_tokens <= 0:
|
| 393 |
+
try_max_new_tokens = 1024
|
| 394 |
+
|
| 395 |
+
try_temperature = float(temperature) if temperature is not None else 0.0
|
| 396 |
+
# Many demos require temperature >= 0.1; clamp to 0.1 minimum to avoid validation failures
|
| 397 |
+
if try_temperature < 0.1:
|
| 398 |
+
try_temperature = 0.1
|
| 399 |
+
|
| 400 |
+
# prepare kwargs for predict
|
| 401 |
+
predict_kwargs = dict(
|
| 402 |
+
input_data=input_payload_str,
|
| 403 |
+
max_new_tokens=float(try_max_new_tokens),
|
| 404 |
+
model_identity=model_identity,
|
| 405 |
+
system_prompt=system_prompt,
|
| 406 |
+
developer_prompt=developer_prompt,
|
| 407 |
+
reasoning_effort=reasoning_effort,
|
| 408 |
+
temperature=float(try_temperature),
|
| 409 |
+
top_p=0.9,
|
| 410 |
+
top_k=50,
|
| 411 |
+
repetition_penalty=1.0,
|
| 412 |
+
api_name="/chat"
|
| 413 |
+
)
|
| 414 |
|
| 415 |
+
# attempt + one retry with safer defaults if AppError occurs
|
| 416 |
+
last_exc = None
|
| 417 |
+
for attempt in (1, 2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
try:
|
| 419 |
+
logger.info("Calling LLM Space %s (attempt %d) with temperature=%s, max_new_tokens=%s",
|
| 420 |
+
LLM_GRADIO_SPACE, attempt, predict_kwargs.get("temperature"), predict_kwargs.get("max_new_tokens"))
|
| 421 |
+
result = client.predict(**predict_kwargs)
|
| 422 |
+
# normalize to string
|
| 423 |
+
if isinstance(result, (dict, list)):
|
| 424 |
+
text_out = json.dumps(result)
|
| 425 |
+
else:
|
| 426 |
+
text_out = str(result)
|
| 427 |
+
if not text_out or len(text_out.strip()) == 0:
|
| 428 |
+
raise RuntimeError("LLM returned empty response")
|
| 429 |
+
logger.info("LLM raw output:\n%s", text_out)
|
| 430 |
|
| 431 |
+
# parse with regex extractor (may raise)
|
| 432 |
+
parsed = extract_json_via_regex(text_out)
|
| 433 |
+
if not isinstance(parsed, dict):
|
| 434 |
+
raise ValueError("Parsed LLM output is not a JSON object/dict")
|
| 435 |
|
| 436 |
+
# pretty log parsed JSON
|
| 437 |
+
try:
|
| 438 |
+
logger.info("LLM parsed JSON:\n%s", json.dumps(parsed, indent=2, ensure_ascii=False))
|
| 439 |
+
except Exception:
|
| 440 |
+
logger.info("LLM parsed JSON (raw dict): %s", str(parsed))
|
| 441 |
+
|
| 442 |
+
# defensive clamps (same as before)
|
| 443 |
+
def safe_prob(val):
|
| 444 |
+
try:
|
| 445 |
+
v = float(val)
|
| 446 |
+
return max(0.0, min(1.0, v))
|
| 447 |
+
except Exception:
|
| 448 |
+
return 0.0
|
| 449 |
+
|
| 450 |
+
for k in [
|
| 451 |
+
"jaundice_probability",
|
| 452 |
+
"anemia_probability",
|
| 453 |
+
"hydration_issue_probability",
|
| 454 |
+
"neurological_issue_probability"
|
| 455 |
+
]:
|
| 456 |
+
parsed[k] = safe_prob(parsed.get(k, 0.0))
|
| 457 |
|
| 458 |
+
try:
|
| 459 |
+
rs = float(parsed.get("risk_score", 0.0))
|
| 460 |
+
parsed["risk_score"] = round(max(0.0, min(100.0, rs)), 2)
|
| 461 |
+
except Exception:
|
| 462 |
+
parsed["risk_score"] = 0.0
|
| 463 |
+
|
| 464 |
+
parsed["confidence"] = safe_prob(parsed.get("confidence", 0.0))
|
| 465 |
+
parsed["summary"] = str(parsed.get("summary", "") or "").strip()
|
| 466 |
+
parsed["recommendation"] = str(parsed.get("recommendation", "") or "").strip()
|
| 467 |
+
|
| 468 |
+
for k in [
|
| 469 |
+
"jaundice_probability",
|
| 470 |
+
"anemia_probability",
|
| 471 |
+
"hydration_issue_probability",
|
| 472 |
+
"neurological_issue_probability",
|
| 473 |
+
"confidence",
|
| 474 |
+
"risk_score"
|
| 475 |
+
]:
|
| 476 |
+
parsed[f"{k}_was_missing"] = False
|
| 477 |
+
|
| 478 |
+
return parsed
|
| 479 |
+
|
| 480 |
+
except AppError as app_e:
|
| 481 |
+
# Specific remote validation error: log and attempt a single retry with ultra-safe defaults
|
| 482 |
+
logger.exception("LLM AppError (remote validation failed) on attempt %d: %s", attempt, str(app_e))
|
| 483 |
+
last_exc = app_e
|
| 484 |
+
if attempt == 1:
|
| 485 |
+
# tighten inputs and retry: force temperature=0.2, max_new_tokens=512
|
| 486 |
+
predict_kwargs["temperature"] = 0.2
|
| 487 |
+
predict_kwargs["max_new_tokens"] = float(512)
|
| 488 |
+
logger.info("Retrying LLM call with temperature=0.2 and max_new_tokens=512")
|
| 489 |
+
continue
|
| 490 |
+
else:
|
| 491 |
+
# no more retries
|
| 492 |
+
raise RuntimeError(f"LLM call failed (AppError): {app_e}")
|
| 493 |
+
except Exception as e:
|
| 494 |
+
logger.exception("LLM call failed on attempt %d: %s", attempt, str(e))
|
| 495 |
+
last_exc = e
|
| 496 |
+
# try one retry only for non-AppError exceptions
|
| 497 |
+
if attempt == 1:
|
| 498 |
+
predict_kwargs["temperature"] = 0.2
|
| 499 |
+
predict_kwargs["max_new_tokens"] = float(512)
|
| 500 |
+
continue
|
| 501 |
+
raise RuntimeError(f"LLM call failed: {e}")
|
| 502 |
|
| 503 |
+
# if we reach here, raise last caught exception
|
| 504 |
+
raise RuntimeError(f"LLM call ultimately failed: {last_exc}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
# -----------------------
|
| 507 |
# API endpoints
|