Shad0wKillar commited on
Commit
f98cb2d
·
verified ·
1 Parent(s): 4c4efcd

Made the UI better.

Browse files
Files changed (1) hide show
  1. main.py +62 -21
main.py CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
9
 
10
  app = FastAPI()
11
 
12
- # Map model names to their specific configuration and repo details
13
  MODEL_CONFIGS = {
14
  "b1": {"repo": "Shad0wKillar/efficientnet-b1", "file": "EfficientNet_B1_20percent.pth", "features": 1280},
15
  "b3": {"repo": "Shad0wKillar/efficientnet-b3", "file": "EfficientNet_B3_20percent.pth", "features": 1536},
@@ -18,7 +18,7 @@ MODEL_CONFIGS = {
18
  }
19
 
20
  def create_model(model_type):
21
- # I matched the model architecture to the specific version
22
  if model_type == "b1": model = torchvision.models.efficientnet_b1()
23
  elif model_type == "b3": model = torchvision.models.efficientnet_b3()
24
  elif model_type == "b5": model = torchvision.models.efficientnet_b5()
@@ -30,7 +30,7 @@ def create_model(model_type):
30
  )
31
  return model
32
 
33
- # I pre-loaded all models into memory for fast access
34
  loaded_models = {}
35
  for m_type, config in MODEL_CONFIGS.items():
36
  m = create_model(m_type)
@@ -50,7 +50,7 @@ class_names = ["pizza", "steak", "sushi"]
50
 
51
  @app.get("/", response_class=HTMLResponse)
52
  async def read_root():
53
- # I added a dropdown to the UI to select the model
54
  html_content = """
55
  <!DOCTYPE html>
56
  <html lang="en">
@@ -60,47 +60,79 @@ async def read_root():
60
  <title>EfficientNet Multi-Model API</title>
61
  <style>
62
  body { font-family: system-ui, sans-serif; background-color: #0b0f19; color: #e5e7eb; display: flex; justify-content: center; align-items: center; min-height: 100vh; margin: 0; padding: 20px; }
63
- .container { background-color: #1e293b; border: 1px solid #374151; border-radius: 8px; padding: 30px; width: 100%; max-width: 450px; }
64
- select, input[type="file"], button { width: 100%; padding: 10px; margin-bottom: 15px; border-radius: 6px; border: 1px solid #374151; background: #0b0f19; color: #e5e7eb; box-sizing: border-box; }
65
- button { background: #ffffff; color: #000; font-weight: 600; cursor: pointer; border: none; }
66
- #preview { max-width: 100%; border-radius: 6px; display: none; margin-bottom: 15px; }
67
- .result-box { padding: 15px; background: #0b0f19; border-radius: 6px; display: none; text-align: center; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  </style>
69
  </head>
70
  <body>
71
  <div class="container">
72
- <h2>Food Classification</h2>
73
- <label for="modelSelect">Select Model Architecture:</label>
 
74
  <select id="modelSelect">
75
  <option value="b1">EfficientNet-B1</option>
76
  <option value="b3">EfficientNet-B3</option>
77
  <option value="b5">EfficientNet-B5</option>
78
  <option value="b7">EfficientNet-B7</option>
79
  </select>
 
80
  <input type="file" id="imageInput" accept="image/*" onchange="previewImage(event)">
 
 
 
 
 
81
  <img id="preview">
82
  <button onclick="testAPI()" id="runBtn">Run Prediction</button>
 
83
  <div class="result-box" id="resultBox">
84
- <div id="topPrediction" style="font-size: 1.5rem; color: #10b981; font-weight: 700;"></div>
85
- <div id="rawProbs" style="font-size: 0.75rem; color: #6b7280; margin-top: 10px;"></div>
86
  </div>
87
  </div>
 
88
  <script>
89
  function previewImage(event) {
90
  const reader = new FileReader();
 
 
 
91
  reader.onload = () => {
92
  const p = document.getElementById('preview');
93
- p.src = reader.result; p.style.display = 'block';
 
94
  };
95
- reader.readAsDataURL(event.target.files[0]);
96
  }
 
97
  async function testAPI() {
98
  const file = document.getElementById('imageInput').files[0];
99
  const model = document.getElementById('modelSelect').value;
100
- if (!file) return alert("Select an image");
101
 
102
  const btn = document.getElementById('runBtn');
103
- btn.innerText = "Processing..."; btn.disabled = true;
104
 
105
  const formData = new FormData();
106
  formData.append("file", file);
@@ -109,9 +141,18 @@ async def read_root():
109
  const res = await fetch(`/predict?model_type=${model}`, { method: "POST", body: formData });
110
  const data = await res.json();
111
 
112
- const best = Object.entries(data).reduce((a, b) => a[1] > b[1] ? a : b);
113
- document.getElementById('topPrediction').innerText = best[0].toUpperCase();
114
- document.getElementById('rawProbs').innerText = JSON.stringify(data);
 
 
 
 
 
 
 
 
 
115
  document.getElementById('resultBox').style.display = 'block';
116
  } catch (e) { alert("Error: " + e.message); }
117
  finally { btn.innerText = "Run Prediction"; btn.disabled = false; }
@@ -124,7 +165,7 @@ async def read_root():
124
 
125
  @app.post("/predict")
126
  async def predict(model_type: str = Query("b1"), file: UploadFile = File(...)):
127
- # I routed the request to the specific loaded model
128
  if model_type not in loaded_models:
129
  return {"error": "Model not found"}
130
 
 
9
 
10
  app = FastAPI()
11
 
12
+ # Model configurations for the pre-loaded dictionary
13
  MODEL_CONFIGS = {
14
  "b1": {"repo": "Shad0wKillar/efficientnet-b1", "file": "EfficientNet_B1_20percent.pth", "features": 1280},
15
  "b3": {"repo": "Shad0wKillar/efficientnet-b3", "file": "EfficientNet_B3_20percent.pth", "features": 1536},
 
18
  }
19
 
20
  def create_model(model_type):
21
+ # I matched architectures to their specific feature counts
22
  if model_type == "b1": model = torchvision.models.efficientnet_b1()
23
  elif model_type == "b3": model = torchvision.models.efficientnet_b3()
24
  elif model_type == "b5": model = torchvision.models.efficientnet_b5()
 
30
  )
31
  return model
32
 
33
+ # I pre-loaded the models to avoid cold-start delays
34
  loaded_models = {}
35
  for m_type, config in MODEL_CONFIGS.items():
36
  m = create_model(m_type)
 
50
 
51
  @app.get("/", response_class=HTMLResponse)
52
  async def read_root():
53
+ # I styled a new custom upload area and formatted the probability output
54
  html_content = """
55
  <!DOCTYPE html>
56
  <html lang="en">
 
60
  <title>EfficientNet Multi-Model API</title>
61
  <style>
62
  body { font-family: system-ui, sans-serif; background-color: #0b0f19; color: #e5e7eb; display: flex; justify-content: center; align-items: center; min-height: 100vh; margin: 0; padding: 20px; }
63
+ .container { background-color: #1e293b; border: 1px solid #374151; border-radius: 12px; padding: 30px; width: 100%; max-width: 450px; box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.3); }
64
+
65
+ /* Styled Select Box */
66
+ select { width: 100%; padding: 12px; margin-bottom: 20px; border-radius: 8px; border: 1px solid #374151; background: #0b0f19; color: #e5e7eb; font-size: 14px; outline: none; }
67
+
68
+ /* Attractive Upload Area */
69
+ .upload-label {
70
+ display: flex; flex-direction: column; align-items: center; justify-content: center;
71
+ width: 100%; height: 120px; border: 2px dashed #4b5563; border-radius: 12px;
72
+ cursor: pointer; transition: all 0.2s ease; margin-bottom: 20px; color: #9ca3af;
73
+ }
74
+ .upload-label:hover { border-color: #3b82f6; background-color: #1a2333; color: #f3f4f6; }
75
+ #imageInput { display: none; }
76
+
77
+ /* Run Button */
78
+ button { width: 100%; padding: 12px; background: #3b82f6; color: white; font-weight: 700; cursor: pointer; border: none; border-radius: 8px; transition: background 0.2s; }
79
+ button:hover { background: #2563eb; }
80
+ button:disabled { background: #4b5563; cursor: not-allowed; }
81
+
82
+ #preview { max-width: 100%; border-radius: 8px; display: none; margin-bottom: 20px; border: 1px solid #374151; }
83
+
84
+ .result-box { padding: 20px; background: #0b0f19; border-radius: 8px; display: none; text-align: center; border: 1px solid #374151; }
85
+ .prob-text { color: #fbbf24; font-family: monospace; font-size: 0.9rem; margin-top: 10px; line-height: 1.5; }
86
  </style>
87
  </head>
88
  <body>
89
  <div class="container">
90
+ <h2 style="margin-top:0">Food Classifier</h2>
91
+
92
+ <label style="font-size: 12px; color: #9ca3af; display: block; margin-bottom: 5px;">Model Architecture</label>
93
  <select id="modelSelect">
94
  <option value="b1">EfficientNet-B1</option>
95
  <option value="b3">EfficientNet-B3</option>
96
  <option value="b5">EfficientNet-B5</option>
97
  <option value="b7">EfficientNet-B7</option>
98
  </select>
99
+
100
  <input type="file" id="imageInput" accept="image/*" onchange="previewImage(event)">
101
+ <label for="imageInput" class="upload-label" id="dropZone">
102
+ <span style="font-size: 24px; margin-bottom: 8px;">📷</span>
103
+ <span id="uploadText">Click to upload image</span>
104
+ </label>
105
+
106
  <img id="preview">
107
  <button onclick="testAPI()" id="runBtn">Run Prediction</button>
108
+
109
  <div class="result-box" id="resultBox">
110
+ <div id="topPrediction" style="font-size: 1.8rem; color: #10b981; font-weight: 800; text-transform: uppercase;"></div>
111
+ <div id="rawProbs" class="prob-text"></div>
112
  </div>
113
  </div>
114
+
115
  <script>
116
  function previewImage(event) {
117
  const reader = new FileReader();
118
+ const file = event.target.files[0];
119
+ if (!file) return;
120
+
121
  reader.onload = () => {
122
  const p = document.getElementById('preview');
123
+ p.src = reader.result; p.style.display = 'block';
124
+ document.getElementById('uploadText').innerText = file.name;
125
  };
126
+ reader.readAsDataURL(file);
127
  }
128
+
129
  async function testAPI() {
130
  const file = document.getElementById('imageInput').files[0];
131
  const model = document.getElementById('modelSelect').value;
132
+ if (!file) return alert("Please select an image first");
133
 
134
  const btn = document.getElementById('runBtn');
135
+ btn.innerText = "Analyzing..."; btn.disabled = true;
136
 
137
  const formData = new FormData();
138
  formData.append("file", file);
 
141
  const res = await fetch(`/predict?model_type=${model}`, { method: "POST", body: formData });
142
  const data = await res.json();
143
 
144
+ // I handled the decimal formatting and class extraction here
145
+ const entries = Object.entries(data);
146
+ const best = entries.reduce((a, b) => a[1] > b[1] ? a : b);
147
+
148
+ document.getElementById('topPrediction').innerText = best[0];
149
+
150
+ // Cleaned up the probabilities display
151
+ const formattedProbs = entries
152
+ .map(([name, prob]) => `${name.toUpperCase()}: ${prob.toFixed(2)}`)
153
+ .join(" | ");
154
+
155
+ document.getElementById('rawProbs').innerText = formattedProbs;
156
  document.getElementById('resultBox').style.display = 'block';
157
  } catch (e) { alert("Error: " + e.message); }
158
  finally { btn.innerText = "Run Prediction"; btn.disabled = false; }
 
165
 
166
  @app.post("/predict")
167
  async def predict(model_type: str = Query("b1"), file: UploadFile = File(...)):
168
+ # I kept the routing logic the same for speed
169
  if model_type not in loaded_models:
170
  return {"error": "Model not found"}
171