|
|
import io
import logging

import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException

from vit_classifier import ViTClassifier
|
|
|
|
|
app = FastAPI(
    title="ViT Model Deployment",
)

# Initialise the classifier eagerly, at import time, rather than lazily on the
# first request. The return value is discarded on purpose: get_instance() is
# called only for its side effect (presumably constructing and caching a shared
# model instance — confirm in vit_classifier).
print("Initializing model...")
ViTClassifier.get_instance()
|
|
|
|
|
@app.get("/")
def read_root():
    """Return a fixed status payload showing that the API process is up."""
    return dict(status="online", message="ViT Model API is running")
|
|
|
|
|
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the shared ViT classifier.

    Args:
        file: Multipart upload. Its declared ``content_type`` must start with
            ``"image/"``; a missing content type is rejected as non-image.

    Returns:
        A dict with ``status`` (always ``"success"``), ``prediction``,
        ``confidence`` and ``probabilities``, the latter three taken directly
        from the tuple returned by ``ViTClassifier.predict``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image;
            500 if the model returns no prediction or any other error occurs.
    """
    # content_type is None when the client omits the part's Content-Type
    # header; treat that as "not an image" instead of crashing with a 500.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()
        image_file = io.BytesIO(image_data)

        classifier = ViTClassifier.get_instance()
        # NOTE(review): predict() is called synchronously inside an async
        # handler, so inference blocks the event loop for its duration.
        # Consider fastapi.concurrency.run_in_threadpool once ViTClassifier
        # is confirmed to be thread-safe.
        predicted_class, confidence, all_probs = classifier.predict(image_file)

        if predicted_class is None:
            raise HTTPException(status_code=500, detail="Model failed to predict")

        return {
            "status": "success",
            "prediction": predicted_class,
            "confidence": confidence,
            "probabilities": all_probs,
        }
    except HTTPException:
        # Already a well-formed HTTP error: propagate it unchanged. Letting the
        # generic handler below catch it would replace its detail with str()
        # of the exception object.
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees a 500.
        # NOTE(review): str(e) exposes internal error text to clients.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")