File size: 10,977 Bytes
2761f40 4c4efcd f9d30f0 2761f40 1e6f209 4c4efcd 1e6f209 4c4efcd 2761f40 4c4efcd 2761f40 1e6f209 4c4efcd 2761f40 c6403bc 1e6f209 c6403bc f9d30f0 c6403bc f9d30f0 63ef65b c6403bc 63ef65b f98cb2d 63ef65b 1e6f209 63ef65b f98cb2d 63ef65b f98cb2d 63ef65b f98cb2d 63ef65b f98cb2d 63ef65b 1e6f209 63ef65b 1e6f209 63ef65b 1e6f209 63ef65b f98cb2d 63ef65b 1e6f209 c6403bc 63ef65b 1e6f209 63ef65b 1e6f209 63ef65b f98cb2d 63ef65b 1e6f209 63ef65b f9d30f0 c6403bc f98cb2d c6403bc f9d30f0 f98cb2d 63ef65b 4c4efcd f98cb2d f9d30f0 f98cb2d f9d30f0 f98cb2d c6403bc 4c4efcd 63ef65b c6403bc 63ef65b 4c4efcd 63ef65b c6403bc 4c4efcd c6403bc 4c4efcd f9d30f0 f98cb2d 63ef65b f98cb2d 63ef65b 1e6f209 63ef65b c6403bc 2761f40 4c4efcd 1e6f209 4c4efcd 2761f40 4c4efcd 2761f40 4c4efcd 2761f40 4c4efcd f9d30f0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 | import io
from fastapi import FastAPI, UploadFile, File, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
import torch
import torchvision
from torchvision.transforms import InterpolationMode
# FastAPI application; the page and prediction routes below register on it
# via the @app.get / @app.post decorators.
app = FastAPI()
# Per-variant checkpoint location on the Hugging Face Hub ("repo" + "file")
# and the input width of that variant's classifier Linear layer ("features").
MODEL_CONFIGS = {
    "b1": {
        "repo": "Shad0wKillar/efficientnet-b1",
        "file": "EfficientNet_B1_20percent.pth",
        "features": 1280,
    },
    "b3": {
        "repo": "Shad0wKillar/efficientnet-b3",
        "file": "EfficientNet_B3_20percent.pth",
        "features": 1536,
    },
    "b5": {
        "repo": "Shad0wKillar/efficientnet-b5",
        "file": "EfficientNet_B5_20percent.pth",
        "features": 2048,
    },
    "b7": {
        "repo": "Shad0wKillar/efficientnet-b7",
        "file": "EfficientNet_B7_20percent.pth",
        "features": 2560,
    },
}
def create_model(model_type, num_classes=3):
    """Build an EfficientNet variant with a replacement classification head.

    The torchvision constructor is called with no arguments; the fine-tuned
    checkpoint is loaded separately by the caller. The head is
    Dropout(p=0.2) followed by Linear(features -> num_classes). Its layout
    must match the checkpoint, because the caller uses ``load_state_dict``
    with its default strict checking.

    Args:
        model_type: "b1", "b3", "b5" or "b7" (a key of MODEL_CONFIGS).
        num_classes: Output width of the final Linear layer. Defaults to 3,
            one logit per label in ``class_names``.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: If ``model_type`` is not a supported variant.
    """
    builders = {
        "b1": torchvision.models.efficientnet_b1,
        "b3": torchvision.models.efficientnet_b3,
        "b5": torchvision.models.efficientnet_b5,
        "b7": torchvision.models.efficientnet_b7,
    }
    # Fail with a clear message instead of falling through to an unbound
    # ``model`` variable.
    if model_type not in builders:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {sorted(builders)}"
        )
    model = builders[model_type]()
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(
            in_features=MODEL_CONFIGS[model_type]["features"],
            out_features=num_classes,
            bias=True,
        ),
    )
    return model
def _load_pretrained(model_type, config):
    """Build ``model_type``'s architecture, load its fine-tuned weights, return it in eval mode.

    The checkpoint file is fetched with ``hf_hub_download``. The local path it
    returns is loaded on the CPU with ``weights_only=True``.
    """
    network = create_model(model_type)
    checkpoint_path = hf_hub_download(repo_id=config["repo"], filename=config["file"])
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
    network.load_state_dict(state_dict)
    network.eval()
    return network


# Load every variant once at import time so a request never waits on a
# download or weight load. The trade-off is that all models stay resident
# in this dict for the life of the process.
loaded_models = {
    model_type: _load_pretrained(model_type, config)
    for model_type, config in MODEL_CONFIGS.items()
}
# Preprocessing shared by every model:
# Resize(255, bilinear) -> CenterCrop(240) -> ToTensor -> Normalize with the
# ImageNet mean/std.
# NOTE(review): 255/240 looks like torchvision's EfficientNet-B1 eval preset.
# The B3/B5/B7 checkpoints may have been trained at larger resolutions, so
# confirm against the training pipeline.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
        torchvision.transforms.CenterCrop(240),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Label for each output index of the 3-logit classifier head.
class_names = ["pizza", "steak", "sushi"]
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the single-page classifier UI.

    The page lets the user pick a model (b1/b3/b5/b7) and choose an image,
    either by clicking or by drag-and-drop. It POSTs the image to
    ``/predict`` and renders the returned per-class probabilities. Error
    responses of the form ``{"error": ...}`` and non-2xx statuses are shown as
    an error message rather than being rendered as a prediction.
    """
    html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>EfficientNet AI - MultiModel</title>
    <style>
        :root { --bg: #0b0f19; --card: #1e293b; --accent: #3b82f6; --success: #10b981; --amber: #fbbf24; }
        body, html { margin: 0; padding: 0; height: 100%; font-family: system-ui, sans-serif; background-color: var(--bg); color: #e5e7eb; overflow: hidden; }
        .split-container { display: flex; height: 100vh; width: 100vw; }
        .left-panel { flex: 1; padding: 40px; display: flex; flex-direction: column; justify-content: center; border-right: 1px solid #374151; background: #0f172a; }
        /* Column flexbox centres the result block horizontally and vertically */
        .right-panel {
            flex: 1.2;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            background-color: var(--bg);
            position: relative;
            text-align: center;
        }
        .content-width { max-width: 400px; width: 100%; margin: 0 auto; }
        select, button { width: 100%; padding: 14px; margin-bottom: 20px; border-radius: 10px; border: 1px solid #374151; background: var(--bg); color: white; font-size: 15px; outline: none; }
        .upload-label {
            display: flex; flex-direction: column; align-items: center; justify-content: center;
            width: 100%; height: 150px; border: 2px dashed #4b5563; border-radius: 15px;
            cursor: pointer; transition: 0.3s; margin-bottom: 20px; background: #1e293b44;
        }
        .upload-label:hover, .upload-label.dragover { border-color: var(--accent); background: #1e293b88; }
        #imageInput { display: none; }
        button { background: var(--accent); font-weight: 700; border: none; transition: 0.2s; }
        button:hover { background: #2563eb; transform: translateY(-1px); }
        button:disabled { background: #4b5563; opacity: 0.6; }
        #preview { width: 100%; border-radius: 12px; display: none; margin-bottom: 20px; border: 1px solid #374151; object-fit: cover; height: 200px; }
        .result-display { width: 100%; opacity: 0; transform: translateY(20px); transition: 0.5s ease-out; }
        .result-display.show { opacity: 1; transform: translateY(0); }
        /* Full-width status line so the placeholder centres like the result */
        #statusMsg { width: 100%; text-align: center; }
        .placeholder-text { color: #4b5563; font-size: 1.2rem; font-style: italic; width: 100%; display: block; }
        .error-text { color: #ef4444; }
        .prediction-title { font-size: 4rem; font-weight: 900; color: var(--success); text-transform: uppercase; letter-spacing: -2px; margin: 0; }
        .prob-row { display: flex; justify-content: center; gap: 15px; margin-top: 20px; flex-wrap: wrap; padding: 0 20px; }
        .prob-pill { background: #1e293b; padding: 8px 15px; border-radius: 20px; border: 1px solid #374151; color: var(--amber); font-family: monospace; font-weight: bold; }
        @keyframes pulse { 0% { opacity: 0.5; } 50% { opacity: 1; } 100% { opacity: 0.5; } }
        .loading { animation: pulse 1s infinite; color: var(--accent); font-size: 1.5rem; font-weight: bold; width: 100%; text-align: center; }
    </style>
</head>
<body>
    <div class="split-container">
        <div class="left-panel">
            <div class="content-width">
                <h2 style="margin: 0 0 10px 0; font-size: 2rem;">Classifier</h2>
                <p style="color: #9ca3af; margin-bottom: 30px;">Select a model and upload an image to begin.</p>
                <select id="modelSelect">
                    <option value="b1">EfficientNet-B1</option>
                    <option value="b3">EfficientNet-B3</option>
                    <option value="b5">EfficientNet-B5</option>
                    <option value="b7">EfficientNet-B7</option>
                </select>
                <input type="file" id="imageInput" accept="image/*" onchange="previewImage(event)">
                <label for="imageInput" class="upload-label">
                    <span style="font-size: 32px; margin-bottom: 10px;">📤</span>
                    <span id="uploadText">Drop or click to upload</span>
                </label>
                <img id="preview">
                <button onclick="testAPI()" id="runBtn">Run Analysis</button>
            </div>
        </div>
        <div class="right-panel" id="resultContainer">
            <!-- Status line (placeholder / loading / error) and the result block -->
            <div id="statusMsg">
                <span class="placeholder-text">Ready for Prediction...</span>
            </div>
            <div class="result-display" id="resultDisplay">
                <div class="prediction-title" id="topPrediction"></div>
                <div class="prob-row" id="probList"></div>
            </div>
        </div>
    </div>
    <script>
        const imageInput = document.getElementById('imageInput');
        const dropZone = document.querySelector('.upload-label');

        function showPreview(file) {
            if (!file) return;
            const reader = new FileReader();
            reader.onload = () => {
                const p = document.getElementById('preview');
                p.src = reader.result; p.style.display = 'block';
                document.getElementById('uploadText').innerText = file.name;
            };
            reader.readAsDataURL(file);
        }

        function previewImage(event) {
            showPreview(event.target.files[0]);
        }

        // The label promises drag-and-drop: cancel the browser default (which
        // opens the dropped file and leaves the app) and hand the file to the input.
        ['dragenter', 'dragover'].forEach(type => dropZone.addEventListener(type, e => {
            e.preventDefault();
            dropZone.classList.add('dragover');
        }));
        ['dragleave', 'drop'].forEach(type => dropZone.addEventListener(type, e => {
            e.preventDefault();
            dropZone.classList.remove('dragover');
        }));
        dropZone.addEventListener('drop', e => {
            const files = e.dataTransfer.files;
            if (!files.length) return;
            imageInput.files = files;  // programmatic assignment does not fire 'change'
            showPreview(files[0]);
        });

        function showError(statusMsg, text) {
            const span = document.createElement('span');
            span.className = 'placeholder-text error-text';
            span.textContent = text;  // textContent so server text is never parsed as HTML
            statusMsg.innerHTML = '';
            statusMsg.appendChild(span);
        }

        async function testAPI() {
            const file = imageInput.files[0];
            const model = document.getElementById('modelSelect').value;
            if (!file) return alert("Please select an image.");
            const statusMsg = document.getElementById('statusMsg');
            const resultDisplay = document.getElementById('resultDisplay');
            const btn = document.getElementById('runBtn');
            resultDisplay.classList.remove('show');
            statusMsg.innerHTML = '<div class="loading">ANALYZING...</div>';
            statusMsg.style.display = 'block';
            btn.disabled = true;
            const formData = new FormData();
            formData.append("file", file);
            try {
                const res = await fetch(`/predict?model_type=${encodeURIComponent(model)}`, { method: "POST", body: formData });
                // A non-JSON body (e.g. a plain-text 500) is treated as an empty payload.
                const data = await res.json().catch(() => ({}));
                // The API reports failures as {"error": "..."}; never render those as a prediction.
                if (!res.ok || data.error) throw new Error(data.error || `Request failed with status ${res.status}`);
                const entries = Object.entries(data);
                if (entries.length === 0) throw new Error("Empty response from server");
                const best = entries.reduce((a, b) => a[1] > b[1] ? a : b);
                document.getElementById('topPrediction').innerText = best[0];
                const list = document.getElementById('probList');
                list.innerHTML = entries.map(([name, prob]) => `
                    <div class="prob-pill">${name.toUpperCase()}: ${prob.toFixed(2)}</div>
                `).join("");
                statusMsg.style.display = 'none';
                resultDisplay.classList.add('show');
            } catch (e) {
                showError(statusMsg, `Error during analysis: ${e.message}`);
            } finally {
                btn.disabled = false;
            }
        }
    </script>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
def _decode_image(image_bytes):
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        ValueError: If Pillow cannot read the bytes as an image (unknown
            format, empty or truncated upload) or the image trips Pillow's
            decompression-bomb limit.
    """
    try:
        with Image.open(io.BytesIO(image_bytes)) as img:
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError (unrecognised format) is an OSError subclass.
        raise ValueError("Uploaded file is not a readable image") from err


def _predict_probabilities(model, image_bytes):
    """Run one forward pass of ``model`` and return ``{class_name: probability}``.

    This is synchronous CPU work; ``predict`` runs it in a worker thread.
    """
    image = _decode_image(image_bytes)
    batch = transform(image).unsqueeze(0)  # prepend a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]
    return dict(zip(class_names, probs.tolist()))


@app.post("/predict")
async def predict(model_type: str = Query("b1"), file: UploadFile = File(...)):
    """Classify an uploaded image with the selected EfficientNet variant.

    Args:
        model_type: Key into ``loaded_models`` (query parameter, default "b1").
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{class_name: probability}`` for every label in ``class_names``.
        On failure, a JSON ``{"error": ...}`` body with status 404 (unknown
        model) or 400 (unreadable image).
    """
    selected_model = loaded_models.get(model_type)
    if selected_model is None:
        return JSONResponse(status_code=404, content={"error": "Model not found"})
    image_bytes = await file.read()
    try:
        # Decoding and the forward pass are blocking CPU work. Running them
        # directly in this async handler would stall every other request on
        # the event loop, so hand them to the thread pool.
        return await run_in_threadpool(_predict_probabilities, selected_model, image_bytes)
    except ValueError as err:
        return JSONResponse(status_code=400, content={"error": str(err)})