"""DentalOS tooth-segmentation API.

FastAPI service that accepts an uploaded image, runs the loaded PyTorch model
on it (or returns canned demo detections when no usable model is loaded) and
responds with the detections plus an annotated PNG encoded as a data URI.
"""
import base64
import io
import os
import time
from collections import Counter

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image, ImageDraw

# ─────────────────────────────────────────────────────────────────────────────
MODEL_PATH = "best_model.pth"
MODEL_TYPE = "pytorch"
IMG_SIZE = 640
CONFIDENCE = 0.35
DEVICE = "cpu"
# ─────────────────────────────────────────────────────────────────────────────

# Class names, indexed by the class id picked from the model output.
TOOTH_CLASSES = ["Molars", "Premolars", "Canines", "Incisors", "Filling"]

# Box/label colour per class (RGB).
CLASS_COLORS = {
    "Molars": (255, 0, 250),
    "Premolars": (255, 0, 8),
    "Canines": (0, 127, 255),
    "Incisors": (42, 255, 0),
    "Filling": (196, 201, 255),
}
DEFAULT_COLOR = (99, 102, 241)

# Upload MIME types accepted by /segment.
ALLOWED_CONTENT_TYPES = frozenset(
    {"image/jpeg", "image/png", "image/jpg", "image/webp"}
)

# Spelling-variation fix: map misspelled / singular label names onto the
# canonical class names.
NAME_MAP = {
    "molar": "Molars",
    "Molar": "Molars",
    "Molars": "Molars",
    "premolar": "Premolars",
    "Premolar": "Premolars",
    "Premolars": "Premolars",
    "canine": "Canines",
    "Canine": "Canines",
    "Canines": "Canines",
    "cannine": "Canines",
    "canin": "Canines",
    "incisor": "Incisors",
    "Incisor": "Incisors",
    "Incisors": "Incisors",
    "incissors": "Incisors",
    "incissor": "Incisors",
    "filling": "Filling",
    "Filling": "Filling",
    "Fillings": "Filling",
}

# Case-insensitive view of NAME_MAP so variants such as "molars" or "MOLAR"
# resolve without listing every capitalisation by hand.
_NAME_MAP_FOLDED = {key.casefold(): value for key, value in NAME_MAP.items()}


def fix_name(raw: str) -> str:
    """Return the canonical class name for *raw*.

    An exact NAME_MAP hit wins; otherwise the lookup is retried ignoring case
    and surrounding whitespace. Unknown names are returned unchanged.
    """
    exact = NAME_MAP.get(raw)
    if exact is not None:
        return exact
    return _NAME_MAP_FOLDED.get(raw.strip().casefold(), raw)


# ─────────────────────────────────────────────────────────────────────────────
app = FastAPI(title="DentalOS Tooth Segmentation API", version="1.0")

# NOTE(review): FastAPI's CORS documentation says allow_origins/methods/headers
# must not be ["*"] when allow_credentials=True. Left unchanged because a
# client may rely on it; either list explicit origins or drop credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Loaded model object; None means the service answers in demo mode.
model = None


def _extract_model(checkpoint):
    """Return the callable model stored in *checkpoint*, or None.

    A checkpoint may be the model itself or a dict holding it under "model".
    A "state_dict" entry (or a bare dict of weights) cannot be called without
    the network definition, so it yields None instead of a broken model.
    """
    if isinstance(checkpoint, dict):
        if "model" in checkpoint:
            candidate = checkpoint["model"]
        elif "state_dict" in checkpoint:
            candidate = checkpoint["state_dict"]
        else:
            # Only weights present: print the keys to help debugging.
            print(f"Checkpoint keys: {list(checkpoint.keys())}")
            candidate = checkpoint
    else:
        candidate = checkpoint
    return candidate if callable(candidate) else None


@app.on_event("startup")
async def load_model():
    """Load the model from MODEL_PATH at startup.

    Leaves the global ``model`` as None (demo mode) when the file is missing
    or when the checkpoint does not contain a callable model object.
    """
    global model
    if not os.path.exists(MODEL_PATH):
        print(f"⚠️ Model file not found: {MODEL_PATH} — running in DEMO mode")
        return

    print(f"🔄 Loading model from {MODEL_PATH} ...")
    t0 = time.time()
    import torch

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

    candidate = _extract_model(checkpoint)
    if candidate is None:
        print(
            "⚠️ Checkpoint contains weights only (no callable model object) "
            "— running in DEMO mode"
        )
        return

    if hasattr(candidate, "eval"):
        candidate.eval()
    model = candidate
    print(f"✅ Model loaded in {time.time()-t0:.2f}s")


@app.get("/")
def root():
    """Describe the service and whether a model is loaded."""
    loaded = model is not None
    return {
        "service": "DentalOS Tooth Segmentation API",
        "model_loaded": loaded,
        "model_type": MODEL_TYPE,
        "status": "ready" if loaded else "demo_mode",
    }


@app.get("/health")
def health():
    """Liveness probe; also reports loaded vs. demo mode."""
    return {"status": "ok", "model": "loaded" if model is not None else "demo"}


@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Detect tooth classes in the uploaded image.

    Returns a JSON object with the annotated image (PNG data URI), the list of
    detections, per-class counts, the inference time in milliseconds, the
    model type and whether the demo path was used.

    Raises:
        HTTPException: 400 when the content type is not accepted or the
            payload cannot be decoded as an image.
    """
    if file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(400, "Only JPG/PNG/WebP supported.")

    contents = await file.read()
    try:
        img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Undecodable uploads are a client error, not a server failure.
        raise HTTPException(400, "Uploaded file is not a valid image.") from exc

    # Decide once so "is_demo" always matches the path actually taken.
    use_demo = model is None

    t0 = time.perf_counter()
    if use_demo:
        detections, annotated_b64 = _demo_response(img_pil)
    else:
        detections, annotated_b64 = _run_pytorch(img_pil)
    inference_ms = round((time.perf_counter() - t0) * 1000)

    counts = Counter(d["label"] for d in detections)
    summary = {
        "Molars": counts["Molars"],
        "Premolars": counts["Premolars"],
        "Canines": counts["Canines"],
        "Incisors": counts["Incisors"],
        "Fillings": counts["Filling"],
    }

    return JSONResponse({
        "annotated_image": annotated_b64,
        "detections": detections,
        "summary": summary,
        "inference_ms": inference_ms,
        "model_type": MODEL_TYPE,
        "is_demo": use_demo,
    })


def _clamp(value: int, low: int, high: int) -> int:
    """Limit *value* to the closed range [low, high]."""
    return max(low, min(value, high))


def _parse_predictions(outputs, img_w: int, img_h: int) -> list:
    """Turn raw model output into a list of detection dicts.

    Assumes each prediction row is [x1, y1, x2, y2, conf, c0, c1, ...] with
    coordinates in the IMG_SIZE x IMG_SIZE input space, as the original code's
    comment described — TODO confirm against the training/export code.
    """
    preds = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
    if hasattr(preds, "cpu"):
        preds = preds.cpu().numpy()
    preds = np.asarray(preds, dtype=float)

    if preds.ndim == 0 or preds.size == 0:
        return []
    if preds.ndim == 1:
        preds = preds[np.newaxis, :]
    elif preds.ndim > 2:
        # Drop the leading batch dimension(s): [1, N, C] -> [N, C]. Iterating
        # the batched array directly would treat all N rows as one prediction.
        preds = preds.reshape(-1, preds.shape[-1])
    if preds.shape[1] < 5:
        return []

    sx, sy = img_w / IMG_SIZE, img_h / IMG_SIZE
    detections = []
    for row in preds:
        conf = float(row[4])
        # "not >=" (rather than "<") also rejects NaN confidences.
        if not conf >= CONFIDENCE:
            continue
        if not np.isfinite(row[:4]).all():
            continue

        xa, xb = int(row[0] * sx), int(row[2] * sx)
        ya, yb = int(row[1] * sy), int(row[3] * sy)
        # Order the corners and keep the box inside the image.
        x1, x2 = _clamp(min(xa, xb), 0, img_w), _clamp(max(xa, xb), 0, img_w)
        y1, y2 = _clamp(min(ya, yb), 0, img_h), _clamp(max(ya, yb), 0, img_h)

        cls_id = int(np.argmax(row[5:])) if row.size > 5 else 0
        raw_label = TOOTH_CLASSES[cls_id] if cls_id < len(TOOTH_CLASSES) else "Unknown"
        label = fix_name(raw_label)

        detections.append({
            "label": label,
            "tooth_id": label,
            "confidence": round(conf, 3),
            "bbox": [x1, y1, x2, y2],
        })
    return detections


def _draw_detections(img: Image.Image, detections: list) -> None:
    """Draw a coloured box and a "<label> <confidence>" tag for each detection."""
    draw = ImageDraw.Draw(img)
    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        color = CLASS_COLORS.get(det["label"], DEFAULT_COLOR)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        text = f"{det['label']} {det['confidence']:.0%}"
        # Filled tag sits directly above the top-left corner of the box.
        draw.rectangle([x1, y1 - 18, x1 + len(text) * 7 + 4, y1], fill=color)
        draw.text((x1 + 2, y1 - 16), text, fill="white")


def _run_pytorch(img_pil: Image.Image):
    """Run the loaded model on *img_pil*.

    Returns a ``(detections, annotated_png_data_uri)`` tuple. Errors raised by
    the model call propagate; errors while interpreting its output are logged
    and result in an empty detection list.
    """
    import torch

    resized = img_pil.resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(resized).astype(np.float32) / 255.0
    # HWC -> CHW, then add the batch dimension.
    img_tensor = torch.from_numpy(pixels.transpose(2, 0, 1)).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(img_tensor)

    img_w, img_h = img_pil.size
    try:
        detections = _parse_predictions(outputs, img_w, img_h)
    except Exception as e:
        # Deliberate best effort: an unexpected output layout must not fail
        # the request; the caller still gets the un-annotated image.
        print(f"Inference parse error: {e}")
        detections = []

    annotated_img = img_pil.copy()
    _draw_detections(annotated_img, detections)
    return detections, _img_to_b64(annotated_img)


def _demo_response(img_pil: Image.Image):
    """Return fixed mock detections, scaled to the image, plus the annotated PNG.

    Works on a copy so the caller's image is left untouched.
    """
    w, h = img_pil.size
    mock = [
        {"label": "Molars", "tooth_id": "Molars", "confidence": 0.91,
         "bbox": [int(w*0.1), int(h*0.15), int(w*0.25), int(h*0.45)]},
        {"label": "Premolars", "tooth_id": "Premolars", "confidence": 0.87,
         "bbox": [int(w*0.3), int(h*0.15), int(w*0.45), int(h*0.45)]},
        {"label": "Incisors", "tooth_id": "Incisors", "confidence": 0.93,
         "bbox": [int(w*0.42), int(h*0.1), int(w*0.58), int(h*0.42)]},
        {"label": "Canines", "tooth_id": "Canines", "confidence": 0.85,
         "bbox": [int(w*0.6), int(h*0.15), int(w*0.72), int(h*0.45)]},
        {"label": "Filling", "tooth_id": "Filling", "confidence": 0.78,
         "bbox": [int(w*0.3), int(h*0.55), int(w*0.45), int(h*0.80)]},
    ]
    annotated_img = img_pil.copy()
    _draw_detections(annotated_img, mock)
    return mock, _img_to_b64(annotated_img)


def _img_to_b64(img: Image.Image) -> str:
    """Encode *img* as a PNG ``data:`` URI."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


if __name__ == "__main__":
    import uvicorn

    # Pass the app object so startup does not depend on this file's name.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)