# app.py -- uploaded by Rohit-Pimpale
# Commit message: "Create app.py"
# Commit: 3be62c9 (verified)
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
# 1. Application setup
# The image processor and the classifier are both fetched from the same
# checkpoint once, at import time, and then shared by every request handler.
model_name = "dsett-ml/BengalCropDisease-finetuned-vit"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
app = FastAPI(title="AgriSmart Disease API")
# 2. Health Check Endpoint
# Logic: lightweight GET route that answers with a fixed payload, so a caller
# can confirm the service is reachable.
@app.get("/")
def read_root():
    """Return a static status payload for liveness checks."""
    health_payload = dict(status="Active", model="Vision Transformer loaded")
    return health_payload
# 3. Prediction Endpoint
# Logic: Intercepts POST requests containing image files
@app.post("/predict")
async def predict_disease(file: UploadFile = File(...)):
    """Classify an uploaded image and return the most probable label.

    Args:
        file: Multipart upload whose content is expected to be an image.

    Returns:
        A dict with the predicted label under ``"disease"`` and its softmax
        probability under ``"confidence"``.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image.
    """
    # Read the whole upload into memory.
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # Decoding untrusted bytes can fail in several ways (unknown format,
    # truncated data, oversized image); report those as a client error
    # instead of letting them surface as an HTTP 500.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # Process into tensors
    inputs = processor(images=image, return_tensors="pt")

    # NOTE(review): inference runs synchronously inside this async handler,
    # so the event loop is blocked while the model executes.
    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the class dimension turns logits into probabilities.
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Top-1 class and its probability (a single image is processed per call).
    confidence, predicted_idx = torch.max(probabilities, dim=-1)

    # Map the class index back to its label using the model config.
    predicted_label = model.config.id2label[predicted_idx.item()]

    return {
        "disease": predicted_label,
        "confidence": float(confidence.item()),
    }