# -------------------------------------------------
# FORCE CPU + LOW MEMORY
# -------------------------------------------------
import os

# CUDA_VISIBLE_DEVICES is only honoured if it is already set when the CUDA
# runtime initializes, so hide the GPU *before* torch (or anything that pulls
# torch in, such as ultralytics) is imported instead of relying on torch's
# lazy CUDA initialization.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import base64  # noqa: E402
import gc  # noqa: E402
import uuid  # noqa: E402

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from fastapi import FastAPI, File, Header, HTTPException, UploadFile  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import JSONResponse  # noqa: E402
from ultralytics import YOLO  # noqa: E402

# Inference only: never record autograd graphs. Note that grad mode is
# thread-local, so this only covers the importing (main) thread.
torch.set_grad_enabled(False)
# Cap intra-op parallelism to keep CPU and memory usage small.
torch.set_num_threads(2)
torch.set_float32_matmul_precision("high")
device = "cpu"

# -------------------------------------------------
# CONFIG
# -------------------------------------------------
DET_MODEL_PATH = "detection.pt"
CLS_MODEL_PATH = "classification.pt"
# Shared secret expected from clients; when unset, the auth check rejects
# every request (fail closed).
API_KEY = os.getenv("API_KEY")

# -------------------------------------------------
# FALLBACK CLASS NAMES
# -------------------------------------------------
# Used only when a model does not provide a usable label for a class id.
FALLBACK_CLASS_NAMES = {
    0: "melanoma",
    1: "warts",
    2: "basal_cell_carcinoma",
    3: "tinea",
    4: "tinea_versicolor",
    5: "corns",
    6: "chickenpox",
    7: "skin_tag",
    8: "cutaneous_candidiasis",
    9: "pityriasis_rosea",
    10: "seborrheic_dermatitis",
    11: "seborrheic_keratoses",
    12: "black_heel",
    13: "psoriasis",
    14: "molluscum_contagiosum",
    15: "ichthyosis",
    16: "acne",
    17: "eczema",
    18: "herpes_simplex",
    19: "herpes_zoster",
    20: "keratosis_pilaris",
    21: "lichen",
}


def resolve_class_name(model, class_id: int) -> str:
    """Return a human-readable label for ``class_id``.

    The model's own label table (``model.names``) wins when it has an entry
    for the id. Both an ``{id: name}`` mapping and a list/tuple indexed by id
    are accepted. Otherwise the id is looked up in ``FALLBACK_CLASS_NAMES``,
    and ``"unknown"`` is returned as a last resort.

    Args:
        model: Any object that may expose a ``names`` attribute.
        class_id: Integer class index produced by the model.

    Returns:
        The class label as a string; never raises.
    """
    try:
        # Reading ``names`` can trigger lazy model setup in some wrappers;
        # a failure there must not fail the whole prediction (best effort).
        names = getattr(model, "names", None)
    except Exception:
        names = None

    label = None
    if isinstance(names, dict):
        label = names.get(class_id)
    elif isinstance(names, (list, tuple)) and 0 <= class_id < len(names):
        # A sequence is indexed by id; ``class_id in names`` would wrongly
        # test whether the *value* is present.
        label = names[class_id]

    if label is not None:
        return str(label)
    return FALLBACK_CLASS_NAMES.get(class_id, "unknown")


# -------------------------------------------------
# FASTAPI
# -------------------------------------------------
app = FastAPI(title="EpiCheck YOLOv8 CPU API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://epi-check.great-site.net",
        "https://epi-check.great-site.net",
    ],
    allow_credentials=True,
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)
# -------------------------------------------------
# AUTH
# -------------------------------------------------
# Imports used only by the request-handling code below.
import hmac
import threading
import traceback

from fastapi.concurrency import run_in_threadpool


def verify_api_key(x_api_key: str):
    """Reject the request unless ``x_api_key`` matches the configured API_KEY.

    Fails closed: if no API_KEY is configured, every request is rejected.

    Raises:
        HTTPException: 401 when the server has no key or the key is wrong.
    """
    if not API_KEY or not x_api_key:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    # Constant-time comparison so response timing does not reveal how much
    # of a guessed key was correct. Compare bytes: compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API Key")


# -------------------------------------------------
# LOAD MODELS (ONCE)
# -------------------------------------------------
print("🚀 Loading detection model...")
det_model = YOLO(DET_MODEL_PATH)
det_model.to(device)

cls_model = None

# Inference runs in worker threads. This lock keeps model access one request
# at a time (as it was when inference ran on the event loop), so the shared
# YOLO models and the lazy loader below are never used concurrently.
_inference_lock = threading.Lock()


def get_cls_model():
    """Return the classification model, loading it on first use.

    Must be called with ``_inference_lock`` held so two requests cannot
    load it at the same time.
    """
    global cls_model
    if cls_model is None:
        print("⚠️ Loading classification model...")
        cls_model = YOLO(CLS_MODEL_PATH)
        cls_model.to(device)
    return cls_model


# -------------------------------------------------
# PREDICT
# -------------------------------------------------
def _detect(img):
    """Run the detection model and build the "detection" payload.

    Returns None when the model reports no boxes.
    """
    result = det_model(
        img,
        imgsz=512,
        conf=0.3,  # 🔥 LOWERED for better sensitivity
        verbose=False,
    )[0]

    boxes = result.boxes
    count = 0 if boxes is None else len(boxes)
    # Log only the count; printing the Boxes object dumps whole tensors.
    print("Detection boxes:", count)
    if count == 0:
        return None

    detections = []
    for b in boxes:
        cid = int(b.cls)
        detections.append({
            "class_id": cid,
            "class_name": resolve_class_name(det_model, cid),
            "confidence": float(b.conf),
            "bbox": b.xyxy[0].tolist(),
        })

    annotated = result.plot()
    ok, buffer = cv2.imencode(".jpg", annotated)
    if not ok:
        raise RuntimeError("JPEG encoding of the annotated image failed")

    return {
        "model_used": "detection",
        "detections": detections,
        "annotated_image_base64": base64.b64encode(buffer).decode(),
    }


def _classify(img):
    """Run the (lazily loaded) classification model and build its payload."""
    model = get_cls_model()
    result = model(img, imgsz=224, verbose=False)[0]
    probs = result.probs
    if probs is None:
        raise RuntimeError("classification model returned no probabilities")
    cid = int(probs.top1)
    return {
        "model_used": "classification",
        "class_id": cid,
        "class_name": resolve_class_name(model, cid),
        "confidence": float(probs.top1conf),
    }


def _run_inference(img):
    """Blocking pipeline: detection first, classification as fallback.

    Meant to run in a worker thread, so the event loop stays free to serve
    other requests (including the health checks) during inference.
    """
    with _inference_lock:
        try:
            # Grad mode is thread-local: the module-level
            # torch.set_grad_enabled(False) does not apply in this worker.
            with torch.no_grad():
                payload = _detect(img)
                if payload is None:
                    print("⚠️ No detections → using classification")
                    payload = _classify(img)
            return payload
        finally:
            # Release per-request intermediates on success *and* failure.
            gc.collect()


@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    x_api_key: str = Header(...)
):
    """Classify or detect skin conditions in an uploaded image.

    Returns either a "detection" payload (boxes plus an annotated JPEG as
    base64) or, when nothing is detected, a "classification" payload.

    Raises:
        HTTPException: 401 for a bad API key, 400 for an empty or
            undecodable image, 500 if inference fails.
    """
    verify_api_key(x_api_key)

    data = await file.read()
    img = None
    if data:
        try:
            img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        except cv2.error:
            # imdecode asserts on some malformed/empty buffers instead of
            # returning None; both mean "not an image" to the client.
            img = None
    del data  # drop the raw upload before the memory-heavy inference
    if img is None:
        raise HTTPException(status_code=400, detail="Invalid image")

    try:
        payload = await run_in_threadpool(_run_inference, img)
    except Exception as e:
        print("❌ Unexpected error:", str(e))
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Prediction failed") from e

    return JSONResponse(payload)


# -------------------------------------------------
# HEALTH CHECK
# -------------------------------------------------
@app.get("/")
def root():
    """Liveness probe used by the hosting platform / frontend."""
    return {"status": "API is running"}


@app.get("/ping")
def ping():
    """Minimal liveness probe."""
    return {"status": "ok"}