AdarshRajDS committed on
Commit
4bb02cf
·
1 Parent(s): 78d34ab

Recreate clean HF Space using HF Dataset for reference images

Browse files
Files changed (4) hide show
  1. Dockerfile +2 -11
  2. app.py +58 -22
  3. dino.py +31 -12
  4. requirements.txt +2 -0
Dockerfile CHANGED
@@ -1,31 +1,22 @@
1
  FROM python:3.11-slim
2
 
3
  WORKDIR /app
4
-
5
- # Make local modules importable
6
  ENV PYTHONPATH=/app
7
 
8
- # Install system dependencies
9
  RUN apt-get update && apt-get install -y \
10
  build-essential \
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
- # Copy requirements first for better caching
14
  COPY requirements.txt .
15
 
16
- # Install Python dependencies
17
  RUN pip install --no-cache-dir --upgrade pip && \
18
- pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
 
19
  pip install --no-cache-dir -r requirements.txt
20
 
21
- # Copy ALL application files
22
  COPY *.py ./
23
-
24
- # Copy model weights
25
  COPY resnet50_multitask_mold.pth ./
26
 
27
- # Expose HF Spaces port
28
  EXPOSE 7860
29
 
30
- # Run the app
31
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
  FROM python:3.11-slim
2
 
3
  WORKDIR /app
 
 
4
  ENV PYTHONPATH=/app
5
 
 
6
  RUN apt-get update && apt-get install -y \
7
  build-essential \
8
  && rm -rf /var/lib/apt/lists/*
9
 
 
10
  COPY requirements.txt .
11
 
 
12
  RUN pip install --no-cache-dir --upgrade pip && \
13
+ pip install --no-cache-dir torch torchvision \
14
+ --index-url https://download.pytorch.org/whl/cpu && \
15
  pip install --no-cache-dir -r requirements.txt
16
 
 
17
  COPY *.py ./
 
 
18
  COPY resnet50_multitask_mold.pth ./
19
 
 
20
  EXPOSE 7860
21
 
 
22
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -2,14 +2,17 @@ from fastapi import FastAPI, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from PIL import Image
4
  import torch, io
5
- from pathlib import Path
6
  from torchvision import transforms
7
 
8
  from model import MultiTaskResNet50
9
- from decision import final_decision #
10
- from advanced_decision import *
 
 
 
 
11
  from gradcam import GradCAM
12
- from dino import *
13
 
14
  app = FastAPI(title="Mold Detection API v2")
15
 
@@ -23,24 +26,49 @@ app.add_middleware(
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
  mold_idx = 4
25
 
26
- # Load model
 
 
27
  model = MultiTaskResNet50().to(device)
28
- model.load_state_dict(torch.load("resnet50_multitask_mold.pth", map_location=device))
 
 
29
  model.eval()
30
 
 
31
  # Transforms
 
32
  transform = transforms.Compose([
33
- transforms.Resize((224,224)),
34
  transforms.ToTensor(),
35
- transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
 
 
 
36
  ])
37
 
 
38
  # Grad-CAM
 
39
  gradcam = GradCAM(model, model.backbone.layer4[-1].conv3)
40
 
41
- # DINO
42
- dino = load_dino(device)
43
- mold_embs = build_embeddings(dino, transform, "mold_reference_images", device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  @app.post("/predict/v1")
46
  async def predict_v1(file: UploadFile):
@@ -48,22 +76,29 @@ async def predict_v1(file: UploadFile):
48
  img_t = transform(img).to(device)
49
  return final_decision(model, img_t)
50
 
 
51
  @app.post("/predict/v2")
52
  async def predict_v2(file: UploadFile):
 
 
53
  img = Image.open(io.BytesIO(await file.read())).convert("RGB")
54
  img_t = transform(img).to(device)
55
 
56
  with torch.no_grad():
57
  out = model(img_t.unsqueeze(0))
58
- cp = torch.softmax(out["class"],1)[0]
59
- bp = torch.softmax(out["bio"],1)[0]
60
 
61
  mold_p = cp[mold_idx].item()
62
- bio_p = bp[1].item()
63
 
64
  mean_p, std_p = mc_uncertainty(model, img_t, mold_idx)
65
- patch_ratio = patch_consistency(model, img, transform, mold_idx, device)
66
- dino_sim = similarity(dino, mold_embs, img, transform, device)
 
 
 
 
67
 
68
  decision = final_decision_v2(
69
  mold_p, bio_p, std_p, patch_ratio, dino_sim
@@ -72,16 +107,17 @@ async def predict_v2(file: UploadFile):
72
  return {
73
  "decision": decision,
74
  "model_outputs": {
75
- "mold_probability": round(mold_p,3),
76
- "biological_probability": round(bio_p,3)
77
  },
78
  "confidence_checks": {
79
- "uncertainty": round(std_p,3),
80
- "patch_ratio": round(patch_ratio,3),
81
- "dino_similarity": round(dino_sim,3)
82
- }
83
  }
84
 
 
85
  @app.post("/explain/gradcam")
86
  async def explain_gradcam(file: UploadFile):
87
  img = Image.open(io.BytesIO(await file.read())).convert("RGB")
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from PIL import Image
4
  import torch, io
 
5
  from torchvision import transforms
6
 
7
  from model import MultiTaskResNet50
8
+ from decision import final_decision
9
+ from advanced_decision import (
10
+ mc_uncertainty,
11
+ patch_consistency,
12
+ final_decision_v2
13
+ )
14
  from gradcam import GradCAM
15
+ from dino import load_dino, build_embeddings, similarity
16
 
17
  app = FastAPI(title="Mold Detection API v2")
18
 
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  mold_idx = 4
28
 
29
+ # ------------------
30
+ # Load main model
31
+ # ------------------
32
  model = MultiTaskResNet50().to(device)
33
+ model.load_state_dict(
34
+ torch.load("resnet50_multitask_mold.pth", map_location=device)
35
+ )
36
  model.eval()
37
 
38
+ # ------------------
39
  # Transforms
40
+ # ------------------
41
  transform = transforms.Compose([
42
+ transforms.Resize((224, 224)),
43
  transforms.ToTensor(),
44
+ transforms.Normalize(
45
+ [0.485, 0.456, 0.406],
46
+ [0.229, 0.224, 0.225]
47
+ )
48
  ])
49
 
50
+ # ------------------
51
  # Grad-CAM
52
+ # ------------------
53
  gradcam = GradCAM(model, model.backbone.layer4[-1].conv3)
54
 
55
+ # ------------------
56
+ # DINO (lazy loaded)
57
+ # ------------------
58
+ dino = None
59
+ mold_embs = None
60
+
61
+
62
def ensure_dino():
    """Load the DINO model and the reference embeddings on first use.

    Populates the module-level ``dino`` and ``mold_embs`` globals. Both
    are assigned together, only after ``load_dino`` and
    ``build_embeddings`` have returned, so a failure while building the
    embeddings cannot leave ``dino`` set while ``mold_embs`` is still
    ``None``. The next call then retries instead of skipping
    initialisation.

    Raises:
        Whatever ``load_dino`` or ``build_embeddings`` raise; the globals
        are left unchanged in that case.
    """
    global dino, mold_embs
    if dino is not None and mold_embs is not None:
        return

    # Reuse an already-loaded model when only the embeddings are missing.
    dino_model = dino if dino is not None else load_dino(device)
    embeddings = build_embeddings(dino_model, transform, device)

    # Publish both globals only once every step has succeeded.
    dino, mold_embs = dino_model, embeddings
67
+
68
+
69
+ # ------------------
70
+ # API endpoints
71
+ # ------------------
72
 
73
  @app.post("/predict/v1")
74
  async def predict_v1(file: UploadFile):
 
76
  img_t = transform(img).to(device)
77
  return final_decision(model, img_t)
78
 
79
+
80
  @app.post("/predict/v2")
81
  async def predict_v2(file: UploadFile):
82
+ ensure_dino()
83
+
84
  img = Image.open(io.BytesIO(await file.read())).convert("RGB")
85
  img_t = transform(img).to(device)
86
 
87
  with torch.no_grad():
88
  out = model(img_t.unsqueeze(0))
89
+ cp = torch.softmax(out["class"], 1)[0]
90
+ bp = torch.softmax(out["bio"], 1)[0]
91
 
92
  mold_p = cp[mold_idx].item()
93
+ bio_p = bp[1].item()
94
 
95
  mean_p, std_p = mc_uncertainty(model, img_t, mold_idx)
96
+ patch_ratio = patch_consistency(
97
+ model, img, transform, mold_idx, device
98
+ )
99
+ dino_sim = similarity(
100
+ dino, mold_embs, img, transform, device
101
+ )
102
 
103
  decision = final_decision_v2(
104
  mold_p, bio_p, std_p, patch_ratio, dino_sim
 
107
  return {
108
  "decision": decision,
109
  "model_outputs": {
110
+ "mold_probability": round(mold_p, 3),
111
+ "biological_probability": round(bio_p, 3),
112
  },
113
  "confidence_checks": {
114
+ "uncertainty": round(std_p, 3),
115
+ "patch_ratio": round(patch_ratio, 3),
116
+ "dino_similarity": round(dino_sim, 3),
117
+ },
118
  }
119
 
120
+
121
  @app.post("/explain/gradcam")
122
  async def explain_gradcam(file: UploadFile):
123
  img = Image.open(io.BytesIO(await file.read())).convert("RGB")
dino.py CHANGED
@@ -1,28 +1,47 @@
1
- import os
2
- import numpy as np
3
  import torch
4
- import torch.hub
5
  from PIL import Image
 
6
  from sklearn.metrics.pairwise import cosine_similarity
7
 
 
8
  def load_dino(device):
9
- model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
 
 
 
10
  model.eval().to(device)
11
  return model
12
 
13
- def build_embeddings(dino, transform, image_dir, device):
 
 
 
 
 
 
14
  embs = []
15
- for f in os.listdir(image_dir):
16
- if f.lower().endswith((".jpg",".png",".jpeg")):
17
- img = Image.open(os.path.join(image_dir,f)).convert("RGB")
18
- t = transform(img).unsqueeze(0).to(device)
19
- with torch.no_grad():
20
- e = dino(t)
21
- embs.append(e.squeeze().cpu().numpy())
 
 
 
 
 
 
 
 
22
  return np.vstack(embs)
23
 
 
24
  def similarity(dino, mold_embs, image, transform, device):
25
  t = transform(image).unsqueeze(0).to(device)
26
  with torch.no_grad():
27
  e = dino(t).cpu().numpy()
 
28
  return float(cosine_similarity(e, mold_embs).max())
 
 
 
1
  import torch
2
+ import numpy as np
3
  from PIL import Image
4
+ from datasets import load_dataset
5
  from sklearn.metrics.pairwise import cosine_similarity
6
 
7
def load_dino(device):
    """Fetch the DINO model from torch.hub and prepare it for inference.

    Loads the ``dinov2_vits14`` entry point of the
    ``facebookresearch/dinov2`` hub repository, switches it to eval mode
    and moves it to ``device``.

    Args:
        device: Target device passed to ``Module.to``.

    Returns:
        The loaded model.
    """
    hub_repo = "facebookresearch/dinov2"
    entry_point = "dinov2_vits14"

    dino_model = torch.hub.load(hub_repo, entry_point)
    dino_model.eval()
    dino_model.to(device)
    return dino_model
15
 
16
+
17
def build_embeddings(dino, transform, device):
    """Embed every reference image of the HF dataset with ``dino``.

    Args:
        dino: Model called on a batched image tensor.
        transform: Callable that turns a PIL image into a tensor.
        device: Device the input tensors are moved to.

    Returns:
        The per-image embeddings stacked with ``np.vstack`` (one row per
        image, assuming each squeezed embedding is 1-D).

    Raises:
        RuntimeError: If the dataset yields no samples.
    """
    reference_set = load_dataset(
        "AdarshDS/mold-reference-images",
        split="train",
    )

    rows = []
    # Inference only: one no-grad scope covers the whole pass.
    with torch.no_grad():
        for record in reference_set:
            rgb = record["image"].convert("RGB")
            batch = transform(rgb).unsqueeze(0).to(device)
            rows.append(dino(batch).squeeze().cpu().numpy())

    if len(rows) == 0:
        raise RuntimeError(
            "No reference images found in HF dataset"
        )

    return np.vstack(rows)
40
 
41
def similarity(dino, mold_embs, image, transform, device):
    """Return the highest cosine similarity between ``image`` and the references.

    Args:
        dino: Model called on a batched image tensor.
        mold_embs: Reference embeddings compared against the query.
        image: Input image handed to ``transform``.
        transform: Callable that turns the image into a tensor.
        device: Device the input tensor is moved to.

    Returns:
        The maximum entry of the cosine-similarity matrix, as a ``float``.
    """
    batch = transform(image).unsqueeze(0)
    batch = batch.to(device)

    with torch.no_grad():
        query = dino(batch).cpu().numpy()

    scores = cosine_similarity(query, mold_embs)
    return float(scores.max())
requirements.txt CHANGED
@@ -6,6 +6,8 @@ pillow
6
  numpy<2
7
  python-multipart
8
  scikit-learn
 
 
9
 
10
 
11
 
 
6
  numpy<2
7
  python-multipart
8
  scikit-learn
9
+ scikit-learn
10
+ datasets
11
 
12
 
13