import io

import torch
import torchvision
from fastapi import FastAPI, File, Query, UploadFile
from fastapi.responses import HTMLResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms import InterpolationMode

app = FastAPI()

# Per-variant checkpoint location on the Hugging Face Hub ("repo" / "file")
# and the input width of the replacement classification layer ("features").
MODEL_CONFIGS = {
    "b1": {"repo": "Shad0wKillar/efficientnet-b1", "file": "EfficientNet_B1_20percent.pth", "features": 1280},
    "b3": {"repo": "Shad0wKillar/efficientnet-b3", "file": "EfficientNet_B3_20percent.pth", "features": 1536},
    "b5": {"repo": "Shad0wKillar/efficientnet-b5", "file": "EfficientNet_B5_20percent.pth", "features": 2048},
    "b7": {"repo": "Shad0wKillar/efficientnet-b7", "file": "EfficientNet_B7_20percent.pth", "features": 2560},
}

# torchvision constructor for each supported variant key.
_MODEL_BUILDERS = {
    "b1": torchvision.models.efficientnet_b1,
    "b3": torchvision.models.efficientnet_b3,
    "b5": torchvision.models.efficientnet_b5,
    "b7": torchvision.models.efficientnet_b7,
}


def create_model(model_type: str) -> torch.nn.Module:
    """Build an EfficientNet variant with a 3-output classification head.

    The backbone is created without pretrained weights; the stock classifier
    is replaced by Dropout + Linear so that the fine-tuned checkpoint loaded
    afterwards matches the module's state-dict layout.

    Args:
        model_type: One of the keys of ``MODEL_CONFIGS`` ("b1", "b3", "b5",
            "b7").

    Returns:
        The freshly constructed (randomly initialised) model.

    Raises:
        ValueError: If ``model_type`` is not a supported variant. Previously
            this fell through the if/elif chain and failed with an
            ``UnboundLocalError`` on the unassigned ``model``.
    """
    try:
        builder = _MODEL_BUILDERS[model_type]
        in_features = MODEL_CONFIGS[model_type]["features"]
    except KeyError:
        raise ValueError(
            f"Unsupported model type {model_type!r}; expected one of {sorted(_MODEL_BUILDERS)}"
        ) from None

    model = builder()
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(in_features=in_features, out_features=3, bias=True),
    )
    return model


def _load_pretrained(model_type: str, config: dict) -> torch.nn.Module:
    """Download one checkpoint from the Hub and return the model in eval mode."""
    model = create_model(model_type)
    checkpoint_path = hf_hub_download(repo_id=config["repo"], filename=config["file"])
    # weights_only=True restricts unpickling to tensors/containers; CPU mapping
    # lets the checkpoint load regardless of the device it was saved from.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Every variant is loaded once at import time so that requests never pay the
# download / deserialisation cost.
loaded_models = {
    model_type: _load_pretrained(model_type, config)
    for model_type, config in MODEL_CONFIGS.items()
}

# The same preprocessing pipeline is applied for every variant.
# NOTE(review): a 255 -> 240 resize/crop is shared by b1/b3/b5/b7 here; confirm
# that this matches the preprocessing each checkpoint was fine-tuned with.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(255, interpolation=InterpolationMode.BILINEAR),
    torchvision.transforms.CenterCrop(240),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Output labels, in the order of the classifier's three outputs.
class_names = ["pizza", "steak", "sushi"]


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the landing page."""
    # NOTE(review): only plain text is visible in this literal; the markup
    # appears to have been stripped upstream -- confirm against the real page.
    html_content = """ EfficientNet AI - MultiModel

Classifier

Select a model and upload an image to begin.

Ready for Prediction...
"""
    return HTMLResponse(content=html_content)


@app.post("/predict")
async def predict(model_type: str = Query("b1"), file: UploadFile = File(...)):
    """Classify an uploaded image with the selected preloaded model.

    Args:
        model_type: Query parameter selecting a key of ``loaded_models``;
            defaults to "b1".
        file: The uploaded image.

    Returns:
        A mapping of each class name to its softmax probability, or
        ``{"error": <message>}`` when the model key is unknown or the upload
        cannot be decoded as an image.
    """
    selected_model = loaded_models.get(model_type)
    if selected_model is None:
        return {"error": "Model not found"}

    image_bytes = await file.read()
    try:
        with Image.open(io.BytesIO(image_bytes)) as uploaded:
            image = uploaded.convert("RGB")
    except OSError:
        # PIL signals empty, unrecognised or truncated image data with OSError
        # (UnidentifiedImageError is a subclass). Report it in the same payload
        # shape as the unknown-model case instead of failing with a 500.
        return {"error": "Invalid image file"}

    # Add the batch dimension expected by the model.
    img_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        logits = selected_model(img_tensor)
        # Single-image batch: take row 0 of the per-class probabilities.
        probs = torch.softmax(logits, dim=1)[0]

    return dict(zip(class_names, probs.tolist()))