Spaces:
Sleeping
Sleeping
| import os, io, base64, time | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image, ImageDraw | |
| import numpy as np | |
# ── Inference configuration ──────────────────────────────────────────────────
MODEL_PATH: str = "best_model.pth"  # checkpoint file read by load_model()
MODEL_TYPE: str = "pytorch"         # echoed back in the API responses
IMG_SIZE: int = 640                 # side of the square the input is resized to
CONFIDENCE: float = 0.35            # predictions scoring below this are dropped
DEVICE: str = "cpu"                 # torch device for the checkpoint and input tensor
# ─────────────────────────────────────────────────────────────────────────────
# Class names the detector was trained on, listed in class-index order
# (a predicted class id is used directly as an index into this list).
TOOTH_CLASSES = [
    "Molars",
    "Premolars",
    "Canines",
    "Incisors",
    "Filling",
]

# Outline / label-tab colour for each class, as (R, G, B) tuples.
CLASS_COLORS = dict(
    Molars=(255, 0, 250),
    Premolars=(255, 0, 8),
    Canines=(0, 127, 255),
    Incisors=(42, 255, 0),
    Filling=(196, 201, 255),
)

# Colour used for any label that has no entry in CLASS_COLORS.
DEFAULT_COLOR = (99, 102, 241)
# Map label spellings the model may emit (including misspellings) onto the
# canonical class names. fix_name() matches these keys case-insensitively and
# ignores surrounding whitespace; the mixed-case entries are kept so code that
# reads NAME_MAP directly keeps working.
NAME_MAP = {
    "molar": "Molars", "Molar": "Molars", "Molars": "Molars",
    "premolar": "Premolars", "Premolar": "Premolars", "Premolars": "Premolars",
    "canine": "Canines", "Canine": "Canines", "Canines": "Canines",
    "cannine": "Canines", "canin": "Canines",
    "incisor": "Incisors", "Incisor": "Incisors", "Incisors": "Incisors",
    "incissors": "Incisors", "incissor": "Incisors",
    "filling": "Filling", "Filling": "Filling", "Fillings": "Filling",
}

# Case-folded view of NAME_MAP used for lookups; built once at import time.
_NAME_MAP_FOLDED = {key.casefold(): value for key, value in NAME_MAP.items()}


def fix_name(raw: str) -> str:
    """Return the canonical class name for a (possibly misspelled) label.

    Matching ignores letter case and surrounding whitespace, so ``"molars"``,
    ``"MOLAR"`` and ``" Molar "`` all resolve to ``"Molars"``.

    Args:
        raw: Label text as produced by the model or class table.

    Returns:
        The canonical name, or *raw* unchanged when it is not recognised.
    """
    if not isinstance(raw, str):
        # Every NAME_MAP key is a str, so a non-str can never match.
        return raw
    return _NAME_MAP_FOLDED.get(raw.strip().casefold(), raw)
# ── FastAPI application ──────────────────────────────────────────────────────
app = FastAPI(title="DentalOS Tooth Segmentation API", version="1.0")

# Fully open CORS policy: every origin, method and header is allowed.
# NOTE(review): the CORS spec forbids a literal "*" origin on credentialed
# responses, yet allow_credentials=True is combined with allow_origins=["*"]
# here. Confirm how the installed Starlette version resolves this (it may
# reflect the caller's Origin, accepting credentialed requests from any site).
# If no cookies/auth headers are used, allow_credentials=False is safer.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# The loaded detector (a callable model object), or None when the service
# should answer with canned demo detections.
model = None


async def load_model():
    """Load the checkpoint at MODEL_PATH into the module-level ``model``.

    ``model`` is left as None (demo mode) when the file is missing, or when
    the checkpoint does not yield a callable model. A bare state_dict is one
    example: it cannot be run without the architecture class that produced it.

    NOTE(review): no startup registration for this coroutine is visible in
    this view. Confirm it is wired up (e.g. a FastAPI startup event or a
    lifespan handler).
    """
    global model
    if not os.path.exists(MODEL_PATH):
        print(f"⚠️  Model file not found: {MODEL_PATH} — running in DEMO mode")
        return
    print(f"🔄 Loading model from {MODEL_PATH} ...")
    t0 = time.time()
    import torch

    # weights_only=False is required to unpickle a full model object, but it
    # executes arbitrary pickled code: only point MODEL_PATH at a trusted file.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

    # Pull the model out of a training-style checkpoint dict if necessary.
    candidate = checkpoint
    if isinstance(checkpoint, dict):
        if "model" in checkpoint:
            candidate = checkpoint["model"]
        elif "state_dict" in checkpoint:
            candidate = checkpoint["state_dict"]
        else:
            # Looks like weights only; print the keys to aid debugging.
            print(f"Checkpoint keys: {list(checkpoint.keys())}")

    if not callable(candidate):
        # A weights-only mapping cannot be applied to an image tensor. Storing
        # it would report "ready" while every inference call raises TypeError,
        # so stay in demo mode instead.
        print(
            f"⚠️  Checkpoint {MODEL_PATH} holds a {type(candidate).__name__}, "
            "not a callable model — running in DEMO mode"
        )
        model = None
        return

    # Checkpoints sometimes store fp16 weights, while the inference input is
    # built as float32; cast so dtypes match.
    if hasattr(candidate, "float"):
        candidate = candidate.float()
    if hasattr(candidate, "eval"):
        candidate.eval()
    model = candidate
    print(f"✅ Model loaded in {time.time()-t0:.2f}s")
def root():
    """Describe the service and whether a real model or demo mode is active.

    NOTE(review): no route decorator is visible for this handler in this
    view. Confirm it is registered (e.g. with an ``@app.get(...)`` decorator).
    """
    # Use an explicit None check for both fields. Truthiness of a model object
    # is unreliable: container modules define __len__, and multi-element
    # tensors raise on bool(). It also made the two fields able to disagree.
    loaded = model is not None
    return {
        "service": "DentalOS Tooth Segmentation API",
        "model_loaded": loaded,
        "model_type": MODEL_TYPE,
        "status": "ready" if loaded else "demo_mode",
    }
def health():
    """Liveness check: always "ok", plus whether a real model is loaded.

    NOTE(review): no route decorator is visible for this handler in this
    view. Confirm it is registered.
    """
    # Explicit None check; bool() of a model object is not a reliable test.
    return {"status": "ok", "model": "loaded" if model is not None else "demo"}
async def segment(file: UploadFile = File(...)):
    """Detect teeth/fillings in an uploaded image and return annotated results.

    Runs the loaded model, or the demo generator when no model is loaded.

    Returns:
        JSONResponse with ``annotated_image`` (PNG data URL), ``detections``,
        per-class ``summary`` counts, ``inference_ms``, ``model_type`` and
        ``is_demo``.

    Raises:
        HTTPException: 400 when the content type is not an accepted image
            type, or when the uploaded bytes cannot be decoded as an image.

    NOTE(review): no route decorator is visible for this handler in this view.
    Confirm it is registered. Inference also runs synchronously inside this
    async handler, which blocks the event loop while a request is processed.
    """
    from collections import Counter

    if file.content_type not in ("image/jpeg", "image/png", "image/jpg", "image/webp"):
        raise HTTPException(400, "Only JPG/PNG/WEBP supported.")
    contents = await file.read()
    try:
        # convert() forces a full decode, so truncated data fails here too.
        img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as exc:
        # Undecodable uploads are a client error, not an internal 500.
        raise HTTPException(400, "Uploaded file could not be decoded as an image.") from exc

    # perf_counter is monotonic, so the duration is immune to clock changes.
    t0 = time.perf_counter()
    if model is None:
        detections, annotated_b64 = _demo_response(img_pil)
    else:
        detections, annotated_b64 = _run_pytorch(img_pil)
    inference_ms = round((time.perf_counter() - t0) * 1000)

    # One pass over the detections instead of one scan per class.
    counts = Counter(d["label"] for d in detections)
    summary = {
        "Molars": counts["Molars"],
        "Premolars": counts["Premolars"],
        "Canines": counts["Canines"],
        "Incisors": counts["Incisors"],
        # The response key is plural; the detection label is "Filling".
        "Fillings": counts["Filling"],
    }
    return JSONResponse({
        "annotated_image": annotated_b64,
        "detections": detections,
        "summary": summary,
        "inference_ms": inference_ms,
        "model_type": MODEL_TYPE,
        "is_demo": model is None,
    })
def _run_pytorch(img_pil: Image.Image):
    """Run the loaded detector on *img_pil* and draw its detections.

    The image is stretched to IMG_SIZE x IMG_SIZE, scaled to [0, 1], and fed
    to the module-level ``model`` as a float32 NCHW tensor with batch size 1.
    Each output row is parsed as ``[x1, y1, x2, y2, conf, class scores...]``
    in IMG_SIZE pixel space, then rescaled to the original image size.

    NOTE(review): this row layout comes from the original inline comment and
    is not verified against the model. If the network emits YOLO-style
    ``[cx, cy, w, h, obj, cls...]`` rows, three things change: boxes need
    converting, confidence should include the class score, and NMS is still
    required. Confirm against the training code.

    Returns:
        ``(detections, annotated_b64)``.
        - ``detections``: a list of dicts with ``label``, ``tooth_id``,
          ``confidence`` and ``bbox`` keys.
        - ``annotated_b64``: an annotated copy of the image as a PNG data URL.

        Parsing is best-effort. An unexpected output format is logged, and
        whatever was parsed up to that point is returned.
    """
    import torch

    # Plain stretch (no letterboxing); sx/sy below undo it per axis.
    img_resized = img_pil.resize((IMG_SIZE, IMG_SIZE))
    img_arr = np.array(img_resized).astype(np.float32) / 255.0
    # HWC -> CHW, then add a batch axis of size 1.
    img_tensor = torch.from_numpy(img_arr.transpose(2, 0, 1)).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        outputs = model(img_tensor)

    detections = []
    annotated_img = img_pil.copy()  # never draw on the caller's image
    draw = ImageDraw.Draw(annotated_img)
    w, h = img_pil.size
    sx, sy = w / IMG_SIZE, h / IMG_SIZE

    # Expected row: [x1, y1, x2, y2, conf, score_0 .. score_{K-1}] with
    # K = len(TOOTH_CLASSES), i.e. 10 values per row for the 5 classes.
    try:
        for row in _prediction_rows(outputs):
            pred = np.ravel(row)
            if len(pred) < 5:
                continue
            if not np.isfinite(pred[:5]).all():
                # NaN/inf would make int() raise below and abort parsing of
                # every remaining row.
                continue
            conf = float(pred[4])
            if conf < CONFIDENCE:
                continue
            # Rescale to original pixels and order the corners. Recent Pillow
            # versions raise on rectangles whose second corner precedes the first.
            x1, x2 = sorted((int(pred[0] * sx), int(pred[2] * sx)))
            y1, y2 = sorted((int(pred[1] * sy), int(pred[3] * sy)))
            cls_id = int(np.argmax(pred[5:])) if len(pred) > 5 else 0
            raw_label = TOOTH_CLASSES[cls_id] if cls_id < len(TOOTH_CLASSES) else "Unknown"
            label = fix_name(raw_label)
            color = CLASS_COLORS.get(label, DEFAULT_COLOR)
            detections.append({
                "label": label,
                "tooth_id": label,
                "confidence": round(conf, 3),
                "bbox": [x1, y1, x2, y2],
            })
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
            text = f"{label} {conf:.0%}"
            # Filled label tab above the box; len(text) * 7 is a rough width
            # estimate for the default font.
            draw.rectangle([x1, y1 - 18, x1 + len(text) * 7 + 4, y1], fill=color)
            draw.text((x1 + 2, y1 - 16), text, fill="white")
    except Exception as e:
        # Best effort: log and keep whatever was parsed so far.
        print(f"Inference parse error: {e}")
    return detections, _img_to_b64(annotated_img)


def _prediction_rows(outputs):
    """Normalise raw model output into a 2-D numpy array, one prediction per row."""
    # Some models return a tuple whose first element is the prediction tensor.
    preds = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
    if hasattr(preds, "cpu"):
        preds = preds.cpu().numpy()
    preds = np.asarray(preds)
    # Strip the batch axis ([1, num_preds, C] -> [num_preds, C]). Iterating a
    # 3-D array directly would treat the whole batch as one flattened row.
    if preds.ndim >= 3:
        preds = preds[0]
    # A single 1-D (or scalar) output is treated as one row.
    if preds.ndim < 2:
        preds = preds.reshape(1, -1)
    return preds
def _demo_response(img_pil: Image.Image):
    """Return fixed mock detections and an annotated copy of *img_pil*.

    Used when no model is loaded. Boxes sit at fixed fractions of the image
    size, so they scale with any input.

    Annotations are drawn on a copy, leaving the caller's image untouched,
    the same way _run_pytorch does.

    Returns:
        ``(detections, annotated_b64)``: the mock detection dicts and the
        annotated image as a PNG data URL.
    """
    annotated_img = img_pil.copy()
    draw = ImageDraw.Draw(annotated_img)
    w, h = img_pil.size
    mock = [
        {"label": "Molars", "tooth_id": "Molars", "confidence": 0.91,
         "bbox": [int(w * 0.1), int(h * 0.15), int(w * 0.25), int(h * 0.45)]},
        {"label": "Premolars", "tooth_id": "Premolars", "confidence": 0.87,
         "bbox": [int(w * 0.3), int(h * 0.15), int(w * 0.45), int(h * 0.45)]},
        {"label": "Incisors", "tooth_id": "Incisors", "confidence": 0.93,
         "bbox": [int(w * 0.42), int(h * 0.1), int(w * 0.58), int(h * 0.42)]},
        {"label": "Canines", "tooth_id": "Canines", "confidence": 0.85,
         "bbox": [int(w * 0.6), int(h * 0.15), int(w * 0.72), int(h * 0.45)]},
        {"label": "Filling", "tooth_id": "Filling", "confidence": 0.78,
         "bbox": [int(w * 0.3), int(h * 0.55), int(w * 0.45), int(h * 0.80)]},
    ]
    for d in mock:
        x1, y1, x2, y2 = d["bbox"]
        color = CLASS_COLORS.get(d["label"], DEFAULT_COLOR)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        text = f"{d['label']} {d['confidence']:.0%}"
        # Filled label tab above the box (rough width estimate per character).
        draw.rectangle([x1, y1 - 18, x1 + len(text) * 7 + 4, y1], fill=color)
        draw.text((x1 + 2, y1 - 16), text, fill="white")
    return mock, _img_to_b64(annotated_img)
def _img_to_b64(img: Image.Image) -> str:
    """Encode *img* as PNG and return it as a ``data:image/png;base64,...`` URL."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    # Base64 output is pure ASCII, so decoding as ASCII is lossless.
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return f"data:image/png;base64,{encoded}"
if __name__ == "__main__":
    import uvicorn

    # Pass the app object instead of the "app:app" import string. The string
    # form re-imports this file under the module name "app", which executes
    # module-level code a second time and breaks if the file is renamed. An
    # import string is only needed for reload/workers, and both are off here.
    # PORT may override the default 7860.
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)