File size: 6,442 Bytes
f40b22b 5b11294 3cac439 c34dda4 361c20d c34dda4 5b11294 16da67b 5a6e2db c34dda4 7d9abb8 c34dda4 f40b22b 3cac439 361c20d 3cac439 f40b22b c34dda4 7179b2d 361c20d 3cac439 361c20d 3cac439 5a6e2db 7d9abb8 5a6e2db 16da67b 5a6e2db 7d9abb8 5a6e2db 7d9abb8 5a6e2db 7d9abb8 5a6e2db 7d9abb8 5a6e2db 7d9abb8 5a6e2db 7d9abb8 5a6e2db 7d9abb8 5a6e2db 3cac439 361c20d c34dda4 361c20d 5a6e2db 361c20d 3cac439 16da67b c34dda4 5b11294 16da67b 5b11294 c34dda4 5a6e2db 16da67b 5b11294 16da67b 5b11294 16da67b c34dda4 3cac439 361c20d 3cac439 361c20d 3cac439 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 | # app/main.py
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from app.model import load_model, predict_from_bytes
from app.inference import load_classification_model, classify_bytes
from app.inference import load_classification_model, classify_bytes
from app.inference_yolo import classify_yolo_bytes, load_yolo_model
# from app.model import load_model, predict_pca_from_bytes
from ood_detector import OODDetector
from PIL import Image
import io
import os
import uuid
from huggingface_hub import HfApi
import json, os
import hashlib
# ββββββββββββββββββββββββββββββββββββββββββββββ
# FastAPI setup
# ββββββββββββββββββββββββββββββββββββββββββββββ
app = FastAPI(title="NEMO Tools")
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"],
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Static Frontend
# ββββββββββββββββββββββββββββββββββββββββββββββ
BASE_DIR = os.path.dirname(__file__)
STATIC_DIR = os.path.join(BASE_DIR, "static")
INDEX_HTML = os.path.join(STATIC_DIR, "index.html")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# --- CONFIGURATION ---
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_REPO_ID = "AndrewKof/NEMO-user-uploads"
api = HfApi(token=HF_TOKEN)
OOD_PATH = os.path.join(os.path.dirname(__file__), "OOD_Features")
# Check if artifacts exist before loading
if os.path.exists(OOD_PATH):
ood_detector = OODDetector(
model_path="Arew99/dinov2-costum", # Or your local MODEL_DIR
feature_dir=OOD_PATH
)
print("β
OOD Detector initialized.")
else:
ood_detector = None
print("β οΈ OOD artifacts not found. OOD detection will be skipped.")
def save_image_to_hub(image_bytes):
"""
Uploads image only if it doesn't already exist in the dataset.
Uses SHA256 hash of the content to detect duplicates.
"""
# 1. Calculate the hash of the image content
file_hash = hashlib.sha256(image_bytes).hexdigest()
# 2. Use the hash as the filename (e.g., "user_images/a1b2c3d4....png")
filename = f"user_images/{file_hash}.png"
try:
# 3. Check if this specific file already exists on the Hub
if api.file_exists(repo_id=DATASET_REPO_ID, filename=filename, repo_type="dataset"):
print(f"Skipping: {filename} already exists in dataset.")
return # <--- STOP HERE
print(f"New image detected. Uploading {filename}...")
# 4. Upload if it's new
file_object = io.BytesIO(image_bytes)
api.upload_file(
path_or_fileobj=file_object,
path_in_repo=filename,
repo_id=DATASET_REPO_ID,
repo_type="dataset"
)
print("Upload successful!")
except Exception as e:
print(f"Error checking/uploading image: {e}")
@app.get("/", response_class=HTMLResponse)
def serve_frontend():
"""Serve the web interface."""
with open(INDEX_HTML, "r", encoding="utf-8") as f:
return f.read()
# ββββββββββββββββββββββββββββββββββββββββββββββ
# Model Initialization
# ββββββββββββββββββββββββββββββββββββββββββββββ
print("π Loading DINOv2 custom model...")
model_device_tuple = load_model()
print("β
Model loaded and ready for inference!")
# warm-up on startup
load_classification_model()
# --- Load classification model & labels once at startup ---
MAP_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
with open(MAP_PATH, "r") as f:
ID2NAME = json.load(f)
cls_model = load_model()
print("β
Classification model loaded and ready for inference!")
# ββββββββββββββββββββββββββββββββββββββββββββββ
# API Endpoints
# ββββββββββββββββββββββββββββββββββββββββββββββ
@app.post("/attention")
async def generate_attention(file: UploadFile = File(...)):
"""Generate and return mean attention map for uploaded image."""
image_bytes = await file.read()
save_image_to_hub(image_bytes)
result = predict_from_bytes(model_device_tuple, image_bytes)
return result
# @app.post("/classify")
# async def classify(
# file: UploadFile = File(...),
# model: str = Form("dino") # <--- Read 'model' from FormData (default 'dino')
# ):
# image_bytes = await file.read()
# save_image_to_hub(image_bytes)
# if model == "yolo":
# print("π§ Running YOLOv11 Inference...")
# return classify_yolo_bytes(image_bytes)
# else:
# print("π¦ Running DINOv2 Inference...")
# return classify_bytes(image_bytes)
@app.post("/classify")
async def classify(
file: UploadFile = File(...),
model: str = Form("dino")
):
image_bytes = await file.read()
save_image_to_hub(image_bytes)
# 1. First, check if it is OOD (only if detector is loaded)
ood_info = None
if ood_detector:
pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
ood_info = ood_detector.predict(pil_img)
# 2. Run standard classification
if model == "yolo":
response = classify_yolo_bytes(image_bytes)
else:
response = classify_bytes(image_bytes)
# 3. Attach OOD info to the response
if ood_info:
response["ood_metadata"] = ood_info
return response
@app.get("/api")
def api_root():
return {"message": "NEMO Tools backend running."}
# ββββββββββββββββββββββββββββββββββββββββββββββ
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
|