File size: 6,442 Bytes
f40b22b
5b11294
3cac439
c34dda4
361c20d
c34dda4
5b11294
 
 
16da67b
 
5a6e2db
 
 
 
c34dda4
7d9abb8
c34dda4
f40b22b
3cac439
361c20d
 
 
3cac439
f40b22b
c34dda4
 
 
 
 
 
 
7179b2d
361c20d
 
 
 
 
 
3cac439
361c20d
3cac439
5a6e2db
 
 
 
7d9abb8
5a6e2db
 
 
16da67b
 
 
 
 
 
 
 
 
 
 
 
5a6e2db
 
7d9abb8
 
5a6e2db
7d9abb8
 
 
 
 
5a6e2db
 
7d9abb8
 
 
 
 
 
5a6e2db
7d9abb8
 
5a6e2db
7d9abb8
5a6e2db
 
 
 
 
7d9abb8
5a6e2db
7d9abb8
5a6e2db
3cac439
 
361c20d
 
 
 
 
 
 
 
 
 
 
c34dda4
 
 
 
 
 
 
 
 
 
 
361c20d
 
 
 
 
 
 
5a6e2db
361c20d
 
3cac439
16da67b
 
 
 
 
 
 
 
 
 
 
 
 
 
c34dda4
5b11294
 
16da67b
5b11294
c34dda4
5a6e2db
16da67b
 
 
 
 
 
 
 
5b11294
16da67b
5b11294
16da67b
 
 
 
 
 
 
c34dda4
3cac439
 
361c20d
3cac439
361c20d
3cac439
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# app/main.py
import hashlib
import io  
import json, os
import os
import uuid

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import HfApi
from PIL import Image

from app.inference import load_classification_model, classify_bytes
from app.inference import load_classification_model, classify_bytes
from app.inference_yolo import classify_yolo_bytes, load_yolo_model
from app.model import load_model, predict_from_bytes
# from app.model import load_model, predict_pca_from_bytes
from ood_detector import OODDetector



# ──────────────────────────────────────────────
# FastAPI setup
# ──────────────────────────────────────────────
app = FastAPI(title="NEMO Tools")

# CORS middleware is currently disabled (commented out). To re-enable it,
# uncomment the block below and import CORSMiddleware from
# fastapi.middleware.cors.
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# ──────────────────────────────────────────────
# Static Frontend
# ──────────────────────────────────────────────
# All frontend paths are resolved from the directory containing this module.
BASE_DIR = os.path.dirname(__file__)
STATIC_DIR = os.path.join(BASE_DIR, "static")
INDEX_HTML = os.path.join(BASE_DIR, "static", "index.html")

# Serve everything under static/ at the /static URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


# --- CONFIGURATION ---
# Hub token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Dataset repository that receives user uploads (see save_image_to_hub).
DATASET_REPO_ID = "AndrewKof/NEMO-user-uploads"

api = HfApi(token=HF_TOKEN)

# Reuse BASE_DIR (this module's directory) rather than recomputing it.
OOD_PATH = os.path.join(BASE_DIR, "OOD_Features")
# OOD detection is optional: build the detector only when its feature
# artifacts are present, otherwise endpoints skip the OOD check.
if os.path.exists(OOD_PATH):
    ood_detector = OODDetector(
        model_path="Arew99/dinov2-costum",  # Or your local MODEL_DIR
        feature_dir=OOD_PATH,
    )
    print("✅ OOD Detector initialized.")
else:
    ood_detector = None
    print("⚠️ OOD artifacts not found. OOD detection will be skipped.")

def save_image_to_hub(image_bytes):
    """
    Upload raw image bytes to the Hub dataset unless already present.

    The file is stored as ``user_images/<sha256-of-content>.png`` in the
    dataset ``DATASET_REPO_ID``. Identical content therefore always maps to the
    same path, so a single existence check detects duplicates.

    This is best-effort: empty input is skipped, and any error raised while
    checking or uploading is printed and swallowed, so callers never see an
    exception from this function.

    Args:
        image_bytes: Raw bytes of the uploaded file.

    Returns:
        None.
    """
    # Guard empty/None input. Otherwise a zero-byte "image" would be stored
    # under the hash of the empty string, and None would raise TypeError
    # outside the best-effort try below.
    if not image_bytes:
        print("Skipping: empty upload, nothing to save.")
        return

    # Content hash doubles as the filename, which makes de-duplication a
    # simple existence check.
    file_hash = hashlib.sha256(image_bytes).hexdigest()
    # NOTE(review): the ".png" suffix is applied regardless of the real image
    # format. Changing it would break de-duplication against files already
    # stored under ".png" names, so it is left as-is.
    filename = f"user_images/{file_hash}.png"

    try:
        if api.file_exists(repo_id=DATASET_REPO_ID, filename=filename, repo_type="dataset"):
            print(f"Skipping: {filename} already exists in dataset.")
            return

        print(f"New image detected. Uploading {filename}...")
        api.upload_file(
            path_or_fileobj=io.BytesIO(image_bytes),
            path_in_repo=filename,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
        )
        print("Upload successful!")

    except Exception as e:
        # Deliberately broad: archiving uploads must never break inference.
        print(f"Error checking/uploading image: {e}")

@app.get("/", response_class=HTMLResponse)
def serve_frontend():
    """Return the bundled ``index.html`` page as the site's landing page."""
    with open(INDEX_HTML, encoding="utf-8") as html_file:
        page = html_file.read()
    return page

# ──────────────────────────────────────────────
# Model Initialization
# ──────────────────────────────────────────────
print("🚀 Loading DINOv2 custom model...")
model_device_tuple = load_model()
print("✅ Model loaded and ready for inference!")

# Warm-up on startup. The return value is unused; presumably app.inference
# caches the loaded model internally (TODO confirm).
load_classification_model()

# --- Load class-id -> name labels once at startup ---
# NOTE(review): ID2NAME is not referenced elsewhere in this module.
MAP_PATH = os.path.join(BASE_DIR, "id2name.json")
with open(MAP_PATH, "r", encoding="utf-8") as f:
    ID2NAME = json.load(f)

# Alias the model loaded above instead of calling load_model() again, which
# would load a second, identical copy into memory. cls_model is not used in
# this module; the name is kept only for backward compatibility.
cls_model = model_device_tuple
print("✅ Classification model loaded and ready for inference!")

# ──────────────────────────────────────────────
# API Endpoints
# ──────────────────────────────────────────────
@app.post("/attention")
async def generate_attention(file: UploadFile = File(...)):
    """Generate and return the mean attention map for an uploaded image.

    The raw upload is first archived to the Hub dataset (best-effort). It is
    then passed to ``predict_from_bytes`` together with the model loaded at
    startup.
    """
    image_bytes = await file.read()
    # save_image_to_hub performs blocking network calls. Run it in a worker
    # thread so the event loop keeps serving other requests meanwhile.
    await run_in_threadpool(save_image_to_hub, image_bytes)
    return predict_from_bytes(model_device_tuple, image_bytes)

# @app.post("/classify")
# async def classify(
#     file: UploadFile = File(...), 
#     model: str = Form("dino")  # <--- Read 'model' from FormData (default 'dino')
# ):
#     image_bytes = await file.read()
#     save_image_to_hub(image_bytes)
#     if model == "yolo":
#         print("🧠 Running YOLOv11 Inference...")
#         return classify_yolo_bytes(image_bytes)
#     else:
#         print("🦕 Running DINOv2 Inference...")
#         return classify_bytes(image_bytes)

@app.post("/classify")
async def classify(
    file: UploadFile = File(...),
    model: str = Form("dino"),
):
    """Classify an uploaded image with the selected backend.

    Args:
        file: The uploaded image.
        model: ``"yolo"`` selects ``classify_yolo_bytes``. Any other value
            (default ``"dino"``) uses ``classify_bytes``.

    Returns:
        The classifier's response. When the OOD detector is loaded and
        returns a truthy result, that result is added under the
        ``"ood_metadata"`` key.
    """
    image_bytes = await file.read()
    # Blocking network I/O: keep the Hub upload off the event loop.
    await run_in_threadpool(save_image_to_hub, image_bytes)

    # 1. Out-of-distribution check (only if the detector was built at startup).
    ood_info = None
    if ood_detector is not None:
        # Close the decoded source image once the RGB copy has been made.
        with Image.open(io.BytesIO(image_bytes)) as src_img:
            pil_img = src_img.convert("RGB")
        ood_info = ood_detector.predict(pil_img)

    # 2. Standard classification.
    # NOTE(review): inference still runs on the event loop thread.
    if model == "yolo":
        response = classify_yolo_bytes(image_bytes)
    else:
        response = classify_bytes(image_bytes)

    # 3. Attach OOD info.
    # NOTE(review): assumes both classifiers return a mutable mapping; confirm.
    if ood_info:
        response["ood_metadata"] = ood_info

    return response

@app.get("/api")
def api_root():
    """Return a static message confirming the backend is running."""
    status = dict(message="NEMO Tools backend running.")
    return status

# ──────────────────────────────────────────────
if __name__ == "__main__":
    # Entry point when this module is executed directly: serve the app on all
    # interfaces at port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")