# NEMOtools / app / main.py
# AndrewKof's picture
# Add OOD Detector logic and artifacts
# 16da67b
# app/main.py
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from app.model import load_model, predict_from_bytes
from app.inference import load_classification_model, classify_bytes
from app.inference import load_classification_model, classify_bytes
from app.inference_yolo import classify_yolo_bytes, load_yolo_model
# from app.model import load_model, predict_pca_from_bytes
from ood_detector import OODDetector
from PIL import Image
import io
import os
import uuid
from huggingface_hub import HfApi
import json, os
import hashlib
# ──────────────────────────────────────────────
# FastAPI setup
# ──────────────────────────────────────────────
# Filesystem locations of the bundled frontend, resolved relative to this module.
BASE_DIR = os.path.dirname(__file__)
STATIC_DIR = os.path.join(BASE_DIR, "static")
INDEX_HTML = os.path.join(STATIC_DIR, "index.html")

# The ASGI application object that every route below registers on.
app = FastAPI(title="NEMO Tools")
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"],
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
# ──────────────────────────────────────────────
# Static Frontend
# ──────────────────────────────────────────────
# Expose everything under STATIC_DIR at the /static URL prefix.
static_files = StaticFiles(directory=STATIC_DIR)
app.mount("/static", static_files, name="static")
# --- CONFIGURATION ---
# Hub credentials come from the environment; HF_TOKEN is None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_REPO_ID = "AndrewKof/NEMO-user-uploads"
api = HfApi(token=HF_TOKEN)

# Directory next to this module that holds the OOD detector artifacts.
OOD_PATH = os.path.join(os.path.dirname(__file__), "OOD_Features")

# The detector is optional: it stays None unless its artifacts are present,
# and the /classify endpoint skips the OOD step in that case.
ood_detector = None
if not os.path.exists(OOD_PATH):
    print("⚠️ OOD artifacts not found. OOD detection will be skipped.")
else:
    ood_detector = OODDetector(
        model_path="Arew99/dinov2-costum",  # Or your local MODEL_DIR
        feature_dir=OOD_PATH,
    )
    print("βœ… OOD Detector initialized.")
def save_image_to_hub(image_bytes):
    """
    Upload an image to the Hub dataset unless identical content is already there.

    The SHA-256 digest of the raw bytes is used as the file name
    ("user_images/<digest>.png"), so identical content always maps to the same
    path and a single existence check is enough to detect duplicates.

    This is a best-effort operation: it never raises. Any error while checking
    or uploading is printed and swallowed so the calling endpoint keeps working.

    Args:
        image_bytes: Raw bytes of the uploaded image. Empty or missing payloads
            are ignored.

    Returns:
        None in every case.
    """
    # Guard first: an empty payload would otherwise be stored as a zero-byte
    # "image", and None would make hashlib raise outside the try block below,
    # breaking the best-effort contract.
    if not image_bytes:
        print("Skipping: empty image payload.")
        return

    # 1. Calculate the hash of the image content
    file_hash = hashlib.sha256(image_bytes).hexdigest()

    # 2. Use the hash as the filename (e.g., "user_images/a1b2c3d4....png")
    # NOTE(review): the ".png" suffix is applied regardless of the actual image
    # format; the bytes themselves are uploaded unmodified.
    filename = f"user_images/{file_hash}.png"

    try:
        # 3. Check if this specific file already exists on the Hub
        if api.file_exists(repo_id=DATASET_REPO_ID, filename=filename, repo_type="dataset"):
            print(f"Skipping: {filename} already exists in dataset.")
            return

        print(f"New image detected. Uploading {filename}...")

        # 4. Upload if it's new
        file_object = io.BytesIO(image_bytes)
        api.upload_file(
            path_or_fileobj=file_object,
            path_in_repo=filename,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
        )
        print("Upload successful!")
    except Exception as e:
        # Deliberately broad: archiving user uploads must never fail a request.
        print(f"Error checking/uploading image: {e}")
@app.get("/", response_class=HTMLResponse)
def serve_frontend():
    """Return the text of the static index page as the HTML response for "/"."""
    with open(INDEX_HTML, mode="r", encoding="utf-8") as page:
        html = page.read()
    return html
# ──────────────────────────────────────────────
# Model Initialization
# ──────────────────────────────────────────────
# Load the model once at import time; the /attention endpoint passes this
# object to predict_from_bytes on every request.
print("πŸš€ Loading DINOv2 custom model...")
model_device_tuple = load_model()
print("βœ… Model loaded and ready for inference!")

# warm-up on startup (the return value is intentionally discarded here)
load_classification_model()

# --- Load classification model & labels once at startup ---
MAP_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
# Read the label map as UTF-8 explicitly instead of relying on the platform's
# default text encoding, matching how the index page is read in this file.
with open(MAP_PATH, "r", encoding="utf-8") as f:
    ID2NAME = json.load(f)

# NOTE(review): this is a second load_model() call (the same loader used for
# model_device_tuple above) and cls_model is not referenced anywhere else in
# this file. It looks like a redundant load; confirm whether
# load_classification_model() was intended before removing or changing it.
cls_model = load_model()
print("βœ… Classification model loaded and ready for inference!")
# ──────────────────────────────────────────────
# API Endpoints
# ──────────────────────────────────────────────
@app.post("/attention")
async def generate_attention(file: UploadFile = File(...)):
    """
    Produce the mean attention map for an uploaded image.

    The upload is read fully into memory, archived through save_image_to_hub,
    and then handed to predict_from_bytes together with the model loaded at
    startup. Whatever predict_from_bytes returns is sent back unchanged.
    """
    payload = await file.read()
    save_image_to_hub(payload)
    return predict_from_bytes(model_device_tuple, payload)
# @app.post("/classify")
# async def classify(
# file: UploadFile = File(...),
# model: str = Form("dino") # <--- Read 'model' from FormData (default 'dino')
# ):
# image_bytes = await file.read()
# save_image_to_hub(image_bytes)
# if model == "yolo":
# print("🧠 Running YOLOv11 Inference...")
# return classify_yolo_bytes(image_bytes)
# else:
# print("πŸ¦• Running DINOv2 Inference...")
# return classify_bytes(image_bytes)
@app.post("/classify")
async def classify(
    file: UploadFile = File(...),
    model: str = Form("dino"),
):
    """
    Classify an uploaded image and attach OOD information when available.

    Args:
        file: The uploaded image.
        model: Form field selecting the backend; "yolo" routes to
            classify_yolo_bytes, any other value to classify_bytes.

    Returns:
        The selected classifier's response, with an "ood_metadata" entry added
        when the OOD detector is loaded and produced a truthy result.
    """
    raw = await file.read()
    save_image_to_hub(raw)

    # Step 1: out-of-distribution check, only when a detector was initialized.
    ood_info = None
    if ood_detector:
        rgb_image = Image.open(io.BytesIO(raw)).convert("RGB")
        ood_info = ood_detector.predict(rgb_image)

    # Step 2: pick the classifier requested by the client and run it.
    classifier = classify_yolo_bytes if model == "yolo" else classify_bytes
    response = classifier(raw)

    # Step 3: piggy-back the OOD result on the classifier's response.
    if ood_info:
        response["ood_metadata"] = ood_info
    return response
@app.get("/api")
def api_root():
    """Return a fixed status message showing that the backend is up."""
    status = {"message": "NEMO Tools backend running."}
    return status
# ──────────────────────────────────────────────
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    # The port defaults to 7860 (the previous hard-coded value) but can be
    # overridden with the PORT environment variable for other hosts/platforms.
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)