import gc
import io
import os
import tempfile
from datetime import datetime

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import HTMLResponse
from PIL import Image, ImageStat
from torchvision import transforms
from transformers import AutoModelForImageClassification, pipeline

# ==========================================
# 1. CONFIGURATION & GUARDRAILS
# ==========================================
# One entry per diagnostic task. Keys read by the engine:
#   type          - "image" or "audio"; selects the inference path.
#   id            - Hugging Face model identifier to load.
#   desc          - human-readable description echoed in the API response.
#   safe          - (image) substrings of a predicted label that mean "no finding".
#   target_labels - (audio) substrings of a predicted label that count as a
#                   respiratory event.
#   rules         - input guardrails applied before the model is loaded.
#
# NOTE(review): "fracture" uses the same chest X-ray pneumonia model id as
# "lungs"; predict() only relabels its output as "Fracture / Anomaly" or
# "Healthy Bone". Confirm this is intended before relying on that task.
MODELS = {
    "lungs": {
        "type": "image",
        "id": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
        "desc": "Chest X-Ray Analysis",
        "safe": ["NORMAL", "normal", "No Pneumonia"],
        "rules": {
            "max_sat": 30,
            "reject_msg": "❌ Invalid: Too colorful. Please upload a B&W X-Ray.",
        },
    },
    "cough": {
        "type": "audio",
        "id": "MIT/ast-finetuned-audioset-10-10-0.4593",
        "desc": "Respiratory Audio Analysis",
        "target_labels": [
            "Cough",
            "Throat clearing",
            "Respiratory sounds",
            "Wheeze",
            "Gasping",
        ],
        "rules": {
            "min_duration": 0.5,
            "reject_msg": "❌ Invalid: Audio too short or silent.",
        },
    },
    "fracture": {
        "type": "image",
        "id": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
        "desc": "Bone Trauma X-Ray",
        "safe": ["NORMAL", "normal", "No Pneumonia"],
        "rules": {
            "max_sat": 30,
            "reject_msg": "❌ Invalid: Too colorful. Please upload a B&W X-Ray.",
        },
    },
    "brain": {
        "type": "image",
        "id": "Hemgg/brain-tumor-classification",
        "desc": "Brain MRI Scan Analysis",
        "safe": ["no_tumor"],
        "rules": {
            "max_sat": 30,
            "reject_msg": "❌ Invalid: This looks like a Photo. Please upload a B&W MRI Scan.",
        },
    },
    "eye": {
        "type": "image",
        "id": "AventIQ-AI/resnet18-cataract-detection-system",
        "desc": "Ophthalmology Scan",
        "safe": ["Normal", "normal", "healthy"],
        "rules": {
            "min_sat": 20,
            "min_white": 0.05,
            "reject_msg": "❌ Invalid: No eye detected (Missing white sclera).",
        },
    },
    "skin": {
        "type": "image",
        "id": "Anwarkh1/Skin_Cancer-Image_Classification",
        "desc": "Dermatology Lesion Scan",
        "safe": ["Benign", "benign", "nv", "bkl"],
        "rules": {
            "min_sat": 20,
            "max_white": 0.15,
            "reject_msg": "❌ Invalid: Image looks like an Eye or Document.",
        },
    },
}


# ==========================================
# 2. MEDICAL ENGINE (Logic)
# ==========================================
class MedicalEngine:
    """Validate an uploaded image or audio clip and classify it for one task.

    Models are loaded on every call and released afterwards rather than kept
    resident, so only one model is held in memory at a time.
    """

    def __init__(self):
        self.device = "cpu"
        print("✅ System Initialized: Medical Engine Ready")
        # NOTE(review): every image model gets the same 224x224 resize and
        # 0.5/0.5 normalization; confirm this matches each checkpoint's own
        # preprocessing configuration.
        self.img_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def validate_image(self, image, task):
        """Apply the colour guardrails configured for *task* to a PIL image.

        Args:
            image: PIL image to inspect (converted to HSV internally).
            task: key into ``MODELS`` whose ``rules`` are applied.

        Returns:
            ``(True, "")`` when the image passes, otherwise
            ``(False, message)`` with a user-facing rejection message.
        """
        rules = MODELS[task].get("rules", {})

        hsv = np.array(image.convert("HSV"))
        saturation = hsv[:, :, 1]
        value = hsv[:, :, 2]

        avg_sat = np.mean(saturation)
        # "White" = low saturation and high brightness (thresholds on PIL's
        # 0-255 HSV scale); used to tell eye photos from skin photos.
        white_pixels = np.logical_and(saturation < 40, value > 180)
        white_ratio = np.sum(white_pixels) / white_pixels.size

        print(f"🔍 Analysis [{task}]: Sat={int(avg_sat)}, WhiteRatio={white_ratio:.3f}")

        if "max_sat" in rules and avg_sat > rules["max_sat"]:
            return False, rules["reject_msg"]
        if "min_sat" in rules and avg_sat < rules["min_sat"]:
            return False, "❌ Invalid: Image is B&W. Color photo required."
        if "min_white" in rules and white_ratio < rules["min_white"]:
            return False, rules["reject_msg"]
        if "max_white" in rules and white_ratio > rules["max_white"]:
            return False, rules["reject_msg"]
        return True, ""

    def validate_audio(self, audio_array, sr, min_duration=0.5):
        """Reject clips that are too short or effectively silent.

        Args:
            audio_array: 1-D sequence of audio samples.
            sr: sampling rate of ``audio_array`` in Hz.
            min_duration: minimum accepted clip length in seconds. Defaults
                to 0.5, the value that was previously hard-coded.

        Returns:
            ``(True, "")`` when the clip passes, otherwise
            ``(False, message)``.
        """
        n_samples = len(audio_array)
        # An empty clip is handled here so np.max below never sees a
        # zero-length array, whatever min_duration is.
        if n_samples == 0 or n_samples / sr < min_duration:
            return False, f"❌ Audio too short (< {min_duration}s)."
        if np.max(np.abs(audio_array)) < 0.01:
            return False, "❌ Audio is silent/empty."
        return True, ""

    def predict(self, file_bytes, task):
        """Run the configured model for *task* on the raw upload bytes.

        Args:
            file_bytes: raw content of the uploaded file.
            task: key into ``MODELS``.

        Returns:
            A dict. On success it has ``task``, ``desc``, ``prediction``
            (``label`` and ``score``) and ``risk``. On rejection or failure
            it has an ``error`` message, with ``risk`` set to ``"INVALID"``
            for inputs that failed validation.
        """
        model_cfg = MODELS[task]
        if model_cfg["type"] == "audio":
            return self._predict_audio(file_bytes, task, model_cfg)
        return self._predict_image(file_bytes, task, model_cfg)

    def _predict_audio(self, file_bytes, task, model_cfg):
        """Decode, validate and classify an audio upload."""
        audio_path = None
        try:
            # A unique temp file per request: a fixed file name would let
            # concurrent requests overwrite each other's audio and would
            # leave the upload on disk afterwards.
            fd, audio_path = tempfile.mkstemp(prefix="mediscan_audio_")
            with os.fdopen(fd, "wb") as f:
                f.write(file_bytes)

            try:
                audio, sr = librosa.load(audio_path, sr=16000)
            except Exception:
                return {"error": "Audio Format Error. Use .wav or .mp3", "risk": "INVALID"}

            min_duration = model_cfg.get("rules", {}).get("min_duration", 0.5)
            is_valid, msg = self.validate_audio(audio, sr, min_duration)
            if not is_valid:
                return {"error": msg, "risk": "INVALID"}

            classifier = pipeline("audio-classification", model=model_cfg["id"])
            outputs = classifier(audio_path)
            # Release the model right away, as the image path does.
            del classifier
            gc.collect()

            top = outputs[0]
            targets = model_cfg["target_labels"]
            # First of the top three results whose label mentions a target
            # sound. Label, score and risk all come from this entry so the
            # reported finding is always the one that triggered the match
            # (previously they came from the top result even when the match
            # was at rank 2 or 3).
            match = next(
                (res for res in outputs[:3] if any(t in res["label"] for t in targets)),
                None,
            )
            if match is None:
                label = "Normal Background Noise"
                score = top["score"]
                risk = "LOW"
            else:
                label = f"Detected: {match['label']}"
                score = match["score"]
                risk = "HIGH" if score > 0.4 else "LOW"

            return {
                "task": task,
                "desc": model_cfg["desc"],
                "prediction": {"label": label, "score": score},
                "risk": risk,
            }
        except Exception as e:
            return {"error": f"Audio Error: {str(e)}"}
        finally:
            if audio_path is not None:
                try:
                    os.remove(audio_path)
                except OSError:
                    # Best-effort cleanup; a leftover temp file must not
                    # turn a finished prediction into a failure.
                    pass

    def _predict_image(self, file_bytes, task, model_cfg):
        """Decode, validate and classify an image upload."""
        try:
            image = Image.open(io.BytesIO(file_bytes)).convert("RGB")

            is_valid, msg = self.validate_image(image, task)
            if not is_valid:
                return {
                    "task": task,
                    "risk": "INVALID",
                    "error": msg,
                    "prediction": {"label": "Rejected", "score": 0.0},
                }

            print(f"⏳ Loading Model: {task}...")
            model = AutoModelForImageClassification.from_pretrained(model_cfg["id"])
            model.to(self.device)
            model.eval()

            inputs = self.img_transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = model(inputs)
            probs = F.softmax(outputs.logits, dim=-1)

            results = [
                {"label": model.config.id2label[i], "score": float(score)}
                for i, score in enumerate(probs[0])
            ]
            results.sort(key=lambda r: r["score"], reverse=True)
            top = results[0]

            top_label = top["label"].lower()
            is_safe = any(s.lower() in top_label for s in model_cfg["safe"])

            if top["score"] < 0.5:
                risk = "UNCERTAIN"
            elif is_safe:
                risk = "LOW"
            else:
                risk = "HIGH" if top["score"] > 0.70 else "MODERATE"

            if task == "fracture":
                # NOTE(review): an UNCERTAIN result is also labelled
                # "Healthy Bone" here; confirm that is the intended wording.
                top["label"] = (
                    "Fracture / Anomaly" if risk in ["HIGH", "MODERATE"] else "Healthy Bone"
                )

            # Drop the model before returning so it is not kept in memory
            # between requests.
            del model
            gc.collect()
            return {
                "task": task,
                "desc": model_cfg["desc"],
                "prediction": top,
                "risk": risk,
            }
        except Exception as e:
            return {"error": f"Image Error: {str(e)}"}


# ==========================================
# 3. API & FRONTEND
# ==========================================
app = FastAPI()
engine = MedicalEngine()
# In-memory log of successful predictions, newest first. Lost on restart.
HISTORY = []


@app.post("/predict/{task}")
async def predict_route(task: str, patient: str, age: str, file: UploadFile = File(...)):
    """Classify the uploaded file for *task* and record successful results.

    ``age`` is accepted as a parameter but not used by this handler.
    """
    if task not in MODELS:
        return {"error": "Invalid Task"}

    content = await file.read()
    # NOTE(review): engine.predict is synchronous, so this handler holds the
    # event loop for the whole model load and inference.
    result = engine.predict(content, task)

    if "error" not in result and result.get("risk") != "INVALID":
        HISTORY.insert(0, {
            "time": datetime.now().strftime("%H:%M"),
            "patient": patient,
            "task": task.capitalize(),
            "diagnosis": result["prediction"]["label"],
            "risk": result["risk"],
        })
    return result


@app.get("/history")
def get_history():
    """Return the recorded predictions, newest first."""
    return HISTORY


@app.post("/reset_history")
def reset_history():
    """Remove every recorded prediction."""
    # Clear in place so every reference to HISTORY sees the reset.
    HISTORY.clear()
    return {"status": "cleared"}


@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the front-end page."""
    # NOTE(review): the page body below contains only text and no HTML tags
    # in this view of the file; confirm the markup was not lost.
    return """ MediScan Rural | Govt of Meghalaya

Select a Category

Tap to upload

Supported: JPG, PNG, WAV, MP3

Recent Patients

TimePatientCategoryDiagnosisRisk
"""