# Author: Shad0wKillar — commit 1e6f209 (verified): "Centered in the right window."
import io
from fastapi import FastAPI, UploadFile, File, Query
from fastapi.responses import HTMLResponse
import torch
import torchvision
from torchvision.transforms import InterpolationMode
from huggingface_hub import hf_hub_download
from PIL import Image
# FastAPI application instance; the "/" and "/predict" routes below are registered on it.
app = FastAPI()
# Checkpoint registry: each EfficientNet variant maps to its Hugging Face Hub
# repo, the fine-tuned weight file inside it, and the width of the feature
# vector feeding the classifier head.
_FEATURE_WIDTHS = {"b1": 1280, "b3": 1536, "b5": 2048, "b7": 2560}
MODEL_CONFIGS = {
    variant: {
        "repo": f"Shad0wKillar/efficientnet-{variant}",
        "file": f"EfficientNet_{variant.upper()}_20percent.pth",
        "features": width,
    }
    for variant, width in _FEATURE_WIDTHS.items()
}
def create_model(model_type):
    """Build an EfficientNet backbone with a fresh 3-class classifier head.

    Args:
        model_type: One of the MODEL_CONFIGS keys ("b1", "b3", "b5", "b7").

    Returns:
        A torchvision EfficientNet whose classifier is replaced by a
        Dropout(0.2) + Linear head emitting 3 logits (matching the
        fine-tuned checkpoints in EfficientNet_TransferLearned.zip).

    Raises:
        ValueError: If ``model_type`` is not a supported variant. (The
            original if/elif chain left ``model`` unbound for unknown
            variants and crashed with UnboundLocalError.)
    """
    builders = {
        "b1": torchvision.models.efficientnet_b1,
        "b3": torchvision.models.efficientnet_b3,
        "b5": torchvision.models.efficientnet_b5,
        "b7": torchvision.models.efficientnet_b7,
    }
    if model_type not in builders:
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    model = builders[model_type]()
    # Replace the stock ImageNet head; in_features must match the variant's
    # feature width so the checkpoint's Linear weights load cleanly.
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(in_features=MODEL_CONFIGS[model_type]["features"], out_features=3, bias=True),
    )
    return model
# Eagerly download and initialise every checkpoint at startup so the first
# request does not pay the download / weight-loading cost.
loaded_models = {}
for variant, cfg in MODEL_CONFIGS.items():
    checkpoint_path = hf_hub_download(repo_id=cfg["repo"], filename=cfg["file"])
    net = create_model(variant)
    # weights_only=True avoids unpickling arbitrary objects from the checkpoint.
    net.load_state_dict(torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True))
    net.eval()
    loaded_models[variant] = net
# Inference preprocessing: bilinear resize (short side -> 255), 240x240 center
# crop, tensor conversion, then ImageNet-statistics normalisation.
# NOTE(review): presumably mirrors the training pipeline — confirm against the
# training code if crops/stats ever change.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
    torchvision.transforms.CenterCrop(240),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# Labels in the same order as the classifier's 3 output logits.
class_names = ["pizza", "steak", "sushi"]
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the single-page UI (model selector, upload widget, result pane).

    The embedded page posts the chosen image to /predict and renders the
    per-class probabilities. Bug fix vs. the original: the fetch handler now
    treats non-2xx responses and `{"error": ...}` payloads as failures instead
    of reducing over them as if they were probability maps.
    """
    # Flexbox keeps the result centered in the right window (.right-panel).
    html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>EfficientNet AI - MultiModel</title>
<style>
:root { --bg: #0b0f19; --card: #1e293b; --accent: #3b82f6; --success: #10b981; --amber: #fbbf24; }
body, html { margin: 0; padding: 0; height: 100%; font-family: system-ui, sans-serif; background-color: var(--bg); color: #e5e7eb; overflow: hidden; }
.split-container { display: flex; height: 100vh; width: 100vw; }
.left-panel { flex: 1; padding: 40px; display: flex; flex-direction: column; justify-content: center; border-right: 1px solid #374151; background: #0f172a; }
/* I added flex-direction column and width 100% to ensure true centering */
.right-panel {
flex: 1.2;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
background-color: var(--bg);
position: relative;
text-align: center;
}
.content-width { max-width: 400px; width: 100%; margin: 0 auto; }
select, button { width: 100%; padding: 14px; margin-bottom: 20px; border-radius: 10px; border: 1px solid #374151; background: var(--bg); color: white; font-size: 15px; outline: none; }
.upload-label {
display: flex; flex-direction: column; align-items: center; justify-content: center;
width: 100%; height: 150px; border: 2px dashed #4b5563; border-radius: 15px;
cursor: pointer; transition: 0.3s; margin-bottom: 20px; background: #1e293b44;
}
.upload-label:hover { border-color: var(--accent); background: #1e293b88; }
#imageInput { display: none; }
button { background: var(--accent); font-weight: 700; border: none; transition: 0.2s; }
button:hover { background: #2563eb; transform: translateY(-1px); }
button:disabled { background: #4b5563; opacity: 0.6; }
#preview { width: 100%; border-radius: 12px; display: none; margin-bottom: 20px; border: 1px solid #374151; object-fit: cover; height: 200px; }
.result-display { width: 100%; opacity: 0; transform: translateY(20px); transition: 0.5s ease-out; }
.result-display.show { opacity: 1; transform: translateY(0); }
/* I forced the placeholder to occupy full width for centering */
#statusMsg { width: 100%; text-align: center; }
.placeholder-text { color: #4b5563; font-size: 1.2rem; font-style: italic; width: 100%; display: block; }
.prediction-title { font-size: 4rem; font-weight: 900; color: var(--success); text-transform: uppercase; letter-spacing: -2px; margin: 0; }
.prob-row { display: flex; justify-content: center; gap: 15px; margin-top: 20px; flex-wrap: wrap; padding: 0 20px; }
.prob-pill { background: #1e293b; padding: 8px 15px; border-radius: 20px; border: 1px solid #374151; color: var(--amber); font-family: monospace; font-weight: bold; }
@keyframes pulse { 0% { opacity: 0.5; } 50% { opacity: 1; } 100% { opacity: 0.5; } }
.loading { animation: pulse 1s infinite; color: var(--accent); font-size: 1.5rem; font-weight: bold; width: 100%; text-align: center; }
</style>
</head>
<body>
<div class="split-container">
<div class="left-panel">
<div class="content-width">
<h2 style="margin: 0 0 10px 0; font-size: 2rem;">Classifier</h2>
<p style="color: #9ca3af; margin-bottom: 30px;">Select a model and upload an image to begin.</p>
<select id="modelSelect">
<option value="b1">EfficientNet-B1</option>
<option value="b3">EfficientNet-B3</option>
<option value="b5">EfficientNet-B5</option>
<option value="b7">EfficientNet-B7</option>
</select>
<input type="file" id="imageInput" accept="image/*" onchange="previewImage(event)">
<label for="imageInput" class="upload-label">
<span style="font-size: 32px; margin-bottom: 10px;">📤</span>
<span id="uploadText">Drop or click to upload</span>
</label>
<img id="preview">
<button onclick="testAPI()" id="runBtn">Run Analysis</button>
</div>
</div>
<div class="right-panel" id="resultContainer">
<!-- I ensured this container is the central focus of the right side -->
<div id="statusMsg">
<span class="placeholder-text">Ready for Prediction...</span>
</div>
<div class="result-display" id="resultDisplay">
<div class="prediction-title" id="topPrediction"></div>
<div class="prob-row" id="probList"></div>
</div>
</div>
</div>
<script>
function previewImage(event) {
const file = event.target.files[0];
if (!file) return;
const reader = new FileReader();
reader.onload = () => {
const p = document.getElementById('preview');
p.src = reader.result; p.style.display = 'block';
document.getElementById('uploadText').innerText = file.name;
};
reader.readAsDataURL(file);
}
async function testAPI() {
const file = document.getElementById('imageInput').files[0];
const model = document.getElementById('modelSelect').value;
if (!file) return alert("Please select an image.");
const statusMsg = document.getElementById('statusMsg');
const resultDisplay = document.getElementById('resultDisplay');
const btn = document.getElementById('runBtn');
resultDisplay.classList.remove('show');
statusMsg.innerHTML = '<div class="loading">ANALYZING...</div>';
statusMsg.style.display = 'block';
btn.disabled = true;
const formData = new FormData();
formData.append("file", file);
try {
const res = await fetch(`/predict?model_type=${model}`, { method: "POST", body: formData });
if (!res.ok) throw new Error(`HTTP ${res.status}`);
const data = await res.json();
if (data.error) throw new Error(data.error);
const entries = Object.entries(data);
const best = entries.reduce((a, b) => a[1] > b[1] ? a : b);
document.getElementById('topPrediction').innerText = best[0];
const list = document.getElementById('probList');
list.innerHTML = entries.map(([name, prob]) => `
<div class="prob-pill">${name.toUpperCase()}: ${prob.toFixed(2)}</div>
`).join("");
statusMsg.style.display = 'none';
resultDisplay.classList.add('show');
} catch (e) {
statusMsg.innerHTML = '<span class="placeholder-text" style="color: #ef4444;">Error during analysis.</span>';
} finally {
btn.disabled = false;
}
}
</script>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
@app.post("/predict")
async def predict(model_type: str = Query("b1"), file: UploadFile = File(...)):
    """Classify an uploaded image with the selected EfficientNet variant.

    The ``model_type`` query parameter picks one of the preloaded models; the
    image is preprocessed and scored on CPU, and the response maps each class
    name to its softmax probability. Unknown variants yield an error payload.
    """
    if model_type not in loaded_models:
        return {"error": "Model not found"}
    raw = await file.read()
    pil_image = Image.open(io.BytesIO(raw)).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add batch dim -> (1, C, H, W)
    net = loaded_models[model_type]
    with torch.no_grad():
        probabilities = torch.softmax(net(batch), dim=1).squeeze()
    return {name: float(p) for name, p in zip(class_names, probabilities)}