# ViT_MODEL / main.py
# Last updated by pranit144 in commit ed57c13 ("Update main.py").
from fastapi import FastAPI, UploadFile, File, HTTPException
from vit_classifier import ViTClassifier
import io
import uvicorn
def _warm_up_model():
    """Announce and trigger construction of the shared ViT classifier.

    Runs once at import time (see the call below). NOTE(review): presumably
    ``ViTClassifier.get_instance`` caches a single instance so request
    handlers reuse it instead of reloading the model — confirm in
    vit_classifier.
    """
    print("Initializing model...")
    ViTClassifier.get_instance()


# ASGI application object; the route decorators below register on it.
app = FastAPI(title="ViT Model Deployment")
_warm_up_model()
@app.get("/")
def read_root():
    """Health-check endpoint confirming the API process is serving requests."""
    health = {
        "status": "online",
        "message": "ViT Model API is running",
    }
    return health
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the shared ViT classifier.

    Args:
        file: Multipart upload whose declared Content-Type must be ``image/*``.

    Returns:
        A dict with ``status`` ("success"), ``prediction``, ``confidence`` and
        ``probabilities``, taken from the tuple returned by
        ``ViTClassifier.predict``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image;
            500 if reading/classification fails or the model returns no class.
    """
    # content_type is None when the part carries no Content-Type header, and
    # media types are case-insensitive, so normalise before the prefix check.
    content_type = (file.content_type or "").lower()
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()
        classifier = ViTClassifier.get_instance()
        # NOTE(review): predict() is a synchronous call inside an async
        # handler, so it blocks the event loop while it runs. Consider
        # offloading it to a thread pool once ViTClassifier's thread safety
        # is confirmed.
        predicted_class, confidence, all_probs = classifier.predict(
            io.BytesIO(image_data)
        )
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Checked outside the try so this HTTPException is not caught and
    # re-wrapped by the generic handler above (which mangled its detail).
    if predicted_class is None:
        raise HTTPException(status_code=500, detail="Model failed to predict")

    return {
        "status": "success",
        "prediction": predicted_class,
        "confidence": confidence,
        "probabilities": all_probs,
    }
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces at port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)