Spaces:
Sleeping
Sleeping
import base64
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import AutoImageProcessor, AutoModelForImageClassification
from ultralytics import YOLO
# --- App Config ---
# Descriptive metadata handed to the FastAPI application object.
_APP_METADATA = {
    "title": "Food & Vegetable AI API",
    "description": "Separate APIs for ViT Classification and YOLO Detection",
    "version": "2.1.0",
}

app = FastAPI(**_APP_METADATA)
# Allow browser clients from any origin to call the API.
# allow_credentials is False on purpose: FastAPI documents that a wildcard
# origin list must not be combined with credentials, and none of the endpoints
# in this module read cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# JSON request body for the base64 prediction endpoints.
# `image` is either a bare base64 string or a string with a header before the
# first comma (presumably a data URL such as "data:image/png;base64,<payload>");
# the handlers discard everything up to the first comma before decoding.
class Base64ImageRequest(BaseModel):
    image: str  # base64-encoded image, optionally "<header>,<payload>"
# --- Model Loading ---
# Hugging Face Hub id of the ViT food classifier and path of the YOLO weights.
VIT_MODEL_ID = "eslamxm/vit-base-food101"
YOLO_WEIGHTS_PATH = "yolo_fruits_and_vegetables_v3.pt"

print("Loading models...")

# Each model is loaded in its own try/except so that a failure in one (e.g. a
# missing weights file) does not disable the endpoints of the other.
try:
    vit_model = AutoModelForImageClassification.from_pretrained(VIT_MODEL_ID)
    vit_processor = AutoImageProcessor.from_pretrained(VIT_MODEL_ID)
    print("✓ ViT model loaded successfully")
except Exception as e:
    print(f"✗ ViT model loading failed: {e}")
    # Reset both names so they are always defined and checkable against None.
    vit_model = None
    vit_processor = None

try:
    yolo_model = YOLO(YOLO_WEIGHTS_PATH)
    print("✓ YOLO model loaded successfully")
except Exception as e:
    print(f"✗ YOLO model loading failed: {e}")
    yolo_model = None
# --- Utility ---
def load_image_from_bytes(image_bytes: bytes) -> Image.Image:
    """Decode an in-memory encoded image into an RGB PIL image.

    Args:
        image_bytes: Raw image file contents in any format PIL can open.

    Returns:
        The decoded image converted to "RGB" mode, so palette, alpha or
        grayscale inputs come out with three channels.

    Raises:
        Whatever ``PIL.Image.open`` raises for unreadable data (e.g.
        ``PIL.UnidentifiedImageError``).
    """
    stream = io.BytesIO(image_bytes)
    decoded = Image.open(stream)
    rgb_image = decoded.convert("RGB")
    return rgb_image
# --- YOLO Detection ---
def run_yolo(image: Image.Image):
    """Run YOLO detection on *image* and collect every predicted box.

    Returns:
        dict with
          - ``"detections"``: list of ``{"label", "confidence", "bbox"}``
            dicts, where ``bbox`` is the box's ``xyxy`` coordinate list;
          - ``"summary"``: mapping of label -> number of detections.

    Raises:
        HTTPException(500): the YOLO model is not loaded.
    """
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLO model not loaded")

    detections = []
    summary = {}
    for result in yolo_model(image):
        boxes = result.boxes
        if boxes is None:  # result carries no boxes; nothing to collect
            continue
        # Convert each tensor to Python lists once, instead of indexing (and,
        # for device tensors, copying to host) element by element.
        for cls_id, conf, xyxy in zip(
            boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist()
        ):
            label = yolo_model.names[int(cls_id)]
            detections.append(
                {"label": label, "confidence": float(conf), "bbox": xyxy}
            )
            summary[label] = summary.get(label, 0) + 1

    return {"detections": detections, "summary": summary}
# --- ViT Classification ---
def run_vit(image: Image.Image):
    """Classify *image* with the ViT food model and return the top-1 class.

    Returns:
        dict with ``"label"`` (looked up in ``config.id2label``) and
        ``"confidence"`` (its softmax probability rounded to 4 decimals).

    Raises:
        HTTPException(500): the ViT model or its processor is not loaded.
    """
    # Short-circuit order matters: vit_processor is only guaranteed to exist
    # once vit_model is known to be non-None.
    if vit_model is None or vit_processor is None:
        raise HTTPException(status_code=500, detail="ViT model not loaded")

    inputs = vit_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = vit_model(**inputs).logits
    # One image in -> one row out; select that row explicitly rather than
    # arg-maxing over the flattened tensor.
    probs = torch.softmax(logits, dim=-1)[0]
    pred_id = int(probs.argmax())
    return {
        "label": vit_model.config.id2label[pred_id],
        "confidence": round(probs[pred_id].item(), 4),
    }
# --- Routes ---
# NOTE(review): SOURCE showed no route decorator for this handler; "/" is
# assumed from its name and role — confirm.
@app.get("/")
async def root():
    """Report that the API is up and list the prediction endpoints."""
    return {
        "message": "API running",
        "endpoints": ["/predict-vit", "/predict-yolo"],
    }
# ---------- YOLO Endpoint ----------
@app.post("/predict-yolo")
async def predict_yolo(file: UploadFile = File(...)):
    """Run YOLO detection on an uploaded image file.

    Returns ``run_yolo``'s result wrapped under the ``"detection"`` key.

    Raises:
        HTTPException(400): the upload could not be identified as an image.
        HTTPException(500): the model is unavailable or anything else failed.
    """
    try:
        image = load_image_from_bytes(await file.read())
        return {"detection": run_yolo(image)}
    except HTTPException:
        # Already carries the intended status/detail; re-wrapping it would
        # turn the detail into "<status>: <detail>" under a new 500.
        raise
    except UnidentifiedImageError as e:
        # Client sent something that is not an image: a client error, not 500.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------- ViT Endpoint ----------
@app.post("/predict-vit")
async def predict_vit(file: UploadFile = File(...)):
    """Classify an uploaded image file with the ViT food model.

    Returns ``run_vit``'s result wrapped under the ``"classification"`` key.

    Raises:
        HTTPException(400): the upload could not be identified as an image.
        HTTPException(500): the model is unavailable or anything else failed.
    """
    try:
        image = load_image_from_bytes(await file.read())
        return {"classification": run_vit(image)}
    except HTTPException:
        # Already carries the intended status/detail; re-wrapping it would
        # turn the detail into "<status>: <detail>" under a new 500.
        raise
    except UnidentifiedImageError as e:
        # Client sent something that is not an image: a client error, not 500.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------- Base64 Support (optional) ----------
def _decode_base64_image(data: str) -> Image.Image:
    """Decode a base64 string, optionally prefixed by "<header>,", to RGB."""
    # Keep only the text after the first comma when one is present (e.g. a
    # data-URL header); the base64 alphabet itself never contains a comma.
    encoded = data.split(",", 1)[-1]
    return load_image_from_bytes(base64.b64decode(encoded))


def _predict_from_base64(data: str, predict, model_name: str) -> dict:
    """Decode *data* and run *predict* on it, mapping failures to HTTP errors.

    Args:
        data: base64 image payload from the request body.
        predict: ``run_vit`` or ``run_yolo``.
        model_name: prefix used in error details ("ViT" / "YOLO").

    Raises:
        HTTPException(400): payload is not valid base64 or not an image.
        HTTPException(500): any other failure. HTTPExceptions raised by
            *predict* (e.g. "model not loaded") propagate unchanged.
    """
    try:
        image = _decode_base64_image(data)
    except (ValueError, UnidentifiedImageError) as e:
        # b64decode raises binascii.Error (a ValueError subclass) on bad
        # padding and a plain ValueError on non-ASCII input.
        raise HTTPException(
            status_code=400, detail=f"{model_name} base64 error: {e}"
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"{model_name} base64 error: {e}"
        ) from e

    try:
        return predict(image)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"{model_name} base64 error: {e}"
        ) from e


# NOTE(review): SOURCE showed no route decorators for these handlers; the
# paths below are assumed — confirm against existing clients.
@app.post("/predict-vit-base64")
async def predict_vit_base64(request: Base64ImageRequest):
    """Classify a base64-encoded image with the ViT model."""
    return {"classification": _predict_from_base64(request.image, run_vit, "ViT")}


@app.post("/predict-yolo-base64")
async def predict_yolo_base64(request: Base64ImageRequest):
    """Run YOLO detection on a base64-encoded image."""
    return {"detection": _predict_from_base64(request.image, run_yolo, "YOLO")}
# --- Run ---
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on every
    # interface (7860 presumably matches the port the host platform expects).
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)