# Source captured from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
# 1. Initialize API and load model into RAM
# The processor and model are loaded once, at import time, so every request
# reuses the same in-memory weights instead of reloading them per call.
app = FastAPI(title="AgriSmart Disease API")
model_name = "dsett-ml/BengalCropDisease-finetuned-vit"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# from_pretrained() already returns the model in evaluation mode; calling
# eval() explicitly documents that this service only runs inference
# (dropout and similar training-only layers stay disabled).
model.eval()
# 2. Health Check Endpoint
# Logic: Provides a simple GET route to verify the container is running.
# Without the route decorator FastAPI never registers this handler.
@app.get("/")
def read_root():
    """Report that the service is up and the model was loaded at startup.

    Returns:
        A static status dict; reaching this handler at all implies the
        module-level model loading completed.
    """
    return {"status": "Active", "model": "Vision Transformer loaded"}
# 3. Prediction Endpoint
# Logic: Intercepts POST requests containing image files


def _classify_image(image):
    """Run the ViT classifier on one RGB PIL image and return (label, probability)."""
    # Preprocess into a batch of one image tensor.
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the class dimension, then take the top class of the
    # single image in the batch.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    confidence, predicted_idx = torch.max(probabilities, dim=1)
    predicted_label = model.config.id2label[predicted_idx.item()]
    return predicted_label, float(confidence.item())


# NOTE(review): the original route decorator was missing from this file, so
# the handler was never registered; the "/predict" path is an assumption —
# confirm it against the deployed client.
@app.post("/predict")
async def predict_disease(file: UploadFile = File(...)):
    """Classify the crop disease shown in an uploaded image.

    Args:
        file: Multipart-uploaded image. Any format Pillow can decode is
            accepted; it is converted to RGB before preprocessing.

    Returns:
        A dict with ``"disease"`` (the label from ``model.config.id2label``)
        and ``"confidence"`` (the softmax probability of that label, as float).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # Read the whole upload stream into memory.
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (empty/unrecognised data) subclasses OSError,
        # and truncated files surface as OSError during convert(). These are
        # client errors, not server failures, so report 400 instead of 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from err

    # Model inference is synchronous and compute-bound; running it directly in
    # this coroutine would block the event loop for every other request, so
    # hand it to the worker thread pool.
    predicted_label, confidence = await run_in_threadpool(_classify_image, image)

    # Return JSON response
    return {
        "disease": predicted_label,
        "confidence": confidence,
    }