Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import gc | |
| import uuid | |
| import base64 | |
| import torch | |
| import numpy as np | |
| from fastapi import FastAPI, UploadFile, File, Header, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from ultralytics import YOLO | |
| # ------------------------------------------------- | |
| # FORCE CPU + LOW MEMORY | |
| # ------------------------------------------------- | |
# Everything in this service runs on the CPU: CUDA devices are masked out
# through the environment and the device string is pinned accordingly.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
device = "cpu"

# Inference-only process: cap torch at two threads, switch autograd off
# globally, and select the "high" float32 matmul precision mode.
torch.set_num_threads(2)
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision("high")
| # ------------------------------------------------- | |
| # CONFIG | |
| # ------------------------------------------------- | |
# File names of the detection and classification weights.
DET_MODEL_PATH, CLS_MODEL_PATH = "detection.pt", "classification.pt"

# Shared secret expected in the request header; None when the variable is unset.
API_KEY = os.environ.get("API_KEY")
| # ------------------------------------------------- | |
| # FALLBACK CLASS NAMES | |
| # ------------------------------------------------- | |
# Class id -> label table. The position of each label in the tuple below is
# its class id (0 through 21), so the order must not be changed.
FALLBACK_CLASS_NAMES = dict(enumerate((
    "melanoma",
    "warts",
    "basal_cell_carcinoma",
    "tinea",
    "tinea_versicolor",
    "corns",
    "chickenpox",
    "skin_tag",
    "cutaneous_candidiasis",
    "pityriasis_rosea",
    "seborrheic_dermatitis",
    "seborrheic_keratoses",
    "black_heel",
    "psoriasis",
    "molluscum_contagiosum",
    "ichthyosis",
    "acne",
    "eczema",
    "herpes_simplex",
    "herpes_zoster",
    "keratosis_pilaris",
    "lichen",
)))
def resolve_class_name(model, class_id: int) -> str:
    """Return the label for ``class_id``, preferring the model's own names.

    The model's ``names`` attribute is consulted first and may be either a
    mapping (id -> label) or a sequence indexed by id. When the model has no
    usable entry for ``class_id``, the lookup falls back to
    ``FALLBACK_CLASS_NAMES`` and finally to the literal ``"unknown"``.

    Args:
        model: Any object that may expose a ``names`` attribute.
        class_id: Integer class index produced by the model.

    Returns:
        The resolved label, or ``"unknown"`` if no table knows the id.
    """
    # Best-effort read: a model without ``names`` (or one whose attribute
    # access fails) must not break the response, so fall through to the table.
    try:
        names = model.names
    except Exception:
        names = None

    if isinstance(names, dict):
        if class_id in names:
            return names[class_id]
    elif isinstance(names, (list, tuple)):
        # ``in`` on a sequence tests values, not positions, so index explicitly.
        if 0 <= class_id < len(names):
            return names[class_id]

    return FALLBACK_CLASS_NAMES.get(class_id, "unknown")
| # ------------------------------------------------- | |
| # FASTAPI | |
| # ------------------------------------------------- | |
app = FastAPI(title="EpiCheck YOLOv8 CPU API")

# CORS policy: only the EpiCheck site (over http and https) may call the API
# from a browser, with credentials, using POST or GET and any request header.
_cors_settings = {
    "allow_origins": [
        "http://epi-check.great-site.net",
        "https://epi-check.great-site.net",
    ],
    "allow_credentials": True,
    "allow_methods": ["POST", "GET"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
| # ------------------------------------------------- | |
| # AUTH | |
| # ------------------------------------------------- | |
def verify_api_key(x_api_key: str):
    """Reject the request unless ``x_api_key`` matches the configured key.

    Args:
        x_api_key: Key supplied by the caller.

    Raises:
        HTTPException: 401 when no ``API_KEY`` is configured, when the caller
            sent no key, or when the two keys differ.
    """
    import hmac

    # With no configured key every request is refused, never allowed.
    if not API_KEY or not x_api_key:
        raise HTTPException(status_code=401, detail="Invalid API Key")

    # Compare in constant time so response timing does not leak how much of
    # the key matched. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    supplied = x_api_key.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid API Key")
| # ------------------------------------------------- | |
| # LOAD MODELS (ONCE) | |
| # ------------------------------------------------- | |
# The classifier slot starts out empty; get_cls_model() fills it on first use.
cls_model = None

# The detector is loaded once, at import time, and moved to the chosen device.
print("🚀 Loading detection model...")
det_model = YOLO(DET_MODEL_PATH)
det_model.to(device)
def get_cls_model():
    """Return the classification model, loading it on the first call.

    The loaded model is stored in the module-level ``cls_model`` so that
    later calls return the same object without loading again.
    """
    global cls_model

    # Already loaded: hand back the stored instance.
    if cls_model is not None:
        return cls_model

    print("⚠️ Loading classification model...")
    cls_model = YOLO(CLS_MODEL_PATH)
    cls_model.to(device)
    return cls_model
| # ------------------------------------------------- | |
| # PREDICT | |
| # ------------------------------------------------- | |
# NOTE(review): no route decorator is visible above this handler in this view;
# presumably it is registered on `app` as a POST route — confirm.
async def predict(
    file: UploadFile = File(...),
    x_api_key: str = Header(...)
):
    """Run detection on an uploaded image, falling back to classification.

    The image is decoded with OpenCV and passed to the detection model. If at
    least one box is found, the response lists every detection together with
    a base64-encoded JPEG of the annotated image. Otherwise the classification
    model is loaded on demand and its top-1 prediction is returned.

    Args:
        file: Uploaded image file.
        x_api_key: API key header value, checked by ``verify_api_key``.

    Returns:
        JSONResponse whose ``model_used`` field is ``"detection"`` or
        ``"classification"``.

    Raises:
        HTTPException: 401 for a bad API key, 400 when the upload is empty or
            cannot be decoded as an image, 500 when inference or encoding
            fails.
    """
    verify_api_key(x_api_key)

    data = await file.read()
    # An empty upload can never be an image; reject it before OpenCV sees it.
    if not data:
        raise HTTPException(status_code=400, detail="Invalid image")

    # OpenCV may raise instead of returning None for some malformed buffers;
    # both outcomes are a client error, not a server fault.
    try:
        img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    except cv2.error:
        img = None
    if img is None:
        raise HTTPException(status_code=400, detail="Invalid image")

    try:
        # ---------------- DETECTION ----------------
        result = det_model(
            img,
            imgsz=512,
            conf=0.3,  # lowered confidence threshold for better sensitivity
            verbose=False
        )[0]

        # Debug logs (remove in production if needed)
        print("Detection boxes:", result.boxes)

        if result.boxes is not None and len(result.boxes) > 0:
            detections = [
                {
                    "class_id": int(b.cls),
                    "class_name": resolve_class_name(det_model, int(b.cls)),
                    "confidence": float(b.conf),
                    "bbox": b.xyxy[0].tolist(),
                }
                for b in result.boxes
            ]

            annotated = result.plot()
            encoded_ok, buffer = cv2.imencode(".jpg", annotated)
            # imencode reports failure through its first return value; do not
            # ship an unusable buffer to the client.
            if not encoded_ok:
                raise RuntimeError("JPEG encoding of annotated image failed")
            annotated_b64 = base64.b64encode(buffer).decode()

            # Drop the result object before building the response to keep
            # peak memory low.
            del result
            gc.collect()

            return JSONResponse({
                "model_used": "detection",
                "detections": detections,
                "annotated_image_base64": annotated_b64
            })

        # ---------------- FALLBACK ----------------
        print("⚠️ No detections → using classification")
        model = get_cls_model()
        result = model(img, imgsz=224, verbose=False)[0]

        probs = result.probs
        cid = int(probs.top1)
        confidence = float(probs.top1conf)

        del result, probs
        gc.collect()

        return JSONResponse({
            "model_used": "classification",
            "class_id": cid,
            "class_name": resolve_class_name(model, cid),
            "confidence": confidence
        })

    except HTTPException:
        # Deliberate HTTP errors must reach the client unchanged.
        raise
    except Exception as e:
        print("❌ Unexpected error:", str(e))
        raise HTTPException(status_code=500, detail="Prediction failed") from e
| # ------------------------------------------------- | |
| # HEALTH CHECK | |
| # ------------------------------------------------- | |
# NOTE(review): no route decorator is visible in this view — confirm how this
# handler is registered.
def root():
    """Return the fixed payload announcing that the API is running."""
    payload = dict(status="API is running")
    return payload
# NOTE(review): no route decorator is visible in this view — confirm how this
# handler is registered.
def ping():
    """Return the fixed ``ok`` status payload."""
    return dict(status="ok")