# app/main.py
"""FastAPI entry point for NEMO Tools.

Serves the static web frontend, exposes the attention-map and classification
endpoints, and mirrors user uploads to a Hugging Face dataset repository.
"""
import hashlib
import io
import json
import os
import uuid

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import HfApi
from PIL import Image

from app.inference import classify_bytes, load_classification_model
from app.inference_yolo import classify_yolo_bytes, load_yolo_model
from app.model import load_model, predict_from_bytes
# from app.model import load_model, predict_pca_from_bytes
from ood_detector import OODDetector

# ──────────────────────────────────────────────
# FastAPI setup
# ──────────────────────────────────────────────
app = FastAPI(title="NEMO Tools")

# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# ──────────────────────────────────────────────
# Static Frontend
# ──────────────────────────────────────────────
BASE_DIR = os.path.dirname(__file__)
STATIC_DIR = os.path.join(BASE_DIR, "static")
INDEX_HTML = os.path.join(STATIC_DIR, "index.html")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# --- CONFIGURATION ---
# HF_TOKEN may be None when the env var is unset; HfApi then receives
# token=None (presumably falling back to huggingface_hub's own credential
# lookup — confirm for the deployment environment).
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_REPO_ID = "AndrewKof/NEMO-user-uploads"
api = HfApi(token=HF_TOKEN)

OOD_PATH = os.path.join(BASE_DIR, "OOD_Features")

# Check if artifacts exist before loading; without them the /classify
# endpoint skips OOD detection (ood_detector stays None).
if os.path.exists(OOD_PATH):
    ood_detector = OODDetector(
        model_path="Arew99/dinov2-costum",  # Or your local MODEL_DIR
        feature_dir=OOD_PATH,
    )
    print("✅ OOD Detector initialized.")
else:
    ood_detector = None
    print("⚠️ OOD artifacts not found. OOD detection will be skipped.")


def save_image_to_hub(image_bytes: bytes) -> None:
    """Upload *image_bytes* to the dataset repo unless an identical file exists.

    The SHA-256 hex digest of the content is used as the filename, so the same
    image is stored at most once no matter how often it is submitted. Any error
    raised while checking or uploading is printed and swallowed: mirroring
    uploads is best-effort.

    This performs blocking network I/O; call it from async code via
    ``run_in_threadpool`` rather than directly.

    NOTE(review): every file is stored with a ``.png`` suffix regardless of its
    real format. Changing that would break dedup against files already in the
    repo, so it is intentionally left as-is.
    """
    # 1. Hash the content; identical bytes always map to the same path.
    file_hash = hashlib.sha256(image_bytes).hexdigest()

    # 2. Use the hash as the filename (e.g., "user_images/a1b2c3d4....png").
    filename = f"user_images/{file_hash}.png"

    try:
        # 3. Skip the upload if this exact content is already on the Hub.
        if api.file_exists(repo_id=DATASET_REPO_ID, filename=filename, repo_type="dataset"):
            print(f"Skipping: {filename} already exists in dataset.")
            return

        print(f"New image detected. Uploading {filename}...")

        # 4. Upload if it's new.
        api.upload_file(
            path_or_fileobj=io.BytesIO(image_bytes),
            path_in_repo=filename,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
        )
        print("Upload successful!")

    except Exception as e:  # best-effort: never propagate Hub failures to callers
        print(f"Error checking/uploading image: {e}")


@app.get("/", response_class=HTMLResponse)
def serve_frontend():
    """Serve the web interface."""
    with open(INDEX_HTML, "r", encoding="utf-8") as f:
        return f.read()


# ──────────────────────────────────────────────
# Model Initialization
# ──────────────────────────────────────────────
print("🚀 Loading DINOv2 custom model...")
model_device_tuple = load_model()
print("✅ Model loaded and ready for inference!")

# warm-up on startup (the return value is not kept here)
load_classification_model()

# --- Load classification labels once at startup ---
MAP_PATH = os.path.join(BASE_DIR, "id2name.json")
# Explicit UTF-8 so label names decode the same regardless of system locale.
with open(MAP_PATH, "r", encoding="utf-8") as f:
    ID2NAME = json.load(f)

# Previously a second load_model() call, which loaded the same weights twice.
# Reuse the object loaded above instead.
# NOTE(review): cls_model is not referenced elsewhere in this file — confirm
# nothing imports it before removing it entirely.
cls_model = model_device_tuple
print("✅ Classification model loaded and ready for inference!")


# ──────────────────────────────────────────────
# API Endpoints
# ──────────────────────────────────────────────
@app.post("/attention")
async def generate_attention(file: UploadFile = File(...)):
    """Generate and return mean attention map for uploaded image."""
    image_bytes = await file.read()
    # The Hub upload is blocking network I/O; run it in the threadpool so it
    # does not stall the event loop for every other in-flight request.
    await run_in_threadpool(save_image_to_hub, image_bytes)
    result = predict_from_bytes(model_device_tuple, image_bytes)
    return result
@app.post("/classify")
async def classify(
    file: UploadFile = File(...),
    model: str = Form("dino"),
):
    """Classify an uploaded image, optionally attaching OOD metadata.

    Form fields:
        file: the image to classify.
        model: ``"yolo"`` selects the YOLO classifier; any other value
            (default ``"dino"``) uses the DINOv2 classifier.

    Returns the classifier's response, with an ``"ood_metadata"`` entry added
    when the OOD detector is loaded.

    Raises:
        HTTPException(400): if the upload cannot be decoded as an image.
    """
    # Local imports keep this endpoint self-contained within the module.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    image_bytes = await file.read()

    # Decode once up front. Unreadable payloads get a clean 400 instead of an
    # unhandled 500, and they are rejected before being pushed to the public
    # uploads dataset.
    try:
        pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Blocking Hub I/O: keep it off the event loop.
    await run_in_threadpool(save_image_to_hub, image_bytes)

    # 1. OOD check (only if the detector was loaded at startup).
    ood_info = None
    if ood_detector is not None:
        ood_info = ood_detector.predict(pil_img)

    # 2. Run standard classification.
    if model == "yolo":
        response = classify_yolo_bytes(image_bytes)
    else:
        response = classify_bytes(image_bytes)

    # 3. Attach OOD info to the response (assumes the classifiers return a
    #    mutable mapping — TODO confirm for both backends).
    if ood_info is not None:
        response["ood_metadata"] = ood_info

    return response


@app.get("/api")
def api_root():
    """Health-check endpoint confirming the backend is up."""
    return {"message": "NEMO Tools backend running."}


# ──────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)