# epicheck-api / app.py
# Author: jefhgofk — commit e5c2f47 "Update app.py" (verified)
import base64
import gc
import hmac
import os
import threading
import traceback
import uuid

import cv2
import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from ultralytics import YOLO
# -------------------------------------------------
# FORCE CPU + LOW MEMORY
# -------------------------------------------------
# Hide every CUDA device from torch. torch initialises CUDA lazily, so setting
# this after `import torch` still takes effect as long as nothing has touched
# CUDA yet.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})

# Inference-only service: skip autograd bookkeeping to save memory.
# NOTE: torch grad mode is thread-local; this call only affects the thread
# that imports this module.
torch.set_grad_enabled(False)

# Cap intra-op CPU parallelism so the process stays within a small CPU budget.
torch.set_num_threads(2)
torch.set_float32_matmul_precision("high")

device = "cpu"
# -------------------------------------------------
# CONFIG
# -------------------------------------------------
DET_MODEL_PATH = "detection.pt"  # detection weights, tried first per request
CLS_MODEL_PATH = "classification.pt"  # classifier used when detection finds nothing

# Shared secret expected in the `x-api-key` request header; None when unset.
API_KEY = os.environ.get("API_KEY")
# -------------------------------------------------
# FALLBACK CLASS NAMES
# -------------------------------------------------
# Labels used when a model does not expose a usable `names` entry for an id.
FALLBACK_CLASS_NAMES = {
    0: "melanoma", 1: "warts", 2: "basal_cell_carcinoma", 3: "tinea",
    4: "tinea_versicolor", 5: "corns", 6: "chickenpox", 7: "skin_tag",
    8: "cutaneous_candidiasis", 9: "pityriasis_rosea",
    10: "seborrheic_dermatitis", 11: "seborrheic_keratoses",
    12: "black_heel", 13: "psoriasis", 14: "molluscum_contagiosum",
    15: "ichthyosis", 16: "acne", 17: "eczema",
    18: "herpes_simplex", 19: "herpes_zoster",
    20: "keratosis_pilaris", 21: "lichen",
}


def resolve_class_name(model, class_id: int) -> str:
    """Return the human-readable label for ``class_id``.

    The label comes from ``model.names`` when that attribute covers
    ``class_id``. Otherwise ``FALLBACK_CLASS_NAMES`` is consulted, and
    ``"unknown"`` is returned for ids found in neither.

    Args:
        model: Any object; only its optional ``names`` attribute is read.
            Both an ``{id: name}`` dict and a list/tuple of names are accepted.
        class_id: Integer class index produced by the model.

    Returns:
        The class label as a string.
    """
    names = getattr(model, "names", None)
    if isinstance(names, dict):
        if class_id in names:
            return str(names[class_id])
    elif isinstance(names, (list, tuple)):
        # `in` on a sequence tests values, not indices, so index explicitly and
        # reject negative ids (which would silently wrap to the end).
        if 0 <= class_id < len(names):
            return str(names[class_id])
    return FALLBACK_CLASS_NAMES.get(class_id, "unknown")
# -------------------------------------------------
# FASTAPI
# -------------------------------------------------
app = FastAPI(title="EpiCheck YOLOv8 CPU API")

# Browsers may call this API only from the EpiCheck front-end origin, over
# either plain HTTP or HTTPS.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=[
        "http://epi-check.great-site.net",
        "https://epi-check.great-site.net",
    ],
)
# -------------------------------------------------
# AUTH
# -------------------------------------------------
def verify_api_key(x_api_key: str):
    """Reject the request unless ``x_api_key`` matches the configured secret.

    Fails closed: when ``API_KEY`` is unset or empty, every request is rejected.

    Args:
        x_api_key: Key supplied by the client.

    Raises:
        HTTPException: 401 when the key is missing or wrong, or when no key
            is configured.
    """
    expected = API_KEY
    provided = x_api_key or ""
    # Constant-time comparison so response latency does not reveal how much of
    # a guessed key is correct. Encode first: compare_digest only accepts
    # ASCII-only str, and would raise TypeError on other text.
    if not expected or not hmac.compare_digest(
        provided.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API Key")
# -------------------------------------------------
# LOAD MODELS (ONCE)
# -------------------------------------------------
# The detection model serves every request, so it is loaded eagerly at import.
print("🚀 Loading detection model...")
det_model = YOLO(DET_MODEL_PATH)
det_model.to(device)

# The classification model is only needed when detection finds nothing, so it
# is loaded on first use and cached in this module-level global.
cls_model = None


def get_cls_model():
    """Return the cached classification model, loading it on the first call.

    NOTE(review): there is no locking here, so concurrent first calls from
    different threads could each load the model.
    """
    global cls_model
    if cls_model is not None:
        return cls_model
    print("⚠️ Loading classification model...")
    cls_model = YOLO(CLS_MODEL_PATH)
    cls_model.to(device)
    return cls_model
# -------------------------------------------------
# PREDICT
# -------------------------------------------------
# Inference runs on worker threads (see `predict`), but the YOLO models are
# shared module-level singletons. Serialize access so that only one request
# drives them at a time; this also bounds peak memory on a small CPU host.
_INFERENCE_LOCK = threading.Lock()


def _decode_image(data: bytes) -> np.ndarray:
    """Decode uploaded bytes into a 3-channel image, or raise HTTP 400."""
    img = None
    if data:
        try:
            img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        except cv2.error:
            # imdecode raises rather than returning None for some bad
            # buffers; treat those like any other undecodable upload.
            img = None
    if img is None:
        raise HTTPException(status_code=400, detail="Invalid image")
    return img


def _detect(img: np.ndarray):
    """Run the detector on ``img``.

    Returns the detection response payload, or None when no box passes the
    confidence threshold (the caller then falls back to classification).
    """
    result = det_model(
        img,
        imgsz=512,
        conf=0.3,  # 🔥 LOWERED for better sensitivity
        verbose=False,
    )[0]
    # Debug logs (remove in production if needed)
    print("Detection boxes:", result.boxes)
    if result.boxes is None or len(result.boxes) == 0:
        return None

    detections = []
    for box in result.boxes:
        cid = int(box.cls)
        detections.append({
            "class_id": cid,
            "class_name": resolve_class_name(det_model, cid),
            "confidence": float(box.conf),
            "bbox": box.xyxy[0].tolist(),
        })

    annotated = result.plot()
    ok, buffer = cv2.imencode(".jpg", annotated)
    if not ok:
        raise RuntimeError("JPEG encoding of the annotated image failed")
    return {
        "model_used": "detection",
        "detections": detections,
        "annotated_image_base64": base64.b64encode(buffer).decode(),
    }


def _classify(img: np.ndarray) -> dict:
    """Run the fallback classifier and return its top-1 prediction payload."""
    print("⚠️ No detections → using classification")
    model = get_cls_model()
    result = model(img, imgsz=224, verbose=False)[0]
    probs = result.probs
    cid = int(probs.top1)
    return {
        "model_used": "classification",
        "class_id": cid,
        "class_name": resolve_class_name(model, cid),
        "confidence": float(probs.top1conf),
    }


def _run_inference(img: np.ndarray) -> dict:
    """Detect first, fall back to classification. Intended for a worker thread."""
    try:
        # torch grad mode is thread-local, so the module-level
        # torch.set_grad_enabled(False) does not apply on this worker thread.
        # Disable grad explicitly here.
        with _INFERENCE_LOCK, torch.no_grad():
            payload = _detect(img)
            if payload is None:
                payload = _classify(img)
            return payload
    finally:
        # Release intermediate arrays/tensors promptly, on success or failure.
        gc.collect()


@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    x_api_key: str = Header(...)
):
    """Run inference on an uploaded image.

    The detection model is tried first. If it finds no boxes, the
    classification model's top-1 prediction is returned instead.

    Args:
        file: Uploaded image in any format OpenCV can decode.
        x_api_key: Value of the ``x-api-key`` header.

    Returns:
        JSONResponse with ``model_used`` and either:
        - detection: ``detections`` and ``annotated_image_base64``;
        - classification: ``class_id``, ``class_name`` and ``confidence``.

    Raises:
        HTTPException: 401 for a bad API key, 400 for an empty or undecodable
            image, 500 if inference fails.
    """
    verify_api_key(x_api_key)
    data = await file.read()
    img = _decode_image(data)
    del data

    try:
        # Model inference is CPU-bound and blocking. Run it in the thread pool
        # so the event loop (and health checks) stay responsive meanwhile.
        payload = await run_in_threadpool(_run_inference, img)
    except Exception as e:
        print("❌ Unexpected error:", str(e))
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Prediction failed") from e
    return JSONResponse(payload)
# -------------------------------------------------
# HEALTH CHECK
# -------------------------------------------------
@app.get("/")
def root():
    """Return a fixed payload confirming the service is up."""
    body = {"status": "API is running"}
    return body


@app.get("/ping")
def ping():
    """Return a fixed minimal status payload for liveness probes."""
    body = {"status": "ok"}
    return body