"""FastAPI service exposing a ViT image classifier over HTTP."""

import io
import logging

import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile

from vit_classifier import ViTClassifier

logger = logging.getLogger(__name__)

app = FastAPI(title="ViT Model Deployment")

# Initialize the model eagerly at import time so the first request does not
# pay the loading cost. The handler calls get_instance() again; presumably it
# returns the same cached instance (TODO confirm in vit_classifier).
print("Initializing model...")
ViTClassifier.get_instance()


@app.get("/")
def read_root():
    """Health check: report that the API is up."""
    return {"status": "online", "message": "ViT Model API is running"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the shared ViT classifier.

    Args:
        file: Multipart upload. Its declared content type must start with
            "image/" and its body must be non-empty.

    Returns:
        A dict with "status", "prediction", "confidence" and "probabilities",
        the last three taken from ViTClassifier.predict.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or is
            empty; 500 if the model raises or returns no prediction.
    """
    # content_type may be None when the client omits the part's Content-Type;
    # check it before calling startswith so we answer 400 rather than crash.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()
        if not image_data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        classifier = ViTClassifier.get_instance()
        # NOTE(review): predict() is called synchronously inside an async
        # handler, so a slow model blocks the event loop. Consider
        # fastapi.concurrency.run_in_threadpool if the classifier is
        # thread-safe (not verifiable from here).
        predicted_class, confidence, all_probs = classifier.predict(
            io.BytesIO(image_data)
        )
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged instead of being
        # re-wrapped as a generic 500 with a mangled detail string.
        raise
    except Exception as e:
        logger.exception("Prediction failed for upload %r", file.filename)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Checked outside the try so this 500 keeps its intended detail message.
    if predicted_class is None:
        raise HTTPException(status_code=500, detail="Model failed to predict")

    return {
        "status": "success",
        "prediction": predicted_class,
        "confidence": confidence,
        "probabilities": all_probs,
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)