x-ray / main.py
Yousuf-Islam's picture
Update main.py
e919ef3 verified
import io
import traceback

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from torchvision import transforms

from model_loader import load_model
app = FastAPI()

# Allow browser front-ends served from any origin to call this API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# very permissive pairing; if clients never send cookies/auth headers,
# consider allow_credentials=False.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Module-level inference state, populated once when this module is imported.
# `model` stays None if loading fails; predict() checks for that.
model = None
device = torch.device("cpu")

# --- LOAD MODEL ---
print("--- STARTING SERVER ---")
try:
    _loaded_model = load_model("InceptionViT_best_model.pth")
    # load_model's internals are not visible here, so enforce inference setup
    # explicitly. Both calls are harmless if load_model already did them:
    #   - to(device): keep weights on the same device as the input tensors
    #     that predict() builds with .to(device);
    #   - eval(): disable dropout and use running BatchNorm statistics so
    #     predictions are deterministic.
    if isinstance(_loaded_model, torch.nn.Module):
        _loaded_model.to(device)
        _loaded_model.eval()
except Exception as e:
    # Deliberate best-effort: keep the server up ("/" still answers) and let
    # /predict report "Model not loaded". Print the full traceback, not just
    # the message, so the root cause is diagnosable from the server logs.
    print(f"❌ CRITICAL ERROR: {e}")
    traceback.print_exc()
else:
    # Publish the model only after every setup step succeeded.
    model = _loaded_model
    print("✅ Model loaded successfully!")
# --- TRANSFORM ---
# Preprocessing applied to every uploaded image before inference. Per the
# original author, this matches the training-time preprocessing exactly.
# The normalization values are the standard ImageNet per-channel statistics.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_INPUT_SIZE = (224, 224)

transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),       # fixed height x width
        transforms.ToTensor(),                # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
@app.get("/")
def home():
    """Liveness endpoint: respond with a fixed status payload.

    This does not report whether the model finished loading.
    """
    payload = {"status": "Running"}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the loaded model.

    Args:
        file: Multipart upload whose content is the raw image bytes.

    Returns:
        ``{"prediction": <predicted class index as str>,
        "confidence": <its softmax probability as float>}`` on success, or
        ``{"error": <message>}`` when the model is not loaded or the upload
        cannot be decoded as an image.
    """
    if model is None:
        return {"error": "Model not loaded"}

    # NOTE(review): the whole upload is buffered in memory with no size cap,
    # and the decoding/inference below run synchronously inside this async
    # handler, blocking the event loop while they execute.
    image_data = await file.read()
    try:
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # Empty, truncated or non-image uploads (UnidentifiedImageError is an
        # OSError subclass) would otherwise surface as an unhandled exception
        # and an HTTP 500; report them the same way as the error above.
        return {"error": "Invalid image file"}

    # unsqueeze(0) adds the batch dimension the model is called with.
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        # Highest-probability class along the class dimension.
        confidence, predicted = torch.max(probabilities, 1)

    return {
        "prediction": str(predicted.item()),
        "confidence": float(confidence.item()),
    }