# Rukhsar9684's picture
# Update app.py
# cd9fcad verified
import os, io, base64, time
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image, ImageDraw
import numpy as np
# ─────────────────────────────────────────────────────────────────────────────
# Runtime configuration.
MODEL_PATH = "best_model.pth"  # checkpoint read at startup; missing file => DEMO mode
MODEL_TYPE = "pytorch"
IMG_SIZE = 640  # side length of the square the image is resized to for inference
CONFIDENCE = 0.35  # predictions scoring below this threshold are discarded
DEVICE = "cpu"
# ─────────────────────────────────────────────────────────────────────────────
# The project's actual classes; a prediction's argmax class index selects the
# entry at that position.
TOOTH_CLASSES = ["Molars", "Premolars", "Canines", "Incisors", "Filling"]

# The project's actual per-class drawing colours (RGB).
CLASS_COLORS = {
    "Molars": (255, 0, 250),
    "Premolars": (255, 0, 8),
    "Canines": (0, 127, 255),
    "Incisors": (42, 255, 0),
    "Filling": (196, 201, 255),
}
# Colour used for any label not present in CLASS_COLORS.
DEFAULT_COLOR = (99, 102, 241)

# Spelling-variation fix: every spelling the model might produce, grouped under
# the correct class name. NAME_MAP flattens this into {variant: canonical}.
_SPELLING_VARIANTS = (
    ("Molars", ("molar", "Molar", "Molars")),
    ("Premolars", ("premolar", "Premolar", "Premolars")),
    ("Canines", ("canine", "Canine", "Canines", "cannine", "canin")),
    ("Incisors", ("incisor", "Incisor", "Incisors", "incissors", "incissor")),
    ("Filling", ("filling", "Filling", "Fillings")),
)
NAME_MAP = {
    variant: canonical
    for canonical, variants in _SPELLING_VARIANTS
    for variant in variants
}
def fix_name(raw: str) -> str:
    """Map a (possibly misspelled) class name to its canonical label.

    An exact match in NAME_MAP is tried first. Failing that, the lookup is
    retried ignoring letter case and surrounding whitespace, so variants such
    as "CANINE" or " Cannine " also resolve. Names that match nothing (and
    non-string inputs) are returned unchanged.
    """
    canonical = NAME_MAP.get(raw)
    if canonical is not None:
        return canonical
    if isinstance(raw, str):
        wanted = raw.strip().casefold()
        for variant, name in NAME_MAP.items():
            if variant.casefold() == wanted:
                return name
    return raw
# ─────────────────────────────────────────────────────────────────────────────
# FastAPI application and its CORS policy.
app = FastAPI(version="1.0", title="DentalOS Tooth Segmentation API")

# NOTE(review): wildcard origins together with allow_credentials=True lets any
# website make credentialed cross-origin calls (Starlette echoes the request's
# Origin in that case). Nothing visible in this file uses cookies or auth, so
# consider allow_credentials=False or an explicit origin list — confirm with
# the frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Model object populated by the startup hook; None means DEMO mode.
model = None
@app.on_event("startup")
async def load_model():
    """Load the checkpoint at MODEL_PATH into the module-level ``model``.

    ``model`` stays None (DEMO mode) when the file is missing, or when the
    checkpoint does not yield a callable object. A bare state_dict holds only
    weights and cannot be run without the network class that produced it.

    NOTE(review): ``app.on_event`` is deprecated in recent FastAPI releases in
    favour of lifespan handlers; it is kept to leave the app wiring unchanged.
    """
    global model
    if not os.path.exists(MODEL_PATH):
        print(f"⚠️ Model file not found: {MODEL_PATH} — running in DEMO mode")
        return
    print(f"🔄 Loading model from {MODEL_PATH} ...")
    t0 = time.time()
    import torch

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code embedded in the file. Only deploy checkpoints you trust.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

    # Extract the model from the checkpoint.
    if isinstance(checkpoint, dict):
        if "model" in checkpoint:
            loaded = checkpoint["model"]
        elif "state_dict" in checkpoint:
            loaded = checkpoint["state_dict"]
        else:
            # Weights only: print the keys to help with debugging.
            print(f"Checkpoint keys: {list(checkpoint.keys())}")
            loaded = checkpoint
    else:
        loaded = checkpoint

    # Inference calls model(tensor). A weights mapping (or a bare tensor) would
    # make every request fail with "object is not callable" while the service
    # reported itself ready, so fall back to DEMO mode instead.
    if not callable(loaded):
        print(
            f"⚠️ Checkpoint has no runnable model (got {type(loaded).__name__}) "
            "— running in DEMO mode"
        )
        return
    if hasattr(loaded, "eval"):
        loaded.eval()
    # Publish only a fully prepared model.
    model = loaded
    print(f"✅ Model loaded in {time.time()-t0:.2f}s")
@app.get("/")
def root():
    """Describe the service and report whether a model or DEMO mode is active."""
    # Identity check rather than truthiness: `if model` would call the model's
    # __bool__/__len__ (a tensor raises, an empty container is falsy), which
    # could contradict "model_loaded".
    loaded = model is not None
    return {
        "service": "DentalOS Tooth Segmentation API",
        "model_loaded": loaded,
        "model_type": MODEL_TYPE,
        "status": "ready" if loaded else "demo_mode",
    }
@app.get("/health")
def health():
    """Liveness probe; also reports whether a real model is loaded."""
    # `is not None` avoids invoking the model object's truthiness.
    return {"status": "ok", "model": "loaded" if model is not None else "demo"}
@app.post("/segment")
async def segment(file: UploadFile = File(...)):
    """Detect and label teeth in an uploaded JPEG/PNG/WEBP image.

    Runs the loaded model, or returns canned DEMO detections when no model is
    loaded. Inference is CPU-bound, so it runs in a worker thread to keep the
    event loop (and /health) responsive while an image is being processed.

    Returns:
        JSON with the annotated image (PNG data URL), the detections, per-class
        counts, inference time in milliseconds, the model type and a demo flag.

    Raises:
        HTTPException(400): unsupported content type or undecodable image data.
    """
    from collections import Counter

    from fastapi.concurrency import run_in_threadpool

    if file.content_type not in ["image/jpeg", "image/png", "image/jpg", "image/webp"]:
        raise HTTPException(400, "Only JPG/PNG/WEBP supported.")
    contents = await file.read()
    try:
        img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Corrupt or non-image bytes are a client error, not a server crash.
        raise HTTPException(400, "Could not decode the uploaded image.") from exc

    # Decide the mode once so "is_demo" matches the code path actually taken.
    is_demo = model is None
    runner = _demo_response if is_demo else _run_pytorch
    t0 = time.perf_counter()
    detections, annotated_b64 = await run_in_threadpool(runner, img_pil)
    inference_ms = round((time.perf_counter() - t0) * 1000)

    counts = Counter(d["label"] for d in detections)
    summary = {
        "Molars": counts["Molars"],
        "Premolars": counts["Premolars"],
        "Canines": counts["Canines"],
        "Incisors": counts["Incisors"],
        # The detection label is singular "Filling"; the summary key is plural.
        "Fillings": counts["Filling"],
    }
    return JSONResponse({
        "annotated_image": annotated_b64,
        "detections": detections,
        "summary": summary,
        "inference_ms": inference_ms,
        "model_type": MODEL_TYPE,
        "is_demo": is_demo,
    })
def _run_pytorch(img_pil: Image.Image):
    """Run the loaded model on *img_pil* and draw its detections.

    The image is resized (stretched, no letterboxing) to IMG_SIZE x IMG_SIZE,
    scaled to [0, 1] and passed to the global ``model`` as a float32 tensor of
    shape (1, 3, IMG_SIZE, IMG_SIZE). Predictions below CONFIDENCE are
    discarded. Surviving boxes are scaled back to the original image size,
    clamped to its bounds, and drawn on a copy of the image.

    Returns:
        (detections, annotated_b64): a list of dicts with ``label``,
        ``tooth_id``, ``confidence`` and ``bbox`` ([x1, y1, x2, y2] in
        original-image pixels), and the annotated image as a PNG data URL.

    Errors while parsing the model output are logged, and the detections
    collected so far are returned (best effort). Errors raised by the forward
    pass itself propagate to the caller.
    """
    import traceback

    import torch

    img_resized = img_pil.resize((IMG_SIZE, IMG_SIZE))
    img_arr = np.array(img_resized).astype(np.float32) / 255.0
    # HWC -> CHW, then prepend the batch dimension.
    img_tensor = torch.from_numpy(img_arr.transpose(2, 0, 1)).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        outputs = model(img_tensor)

    detections = []
    annotated_img = img_pil.copy()
    draw = ImageDraw.Draw(annotated_img)
    w, h = img_pil.size
    sx, sy = w / IMG_SIZE, h / IMG_SIZE

    # Assumed output layout: [1, num_preds, 5 + num_classes], each row being
    # [x1, y1, x2, y2, conf, c0, c1, ...] in IMG_SIZE-pixel coordinates
    # (10 values per row for the 5 TOOTH_CLASSES).
    # NOTE(review): this layout is assumed, not checked here — confirm it
    # against the trained model's export.
    try:
        preds = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
        if hasattr(preds, "cpu"):
            preds = preds.cpu().numpy()
        preds = np.asarray(preds)
        # Strip the batch dimension. Iterating a [1, N, C] array would yield a
        # single [N, C] slice that flattens into one bogus "prediction".
        if preds.ndim == 3:
            preds = preds[0]
        elif preds.ndim == 1:
            preds = preds[np.newaxis, :]
        for pred in preds:
            pred = pred.flatten()
            if len(pred) < 5 or not np.all(np.isfinite(pred[:5])):
                continue
            conf = float(pred[4])
            if conf < CONFIDENCE:
                continue
            # Clamp each corner to the image, then order each pair: PIL raises
            # on rectangles whose second corner precedes the first, which would
            # abort all remaining detections.
            x1, x2 = sorted(min(max(int(v * sx), 0), w - 1) for v in (pred[0], pred[2]))
            y1, y2 = sorted(min(max(int(v * sy), 0), h - 1) for v in (pred[1], pred[3]))
            cls_id = int(np.argmax(pred[5:])) if len(pred) > 5 else 0
            raw_label = TOOTH_CLASSES[cls_id] if cls_id < len(TOOTH_CLASSES) else "Unknown"
            label = fix_name(raw_label)
            color = CLASS_COLORS.get(label, DEFAULT_COLOR)
            detections.append({
                "label": label,
                "tooth_id": label,
                "confidence": round(conf, 3),
                "bbox": [x1, y1, x2, y2],
            })
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
            text = f"{label} {conf:.0%}"
            # Label background sized roughly to the default font (~7px/char).
            draw.rectangle([x1, y1-18, x1+len(text)*7+4, y1], fill=color)
            draw.text((x1+2, y1-16), text, fill="white")
    except Exception as e:
        # Best effort: keep whatever was parsed, but keep the traceback visible.
        print(f"Inference parse error: {e}")
        traceback.print_exc()
    return detections, _img_to_b64(annotated_img)
def _demo_response(img_pil: Image.Image):
    """Return canned detections and an annotated image for DEMO mode.

    No model is run: one box per class is placed at fixed fractions of the
    image size. Annotations are drawn on a copy, so the caller's image is not
    modified.

    Returns:
        (detections, annotated_b64), in the same shape that real inference
        produces.
    """
    annotated_img = img_pil.copy()
    draw = ImageDraw.Draw(annotated_img)
    w, h = img_pil.size
    mock = [
        {"label": "Molars", "tooth_id": "Molars", "confidence": 0.91,
         "bbox": [int(w*0.1), int(h*0.15), int(w*0.25), int(h*0.45)]},
        {"label": "Premolars", "tooth_id": "Premolars", "confidence": 0.87,
         "bbox": [int(w*0.3), int(h*0.15), int(w*0.45), int(h*0.45)]},
        {"label": "Incisors", "tooth_id": "Incisors", "confidence": 0.93,
         "bbox": [int(w*0.42), int(h*0.1), int(w*0.58), int(h*0.42)]},
        {"label": "Canines", "tooth_id": "Canines", "confidence": 0.85,
         "bbox": [int(w*0.6), int(h*0.15), int(w*0.72), int(h*0.45)]},
        {"label": "Filling", "tooth_id": "Filling", "confidence": 0.78,
         "bbox": [int(w*0.3), int(h*0.55), int(w*0.45), int(h*0.80)]},
    ]
    for d in mock:
        x1, y1, x2, y2 = d["bbox"]
        color = CLASS_COLORS.get(d["label"], DEFAULT_COLOR)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        text = f"{d['label']} {d['confidence']:.0%}"
        # Label background sized roughly to the default font (~7px/char).
        draw.rectangle([x1, y1-18, x1+len(text)*7+4, y1], fill=color)
        draw.text((x1+2, y1-16), text, fill="white")
    return mock, _img_to_b64(annotated_img)
def _img_to_b64(img: Image.Image) -> str:
    """Encode *img* as PNG and return it as a ``data:image/png;base64,`` URL."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes).decode()
    return f"data:image/png;base64,{encoded}"
if __name__ == "__main__":
    import uvicorn

    # Serve the app object directly. The "app:app" import string made uvicorn
    # re-import this file as a module named "app", which breaks if the file is
    # renamed. PORT may be overridden via the environment; the default is 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
        reload=False,
    )