Shad0wKillar commited on
Commit
f9d30f0
·
verified ·
1 Parent(s): c6403bc

Added better UI

Browse files
Files changed (1) hide show
  1. main.py +156 -30
main.py CHANGED
@@ -1,28 +1,23 @@
1
  import io
2
  from fastapi import FastAPI, UploadFile, File
3
- from fastapi.responses import HTMLResponse # Add this import
4
  import torch
5
  import torchvision
6
  from torchvision.transforms import InterpolationMode
7
- from fastapi import FastAPI, UploadFile, File
8
  from huggingface_hub import hf_hub_download
9
  from PIL import Image
10
 
11
  app = FastAPI()
12
 
13
-
14
- # I recreated the exact architecture from your training script
15
  def create_model():
16
  model = torchvision.models.efficientnet_b1()
17
- # I replaced the classifier exactly as you did in training
18
  model.classifier = torch.nn.Sequential(
19
  torch.nn.Dropout(p=0.2, inplace=True),
20
  torch.nn.Linear(in_features=1280, out_features=3, bias=True),
21
  )
22
  return model
23
 
24
-
25
- # I load the model and weights when the Docker container starts
26
  model = create_model()
27
  weights_path = hf_hub_download(
28
  repo_id="Shad0wKillar/efficientnet-b1", filename="EfficientNet_B1_20percent.pth"
@@ -32,7 +27,6 @@ model.load_state_dict(
32
  )
33
  model.eval()
34
 
35
- # I mapped the exact auto_transform sequence you provided
36
  transform = torchvision.transforms.Compose(
37
  [
38
  torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
@@ -49,41 +43,152 @@ class_names = ["pizza", "steak", "sushi"]
49
 
50
  @app.get("/", response_class=HTMLResponse)
51
  async def read_root():
 
52
  html_content = """
53
  <!DOCTYPE html>
54
- <html>
55
  <head>
 
 
56
  <title>Model Testing API</title>
57
  <style>
58
- body { font-family: sans-serif; max-width: 600px; margin: 40px auto; padding: 20px; background: #121212; color: #ffffff; }
59
- .box { border: 1px solid #333; padding: 20px; border-radius: 8px; background: #1e1e1e; }
60
- button { background: #3b82f6; color: white; border: none; padding: 10px 15px; border-radius: 4px; cursor: pointer; margin-top: 10px; }
61
- button:hover { background: #2563eb; }
62
- pre { background: #000; padding: 15px; border-radius: 4px; overflow-x: auto; color: #10b981; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  </style>
64
  </head>
65
  <body>
66
- <div class="box">
67
- <h2>Backend Diagnostic UI</h2>
68
- <p>Upload an image to test the API directly.</p>
69
- <input type="file" id="imageInput" accept="image/jpeg, image/png">
70
- <button onclick="testAPI()">Run Prediction</button>
 
71
 
72
- <h3 style="margin-top: 20px;">Response:</h3>
73
- <pre id="output">Awaiting image...</pre>
 
 
 
 
74
  </div>
75
 
76
  <script>
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  async function testAPI() {
78
  const fileInput = document.getElementById('imageInput');
79
- const output = document.getElementById('output');
 
 
 
80
 
81
  if (fileInput.files.length === 0) {
82
- output.innerText = "Please select an image first.";
83
  return;
84
  }
85
 
86
- output.innerText = "Processing...";
 
 
 
 
 
87
 
88
  const formData = new FormData();
89
  formData.append("file", fileInput.files[0]);
@@ -93,10 +198,34 @@ async def read_root():
93
  method: "POST",
94
  body: formData
95
  });
 
 
 
96
  const data = await response.json();
97
- output.innerText = JSON.stringify(data, null, 2);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  } catch (error) {
99
- output.innerText = "Error hitting API: " + error.message;
 
 
 
 
 
100
  }
101
  }
102
  </script>
@@ -108,16 +237,13 @@ async def read_root():
108
 
109
  @app.post("/predict")
110
  async def predict(file: UploadFile = File(...)):
111
- # I read the incoming bytes into a PIL image
112
  image_bytes = await file.read()
113
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
114
 
115
- # I process the image and run inference
116
  img_tensor = transform(image).unsqueeze(0)
117
 
118
  with torch.no_grad():
119
  logits = model(img_tensor)
120
  probs = torch.softmax(logits, dim=1).squeeze()
121
 
122
- # I format the output as a JSON dictionary
123
- return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
 
1
  import io
2
  from fastapi import FastAPI, UploadFile, File
3
+ from fastapi.responses import HTMLResponse
4
  import torch
5
  import torchvision
6
  from torchvision.transforms import InterpolationMode
 
7
  from huggingface_hub import hf_hub_download
8
  from PIL import Image
9
 
10
  app = FastAPI()
11
 
 
 
12
  def create_model():
13
  model = torchvision.models.efficientnet_b1()
 
14
  model.classifier = torch.nn.Sequential(
15
  torch.nn.Dropout(p=0.2, inplace=True),
16
  torch.nn.Linear(in_features=1280, out_features=3, bias=True),
17
  )
18
  return model
19
 
20
+ # I loaded the model and weights for the container
 
21
  model = create_model()
22
  weights_path = hf_hub_download(
23
  repo_id="Shad0wKillar/efficientnet-b1", filename="EfficientNet_B1_20percent.pth"
 
27
  )
28
  model.eval()
29
 
 
30
  transform = torchvision.transforms.Compose(
31
  [
32
  torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
 
43
 
44
  @app.get("/", response_class=HTMLResponse)
45
  async def read_root():
46
+ # I built the new UI matching the Hugging Face dark theme
47
  html_content = """
48
  <!DOCTYPE html>
49
+ <html lang="en">
50
  <head>
51
+ <meta charset="UTF-8">
52
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
53
  <title>Model Testing API</title>
54
  <style>
55
+ body {
56
+ font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
57
+ background-color: #0b0f19;
58
+ color: #e5e7eb;
59
+ display: flex;
60
+ justify-content: center;
61
+ align-items: center;
62
+ min-height: 100vh;
63
+ margin: 0;
64
+ padding: 20px;
65
+ }
66
+ .container {
67
+ background-color: #1e293b;
68
+ border: 1px solid #374151;
69
+ border-radius: 8px;
70
+ padding: 30px;
71
+ width: 100%;
72
+ max-width: 450px;
73
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
74
+ }
75
+ h2 {
76
+ margin-top: 0;
77
+ font-size: 1.5rem;
78
+ font-weight: 600;
79
+ color: #f3f4f6;
80
+ }
81
+ .subtitle {
82
+ color: #9ca3af;
83
+ font-size: 0.875rem;
84
+ margin-bottom: 20px;
85
+ }
86
+ input[type="file"] {
87
+ display: block;
88
+ width: 100%;
89
+ padding: 10px;
90
+ font-size: 0.875rem;
91
+ color: #9ca3af;
92
+ background-color: #0b0f19;
93
+ border: 1px solid #374151;
94
+ border-radius: 6px;
95
+ cursor: pointer;
96
+ box-sizing: border-box;
97
+ margin-bottom: 15px;
98
+ }
99
+ #preview {
100
+ max-width: 100%;
101
+ max-height: 300px;
102
+ object-fit: cover;
103
+ border-radius: 6px;
104
+ display: none;
105
+ margin-bottom: 15px;
106
+ border: 1px solid #374151;
107
+ }
108
+ button {
109
+ background-color: #ffffff;
110
+ color: #000000;
111
+ font-weight: 600;
112
+ padding: 10px 15px;
113
+ border: none;
114
+ border-radius: 6px;
115
+ cursor: pointer;
116
+ width: 100%;
117
+ transition: background-color 0.2s;
118
+ }
119
+ button:hover { background-color: #e5e7eb; }
120
+ button:disabled { background-color: #9ca3af; cursor: not-allowed; }
121
+
122
+ .result-box {
123
+ margin-top: 20px;
124
+ padding: 15px;
125
+ background-color: #0b0f19;
126
+ border: 1px solid #374151;
127
+ border-radius: 6px;
128
+ display: none;
129
+ text-align: center;
130
+ }
131
+ .prediction {
132
+ font-size: 1.5rem;
133
+ font-weight: 700;
134
+ color: #10b981;
135
+ margin-bottom: 8px;
136
+ }
137
+ .raw-probs {
138
+ font-size: 0.75rem;
139
+ color: #6b7280;
140
+ margin: 0;
141
+ }
142
  </style>
143
  </head>
144
  <body>
145
+ <div class="container">
146
+ <h2>Image Classification</h2>
147
+ <div class="subtitle">Upload an image to test the API endpoint.</div>
148
+
149
+ <input type="file" id="imageInput" accept="image/jpeg, image/png" onchange="previewImage(event)">
150
+ <img id="preview" alt="Image preview">
151
 
152
+ <button onclick="testAPI()" id="runBtn">Run Prediction</button>
153
+
154
+ <div class="result-box" id="resultBox">
155
+ <div class="prediction" id="topPrediction"></div>
156
+ <div class="raw-probs" id="rawProbs"></div>
157
+ </div>
158
  </div>
159
 
160
  <script>
161
+ function previewImage(event) {
162
+ const reader = new FileReader();
163
+ reader.onload = function(){
164
+ const preview = document.getElementById('preview');
165
+ preview.src = reader.result;
166
+ preview.style.display = 'block';
167
+ document.getElementById('resultBox').style.display = 'none';
168
+ };
169
+ if (event.target.files[0]) {
170
+ reader.readAsDataURL(event.target.files[0]);
171
+ }
172
+ }
173
+
174
  async function testAPI() {
175
  const fileInput = document.getElementById('imageInput');
176
+ const resultBox = document.getElementById('resultBox');
177
+ const topPrediction = document.getElementById('topPrediction');
178
+ const rawProbs = document.getElementById('rawProbs');
179
+ const runBtn = document.getElementById('runBtn');
180
 
181
  if (fileInput.files.length === 0) {
182
+ alert("Please select an image first.");
183
  return;
184
  }
185
 
186
+ runBtn.innerText = "Processing...";
187
+ runBtn.disabled = true;
188
+ resultBox.style.display = 'block';
189
+ topPrediction.innerText = "Analyzing...";
190
+ topPrediction.style.color = "#e5e7eb";
191
+ rawProbs.innerText = "";
192
 
193
  const formData = new FormData();
194
  formData.append("file", fileInput.files[0]);
 
198
  method: "POST",
199
  body: formData
200
  });
201
+
202
+ if (!response.ok) throw new Error("API request failed");
203
+
204
  const data = await response.json();
205
+
206
+ let highestClass = "";
207
+ let highestProb = -1;
208
+ let probsList = [];
209
+
210
+ for (const [className, prob] of Object.entries(data)) {
211
+ if (prob > highestProb) {
212
+ highestProb = prob;
213
+ highestClass = className;
214
+ }
215
+ probsList.push(`${className}: ${(prob * 100).toFixed(1)}%`);
216
+ }
217
+
218
+ topPrediction.style.color = "#10b981";
219
+ topPrediction.innerText = highestClass.charAt(0).toUpperCase() + highestClass.slice(1);
220
+ rawProbs.innerText = probsList.join(" • ");
221
+
222
  } catch (error) {
223
+ topPrediction.style.color = "#ef4444";
224
+ topPrediction.innerText = "Error";
225
+ rawProbs.innerText = error.message;
226
+ } finally {
227
+ runBtn.innerText = "Run Prediction";
228
+ runBtn.disabled = false;
229
  }
230
  }
231
  </script>
 
237
 
238
  @app.post("/predict")
239
  async def predict(file: UploadFile = File(...)):
 
240
  image_bytes = await file.read()
241
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
242
 
 
243
  img_tensor = transform(image).unsqueeze(0)
244
 
245
  with torch.no_grad():
246
  logits = model(img_tensor)
247
  probs = torch.softmax(logits, dim=1).squeeze()
248
 
249
+ return {class_names[i]: float(probs[i]) for i in range(len(class_names))}