AndrewKof committed on
Commit
f40b22b
·
1 Parent(s): c6b075f

Fix Docker CMD and UI run button

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -4
  2. app/main.py +46 -23
  3. app/static/index.html +25 -4
Dockerfile CHANGED
@@ -9,7 +9,7 @@ COPY requirements.txt .
9
  RUN pip install --no-cache-dir -r requirements.txt
10
 
11
  # Copy your FastAPI app and dinov2 module
12
- COPY app ./app
13
 
14
  # Set environment variable for module discovery
15
  ENV PYTHONPATH=/code
@@ -17,7 +17,5 @@ ENV PYTHONPATH=/code
17
  # Expose the default Hugging Face Spaces port
18
  EXPOSE 7860
19
 
20
- # Run the FastAPI app
21
  CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
22
-
23
-
 
9
  RUN pip install --no-cache-dir -r requirements.txt
10
 
11
  # Copy your FastAPI app and dinov2 module
12
+ COPY app /code/app
13
 
14
  # Set environment variable for module discovery
15
  ENV PYTHONPATH=/code
 
17
  # Expose the default Hugging Face Spaces port
18
  EXPOSE 7860
19
 
20
+ # ✅ Run FastAPI app
21
  CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
app/main.py CHANGED
@@ -1,32 +1,55 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
- from fastapi.staticfiles import StaticFiles
4
- import uvicorn
5
- from .model import load_model, predict_from_bytes
6
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- app = FastAPI(title="DINOv2 Attention Map Demo")
 
 
9
 
10
- MODEL = None
 
 
11
 
12
- @app.on_event("startup")
13
- async def startup_event():
14
- global MODEL
15
- MODEL = load_model()
 
 
 
 
 
 
16
 
17
  @app.post("/predict")
18
  async def predict(file: UploadFile = File(...)):
19
- contents = await file.read()
20
- result = predict_from_bytes(MODEL, contents)
21
- return JSONResponse(result)
22
-
23
- @app.get("/health")
24
- async def health():
25
- return {"status": "ok"}
26
 
27
- # 👇 Mount static *after* routes
28
- static_dir = os.path.join(os.path.dirname(__file__), "static")
29
- app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
30
 
31
- if __name__ == "__main__":
32
- uvicorn.run("app.main:app", host="0.0.0.0", port=7860)
 
 
 
 
1
+ # app/main.py
 
 
 
 
2
  import os
3
+ import json
4
+ import torch
5
+ from fastapi import FastAPI, File, UploadFile
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from transformers import AutoProcessor, Dinov2ForImageClassification
8
+ from torch.nn.functional import softmax
9
+ from PIL import Image
10
+
11
+ app = FastAPI()
12
+
13
+ # Allow frontend to call backend
14
+ app.add_middleware(
15
+ CORSMiddleware,
16
+ allow_origins=["*"],
17
+ allow_credentials=True,
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
 
22
+ # --- Load model and mapping on startup ---
23
+ print("🚀 Loading model and label mapping...")
24
+ MODEL_ID = "Arew99/dinov2-costum"
25
 
26
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
27
+ model = Dinov2ForImageClassification.from_pretrained(MODEL_ID)
28
+ model.eval()
29
 
30
+ # Load id2name.json
31
+ MAP_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
32
+ with open(MAP_PATH, "r") as f:
33
+ id2name = json.load(f)
34
+
35
+ print(f"✓ Loaded {len(id2name)} labels from id2name.json")
36
+
37
+ @app.get("/")
38
+ def root():
39
+ return {"message": "Welcome to NEMOtools API"}
40
 
41
  @app.post("/predict")
42
  async def predict(file: UploadFile = File(...)):
43
+ """Perform top-5 inference on an uploaded image."""
44
+ image = Image.open(file.file).convert("RGB")
45
+ inputs = processor(images=image, return_tensors="pt")
 
 
 
 
46
 
47
+ with torch.no_grad():
48
+ logits = model(**inputs).logits.squeeze(0)
49
+ probs, idxs = softmax(logits, dim=0).topk(5)
50
 
51
+ results = [
52
+ {"label": id2name[str(i)], "confidence": float(p)}
53
+ for p, i in zip(probs, idxs)
54
+ ]
55
+ return {"predictions": results}
app/static/index.html CHANGED
@@ -174,7 +174,7 @@
174
 
175
  <!-- Classification tool -->
176
  <div id="tool-classification" class="hidden">
177
- <div class="flex flex-col md:flex-row md:items-center gap-4 mb-6 justify-center">
178
  <input id="cls-file" type="file" accept="image/*"
179
  class="block w-full md:w-auto text-sm text-gray-600
180
  file:mr-4 file:py-2 file:px-4
@@ -182,10 +182,22 @@
182
  file:text-sm file:font-semibold
183
  file:bg-indigo-50 file:text-indigo-600
184
  hover:file:bg-indigo-100" />
 
 
 
 
 
 
 
185
  </div>
186
-
 
187
  <div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div>
188
  </div>
 
 
 
 
189
  </div>
190
  </section>
191
  </main>
@@ -324,10 +336,19 @@
324
  fd.append("file", file);
325
 
326
  try {
327
- const res = await fetch("/classify", { method: "POST", body: fd });
328
  if (!res.ok) throw new Error(`Server error: ${res.status}`);
329
  const json = await res.json();
330
- result.textContent = "✅ Predicted class: " + json.label;
 
 
 
 
 
 
 
 
 
331
  } catch (err) {
332
  console.error(err);
333
  result.textContent = "❌ Error: " + err.message;
 
174
 
175
  <!-- Classification tool -->
176
  <div id="tool-classification" class="hidden">
177
+ <div class="flex flex-col items-center gap-4 mb-6 justify-center">
178
  <input id="cls-file" type="file" accept="image/*"
179
  class="block w-full md:w-auto text-sm text-gray-600
180
  file:mr-4 file:py-2 file:px-4
 
182
  file:text-sm file:font-semibold
183
  file:bg-indigo-50 file:text-indigo-600
184
  hover:file:bg-indigo-100" />
185
+
186
+ <!-- 👇 THIS BUTTON TRIGGERS INFERENCE -->
187
+ <button id="runClsBtn"
188
+ onclick="runClassification()"
189
+ class="px-8 py-3 bg-green-600 text-white text-lg font-semibold rounded-full shadow-md hover:bg-green-700 transition">
190
+ ▶️ Run Classification
191
+ </button>
192
  </div>
193
+
194
+ <!-- Where results will appear -->
195
  <div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div>
196
  </div>
197
+
198
+
199
+ <!-- <div id="cls-result" class="text-center text-gray-700 mt-4 text-lg font-medium"></div> -->
200
+ </div>
201
  </div>
202
  </section>
203
  </main>
 
336
  fd.append("file", file);
337
 
338
  try {
339
+ const res = await fetch("/predict", { method: "POST", body: fd }); // ✅ must match FastAPI route
340
  if (!res.ok) throw new Error(`Server error: ${res.status}`);
341
  const json = await res.json();
342
+
343
+ // ✅ display top-5 predictions if available
344
+ if (json.predictions) {
345
+ result.innerHTML = "<h3 class='font-semibold text-indigo-600 mb-2'>Top-5 Predictions:</h3>" +
346
+ json.predictions.map(p =>
347
+ `<div>${p.label} β€” ${(p.confidence * 100).toFixed(2)}%</div>`
348
+ ).join("");
349
+ } else {
350
+ result.textContent = "✅ Predicted class: " + json.label;
351
+ }
352
  } catch (err) {
353
  console.error(err);
354
  result.textContent = "❌ Error: " + err.message;