import io

import torch
import torchvision
from fastapi import FastAPI, File, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms import InterpolationMode
|
|
# ASGI application; the UI page ("/") and the inference endpoint ("/predict")
# are registered on it via the @app.get / @app.post decorators in this module.
app = FastAPI()
|
|
| |
# Deployment table for each supported EfficientNet variant:
#   (variant key, Hugging Face Hub repo id, checkpoint filename,
#    width of the feature vector fed into the classifier's Linear layer)
_MODEL_ROWS = (
    ("b1", "Shad0wKillar/efficientnet-b1", "EfficientNet_B1_20percent.pth", 1280),
    ("b3", "Shad0wKillar/efficientnet-b3", "EfficientNet_B3_20percent.pth", 1536),
    ("b5", "Shad0wKillar/efficientnet-b5", "EfficientNet_B5_20percent.pth", 2048),
    ("b7", "Shad0wKillar/efficientnet-b7", "EfficientNet_B7_20percent.pth", 2560),
)

# variant key -> {"repo": ..., "file": ..., "features": ...}; insertion order
# follows _MODEL_ROWS, so models are loaded in b1, b3, b5, b7 order.
MODEL_CONFIGS = {
    variant: {"repo": repo_id, "file": filename, "features": width}
    for variant, repo_id, filename, width in _MODEL_ROWS
}
|
|
# Backbone constructors for every supported variant key (same keys as
# MODEL_CONFIGS). Called with no arguments, so no pretrained weights are fetched.
_EFFICIENTNET_BUILDERS = {
    "b1": torchvision.models.efficientnet_b1,
    "b3": torchvision.models.efficientnet_b3,
    "b5": torchvision.models.efficientnet_b5,
    "b7": torchvision.models.efficientnet_b7,
}


def create_model(model_type, num_classes=3):
    """Build an EfficientNet backbone with a freshly initialised classifier head.

    The head is replaced with ``Dropout(p=0.2) -> Linear(features, num_classes)``;
    the module is returned untrained so the caller can ``load_state_dict`` a
    fine-tuned checkpoint into it.

    Args:
        model_type: Variant key, one of ``"b1"``, ``"b3"``, ``"b5"``, ``"b7"``.
        num_classes: Number of output logits. Defaults to 3, the length of the
            module's ``class_names`` list.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: If ``model_type`` is not a supported variant. (Previously an
            unknown key crashed with ``UnboundLocalError``.)
    """
    try:
        builder = _EFFICIENTNET_BUILDERS[model_type]
    except KeyError:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            f"expected one of {sorted(_EFFICIENTNET_BUILDERS)}"
        ) from None

    model = builder()
    # torchvision's EfficientNet head is Sequential(Dropout, Linear); reading the
    # Linear's input width keeps the new head correct for every variant without
    # a separately maintained feature-size table.
    in_features = model.classifier[1].in_features
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    return model
|
|
| |
def _load_checkpoint(variant, cfg):
    """Build *variant*, load its checkpoint from the Hub, and return it in eval mode."""
    net = create_model(variant)
    ckpt_path = hf_hub_download(repo_id=cfg["repo"], filename=cfg["file"])
    # weights_only=True restricts torch.load's unpickler to tensors and plain
    # containers; map_location places every tensor on the CPU.
    state = torch.load(ckpt_path, map_location=torch.device("cpu"), weights_only=True)
    net.load_state_dict(state)
    net.eval()
    return net


# Every entry of MODEL_CONFIGS is downloaded and instantiated eagerly at import
# time; requests then only look models up in this dict.
loaded_models = {
    variant: _load_checkpoint(variant, cfg) for variant, cfg in MODEL_CONFIGS.items()
}
|
|
# Preprocessing applied to every uploaded image before inference: shorter side
# resized to 255 (bilinear), centre-cropped to 240x240, converted to a tensor,
# then normalised with the ImageNet channel mean/std.
# NOTE(review): the same 255/240 sizes are used for all variants (b1..b7);
# confirm the larger checkpoints were trained with this pipeline.
_preprocess_steps = [
    torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
    torchvision.transforms.CenterCrop(240),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = torchvision.transforms.Compose(_preprocess_steps)


# Label for each output logit: /predict pairs logit i with class_names[i].
class_names = ["pizza", "steak", "sushi"]
|
|
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the single-page classifier UI.

    The page lets the user choose an EfficientNet variant, pick or drag-and-drop
    an image, POST it to ``/predict?model_type=...``, and then shows the top class
    and every class probability from the JSON response. Error responses
    (non-2xx status or an ``{"error": ...}`` body) are shown as an error message
    instead of being rendered as probabilities.
    """
    html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>EfficientNet AI - MultiModel</title>
    <style>
        :root { --bg: #0b0f19; --card: #1e293b; --accent: #3b82f6; --success: #10b981; --amber: #fbbf24; }
        body, html { margin: 0; padding: 0; height: 100%; font-family: system-ui, sans-serif; background-color: var(--bg); color: #e5e7eb; overflow: hidden; }

        .split-container { display: flex; height: 100vh; width: 100vw; }

        .left-panel { flex: 1; padding: 40px; display: flex; flex-direction: column; justify-content: center; border-right: 1px solid #374151; background: #0f172a; }

        /* I added flex-direction column and width 100% to ensure true centering */
        .right-panel {
            flex: 1.2;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            background-color: var(--bg);
            position: relative;
            text-align: center;
        }

        .content-width { max-width: 400px; width: 100%; margin: 0 auto; }

        select, button { width: 100%; padding: 14px; margin-bottom: 20px; border-radius: 10px; border: 1px solid #374151; background: var(--bg); color: white; font-size: 15px; outline: none; }

        .upload-label {
            display: flex; flex-direction: column; align-items: center; justify-content: center;
            width: 100%; height: 150px; border: 2px dashed #4b5563; border-radius: 15px;
            cursor: pointer; transition: 0.3s; margin-bottom: 20px; background: #1e293b44;
        }
        .upload-label:hover, .upload-label.dragging { border-color: var(--accent); background: #1e293b88; }
        #imageInput { display: none; }

        button { background: var(--accent); font-weight: 700; border: none; transition: 0.2s; }
        button:hover:not(:disabled) { background: #2563eb; transform: translateY(-1px); }
        button:disabled { background: #4b5563; opacity: 0.6; }

        #preview { width: 100%; border-radius: 12px; display: none; margin-bottom: 20px; border: 1px solid #374151; object-fit: cover; height: 200px; }

        .result-display { width: 100%; opacity: 0; transform: translateY(20px); transition: 0.5s ease-out; }
        .result-display.show { opacity: 1; transform: translateY(0); }

        /* I forced the placeholder to occupy full width for centering */
        #statusMsg { width: 100%; text-align: center; }
        .placeholder-text { color: #4b5563; font-size: 1.2rem; font-style: italic; width: 100%; display: block; }

        .prediction-title { font-size: 4rem; font-weight: 900; color: var(--success); text-transform: uppercase; letter-spacing: -2px; margin: 0; }
        .prob-row { display: flex; justify-content: center; gap: 15px; margin-top: 20px; flex-wrap: wrap; padding: 0 20px; }
        .prob-pill { background: #1e293b; padding: 8px 15px; border-radius: 20px; border: 1px solid #374151; color: var(--amber); font-family: monospace; font-weight: bold; }

        @keyframes pulse { 0% { opacity: 0.5; } 50% { opacity: 1; } 100% { opacity: 0.5; } }
        .loading { animation: pulse 1s infinite; color: var(--accent); font-size: 1.5rem; font-weight: bold; width: 100%; text-align: center; }
    </style>
</head>
<body>
    <div class="split-container">
        <div class="left-panel">
            <div class="content-width">
                <h2 style="margin: 0 0 10px 0; font-size: 2rem;">Classifier</h2>
                <p style="color: #9ca3af; margin-bottom: 30px;">Select a model and upload an image to begin.</p>

                <select id="modelSelect">
                    <option value="b1">EfficientNet-B1</option>
                    <option value="b3">EfficientNet-B3</option>
                    <option value="b5">EfficientNet-B5</option>
                    <option value="b7">EfficientNet-B7</option>
                </select>

                <input type="file" id="imageInput" accept="image/*" onchange="previewImage(event)">
                <label for="imageInput" class="upload-label">
                    <span style="font-size: 32px; margin-bottom: 10px;">📤</span>
                    <span id="uploadText">Drop or click to upload</span>
                </label>

                <img id="preview">
                <button onclick="testAPI()" id="runBtn">Run Analysis</button>
            </div>
        </div>

        <div class="right-panel" id="resultContainer">
            <!-- I ensured this container is the central focus of the right side -->
            <div id="statusMsg">
                <span class="placeholder-text">Ready for Prediction...</span>
            </div>
            <div class="result-display" id="resultDisplay">
                <div class="prediction-title" id="topPrediction"></div>
                <div class="prob-row" id="probList"></div>
            </div>
        </div>
    </div>

    <script>
        function previewImage(event) {
            const file = event.target.files[0];
            if (!file) return;
            const reader = new FileReader();
            reader.onload = () => {
                const p = document.getElementById('preview');
                p.src = reader.result; p.style.display = 'block';
                document.getElementById('uploadText').innerText = file.name;
            };
            reader.readAsDataURL(file);
        }

        // The upload area advertises "Drop or click", so accept dropped files too.
        const dropZone = document.querySelector('.upload-label');
        const imageInput = document.getElementById('imageInput');
        ['dragenter', 'dragover'].forEach((type) => dropZone.addEventListener(type, (e) => {
            e.preventDefault();  // required for the browser to allow a drop here
            dropZone.classList.add('dragging');
        }));
        dropZone.addEventListener('dragleave', () => dropZone.classList.remove('dragging'));
        dropZone.addEventListener('drop', (e) => {
            e.preventDefault();  // stop the browser from navigating to the dropped file
            dropZone.classList.remove('dragging');
            if (!e.dataTransfer.files.length) return;
            imageInput.files = e.dataTransfer.files;
            previewImage({ target: imageInput });
        });

        async function testAPI() {
            const file = document.getElementById('imageInput').files[0];
            const model = document.getElementById('modelSelect').value;
            if (!file) return alert("Please select an image.");

            const statusMsg = document.getElementById('statusMsg');
            const resultDisplay = document.getElementById('resultDisplay');
            const btn = document.getElementById('runBtn');

            resultDisplay.classList.remove('show');
            statusMsg.innerHTML = '<div class="loading">ANALYZING...</div>';
            statusMsg.style.display = 'block';
            btn.disabled = true;

            const formData = new FormData();
            formData.append("file", file);

            try {
                const res = await fetch(`/predict?model_type=${encodeURIComponent(model)}`, { method: "POST", body: formData });
                // Non-JSON bodies (e.g. a plain-text 500 page) fall back to an empty object.
                const data = await res.json().catch(() => ({}));
                // Error responses must not be rendered as a probability map.
                if (!res.ok || data.error) {
                    throw new Error(data.error || `HTTP ${res.status}`);
                }

                const entries = Object.entries(data);
                const best = entries.reduce((a, b) => a[1] > b[1] ? a : b);

                document.getElementById('topPrediction').innerText = best[0];

                const list = document.getElementById('probList');
                list.innerHTML = entries.map(([name, prob]) => `
                    <div class="prob-pill">${name.toUpperCase()}: ${prob.toFixed(2)}</div>
                `).join("");

                statusMsg.style.display = 'none';
                resultDisplay.classList.add('show');
            } catch (e) {
                // textContent (not innerHTML) so server-supplied text is never parsed as HTML.
                const errEl = document.createElement('span');
                errEl.className = 'placeholder-text';
                errEl.style.color = '#ef4444';
                errEl.textContent = `Error during analysis: ${e.message}`;
                statusMsg.replaceChildren(errEl);
            } finally {
                btn.disabled = false;
            }
        }
    </script>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
|
|
def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: If Pillow cannot identify or fully read the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
        Image.DecompressionBombError: If the image exceeds Pillow's pixel limit.
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        return img.convert("RGB")


def _predict_probabilities(model, image):
    """Run *model* on *image* and return ``{class_name: softmax probability}``."""
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.inference_mode():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]
    return dict(zip(class_names, probs.tolist()))


@app.post("/predict")
async def predict(model_type: str = Query("b1"), file: UploadFile = File(...)):
    """Classify an uploaded image with the selected EfficientNet variant.

    Args:
        model_type: Key into ``loaded_models`` (query parameter, default ``"b1"``).
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{class_name: probability}`` for every entry of ``class_names`` on
        success; ``{"error": "Model not found"}`` with status 404 for an unknown
        ``model_type``; ``{"error": "Invalid image file"}`` with status 400 when
        the upload cannot be decoded as an image.
    """
    selected_model = loaded_models.get(model_type)
    if selected_model is None:
        # Same body as before, but with an error status instead of 200.
        return JSONResponse(status_code=404, content={"error": "Model not found"})

    image_bytes = await file.read()

    # Decoding and inference are blocking, CPU-bound work; running them in the
    # threadpool keeps the event loop free to serve other requests meanwhile.
    try:
        image = await run_in_threadpool(_decode_image, image_bytes)
    except (OSError, Image.DecompressionBombError):
        return JSONResponse(status_code=400, content={"error": "Invalid image file"})

    return await run_in_threadpool(_predict_probabilities, selected_model, image)