# File size: 1,255 Bytes (revision ed57c13)
from fastapi import FastAPI, UploadFile, File, HTTPException
from vit_classifier import ViTClassifier
import io
import uvicorn
# ASGI application object; the route decorators below register handlers on it.
app = FastAPI(title="ViT Model Deployment")
# Initialize the model at import time, before any request is served.
# NOTE(review): presumably get_instance() builds and caches a shared
# classifier so the first /predict call does not pay the load cost —
# confirm against vit_classifier.
print("Initializing model...")
ViTClassifier.get_instance()
@app.get("/")
def read_root():
    """Report that the API is reachable.

    Returns:
        A dict with a fixed ``status`` and ``message`` describing the service.
    """
    health = {
        "status": "online",
        "message": "ViT Model API is running",
    }
    return health
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the ViT classifier.

    Args:
        file: The uploaded file; its declared content type must start with
            ``image/``.

    Returns:
        A dict with ``status`` ("success"), ``prediction`` (the predicted
        class), ``confidence`` and ``probabilities`` exactly as returned by
        ``ViTClassifier.predict``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image
            (including when no content type was sent at all); 500 if reading
            the upload or running the classifier raises, or if the classifier
            reports no predicted class.
    """
    # content_type is None when the client omits the part's Content-Type
    # header; treat that as "not an image" instead of crashing on .startswith.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()
        image_file = io.BytesIO(image_data)
        classifier = ViTClassifier.get_instance()
        predicted_class, confidence, all_probs = classifier.predict(image_file)
    except Exception as e:
        # Surface unexpected read/inference failures as a 500, keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Checked outside the try so this HTTPException is not caught and
    # re-wrapped by the generic handler above (which mangled its detail).
    if predicted_class is None:
        raise HTTPException(status_code=500, detail="Model failed to predict")

    return {
        "status": "success",
        "prediction": predicted_class,
        "confidence": confidence,
        "probabilities": all_probs,
    }
if __name__ == "__main__":
    # Start the server only when this file is executed as a script;
    # listens on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)