Spaces:
Running
Running
File size: 6,401 Bytes
1947612 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | from __future__ import annotations
import logging
from rest_framework import status
from rest_framework.parsers import MultiPartParser, JSONParser
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from .groq_client import generate_clinical_response
from .inference import model_ready, run_inference
from .serializers import AnalyseRequestSerializer
from .pdf_generator import build_pdf
# Shared module logger for inference, clinical-generation and PDF events.
logger: logging.Logger = logging.getLogger(__name__)

# Model identifier echoed back by the /health and /analyse responses.
MODEL_VERSION: str = "dermavision-dinov2-v1"
class HealthView(APIView):
    """Liveness probe served at ``GET /health``.

    Always reports ``"status": "ok"``, together with whether the inference
    model is ready (``model_ready()``) and the model version string.
    """

    def get(self, request: Request) -> Response:
        payload = {
            "status": "ok",
            "model_loaded": model_ready(),
            "model_version": MODEL_VERSION,
        }
        return Response(payload)
class AnalyseView(APIView):
    """
    POST /analyse — classify a skin-lesion image and return a clinical summary.

    Request fields (multipart form; a JSON body is also accepted by the parser):
        image (required)          — skin lesion image file
        include_heatmap (bool)    — default False
        include_narrative (bool)  — default True
        patient_name (str, optional)
        patient_age (str, optional)
        patient_sex (str, optional)
        symptoms (str, optional)

    Successful (200) response keys:
        primaryFinding, confidence, urgency, urgencyText, treatmentNotes,
        recommendedAction, referralNote, conditionCode, therapyRegimen,
        patientHandout — from the clinical generator, or fallbacks derived
        from the top model prediction when it is disabled or fails;
        allPredictions — raw model output (label/confidence entries);
        heatmap_b64, model_version.

    Error responses:
        400 — serializer validation errors;
        503 — model file not found;
        500 — inference raised, or produced no predictions.
    """

    parser_classes = [MultiPartParser, JSONParser]

    def post(self, request: Request) -> Response:
        serializer = AnalyseRequestSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        image_file = serializer.validated_data["image"]
        include_heatmap = serializer.validated_data["include_heatmap"]
        include_narrative = serializer.validated_data["include_narrative"]

        # Patient context — all optional, passed through to the Groq prompt.
        # Read via _text_field so JSON null / numeric values do not crash.
        patient_name = self._text_field(request.data, "patient_name")
        patient_age = self._text_field(request.data, "patient_age")
        patient_sex = self._text_field(request.data, "patient_sex")
        symptoms = self._text_field(request.data, "symptoms")

        image_bytes = image_file.read()

        # --- Run ONNX inference ---
        try:
            inference_result = run_inference(image_bytes, include_heatmap=include_heatmap)
        except FileNotFoundError as exc:
            logger.error("Model not found: %s", exc)
            return Response(
                {"error": "Model not loaded. Please contact the administrator."},
                status=status.HTTP_503_SERVICE_UNAVAILABLE,
            )
        except Exception as exc:
            logger.exception("Inference error: %s", exc)
            return Response(
                {"error": "Inference failed. Check server logs."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        predictions = inference_result["predictions"]  # top-3 [{label, confidence}]
        heatmap_b64 = inference_result["heatmap_b64"]
        if not predictions:
            # The fallback fields index predictions[0]; without this guard an
            # empty result would escape as an unhandled IndexError.
            logger.error("Inference returned no predictions")
            return Response(
                {"error": "Inference failed. Check server logs."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        # --- Generate structured clinical response via Groq (best-effort) ---
        clinical = (
            self._generate_clinical(
                predictions,
                patient_name=patient_name,
                patient_age=patient_age,
                patient_sex=patient_sex,
                symptoms=symptoms,
            )
            if include_narrative
            else {}
        )

        response_data = self._build_response(clinical, predictions, heatmap_b64)

        # Do not write the patient's name to the logs (personal health data);
        # only record whether one was supplied.
        logger.info(
            "Analyse complete — patient: %s | finding: %s | urgency: %s",
            "named" if patient_name else "anonymous",
            response_data["primaryFinding"],
            response_data["urgency"],
        )
        return Response(response_data, status=status.HTTP_200_OK)

    @staticmethod
    def _text_field(data, key: str) -> str:
        """Return ``data[key]`` as a stripped string, or "" when absent/null.

        Non-string values (e.g. a JSON body with ``"patient_age": 42``) are
        coerced with ``str`` rather than raising AttributeError on ``.strip()``.
        """
        value = data.get(key)
        return "" if value is None else str(value).strip()

    @staticmethod
    def _generate_clinical(
        predictions,
        *,
        patient_name: str,
        patient_age: str,
        patient_sex: str,
        symptoms: str,
    ) -> dict:
        """Request structured clinical fields from the Groq client.

        Best-effort: on any exception, or a non-dict result, a warning is
        logged and ``{}`` is returned so the caller falls back to values
        derived from the raw model predictions.
        """
        try:
            clinical = generate_clinical_response(
                predictions=predictions,
                patient_name=patient_name,
                patient_age=patient_age,
                patient_sex=patient_sex,
                symptoms=symptoms,
            )
        except Exception as exc:
            logger.warning("Clinical response generation failed: %s", exc, exc_info=True)
            return {}
        if not isinstance(clinical, dict):
            logger.warning(
                "Clinical response generation returned %s, expected dict",
                type(clinical).__name__,
            )
            return {}
        return clinical

    @staticmethod
    def _build_response(clinical: dict, predictions, heatmap_b64) -> dict:
        """Merge clinical fields (with fallbacks) and raw model output.

        Assumes ``predictions`` is non-empty; ``post`` checks this first.
        """
        top = predictions[0]
        return {
            # Full structured fields from Groq (or fallback)
            "primaryFinding": clinical.get("primaryFinding", top["label"]),
            "confidence": clinical.get("confidence", round(top["confidence"] * 100)),
            "urgency": clinical.get("urgency", "Moderate"),
            "urgencyText": clinical.get("urgencyText", "Refer to clinic within 3 days."),
            "treatmentNotes": clinical.get("treatmentNotes", []),
            "recommendedAction": clinical.get("recommendedAction", "Refer to appropriate specialist."),
            "referralNote": clinical.get("referralNote", ""),
            # NOTE(review): this fallback is "ringworm" regardless of the model's
            # top label — confirm this is an intended frontend default.
            "conditionCode": clinical.get("conditionCode", "ringworm"),
            "therapyRegimen": clinical.get("therapyRegimen", {}),
            "patientHandout": clinical.get("patientHandout", {}),
            # Raw model output — kept so the frontend can show differential runner-ups
            "allPredictions": predictions,
            "heatmap_b64": heatmap_b64,
            "model_version": MODEL_VERSION,
        }
class GeneratePdfView(APIView):
    """
    POST /pdf — render a case report PDF and return it base64-encoded.

    Expects a JSON object containing:
      - case_id   (optional; "NEW-CASE" when missing, null or empty)
      - patient   (object with name, age, sex, symptoms, healthWorkerName)
      - clinical  (object with primaryFinding, confidence, urgency, referralNote, treatmentNotes)
      - images    (object with original_b64, heatmap_b64)
    A missing or null section is treated as an empty object.

    Responses:
      200 {"pdf_b64": ...}
      400 malformed JSON (raised by DRF's parser), a non-object body, or a
          section that is not an object
      500 {"error": ...} when build_pdf fails — details are logged
          server-side and not echoed to the client.
    """

    parser_classes = [JSONParser]

    # Payload sections forwarded to build_pdf; each must be a JSON object.
    _SECTIONS = ("patient", "clinical", "images")

    def post(self, request: Request) -> Response:
        # Read request.data outside any try: DRF raises ParseError for malformed
        # JSON, and letting it propagate yields DRF's 400 response instead of
        # being turned into a generic 500 below.
        data = request.data
        if not isinstance(data, dict):
            return Response(
                {"error": "Request body must be a JSON object."},
                status=status.HTTP_400_BAD_REQUEST,
            )

        case_id = data.get("case_id") or "NEW-CASE"
        sections = {}
        for key in self._SECTIONS:
            value = data.get(key) or {}
            if not isinstance(value, dict):
                return Response(
                    {"error": f"'{key}' must be a JSON object."},
                    status=status.HTTP_400_BAD_REQUEST,
                )
            sections[key] = value

        try:
            pdf_b64 = build_pdf(
                case_id,
                sections["patient"],
                sections["clinical"],
                sections["images"],
            )
        except Exception as exc:
            logger.exception("Failed to generate PDF: %s", exc)
            return Response(
                {"error": "PDF generation failed. Check server logs."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        return Response({"pdf_b64": pdf_b64}, status=status.HTTP_200_OK)