Shad0wKillar committed on
Commit
4c4efcd
·
verified ·
1 Parent(s): f9d30f0

Added multi-model support (selectable EfficientNet B1/B3/B5/B7).

Browse files
Files changed (1) hide show
  1. main.py +80 -189
main.py CHANGED
@@ -1,5 +1,5 @@
1
  import io
2
- from fastapi import FastAPI, UploadFile, File
3
  from fastapi.responses import HTMLResponse
4
  import torch
5
  import torchvision
@@ -9,224 +9,112 @@ from PIL import Image
9
 
10
  app = FastAPI()
11
 
12
- def create_model():
13
- model = torchvision.models.efficientnet_b1()
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  model.classifier = torch.nn.Sequential(
15
  torch.nn.Dropout(p=0.2, inplace=True),
16
- torch.nn.Linear(in_features=1280, out_features=3, bias=True),
17
  )
18
  return model
19
 
20
- # I loaded the model and weights for the container
21
- model = create_model()
22
- weights_path = hf_hub_download(
23
- repo_id="Shad0wKillar/efficientnet-b1", filename="EfficientNet_B1_20percent.pth"
24
- )
25
- model.load_state_dict(
26
- torch.load(weights_path, map_location=torch.device("cpu"), weights_only=True)
27
- )
28
- model.eval()
29
-
30
- transform = torchvision.transforms.Compose(
31
- [
32
- torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
33
- torchvision.transforms.CenterCrop(240),
34
- torchvision.transforms.ToTensor(),
35
- torchvision.transforms.Normalize(
36
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
37
- ),
38
- ]
39
- )
40
 
41
  class_names = ["pizza", "steak", "sushi"]
42
 
43
-
44
  @app.get("/", response_class=HTMLResponse)
45
  async def read_root():
46
- # I built the new UI matching the Hugging Face dark theme
47
  html_content = """
48
  <!DOCTYPE html>
49
  <html lang="en">
50
  <head>
51
  <meta charset="UTF-8">
52
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
53
- <title>Model Testing API</title>
54
  <style>
55
- body {
56
- font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
57
- background-color: #0b0f19;
58
- color: #e5e7eb;
59
- display: flex;
60
- justify-content: center;
61
- align-items: center;
62
- min-height: 100vh;
63
- margin: 0;
64
- padding: 20px;
65
- }
66
- .container {
67
- background-color: #1e293b;
68
- border: 1px solid #374151;
69
- border-radius: 8px;
70
- padding: 30px;
71
- width: 100%;
72
- max-width: 450px;
73
- box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
74
- }
75
- h2 {
76
- margin-top: 0;
77
- font-size: 1.5rem;
78
- font-weight: 600;
79
- color: #f3f4f6;
80
- }
81
- .subtitle {
82
- color: #9ca3af;
83
- font-size: 0.875rem;
84
- margin-bottom: 20px;
85
- }
86
- input[type="file"] {
87
- display: block;
88
- width: 100%;
89
- padding: 10px;
90
- font-size: 0.875rem;
91
- color: #9ca3af;
92
- background-color: #0b0f19;
93
- border: 1px solid #374151;
94
- border-radius: 6px;
95
- cursor: pointer;
96
- box-sizing: border-box;
97
- margin-bottom: 15px;
98
- }
99
- #preview {
100
- max-width: 100%;
101
- max-height: 300px;
102
- object-fit: cover;
103
- border-radius: 6px;
104
- display: none;
105
- margin-bottom: 15px;
106
- border: 1px solid #374151;
107
- }
108
- button {
109
- background-color: #ffffff;
110
- color: #000000;
111
- font-weight: 600;
112
- padding: 10px 15px;
113
- border: none;
114
- border-radius: 6px;
115
- cursor: pointer;
116
- width: 100%;
117
- transition: background-color 0.2s;
118
- }
119
- button:hover { background-color: #e5e7eb; }
120
- button:disabled { background-color: #9ca3af; cursor: not-allowed; }
121
-
122
- .result-box {
123
- margin-top: 20px;
124
- padding: 15px;
125
- background-color: #0b0f19;
126
- border: 1px solid #374151;
127
- border-radius: 6px;
128
- display: none;
129
- text-align: center;
130
- }
131
- .prediction {
132
- font-size: 1.5rem;
133
- font-weight: 700;
134
- color: #10b981;
135
- margin-bottom: 8px;
136
- }
137
- .raw-probs {
138
- font-size: 0.75rem;
139
- color: #6b7280;
140
- margin: 0;
141
- }
142
  </style>
143
  </head>
144
  <body>
145
  <div class="container">
146
- <h2>Image Classification</h2>
147
- <div class="subtitle">Upload an image to test the API endpoint.</div>
148
-
149
- <input type="file" id="imageInput" accept="image/jpeg, image/png" onchange="previewImage(event)">
150
- <img id="preview" alt="Image preview">
151
-
 
 
 
 
152
  <button onclick="testAPI()" id="runBtn">Run Prediction</button>
153
-
154
  <div class="result-box" id="resultBox">
155
- <div class="prediction" id="topPrediction"></div>
156
- <div class="raw-probs" id="rawProbs"></div>
157
  </div>
158
  </div>
159
-
160
  <script>
161
  function previewImage(event) {
162
  const reader = new FileReader();
163
- reader.onload = function(){
164
- const preview = document.getElementById('preview');
165
- preview.src = reader.result;
166
- preview.style.display = 'block';
167
- document.getElementById('resultBox').style.display = 'none';
168
  };
169
- if (event.target.files[0]) {
170
- reader.readAsDataURL(event.target.files[0]);
171
- }
172
  }
173
-
174
  async function testAPI() {
175
- const fileInput = document.getElementById('imageInput');
176
- const resultBox = document.getElementById('resultBox');
177
- const topPrediction = document.getElementById('topPrediction');
178
- const rawProbs = document.getElementById('rawProbs');
179
- const runBtn = document.getElementById('runBtn');
180
 
181
- if (fileInput.files.length === 0) {
182
- alert("Please select an image first.");
183
- return;
184
- }
185
 
186
- runBtn.innerText = "Processing...";
187
- runBtn.disabled = true;
188
- resultBox.style.display = 'block';
189
- topPrediction.innerText = "Analyzing...";
190
- topPrediction.style.color = "#e5e7eb";
191
- rawProbs.innerText = "";
192
-
193
  const formData = new FormData();
194
- formData.append("file", fileInput.files[0]);
195
-
196
  try {
197
- const response = await fetch("/predict", {
198
- method: "POST",
199
- body: formData
200
- });
201
-
202
- if (!response.ok) throw new Error("API request failed");
203
-
204
- const data = await response.json();
205
-
206
- let highestClass = "";
207
- let highestProb = -1;
208
- let probsList = [];
209
 
210
- for (const [className, prob] of Object.entries(data)) {
211
- if (prob > highestProb) {
212
- highestProb = prob;
213
- highestClass = className;
214
- }
215
- probsList.push(`${className}: ${(prob * 100).toFixed(1)}%`);
216
- }
217
-
218
- topPrediction.style.color = "#10b981";
219
- topPrediction.innerText = highestClass.charAt(0).toUpperCase() + highestClass.slice(1);
220
- rawProbs.innerText = probsList.join(" • ");
221
-
222
- } catch (error) {
223
- topPrediction.style.color = "#ef4444";
224
- topPrediction.innerText = "Error";
225
- rawProbs.innerText = error.message;
226
- } finally {
227
- runBtn.innerText = "Run Prediction";
228
- runBtn.disabled = false;
229
- }
230
  }
231
  </script>
232
  </body>
@@ -234,16 +122,19 @@ async def read_root():
234
  """
235
  return HTMLResponse(content=html_content)
236
 
237
-
238
  @app.post("/predict")
239
- async def predict(file: UploadFile = File(...)):
 
 
 
 
240
  image_bytes = await file.read()
241
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
242
-
243
  img_tensor = transform(image).unsqueeze(0)
244
-
 
245
  with torch.no_grad():
246
- logits = model(img_tensor)
247
  probs = torch.softmax(logits, dim=1).squeeze()
248
-
249
  return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
 
1
  import io
2
+ from fastapi import FastAPI, UploadFile, File, Query
3
  from fastapi.responses import HTMLResponse
4
  import torch
5
  import torchvision
 
9
 
10
  app = FastAPI()
11
 
12
# Per-variant settings, keyed by the short EfficientNet variant name:
#   repo     - Hugging Face repository that hosts the checkpoint
#   file     - checkpoint filename inside that repository
#   features - in_features of the replacement classifier head for that variant
MODEL_CONFIGS = {
    "b1": {"repo": "Shad0wKillar/efficientnet-b1", "file": "EfficientNet_B1_20percent.pth", "features": 1280},
    "b3": {"repo": "Shad0wKillar/efficientnet-b3", "file": "EfficientNet_B3_20percent.pth", "features": 1536},
    "b5": {"repo": "Shad0wKillar/efficientnet-b5", "file": "EfficientNet_B5_20percent.pth", "features": 2048},
    "b7": {"repo": "Shad0wKillar/efficientnet-b7", "file": "EfficientNet_B7_20percent.pth", "features": 2560},
}


def create_model(model_type):
    """Build an untrained EfficientNet variant with a 3-class classifier head.

    Args:
        model_type: Key of ``MODEL_CONFIGS`` ("b1", "b3", "b5" or "b7").

    Returns:
        The ``torchvision.models.efficientnet_<model_type>`` network with its
        ``classifier`` replaced by Dropout(p=0.2) followed by a Linear layer
        with 3 outputs. No weights are loaded here.

    Raises:
        ValueError: If ``model_type`` is not a key of ``MODEL_CONFIGS``.
    """
    # Validate first: previously an unknown name fell through the if/elif
    # chain and failed later with an UnboundLocalError on ``model``.
    if model_type not in MODEL_CONFIGS:
        supported = ", ".join(sorted(MODEL_CONFIGS))
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of: {supported}"
        )

    # torchvision names its builders efficientnet_b1, efficientnet_b3, ...,
    # so the config key doubles as the builder suffix.
    build_backbone = getattr(torchvision.models, f"efficientnet_{model_type}")
    model = build_backbone()

    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(
            in_features=MODEL_CONFIGS[model_type]["features"],
            out_features=3,
            bias=True,
        ),
    )
    return model
32
 
33
def _load_pretrained(variant, settings):
    """Build one variant, load its checkpoint on CPU, and switch it to eval mode."""
    network = create_model(variant)
    checkpoint_path = hf_hub_download(
        repo_id=settings["repo"], filename=settings["file"]
    )
    state_dict = torch.load(
        checkpoint_path, map_location=torch.device("cpu"), weights_only=True
    )
    network.load_state_dict(state_dict)
    network.eval()
    return network


# Every configured model is downloaded and kept in memory at import time, so
# a request only has to look its model up by name.
loaded_models = {
    variant: _load_pretrained(variant, settings)
    for variant, settings in MODEL_CONFIGS.items()
}

# Preprocessing applied to every upload, whichever model is selected.
# NOTE(review): the same 255 -> 240 resize/crop is used for all variants;
# confirm the B3/B5/B7 checkpoints were fine-tuned at this resolution.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
        torchvision.transforms.CenterCrop(240),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
 
 
 
 
 
48
 
49
  # Class labels in model-output order: predict() pairs class_names[i] with
  # the i-th softmax probability.
  # NOTE(review): assumed to match the label order the checkpoints were
  # trained with - confirm.
  class_names = ["pizza", "steak", "sushi"]
50
 
 
51
  @app.get("/", response_class=HTMLResponse)
52
  async def read_root():
53
+ # I added a dropdown to the UI to select the model
54
  html_content = """
55
  <!DOCTYPE html>
56
  <html lang="en">
57
  <head>
58
  <meta charset="UTF-8">
59
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
60
+ <title>EfficientNet Multi-Model API</title>
61
  <style>
62
+ body { font-family: system-ui, sans-serif; background-color: #0b0f19; color: #e5e7eb; display: flex; justify-content: center; align-items: center; min-height: 100vh; margin: 0; padding: 20px; }
63
+ .container { background-color: #1e293b; border: 1px solid #374151; border-radius: 8px; padding: 30px; width: 100%; max-width: 450px; }
64
+ select, input[type="file"], button { width: 100%; padding: 10px; margin-bottom: 15px; border-radius: 6px; border: 1px solid #374151; background: #0b0f19; color: #e5e7eb; box-sizing: border-box; }
65
+ button { background: #ffffff; color: #000; font-weight: 600; cursor: pointer; border: none; }
66
+ #preview { max-width: 100%; border-radius: 6px; display: none; margin-bottom: 15px; }
67
+ .result-box { padding: 15px; background: #0b0f19; border-radius: 6px; display: none; text-align: center; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  </style>
69
  </head>
70
  <body>
71
  <div class="container">
72
+ <h2>Food Classification</h2>
73
+ <label for="modelSelect">Select Model Architecture:</label>
74
+ <select id="modelSelect">
75
+ <option value="b1">EfficientNet-B1</option>
76
+ <option value="b3">EfficientNet-B3</option>
77
+ <option value="b5">EfficientNet-B5</option>
78
+ <option value="b7">EfficientNet-B7</option>
79
+ </select>
80
+ <input type="file" id="imageInput" accept="image/*" onchange="previewImage(event)">
81
+ <img id="preview">
82
  <button onclick="testAPI()" id="runBtn">Run Prediction</button>
 
83
  <div class="result-box" id="resultBox">
84
+ <div id="topPrediction" style="font-size: 1.5rem; color: #10b981; font-weight: 700;"></div>
85
+ <div id="rawProbs" style="font-size: 0.75rem; color: #6b7280; margin-top: 10px;"></div>
86
  </div>
87
  </div>
 
88
  <script>
89
  function previewImage(event) {
90
  const reader = new FileReader();
91
+ reader.onload = () => {
92
+ const p = document.getElementById('preview');
93
+ p.src = reader.result; p.style.display = 'block';
 
 
94
  };
95
+ reader.readAsDataURL(event.target.files[0]);
 
 
96
  }
 
97
  async function testAPI() {
98
+ const file = document.getElementById('imageInput').files[0];
99
+ const model = document.getElementById('modelSelect').value;
100
+ if (!file) return alert("Select an image");
 
 
101
 
102
+ const btn = document.getElementById('runBtn');
103
+ btn.innerText = "Processing..."; btn.disabled = true;
 
 
104
 
 
 
 
 
 
 
 
105
  const formData = new FormData();
106
+ formData.append("file", file);
107
+
108
  try {
109
+ const res = await fetch(`/predict?model_type=${model}`, { method: "POST", body: formData });
110
+ const data = await res.json();
 
 
 
 
 
 
 
 
 
 
111
 
112
+ const best = Object.entries(data).reduce((a, b) => a[1] > b[1] ? a : b);
113
+ document.getElementById('topPrediction').innerText = best[0].toUpperCase();
114
+ document.getElementById('rawProbs').innerText = JSON.stringify(data);
115
+ document.getElementById('resultBox').style.display = 'block';
116
+ } catch (e) { alert("Error: " + e.message); }
117
+ finally { btn.innerText = "Run Prediction"; btn.disabled = false; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  }
119
  </script>
120
  </body>
 
122
  """
123
  return HTMLResponse(content=html_content)
124
 
 
125
  @app.post("/predict")
126
async def predict(model_type: str = Query("b1"), file: UploadFile = File(...)):
    """Classify an uploaded image with one of the pre-loaded models.

    Args:
        model_type: Key into ``loaded_models`` selecting the network
            (defaults to "b1").
        file: The uploaded image.

    Returns:
        On success, a dict mapping each entry of ``class_names`` to its
        softmax probability. On failure, a JSON body ``{"error": <message>}``
        with status 404 (unknown model) or 400 (upload is not a readable
        image).
    """
    # Imported locally so the handler does not rely on changes to the
    # module-level import block.
    from fastapi.concurrency import run_in_threadpool
    from fastapi.responses import JSONResponse

    selected_model = loaded_models.get(model_type)
    if selected_model is None:
        # Same payload as before, but no longer reported as HTTP 200.
        return JSONResponse(status_code=404, content={"error": "Model not found"})

    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError:
        # Pillow signals unrecognised or truncated image data with OSError
        # (UnidentifiedImageError is a subclass); that is a client error.
        return JSONResponse(status_code=400, content={"error": "Invalid image file"})

    def _infer():
        img_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            logits = selected_model(img_tensor)
        # Drop the batch dimension and convert to plain Python floats.
        return torch.softmax(logits, dim=1).squeeze(0).tolist()

    # The forward pass is CPU-bound; running it in the thread pool keeps the
    # event loop free to serve other requests meanwhile.
    probs = await run_in_threadpool(_infer)

    return dict(zip(class_names, probs))