"""AgriSmart crop-disease classification API.

Serves a fine-tuned Vision Transformer (ViT) image classifier over HTTP:
``GET /`` is a health check and ``POST /predict`` classifies an uploaded image.
"""

import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image, UnidentifiedImageError
from transformers import ViTForImageClassification, ViTImageProcessor

# 1. Initialize API and load model into RAM
# Loaded once at import time so every request reuses the same weights.
app = FastAPI(title="AgriSmart Disease API")
model_name = "dsett-ml/BengalCropDisease-finetuned-vit"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# Inference only: make sure dropout etc. are disabled (explicit, idempotent).
model.eval()


# 2. Health Check Endpoint
# Logic: Provides a simple GET route to verify the container is running
@app.get("/")
def read_root():
    """Report that the service is up and the model has been loaded."""
    return {"status": "Active", "model": "Vision Transformer loaded"}


def _load_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the upload is empty or is not a decodable image.
    """
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # convert() forces a full decode, so truncated/corrupt data fails here
        # rather than later inside the processor.
        return image.convert("RGB")
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as err:
        # Client sent something we cannot decode: a 4xx, not a server error.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _classify(image_bytes: bytes) -> dict:
    """Decode *image_bytes*, run the ViT model, and return the top prediction.

    Returns:
        ``{"disease": <label>, "confidence": <softmax probability>}``.

    Raises:
        HTTPException: 400 if the bytes are not a valid image.
    """
    image = _load_image(image_bytes)

    # Process into tensors
    inputs = processor(images=image, return_tensors="pt")

    # Execute inference without building an autograd graph
    with torch.no_grad():
        outputs = model(**inputs)

    # Calculate probabilities and pick the most likely class
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    confidence, predicted_idx = torch.max(probabilities, dim=1)

    # Extract label mapping
    predicted_label = model.config.id2label[predicted_idx.item()]

    return {"disease": predicted_label, "confidence": confidence.item()}


# 3. Prediction Endpoint
# Logic: Intercepts POST requests containing image files
@app.post("/predict")
async def predict_disease(file: UploadFile = File(...)):
    """Classify the crop disease shown in an uploaded image.

    The upload is read asynchronously. The CPU-bound decode and inference then
    run in a worker thread, so they do not block the event loop (and with it
    every other request, including the health check).
    """
    image_bytes = await file.read()
    return await run_in_threadpool(_classify, image_bytes)