Fix Docker CMD and UI run button
Browse files- Dockerfile +2 -4
- app/main.py +46 -23
- app/static/index.html +25 -4
Dockerfile
CHANGED
|
@@ -9,7 +9,7 @@ COPY requirements.txt .
|
|
| 9 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 10 |
|
| 11 |
# Copy your FastAPI app and dinov2 module
|
| 12 |
-
COPY app
|
| 13 |
|
| 14 |
# Set environment variable for module discovery
|
| 15 |
ENV PYTHONPATH=/code
|
|
@@ -17,7 +17,5 @@ ENV PYTHONPATH=/code
|
|
| 17 |
# Expose the default Hugging Face Spaces port
|
| 18 |
EXPOSE 7860
|
| 19 |
|
| 20 |
-
# Run
|
| 21 |
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 9 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 10 |
|
| 11 |
# Copy your FastAPI app and dinov2 module
|
| 12 |
+
COPY app /code/app
|
| 13 |
|
| 14 |
# Set environment variable for module discovery
|
| 15 |
ENV PYTHONPATH=/code
|
|
|
|
| 17 |
# Expose the default Hugging Face Spaces port
|
| 18 |
EXPOSE 7860
|
| 19 |
|
| 20 |
+
# β
Run FastAPI app
|
| 21 |
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
|
|
|
app/main.py
CHANGED
|
@@ -1,32 +1,55 @@
|
|
| 1 |
-
|
| 2 |
-
from fastapi.responses import JSONResponse
|
| 3 |
-
from fastapi.staticfiles import StaticFiles
|
| 4 |
-
import uvicorn
|
| 5 |
-
from .model import load_model, predict_from_bytes
|
| 6 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
@app.post("/predict")
|
| 18 |
async def predict(file: UploadFile = File(...)):
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
@app.get("/health")
|
| 24 |
-
async def health():
|
| 25 |
-
return {"status": "ok"}
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/main.py
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
from fastapi import FastAPI, File, UploadFile
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from transformers import AutoProcessor, Dinov2ForImageClassification
|
| 8 |
+
from torch.nn.functional import softmax
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
app = FastAPI()
|
| 12 |
+
|
| 13 |
+
# Allow frontend to call backend
|
| 14 |
+
app.add_middleware(
|
| 15 |
+
CORSMiddleware,
|
| 16 |
+
allow_origins=["*"],
|
| 17 |
+
allow_credentials=True,
|
| 18 |
+
allow_methods=["*"],
|
| 19 |
+
allow_headers=["*"],
|
| 20 |
+
)
|
| 21 |
|
| 22 |
+
# --- Load model and mapping on startup ---
|
| 23 |
+
print("π Loading model and label mapping...")
|
| 24 |
+
MODEL_ID = "Arew99/dinov2-costum"
|
| 25 |
|
| 26 |
+
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
| 27 |
+
model = Dinov2ForImageClassification.from_pretrained(MODEL_ID)
|
| 28 |
+
model.eval()
|
| 29 |
|
| 30 |
+
# Load id2name.json
|
| 31 |
+
MAP_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
|
| 32 |
+
with open(MAP_PATH, "r") as f:
|
| 33 |
+
id2name = json.load(f)
|
| 34 |
+
|
| 35 |
+
print(f"β Loaded {len(id2name)} labels from id2name.json")
|
| 36 |
+
|
| 37 |
+
@app.get("/")
|
| 38 |
+
def root():
|
| 39 |
+
return {"message": "Welcome to NEMOtools API"}
|
| 40 |
|
| 41 |
@app.post("/predict")
|
| 42 |
async def predict(file: UploadFile = File(...)):
|
| 43 |
+
"""Perform top-5 inference on an uploaded image."""
|
| 44 |
+
image = Image.open(file.file).convert("RGB")
|
| 45 |
+
inputs = processor(images=image, return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
logits = model(**inputs).logits.squeeze(0)
|
| 49 |
+
probs, idxs = softmax(logits, dim=0).topk(5)
|
| 50 |
|
| 51 |
+
results = [
|
| 52 |
+
{"label": id2name[str(i)], "confidence": float(p)}
|
| 53 |
+
for p, i in zip(probs, idxs)
|
| 54 |
+
]
|
| 55 |
+
return {"predictions": results}
|
app/static/index.html
CHANGED
|
@@ -174,7 +174,7 @@
|
|
| 174 |
|
| 175 |
<!-- Classification tool -->
|
| 176 |
<div id="tool-classification" class="hidden">
|
| 177 |
-
<div class="flex flex-col
|
| 178 |
<input id="cls-file" type="file" accept="image/*"
|
| 179 |
class="block w-full md:w-auto text-sm text-gray-600
|
| 180 |
file:mr-4 file:py-2 file:px-4
|
|
@@ -182,10 +182,22 @@
|
|
| 182 |
file:text-sm file:font-semibold
|
| 183 |
file:bg-indigo-50 file:text-indigo-600
|
| 184 |
hover:file:bg-indigo-100" />
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
</div>
|
| 186 |
-
|
|
|
|
| 187 |
<div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div>
|
| 188 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
</div>
|
| 190 |
</section>
|
| 191 |
</main>
|
|
@@ -324,10 +336,19 @@
|
|
| 324 |
fd.append("file", file);
|
| 325 |
|
| 326 |
try {
|
| 327 |
-
const res = await fetch("/
|
| 328 |
if (!res.ok) throw new Error(`Server error: ${res.status}`);
|
| 329 |
const json = await res.json();
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
} catch (err) {
|
| 332 |
console.error(err);
|
| 333 |
result.textContent = "β Error: " + err.message;
|
|
|
|
| 174 |
|
| 175 |
<!-- Classification tool -->
|
| 176 |
<div id="tool-classification" class="hidden">
|
| 177 |
+
<div class="flex flex-col items-center gap-4 mb-6 justify-center">
|
| 178 |
<input id="cls-file" type="file" accept="image/*"
|
| 179 |
class="block w-full md:w-auto text-sm text-gray-600
|
| 180 |
file:mr-4 file:py-2 file:px-4
|
|
|
|
| 182 |
file:text-sm file:font-semibold
|
| 183 |
file:bg-indigo-50 file:text-indigo-600
|
| 184 |
hover:file:bg-indigo-100" />
|
| 185 |
+
|
| 186 |
+
<!-- π THIS BUTTON TRIGGERS INFERENCE -->
|
| 187 |
+
<button id="runClsBtn"
|
| 188 |
+
onclick="runClassification()"
|
| 189 |
+
class="px-8 py-3 bg-green-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-green-700 transition">
|
| 190 |
+
βΆοΈ Run Classification
|
| 191 |
+
</button>
|
| 192 |
</div>
|
| 193 |
+
|
| 194 |
+
<!-- Where results will appear -->
|
| 195 |
<div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div>
|
| 196 |
</div>
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
<!-- <div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div> -->
|
| 200 |
+
</div>
|
| 201 |
</div>
|
| 202 |
</section>
|
| 203 |
</main>
|
|
|
|
| 336 |
fd.append("file", file);
|
| 337 |
|
| 338 |
try {
|
| 339 |
+
const res = await fetch("/predict", { method: "POST", body: fd }); // β
must match FastAPI route
|
| 340 |
if (!res.ok) throw new Error(`Server error: ${res.status}`);
|
| 341 |
const json = await res.json();
|
| 342 |
+
|
| 343 |
+
// β
display top-5 predictions if available
|
| 344 |
+
if (json.predictions) {
|
| 345 |
+
result.innerHTML = "<h3 class='font-semibold text-indigo-600 mb-2'>Top-5 Predictions:</h3>" +
|
| 346 |
+
json.predictions.map(p =>
|
| 347 |
+
`<div>${p.label} β ${(p.confidence * 100).toFixed(2)}%</div>`
|
| 348 |
+
).join("");
|
| 349 |
+
} else {
|
| 350 |
+
result.textContent = "β
Predicted class: " + json.label;
|
| 351 |
+
}
|
| 352 |
} catch (err) {
|
| 353 |
console.error(err);
|
| 354 |
result.textContent = "β Error: " + err.message;
|