dpv007 committed on
Commit
d5eb738
·
verified ·
1 Parent(s): e464210

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -100
app.py CHANGED
@@ -336,7 +336,7 @@ def run_vlm_and_get_features(face_path: str, eye_path: str, prompt: Optional[str
336
  return parsed_features, text_out
337
 
338
  # -----------------------
339
- # Gradio / LLM helper (uses regex extractor on LLM output)
340
  # -----------------------
341
  def run_llm_on_vlm(vlm_features_or_raw: Any,
342
  max_new_tokens: int = 1024,
@@ -346,19 +346,21 @@ def run_llm_on_vlm(vlm_features_or_raw: Any,
346
  system_prompt: Optional[str] = None,
347
  developer_prompt: Optional[str] = None) -> Dict[str, Any]:
348
  """
349
- Call the remote LLM Space's /chat endpoint.
350
- Accepts either:
351
- - a dict (parsed VLM features) -> will be JSON-dumped (backwards compatible)
352
- - a raw string (the exact VLM text output) -> will be forwarded AS-IS (no extra JSON quoting)
353
-
354
- After the LLM returns, we use a regex-based extractor to pull numeric values and strings,
355
- reconstruct a clean JSON dict with numeric defaults (no NaN).
356
  """
357
  if not GRADIO_AVAILABLE:
358
  raise RuntimeError("gradio_client not installed. Add gradio_client to requirements.txt")
359
 
360
- client = get_gradio_client_for_space(LLM_GRADIO_SPACE)
 
 
 
 
361
 
 
362
  model_identity = model_identity or LLM_MODEL_IDENTITY
363
  system_prompt = system_prompt or LLM_SYSTEM_PROMPT
364
  developer_prompt = developer_prompt or LLM_DEVELOPER_PROMPT
@@ -383,103 +385,123 @@ def run_llm_on_vlm(vlm_features_or_raw: Any,
383
  f"{vlm_json_str}\n"
384
  "===END VLM OUTPUT===\n\n"
385
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
- input_payload_str = instruction # we feed only the instruction (which contains the VLM output)
388
-
389
- try:
390
- logger.info("Calling LLM Space %s with strict JSON-only prompt", LLM_GRADIO_SPACE)
391
- result = client.predict(
392
- input_data=input_payload_str,
393
- max_new_tokens=float(max_new_tokens),
394
- model_identity=model_identity,
395
- system_prompt=system_prompt,
396
- developer_prompt=developer_prompt,
397
- reasoning_effort=reasoning_effort,
398
- temperature=float(temperature),
399
- top_p=0.9,
400
- top_k=50,
401
- repetition_penalty=1.0,
402
- api_name="/chat"
403
- )
404
- except Exception as e:
405
- logger.exception("LLM call failed")
406
- raise RuntimeError(f"LLM call failed: {e}")
407
-
408
- # Normalize result to string
409
- if isinstance(result, (dict, list)):
410
- text_out = json.dumps(result)
411
- else:
412
- text_out = str(result)
413
-
414
- if not text_out or len(text_out.strip()) == 0:
415
- raise RuntimeError("LLM returned empty response")
416
-
417
- # LOG raw output for debugging / auditing
418
- logger.info("LLM raw output:\n%s", text_out)
419
-
420
- # Use regex-based extraction (robust)
421
- try:
422
- parsed = extract_json_via_regex(text_out)
423
- except Exception as e:
424
- logger.exception("Regex JSON extraction failed")
425
- # As a last fallback, attempt naive JSON parsing; if that fails, raise with raw output
426
  try:
427
- parsed = json.loads(text_out)
428
- except Exception:
429
- # include raw output in the exception text so logs contain it
430
- raise ValueError(f"Failed to extract JSON from LLM output: {e}\nRaw Output:\n{text_out}")
 
 
 
 
 
 
 
431
 
432
- if not isinstance(parsed, dict):
433
- raise ValueError("Parsed LLM output is not a JSON object/dict")
 
 
434
 
435
- # LOG parsed JSON (pretty-printed)
436
- try:
437
- logger.info("LLM parsed JSON:\n%s", json.dumps(parsed, indent=2, ensure_ascii=False))
438
- except Exception:
439
- logger.info("LLM parsed JSON (raw dict): %s", str(parsed))
440
-
441
- # Final safety clamps (already ensured by extractor, but keep defensive checks)
442
- def safe_prob(val):
443
- try:
444
- v = float(val)
445
- return max(0.0, min(1.0, v))
446
- except Exception:
447
- return 0.0
 
 
 
 
 
 
 
 
448
 
449
- for k in [
450
- "jaundice_probability",
451
- "anemia_probability",
452
- "hydration_issue_probability",
453
- "neurological_issue_probability"
454
- ]:
455
- parsed[k] = safe_prob(parsed.get(k, 0.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
- # risk_score clamp 0..100
458
- try:
459
- rs = float(parsed.get("risk_score", 0.0))
460
- parsed["risk_score"] = round(max(0.0, min(100.0, rs)), 2)
461
- except Exception:
462
- parsed["risk_score"] = 0.0
463
-
464
- # confidence clamp 0..1
465
- parsed["confidence"] = safe_prob(parsed.get("confidence", 0.0))
466
-
467
- # Ensure summary/recommendation are strings
468
- parsed["summary"] = str(parsed.get("summary", "") or "").strip()
469
- parsed["recommendation"] = str(parsed.get("recommendation", "") or "").strip()
470
-
471
- # Optional: add flags indicating missing values (useful for frontend)
472
- for k in [
473
- "jaundice_probability",
474
- "anemia_probability",
475
- "hydration_issue_probability",
476
- "neurological_issue_probability",
477
- "confidence",
478
- "risk_score"
479
- ]:
480
- parsed[f"{k}_was_missing"] = False # extractor already returned defaults; mark as False
481
-
482
- return parsed
483
 
484
  # -----------------------
485
  # API endpoints
 
336
  return parsed_features, text_out
337
 
338
  # -----------------------
339
+ # Gradio / LLM helper (defensive, with retry + clamps)
340
  # -----------------------
341
  def run_llm_on_vlm(vlm_features_or_raw: Any,
342
  max_new_tokens: int = 1024,
 
346
  system_prompt: Optional[str] = None,
347
  developer_prompt: Optional[str] = None) -> Dict[str, Any]:
348
  """
349
+ Call the remote LLM Space's /chat endpoint with defensive input handling and a single retry.
350
+ - Coerces types (int for tokens), clamps ranges where remote spaces often expect them.
351
+ - Retries once with safe defaults if the Space rejects the inputs (e.g. temperature too low).
352
+ - Logs and returns regex-extracted JSON as before.
 
 
 
353
  """
354
  if not GRADIO_AVAILABLE:
355
  raise RuntimeError("gradio_client not installed. Add gradio_client to requirements.txt")
356
 
357
+ # Try to import AppError for specific handling; fallback to Exception if unavailable
358
+ try:
359
+ from gradio_client import AppError # type: ignore
360
+ except Exception:
361
+ AppError = Exception # fallback
362
 
363
+ client = get_gradio_client_for_space(LLM_GRADIO_SPACE)
364
  model_identity = model_identity or LLM_MODEL_IDENTITY
365
  system_prompt = system_prompt or LLM_SYSTEM_PROMPT
366
  developer_prompt = developer_prompt or LLM_DEVELOPER_PROMPT
 
385
  f"{vlm_json_str}\n"
386
  "===END VLM OUTPUT===\n\n"
387
  )
388
+ input_payload_str = instruction
389
+
390
+ # Defensive coercion / clamps
391
+ try_max_new_tokens = int(max_new_tokens) if max_new_tokens is not None else 1024
392
+ if try_max_new_tokens <= 0:
393
+ try_max_new_tokens = 1024
394
+
395
+ try_temperature = float(temperature) if temperature is not None else 0.0
396
+ # Many demos require temperature >= 0.1; clamp to 0.1 minimum to avoid validation failures
397
+ if try_temperature < 0.1:
398
+ try_temperature = 0.1
399
+
400
+ # prepare kwargs for predict
401
+ predict_kwargs = dict(
402
+ input_data=input_payload_str,
403
+ max_new_tokens=float(try_max_new_tokens),
404
+ model_identity=model_identity,
405
+ system_prompt=system_prompt,
406
+ developer_prompt=developer_prompt,
407
+ reasoning_effort=reasoning_effort,
408
+ temperature=float(try_temperature),
409
+ top_p=0.9,
410
+ top_k=50,
411
+ repetition_penalty=1.0,
412
+ api_name="/chat"
413
+ )
414
 
415
+ # attempt + one retry with safer defaults if AppError occurs
416
+ last_exc = None
417
+ for attempt in (1, 2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  try:
419
+ logger.info("Calling LLM Space %s (attempt %d) with temperature=%s, max_new_tokens=%s",
420
+ LLM_GRADIO_SPACE, attempt, predict_kwargs.get("temperature"), predict_kwargs.get("max_new_tokens"))
421
+ result = client.predict(**predict_kwargs)
422
+ # normalize to string
423
+ if isinstance(result, (dict, list)):
424
+ text_out = json.dumps(result)
425
+ else:
426
+ text_out = str(result)
427
+ if not text_out or len(text_out.strip()) == 0:
428
+ raise RuntimeError("LLM returned empty response")
429
+ logger.info("LLM raw output:\n%s", text_out)
430
 
431
+ # parse with regex extractor (may raise)
432
+ parsed = extract_json_via_regex(text_out)
433
+ if not isinstance(parsed, dict):
434
+ raise ValueError("Parsed LLM output is not a JSON object/dict")
435
 
436
+ # pretty log parsed JSON
437
+ try:
438
+ logger.info("LLM parsed JSON:\n%s", json.dumps(parsed, indent=2, ensure_ascii=False))
439
+ except Exception:
440
+ logger.info("LLM parsed JSON (raw dict): %s", str(parsed))
441
+
442
+ # defensive clamps (same as before)
443
+ def safe_prob(val):
444
+ try:
445
+ v = float(val)
446
+ return max(0.0, min(1.0, v))
447
+ except Exception:
448
+ return 0.0
449
+
450
+ for k in [
451
+ "jaundice_probability",
452
+ "anemia_probability",
453
+ "hydration_issue_probability",
454
+ "neurological_issue_probability"
455
+ ]:
456
+ parsed[k] = safe_prob(parsed.get(k, 0.0))
457
 
458
+ try:
459
+ rs = float(parsed.get("risk_score", 0.0))
460
+ parsed["risk_score"] = round(max(0.0, min(100.0, rs)), 2)
461
+ except Exception:
462
+ parsed["risk_score"] = 0.0
463
+
464
+ parsed["confidence"] = safe_prob(parsed.get("confidence", 0.0))
465
+ parsed["summary"] = str(parsed.get("summary", "") or "").strip()
466
+ parsed["recommendation"] = str(parsed.get("recommendation", "") or "").strip()
467
+
468
+ for k in [
469
+ "jaundice_probability",
470
+ "anemia_probability",
471
+ "hydration_issue_probability",
472
+ "neurological_issue_probability",
473
+ "confidence",
474
+ "risk_score"
475
+ ]:
476
+ parsed[f"{k}_was_missing"] = False
477
+
478
+ return parsed
479
+
480
+ except AppError as app_e:
481
+ # Specific remote validation error: log and attempt a single retry with ultra-safe defaults
482
+ logger.exception("LLM AppError (remote validation failed) on attempt %d: %s", attempt, str(app_e))
483
+ last_exc = app_e
484
+ if attempt == 1:
485
+ # tighten inputs and retry: force temperature=0.2, max_new_tokens=512
486
+ predict_kwargs["temperature"] = 0.2
487
+ predict_kwargs["max_new_tokens"] = float(512)
488
+ logger.info("Retrying LLM call with temperature=0.2 and max_new_tokens=512")
489
+ continue
490
+ else:
491
+ # no more retries
492
+ raise RuntimeError(f"LLM call failed (AppError): {app_e}")
493
+ except Exception as e:
494
+ logger.exception("LLM call failed on attempt %d: %s", attempt, str(e))
495
+ last_exc = e
496
+ # try one retry only for non-AppError exceptions
497
+ if attempt == 1:
498
+ predict_kwargs["temperature"] = 0.2
499
+ predict_kwargs["max_new_tokens"] = float(512)
500
+ continue
501
+ raise RuntimeError(f"LLM call failed: {e}")
502
 
503
+ # if we reach here, raise last caught exception
504
+ raise RuntimeError(f"LLM call ultimately failed: {last_exc}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
  # -----------------------
507
  # API endpoints