|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, Form |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from fastapi.responses import HTMLResponse, JSONResponse |
|
|
from app.model import load_model, predict_from_bytes |
|
|
from app.inference import load_classification_model, classify_bytes |
|
|
from app.inference import load_classification_model, classify_bytes |
|
|
from app.inference_yolo import classify_yolo_bytes, load_yolo_model |
|
|
|
|
|
from ood_detector import OODDetector |
|
|
from PIL import Image |
|
|
import io |
|
|
import os |
|
|
import uuid |
|
|
from huggingface_hub import HfApi |
|
|
import json, os |
|
|
import hashlib |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(title="NEMO Tools")

# Locations of the bundled front-end files. ``abspath`` pins them to this
# module's directory even when ``__file__`` is a relative path, so later
# changes of the working directory cannot break the lookup.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
INDEX_HTML = os.path.join(STATIC_DIR, "index.html")

# Expose everything under STATIC_DIR at the /static URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
|
|
|
|
|
|
|
|
|
# Target Hugging Face dataset that user uploads are archived into.
DATASET_REPO_ID = "AndrewKof/NEMO-user-uploads"

# Access token read from the environment (None when unset); the shared Hub
# client below is authenticated with it.
HF_TOKEN = os.getenv("HF_TOKEN")
api = HfApi(token=HF_TOKEN)
|
|
|
|
|
# Optional out-of-distribution detector. It is only built when the
# precomputed feature directory ships next to this module; otherwise
# ``ood_detector`` is None and callers skip OOD checks.
OOD_PATH = os.path.join(os.path.dirname(__file__), "OOD_Features")

if os.path.exists(OOD_PATH):
    ood_detector = OODDetector(
        model_path="Arew99/dinov2-costum",
        feature_dir=OOD_PATH,
    )
    # The emoji here was previously mis-encoded, which split the string
    # literal across two lines (a syntax error).
    print("✅ OOD Detector initialized.")
else:
    ood_detector = None
    print("⚠️ OOD artifacts not found. OOD detection will be skipped.")
|
|
|
|
|
def save_image_to_hub(image_bytes):
    """
    Archive ``image_bytes`` in the dataset repo unless identical content is already there.

    The repo path is derived from the SHA-256 digest of the raw bytes, so
    re-submitting the same content maps to the same path and is skipped.
    Any exception raised while checking or uploading is printed and swallowed,
    so a Hub failure never breaks the calling request (best effort).
    """
    digest = hashlib.sha256(image_bytes).hexdigest()
    # NOTE(review): the ".png" suffix is used whatever the real image format
    # is; changing it would break de-duplication against files already stored.
    target_path = f"user_images/{digest}.png"

    try:
        already_stored = api.file_exists(
            repo_id=DATASET_REPO_ID,
            filename=target_path,
            repo_type="dataset",
        )
        if already_stored:
            print(f"Skipping: {target_path} already exists in dataset.")
            return

        print(f"New image detected. Uploading {target_path}...")
        api.upload_file(
            path_or_fileobj=io.BytesIO(image_bytes),
            path_in_repo=target_path,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
        )
        print("Upload successful!")
    except Exception as exc:
        print(f"Error checking/uploading image: {exc}")
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def serve_frontend():
    """Return the text of INDEX_HTML so the browser renders the web UI."""
    with open(INDEX_HTML, encoding="utf-8") as html_file:
        page = html_file.read()
    return page
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Emoji in the messages below were previously mis-encoded; two of them split
# their string literals across lines, which made the module unparsable.
print("🚀 Loading DINOv2 custom model...")
model_device_tuple = load_model()
print("✅ Model loaded and ready for inference!")

# Return value is intentionally ignored.
# NOTE(review): presumably this initializes state inside app.inference that
# classify_bytes relies on — confirm.
load_classification_model()

# JSON lookup table next to this module (presumably class id -> display name).
# Decode explicitly as UTF-8 rather than the platform's locale encoding.
MAP_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
with open(MAP_PATH, "r", encoding="utf-8") as f:
    ID2NAME = json.load(f)

# Calling load_model() again here would build a second, independent copy of
# the same model that none of this module's endpoints use. Reuse the instance
# loaded above so the name stays available without doubling memory.
cls_model = model_device_tuple
print("✅ Classification model loaded and ready for inference!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/attention")
async def generate_attention(file: UploadFile = File(...)):
    """Generate and return the mean attention map for an uploaded image.

    The raw upload is first archived to the Hub dataset (best effort) and then
    passed to ``predict_from_bytes`` together with ``model_device_tuple``.

    Both steps are plain synchronous calls (Hub network I/O and model
    inference). Calling them inline from an ``async def`` would block the
    event loop and stall every other request. They therefore run in FastAPI's
    worker thread pool; only reading the upload is awaited directly.
    """
    from fastapi.concurrency import run_in_threadpool

    image_bytes = await file.read()
    await run_in_threadpool(save_image_to_hub, image_bytes)
    return await run_in_threadpool(
        predict_from_bytes, model_device_tuple, image_bytes
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _classify_sync(image_bytes, model):
    """Run the blocking part of ``/classify``: archive, OOD-check, classify.

    Args:
        image_bytes: Raw bytes of the uploaded image.
        model: ``"yolo"`` selects ``classify_yolo_bytes``; any other value
            uses ``classify_bytes``.

    Returns:
        The selected classifier's response, with ``"ood_metadata"`` added when
        the OOD detector is enabled and returned a truthy result.
    """
    save_image_to_hub(image_bytes)

    ood_info = None
    # ood_detector is either an OODDetector instance or None, so test for
    # None explicitly instead of relying on the object's truthiness.
    if ood_detector is not None:
        pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        ood_info = ood_detector.predict(pil_img)

    if model == "yolo":
        response = classify_yolo_bytes(image_bytes)
    else:
        response = classify_bytes(image_bytes)

    # NOTE(review): assumes both classifiers return a mutable dict.
    if ood_info:
        response["ood_metadata"] = ood_info

    return response


@app.post("/classify")
async def classify(
    file: UploadFile = File(...),
    model: str = Form("dino"),
):
    """Classify an uploaded image with the requested backend.

    Form fields:
        file: The image upload.
        model: ``"yolo"`` for the YOLO classifier; anything else (default
            ``"dino"``) uses ``classify_bytes``.

    Only reading the upload is awaited on the event loop. The Hub upload, the
    OOD check and model inference are synchronous, so they run together in the
    worker thread pool rather than blocking concurrent requests.
    """
    from fastapi.concurrency import run_in_threadpool

    image_bytes = await file.read()
    return await run_in_threadpool(_classify_sync, image_bytes, model)
|
|
|
|
|
@app.get("/api")
def api_root():
    """Return a fixed status payload confirming the backend is running."""
    status_text = "NEMO Tools backend running."
    return dict(message=status_text)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local development entry point: serve the app on all interfaces, port 7860.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)
|
|
|