AndrewKof committed on
Commit
3cac439
·
1 Parent(s): ef01405

Fix DINOv2 processor and update requirements

Browse files
Files changed (2) hide show
  1. app/main.py +74 -23
  2. requirements.txt +1 -1
app/main.py CHANGED
@@ -1,16 +1,31 @@
1
  # app/main.py
2
  import os
3
  import json
 
 
4
  import torch
5
  from fastapi import FastAPI, File, UploadFile
6
  from fastapi.middleware.cors import CORSMiddleware
7
- from transformers import AutoImageProcessor, Dinov2ForImageClassification
 
 
 
 
 
8
  from torch.nn.functional import softmax
9
  from PIL import Image
10
 
11
- app = FastAPI()
 
 
 
 
 
 
 
 
12
 
13
- # Allow frontend to call backend
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
@@ -19,44 +34,80 @@ app.add_middleware(
19
  allow_headers=["*"],
20
  )
21
 
22
- # --- Load model and mapping on startup ---
 
 
 
 
 
 
 
 
 
 
 
23
  print("🚀 Loading model and label mapping...")
24
 
25
  MODEL_ID = "Arew99/dinov2-costum"
26
 
27
- print("πŸš€ Loading model and label mapping...")
28
  model = Dinov2ForImageClassification.from_pretrained(
29
  MODEL_ID,
30
  num_labels=101,
31
- ignore_mismatched_sizes=True
32
  )
33
- processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
34
-
35
  model.eval()
36
 
37
- # Load id2name.json
38
- MAP_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
39
- with open(MAP_PATH, "r") as f:
40
- id2name = json.load(f)
41
 
 
 
 
42
  print(f"✓ Loaded {len(id2name)} labels from id2name.json")
43
 
44
- @app.get("/")
45
- def root():
46
- return {"message": "Welcome to NEMOtools API"}
47
 
 
 
 
48
  @app.post("/predict")
49
  async def predict(file: UploadFile = File(...)):
50
- """Perform top-5 inference on an uploaded image."""
51
- image = Image.open(file.file).convert("RGB")
52
- inputs = processor(images=image, return_tensors="pt")
 
 
53
 
54
  with torch.no_grad():
55
- logits = model(**inputs).logits.squeeze(0)
56
  probs, idxs = softmax(logits, dim=0).topk(5)
57
 
58
- results = [
59
- {"label": id2name[str(i)], "confidence": float(p)}
60
- for p, i in zip(probs, idxs)
61
- ]
 
62
  return {"predictions": results}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app/main.py
import os
import json
from pathlib import Path

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from transformers import AutoImageProcessor, Dinov2ForImageClassification
from torch.nn.functional import softmax
from PIL import Image

# NOTE(review): released transformers does not export `Dinov2ImageProcessor`;
# DINOv2 checkpoints resolve their processor through AutoImageProcessor
# (which maps to BitImageProcessor).  The alias below keeps any remaining
# `Dinov2ImageProcessor.from_pretrained(...)` call sites in this file working.
Dinov2ImageProcessor = AutoImageProcessor

# -------------------------------------------------
# filesystem layout — everything lives next to this module
# -------------------------------------------------
BASE_DIR = Path(__file__).parent
STATIC_DIR = BASE_DIR / "static"          # bundled frontend assets
INDEX_HTML = STATIC_DIR / "index.html"    # single-page app entry point
MAP_PATH = BASE_DIR / "id2name.json"      # str(class_id) -> label mapping

app = FastAPI(title="NEMO Tools")
27
 
28
+ # CORS so the JS can call us
29
  app.add_middleware(
30
  CORSMiddleware,
31
  allow_origins=["*"],
 
34
  allow_headers=["*"],
35
  )
36
 
# serve /static/* — mount only when the directory exists, so startup does not
# crash (StaticFiles raises at construction time) on builds without a frontend
if STATIC_DIR.is_dir():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")


@app.get("/", response_class=HTMLResponse)
def serve_frontend():
    """Serve the frontend's index.html, or a clean 404 page if it is missing."""
    if INDEX_HTML.is_file():
        return INDEX_HTML.read_text(encoding="utf-8")
    # Previously a missing file surfaced as an opaque 500.
    return HTMLResponse("<h1>Frontend not bundled</h1>", status_code=404)
# -------------------------------------------------
# load model + processor + labels ONCE, at import time
# -------------------------------------------------
print("🚀 Loading model and label mapping...")

MODEL_ID = "Arew99/dinov2-costum"

# model: the fine-tuned DINOv2 classifier (101 classes)
model = Dinov2ForImageClassification.from_pretrained(
    MODEL_ID,
    num_labels=101,
    # the fine-tuned head's size differs from the hub checkpoint's,
    # so skip the strict shape check when loading weights
    ignore_mismatched_sizes=True,
)
model.eval()  # inference only — disable dropout etc.

# processor: taken from the ORIGINAL DINOv2 repo (the custom repo ships no
# preprocessing config).  AutoImageProcessor resolves the correct processor
# class for DINOv2 checkpoints; `Dinov2ImageProcessor` is not exported by
# released transformers, so importing it directly fails.
from transformers import AutoImageProcessor  # noqa: E402

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")

# labels: JSON mapping of str(class_id) -> human-readable name
with MAP_PATH.open("r", encoding="utf-8") as f:
    id2name = json.load(f)
print(f"✓ Loaded {len(id2name)} labels from id2name.json")
68
 
 
 
 
69
 
# -------------------------------------------------
# endpoints
# -------------------------------------------------
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Return the top-5 predictions for an uploaded image.

    Response shape:
        {"predictions": [{"label": str, "confidence": float}, ...]}
    sorted by descending confidence.
    """
    from fastapi import HTTPException  # fastapi is already a file-level dep

    # Reject non-image uploads with a 400 instead of an opaque 500.
    try:
        img = Image.open(file.file).convert("RGB")
    except Exception as exc:  # PIL raises UnidentifiedImageError and friends
        raise HTTPException(status_code=400, detail="invalid image file") from exc

    # The DINOv2 processor expects a list of images -> [img]
    inputs = processor(images=[img], return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits[0]  # shape [num_labels]

    # Never ask for more classes than the model has.
    k = min(5, logits.shape[0])
    probs, idxs = softmax(logits, dim=0).topk(k)

    results = [
        {"label": id2name.get(str(i), f"Class {i}"), "confidence": p}
        for p, i in zip(probs.tolist(), idxs.tolist())
    ]
    return {"predictions": results}
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    """Return only the single most likely label for the uploaded image."""
    image = Image.open(file.file).convert("RGB")
    # processor takes a list of images
    batch = processor(images=[image], return_tensors="pt")

    with torch.no_grad():
        scores = model(**batch).logits[0]

    best = int(scores.argmax().item())
    # fall back to a generic name when the id is missing from the mapping
    return {"label": id2name.get(str(best), f"Class {best}")}
@app.get("/api")
def api_root():
    """Tiny health-check endpoint so callers can verify the backend is up."""
    status_message = "NEMO Tools backend is running."
    return {"message": status_message}
if __name__ == "__main__":
    # Local dev entry point; 7860 is the Hugging Face Spaces default port.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -8,7 +8,7 @@ pillow
8
  numpy
9
 
10
  # Hugging Face bits
11
- transformers
12
  huggingface-hub
13
  peft
14
 
 
8
  numpy
9
 
10
  # Hugging Face bits
11
+ transformers>=4.42.0
12
  huggingface-hub
13
  peft
14