File size: 10,977 Bytes
2761f40
4c4efcd
f9d30f0
2761f40
 
 
 
 
 
 
 
1e6f209
4c4efcd
 
 
 
 
 
 
 
1e6f209
4c4efcd
 
 
 
 
2761f40
 
4c4efcd
2761f40
 
 
1e6f209
4c4efcd
 
 
 
 
 
 
 
 
 
 
 
 
 
2761f40
 
 
c6403bc
 
1e6f209
c6403bc
 
f9d30f0
c6403bc
f9d30f0
 
63ef65b
c6403bc
63ef65b
 
f98cb2d
63ef65b
 
 
 
1e6f209
 
 
 
 
 
 
 
 
 
 
63ef65b
 
 
 
f98cb2d
 
 
63ef65b
 
f98cb2d
63ef65b
f98cb2d
 
63ef65b
 
 
f98cb2d
63ef65b
 
1e6f209
63ef65b
 
1e6f209
 
 
 
63ef65b
1e6f209
63ef65b
f98cb2d
63ef65b
1e6f209
c6403bc
 
 
63ef65b
 
 
 
 
 
 
1e6f209
63ef65b
 
1e6f209
63ef65b
 
 
 
 
 
 
 
 
 
 
 
f98cb2d
63ef65b
1e6f209
 
 
 
63ef65b
 
 
 
f9d30f0
c6403bc
f98cb2d
c6403bc
f9d30f0
f98cb2d
 
63ef65b
4c4efcd
 
f98cb2d
 
f9d30f0
f98cb2d
f9d30f0
f98cb2d
c6403bc
4c4efcd
 
63ef65b
c6403bc
63ef65b
 
4c4efcd
63ef65b
 
 
 
 
c6403bc
 
4c4efcd
 
c6403bc
4c4efcd
 
f9d30f0
f98cb2d
 
 
 
 
63ef65b
 
 
 
f98cb2d
63ef65b
 
 
1e6f209
63ef65b
 
 
c6403bc
 
 
 
 
 
 
2761f40
4c4efcd
1e6f209
4c4efcd
 
 
2761f40
 
 
4c4efcd
 
2761f40
4c4efcd
2761f40
4c4efcd
f9d30f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import io

import torch
import torchvision
from fastapi import FastAPI, File, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms import InterpolationMode

# The ASGI application object; the "/" and "/predict" routes below are
# registered on it through the @app.get / @app.post decorators.
app = FastAPI()

# Per-variant checkpoint location on the Hugging Face Hub ("repo" + "file")
# and the input width of that variant's classifier Linear layer ("features").
# Insertion order (b1, b3, b5, b7) is also the order in which models are loaded.
MODEL_CONFIGS = {
    "b1": {
        "repo": "Shad0wKillar/efficientnet-b1",
        "file": "EfficientNet_B1_20percent.pth",
        "features": 1280,
    },
    "b3": {
        "repo": "Shad0wKillar/efficientnet-b3",
        "file": "EfficientNet_B3_20percent.pth",
        "features": 1536,
    },
    "b5": {
        "repo": "Shad0wKillar/efficientnet-b5",
        "file": "EfficientNet_B5_20percent.pth",
        "features": 2048,
    },
    "b7": {
        "repo": "Shad0wKillar/efficientnet-b7",
        "file": "EfficientNet_B7_20percent.pth",
        "features": 2560,
    },
}

def create_model(model_type, num_classes=3):
    """Build an EfficientNet variant with a freshly initialised classifier head.

    The backbone is constructed without pretrained weights; the caller is
    expected to load a fine-tuned ``state_dict`` afterwards, so the head layout
    built here (Dropout followed by one Linear layer) must match the checkpoint.

    Args:
        model_type: One of the keys of ``MODEL_CONFIGS`` ("b1", "b3", "b5", "b7").
        num_classes: Number of output logits. Defaults to 3, matching the
            checkpoints this app serves.

    Returns:
        A ``torch.nn.Module`` still in training mode; call ``.eval()`` before
        inference.

    Raises:
        ValueError: If ``model_type`` is not a supported variant.
    """
    builders = {
        "b1": torchvision.models.efficientnet_b1,
        "b3": torchvision.models.efficientnet_b3,
        "b5": torchvision.models.efficientnet_b5,
        "b7": torchvision.models.efficientnet_b7,
    }
    # Fail with a clear message instead of an UnboundLocalError on `model`.
    if model_type not in builders:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {sorted(builders)}"
        )

    model = builders[model_type]()
    # Dropout has no parameters, so only the Linear layer (index 1) appears in the
    # checkpoint's state_dict; the dropout rate matters only in training mode.
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(
            in_features=MODEL_CONFIGS[model_type]["features"],
            out_features=num_classes,
            bias=True,
        ),
    )
    return model

def _load_pretrained(variant, cfg):
    """Build ``variant``, load its fine-tuned checkpoint from the Hub on CPU, and return it in eval mode."""
    net = create_model(variant)
    checkpoint_path = hf_hub_download(repo_id=cfg["repo"], filename=cfg["file"])
    # weights_only=True restricts unpickling to tensors/primitive containers.
    state_dict = torch.load(
        checkpoint_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    net.load_state_dict(state_dict)
    net.eval()
    return net


# Every variant is downloaded and loaded once, at import time, so requests never
# pay that cost; all of them stay resident in memory for the process lifetime.
loaded_models = {
    variant: _load_pretrained(variant, cfg)
    for variant, cfg in MODEL_CONFIGS.items()
}

# Preprocessing applied to every uploaded image before inference: resize the
# shorter side to 255 px (bilinear), center-crop 240x240, convert to a float
# tensor, then normalize each channel with the given mean/std (the usual
# ImageNet statistics).
# NOTE(review): one pipeline is shared by every variant; confirm that B3/B5/B7
# were fine-tuned with these same sizes rather than their own native resolution.
_preprocessing_steps = [
    torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
    torchvision.transforms.CenterCrop(240),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = torchvision.transforms.Compose(_preprocessing_steps)

# Label for each classifier output: logit/probability i belongs to class_names[i].
class_names = [
    "pizza",
    "steak",
    "sushi",
]

@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the single-page classifier UI.

    The page lets the user pick one of the EfficientNet variants (b1/b3/b5/b7),
    choose or drag-and-drop an image, and POST it to
    ``/predict?model_type=<variant>``. A successful ``{class_name: probability}``
    response is rendered as a headline (the highest-probability class) plus one
    pill per class. Non-2xx responses, non-JSON bodies, and ``{"error": ...}``
    payloads are shown as an error message instead of being rendered as results.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>EfficientNet AI - MultiModel</title>
        <style>
            :root { --bg: #0b0f19; --card: #1e293b; --accent: #3b82f6; --success: #10b981; --amber: #fbbf24; }
            body, html { margin: 0; padding: 0; height: 100%; font-family: system-ui, sans-serif; background-color: var(--bg); color: #e5e7eb; overflow: hidden; }
            
            .split-container { display: flex; height: 100vh; width: 100vw; }
            
            .left-panel { flex: 1; padding: 40px; display: flex; flex-direction: column; justify-content: center; border-right: 1px solid #374151; background: #0f172a; }
            
            /* I added flex-direction column and width 100% to ensure true centering */
            .right-panel { 
                flex: 1.2; 
                display: flex; 
                flex-direction: column;
                align-items: center; 
                justify-content: center; 
                background-color: var(--bg); 
                position: relative;
                text-align: center;
            }

            .content-width { max-width: 400px; width: 100%; margin: 0 auto; }
            
            select, button { width: 100%; padding: 14px; margin-bottom: 20px; border-radius: 10px; border: 1px solid #374151; background: var(--bg); color: white; font-size: 15px; outline: none; }
            
            .upload-label {
                display: flex; flex-direction: column; align-items: center; justify-content: center;
                width: 100%; height: 150px; border: 2px dashed #4b5563; border-radius: 15px;
                cursor: pointer; transition: 0.3s; margin-bottom: 20px; background: #1e293b44;
            }
            .upload-label:hover, .upload-label.dragover { border-color: var(--accent); background: #1e293b88; }
            /* Keep drag events on the label itself so hovering the inner spans does not fire dragleave. */
            .upload-label > * { pointer-events: none; }
            #imageInput { display: none; }
            
            button { background: var(--accent); font-weight: 700; border: none; transition: 0.2s; }
            button:hover { background: #2563eb; transform: translateY(-1px); }
            button:disabled { background: #4b5563; opacity: 0.6; }

            #preview { width: 100%; border-radius: 12px; display: none; margin-bottom: 20px; border: 1px solid #374151; object-fit: cover; height: 200px; }

            .result-display { width: 100%; opacity: 0; transform: translateY(20px); transition: 0.5s ease-out; }
            .result-display.show { opacity: 1; transform: translateY(0); }
            
            /* I forced the placeholder to occupy full width for centering */
            #statusMsg { width: 100%; text-align: center; }
            .placeholder-text { color: #4b5563; font-size: 1.2rem; font-style: italic; width: 100%; display: block; }
            
            .prediction-title { font-size: 4rem; font-weight: 900; color: var(--success); text-transform: uppercase; letter-spacing: -2px; margin: 0; }
            .prob-row { display: flex; justify-content: center; gap: 15px; margin-top: 20px; flex-wrap: wrap; padding: 0 20px; }
            .prob-pill { background: #1e293b; padding: 8px 15px; border-radius: 20px; border: 1px solid #374151; color: var(--amber); font-family: monospace; font-weight: bold; }
            
            @keyframes pulse { 0% { opacity: 0.5; } 50% { opacity: 1; } 100% { opacity: 0.5; } }
            .loading { animation: pulse 1s infinite; color: var(--accent); font-size: 1.5rem; font-weight: bold; width: 100%; text-align: center; }
        </style>
    </head>
    <body>
        <div class="split-container">
            <div class="left-panel">
                <div class="content-width">
                    <h2 style="margin: 0 0 10px 0; font-size: 2rem;">Classifier</h2>
                    <p style="color: #9ca3af; margin-bottom: 30px;">Select a model and upload an image to begin.</p>
                    
                    <select id="modelSelect">
                        <option value="b1">EfficientNet-B1</option>
                        <option value="b3">EfficientNet-B3</option>
                        <option value="b5">EfficientNet-B5</option>
                        <option value="b7">EfficientNet-B7</option>
                    </select>

                    <input type="file" id="imageInput" accept="image/*" onchange="previewImage(event)">
                    <label for="imageInput" class="upload-label">
                        <span style="font-size: 32px; margin-bottom: 10px;">📤</span>
                        <span id="uploadText">Drop or click to upload</span>
                    </label>

                    <img id="preview">
                    <button onclick="testAPI()" id="runBtn">Run Analysis</button>
                </div>
            </div>
            
            <div class="right-panel" id="resultContainer">
                <!-- I ensured this container is the central focus of the right side -->
                <div id="statusMsg">
                    <span class="placeholder-text">Ready for Prediction...</span>
                </div>
                <div class="result-display" id="resultDisplay">
                    <div class="prediction-title" id="topPrediction"></div>
                    <div class="prob-row" id="probList"></div>
                </div>
            </div>
        </div>

        <script>
            const imageInput = document.getElementById('imageInput');
            const dropZone = document.querySelector('.upload-label');

            function previewImage(event) {
                const file = event.target.files[0];
                if (!file) return;
                const reader = new FileReader();
                reader.onload = () => { 
                    const p = document.getElementById('preview');
                    p.src = reader.result; p.style.display = 'block';
                    document.getElementById('uploadText').innerText = file.name;
                };
                reader.readAsDataURL(file);
            }

            // The upload area says "Drop or click to upload", but a <label> does not
            // forward dropped files to its <input>; without these handlers the browser
            // just opens the dropped file and navigates away from the app.
            ['dragenter', 'dragover'].forEach((type) => dropZone.addEventListener(type, (e) => {
                e.preventDefault();
                dropZone.classList.add('dragover');
            }));
            ['dragleave', 'drop'].forEach((type) => dropZone.addEventListener(type, (e) => {
                e.preventDefault();
                dropZone.classList.remove('dragover');
            }));
            dropZone.addEventListener('drop', (e) => {
                const files = e.dataTransfer ? e.dataTransfer.files : null;
                if (!files || !files.length) return;
                // accept="image/*" only filters the file picker, not drops.
                if (!files[0].type.startsWith('image/')) return alert("Please drop an image file.");
                imageInput.files = files;
                previewImage({ target: imageInput });
            });

            function describeError(data, status) {
                if (data && typeof data.error === 'string') return data.error;
                if (data && typeof data.detail === 'string') return data.detail;
                return `Server responded with HTTP ${status}.`;
            }

            function showError(message) {
                const span = document.createElement('span');
                span.className = 'placeholder-text';
                span.style.color = '#ef4444';
                // textContent, not innerHTML: the message may contain server-supplied text.
                span.textContent = message;
                const statusMsg = document.getElementById('statusMsg');
                statusMsg.innerHTML = '';
                statusMsg.appendChild(span);
            }

            async function testAPI() {
                const file = imageInput.files[0];
                const model = document.getElementById('modelSelect').value;
                if (!file) return alert("Please select an image.");
                
                const statusMsg = document.getElementById('statusMsg');
                const resultDisplay = document.getElementById('resultDisplay');
                const btn = document.getElementById('runBtn');

                resultDisplay.classList.remove('show');
                statusMsg.innerHTML = '<div class="loading">ANALYZING...</div>';
                statusMsg.style.display = 'block';
                btn.disabled = true;

                const formData = new FormData();
                formData.append("file", file);
                
                try {
                    const res = await fetch(`/predict?model_type=${encodeURIComponent(model)}`, { method: "POST", body: formData });
                    // The body is not guaranteed to be JSON (e.g. a plain-text 500 page).
                    const data = await res.json().catch(() => null);
                    // Error responses carry no probabilities; report them instead of
                    // trying to render them as a prediction.
                    if (!res.ok || !data || data.error !== undefined) {
                        throw new Error(describeError(data, res.status));
                    }
                    
                    const entries = Object.entries(data);
                    const best = entries.reduce((a, b) => a[1] > b[1] ? a : b);
                    
                    document.getElementById('topPrediction').innerText = best[0];
                    
                    const list = document.getElementById('probList');
                    list.innerHTML = entries.map(([name, prob]) => `
                        <div class="prob-pill">${name.toUpperCase()}: ${prob.toFixed(2)}</div>
                    `).join("");
                    
                    statusMsg.style.display = 'none';
                    resultDisplay.classList.add('show');
                } catch (e) { 
                    console.error(e);
                    showError(`Error during analysis. ${e.message}`);
                } finally { 
                    btn.disabled = false; 
                }
            }
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

def _load_rgb_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: (including ``PIL.UnidentifiedImageError``) if the bytes are not
            a readable image.
        PIL.Image.DecompressionBombError: if the image exceeds PIL's pixel limit.
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        # convert() returns a new, fully loaded image, so closing the source is safe.
        return img.convert("RGB")


def _run_inference(model, image):
    """Return the softmax class probabilities (as Python floats) for one image."""
    batch = transform(image).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)
    # Row 0 is the single image in the batch; softmax is taken over the class axis.
    return torch.softmax(logits, dim=1)[0].tolist()


@app.post("/predict")
async def predict(model_type: str = Query("b1"), file: UploadFile = File(...)):
    """Classify an uploaded image with the requested EfficientNet variant.

    Args:
        model_type: Key into ``loaded_models`` (one of the ``MODEL_CONFIGS`` keys).
        file: The uploaded image.

    Returns:
        ``{class_name: probability}`` for every entry of ``class_names``, in
        that order. An unknown ``model_type`` yields HTTP 404 with
        ``{"error": "Model not found"}``; an undecodable upload yields HTTP 400
        with ``{"error": "Invalid image file"}``.
    """
    selected_model = loaded_models.get(model_type)
    if selected_model is None:
        # Same body as before, but no longer reported as a 200 success.
        return JSONResponse(status_code=404, content={"error": "Model not found"})

    image_bytes = await file.read()

    # Decoding and inference are CPU-bound; running them in the worker thread pool
    # keeps a slow request from stalling the event loop for every other client.
    try:
        image = await run_in_threadpool(_load_rgb_image, image_bytes)
    except (OSError, Image.DecompressionBombError):
        return JSONResponse(status_code=400, content={"error": "Invalid image file"})

    probs = await run_in_threadpool(_run_inference, selected_model, image)
    return dict(zip(class_names, probs))