# NOTE: this file was scraped from a Hugging Face Space page
# (the "Spaces: Sleeping" status banner was captured along with the code).
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
from torchvision import transforms

from model_loader import load_model

app = FastAPI()

# Allow any origin so a separately hosted frontend can call this API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive — tighten the origin list before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Inference runs on CPU. `model` stays None if loading fails so the
# /predict handler can report the failure instead of crashing the server.
model = None
device = torch.device("cpu")

# --- LOAD MODEL ---
print("--- STARTING SERVER ---")
try:
    model = load_model("InceptionViT_best_model.pth")
    print("✅ Model loaded successfully!")
except Exception as e:
    # Deliberate best-effort: the API still starts, and /predict returns
    # {"error": "Model not loaded"} until the checkpoint problem is fixed.
    print(f"❌ CRITICAL ERROR: {e}")

# --- TRANSFORM ---
# Matches your training code exactly: 224x224 resize followed by the
# standard ImageNet normalization statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
@app.get("/")
def home():
    """Health-check endpoint: confirms the API process is up and reachable."""
    return {"status": "Running"}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the loaded model.

    Returns ``{"prediction": <class index as str>, "confidence": <float>}``
    on success, or ``{"error": ...}`` when the model failed to load at
    startup or the upload cannot be decoded as an image.
    """
    if model is None:
        return {"error": "Model not loaded"}

    image_data = await file.read()
    try:
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception as e:
        # Corrupt or non-image upload: report it as JSON instead of a 500.
        return {"error": f"Invalid image file: {e}"}

    # (1, 3, 224, 224) CPU batch, matching the training-time transform.
    tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    return {
        "prediction": str(predicted.item()),
        "confidence": float(confidence.item()),
    }