malavikapradeep2001 commited on
Commit
d06de52
·
1 Parent(s): 3f3aedd
Files changed (1) hide show
  1. backend/app.py +50 -13
backend/app.py CHANGED
@@ -542,29 +542,56 @@ async def predict(model_name: str = Form(...), file: UploadFile = File(...)):
542
 
543
  num_classes = int(len(proba))
544
 
545
- # Expecting a 3-class CIN classifier. If not 3, return a clear error so the user can supply a 3-class model.
546
- if num_classes != 3:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  return JSONResponse(
548
  content={
549
- "error": "CIN classifier must output 3-class probabilities (CIN1, CIN2, CIN3).",
550
  "detected_num_classes": num_classes,
551
  },
552
  status_code=503,
553
  )
554
 
555
- # Map probabilities to explicit CIN labels (CIN1, CIN2, CIN3)
556
- classes = ["CIN1", "CIN2", "CIN3"]
557
- confidences = {classes[i]: float(proba[i]) for i in range(3)}
558
-
559
- # Primary prediction is the class with the highest probability
560
- predicted_idx = int(np.argmax(proba))
561
- predicted_label = classes[predicted_idx]
562
- avg_confidence = float(np.max(proba)) * 100
563
-
564
  # Generate AI interpretation using Mistral client (if available)
565
  ai_interp = generate_cin_summary(predicted_label, confidences, avg_confidence)
566
 
567
- return {
568
  "model_used": "CIN Classifier",
569
  "prediction": predicted_label,
570
  "confidence": confidences,
@@ -573,6 +600,16 @@ async def predict(model_name: str = Form(...), file: UploadFile = File(...)):
573
  "ai_interpretation": ai_interp,
574
  },
575
  }
 
 
 
 
 
 
 
 
 
 
576
  elif model_name == "histopathology":
577
  result = predict_histopathology(image)
578
  return result
 
542
 
543
  num_classes = int(len(proba))
544
 
545
+ # Handle different classifier output sizes:
546
+ # - If 3 classes: map directly to CIN1/CIN2/CIN3
547
+ # - If 2 classes: apply a conservative heuristic to split High-grade into CIN2/CIN3
548
+ if num_classes == 3:
549
+ classes = ["CIN1", "CIN2", "CIN3"]
550
+ confidences = {classes[i]: float(proba[i]) for i in range(3)}
551
+ predicted_idx = int(np.argmax(proba))
552
+ predicted_label = classes[predicted_idx]
553
+ avg_confidence = float(np.max(proba)) * 100
554
+ mapping_used = "direct_3class"
555
+ elif num_classes == 2:
556
+ # Binary model detected (e.g., Low-grade vs High-grade). We'll convert to CIN1/CIN2/CIN3
557
+ # Heuristic:
558
+ # - CIN1 <- low_grade_prob
559
+ # - Split high_grade_prob into CIN2 and CIN3 based on how confident 'high' is.
560
+ # * If high <= 0.6 -> mostly CIN2
561
+ # * If high >= 0.8 -> mostly CIN3
562
+ # * Between 0.6 and 0.8 -> interpolate
563
+ low_prob = float(proba[0])
564
+ high_prob = float(proba[1])
565
+
566
+ if high_prob <= 0.6:
567
+ cin3_factor = 0.0
568
+ elif high_prob >= 0.8:
569
+ cin3_factor = 1.0
570
+ else:
571
+ cin3_factor = (high_prob - 0.6) / 0.2
572
+
573
+ cin1 = low_prob
574
+ cin3 = high_prob * cin3_factor
575
+ cin2 = high_prob - cin3
576
+
577
+ confidences = {"CIN1": cin1, "CIN2": cin2, "CIN3": cin3}
578
+ # pick highest of the mapped three as primary prediction
579
+ predicted_label = max(confidences.items(), key=lambda x: x[1])[0]
580
+ avg_confidence = float(max(confidences.values())) * 100
581
+ mapping_used = "binary_to_3class_heuristic"
582
+ else:
583
  return JSONResponse(
584
  content={
585
+ "error": "CIN classifier must output 2-class (Low/High) or 3-class probabilities (CIN1, CIN2, CIN3).",
586
  "detected_num_classes": num_classes,
587
  },
588
  status_code=503,
589
  )
590
 
 
 
 
 
 
 
 
 
 
591
  # Generate AI interpretation using Mistral client (if available)
592
  ai_interp = generate_cin_summary(predicted_label, confidences, avg_confidence)
593
 
594
+ response = {
595
  "model_used": "CIN Classifier",
596
  "prediction": predicted_label,
597
  "confidence": confidences,
 
600
  "ai_interpretation": ai_interp,
601
  },
602
  }
603
+
604
+ # If we used the binary->3class heuristic, include a diagnostic field so callers know it was mapped
605
+ if 'mapping_used' in locals() and mapping_used == 'binary_to_3class_heuristic':
606
+ response["mapping_used"] = mapping_used
607
+ response["mapping_note"] = (
608
+ "The server mapped a binary Low/High classifier to CIN1/CIN2/CIN3 using a heuristic split. "
609
+ "This is an approximation — for clinical use please supply a native 3-class model."
610
+ )
611
+
612
+ return response
613
  elif model_name == "histopathology":
614
  result = predict_histopathology(image)
615
  return result