# BiteWise — app.py (uploaded to Hugging Face Spaces by saaddar666, commit bc5f1b9)
import io
import base64
import torch
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from ultralytics import YOLO
from transformers import AutoImageProcessor, AutoModelForImageClassification
# --- App Config ---
# FastAPI application exposing two independent model endpoints:
# ViT food classification and YOLO fruit/vegetable detection.
app = FastAPI(
    title="Food & Vegetable AI API",
    description="Separate APIs for ViT Classification and YOLO Detection",
    version="2.1.0"
)
# Permissive CORS so browser front-ends on any origin can call the API.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests — tighten origins
# (or drop credentials) before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Base64ImageRequest(BaseModel):
    """Request body for the */-base64 endpoints."""
    # Base64-encoded image; an optional "data:...;base64," data-URL
    # prefix is stripped by the endpoints before decoding.
    image: str
# --- Model Loading ---
# Models are loaded once at import time so every request reuses them.
# All handles are pre-initialized to None so they are always defined,
# even if loading fails partway through; run_vit/run_yolo translate a
# None handle into an HTTP 500.
#
# BUG FIX: the original except branch reset only vit_model and
# yolo_model.  If AutoModelForImageClassification.from_pretrained
# raised, vit_processor was left undefined and run_vit crashed with a
# NameError instead of returning a clean "model not loaded" error.
vit_model = None
vit_processor = None
yolo_model = None
print("Loading models...")
try:
    vit_model = AutoModelForImageClassification.from_pretrained(
        "eslamxm/vit-base-food101"
    )
    vit_processor = AutoImageProcessor.from_pretrained(
        "eslamxm/vit-base-food101"
    )
    # Local YOLO weights; the .pt file must ship alongside this app.
    yolo_model = YOLO("yolo_fruits_and_vegetables_v3.pt")
    print("✓ Models loaded successfully")
except Exception as e:
    print(f"✗ Model loading failed: {e}")
    vit_model = None
    vit_processor = None
    yolo_model = None
# --- Utility ---
def load_image_from_bytes(image_bytes: bytes) -> Image.Image:
    """Decode raw image bytes into an RGB PIL image."""
    buffer = io.BytesIO(image_bytes)
    # Force RGB so downstream models never see RGBA/grayscale inputs.
    return Image.open(buffer).convert("RGB")
# --- YOLO Detection ---
def run_yolo(image: Image.Image) -> dict:
    """Run YOLO object detection on *image*.

    Returns a dict with:
      - "detections": list of {"label", "confidence", "bbox"} entries,
        where bbox is [x1, y1, x2, y2] from the model's xyxy output.
      - "summary": count of detections per label.

    Raises:
        HTTPException: 500 when the model failed to load at startup.
    """
    if yolo_model is None:
        raise HTTPException(status_code=500, detail="YOLO model not loaded")
    results = yolo_model(image)
    detections = []
    summary = {}
    for r in results:
        # Iterate the parallel class/confidence/bbox tensors directly
        # instead of the original index-based range(len(r.boxes)) loop.
        for cls, conf, xyxy in zip(r.boxes.cls, r.boxes.conf, r.boxes.xyxy):
            label = yolo_model.names[int(cls)]
            detections.append({
                "label": label,
                "confidence": float(conf),
                "bbox": xyxy.tolist()
            })
            summary[label] = summary.get(label, 0) + 1
    return {
        "detections": detections,
        "summary": summary
    }
# --- ViT Classification ---
def run_vit(image: Image.Image):
    """Classify *image* with the Food-101 ViT model.

    Returns {"label": <class name>, "confidence": <softmax prob, 4 dp>}.

    Raises:
        HTTPException: 500 when the model failed to load at startup.
    """
    if vit_model is None:
        raise HTTPException(status_code=500, detail="ViT model not loaded")
    encoded = vit_processor(images=image, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = vit_model(**encoded).logits
    scores = torch.nn.functional.softmax(logits, dim=-1)
    best = scores.argmax().item()
    return {
        "label": vit_model.config.id2label[best],
        "confidence": round(scores[0][best].item(), 4)
    }
# --- Routes ---
@app.get("/")
async def root():
    """Health-check endpoint listing the available prediction routes."""
    endpoints = ["/predict-vit", "/predict-yolo"]
    return {"message": "API running", "endpoints": endpoints}
# ---------- YOLO Endpoint ----------
@app.post("/predict-yolo")
async def predict_yolo(file: UploadFile = File(...)):
    """Detect fruits/vegetables in an uploaded image file."""
    try:
        raw = await file.read()
        pil_image = load_image_from_bytes(raw)
        return {"detection": run_yolo(pil_image)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# ---------- ViT Endpoint ----------
@app.post("/predict-vit")
async def predict_vit(file: UploadFile = File(...)):
    """Classify an uploaded food image with the ViT model."""
    try:
        raw = await file.read()
        pil_image = load_image_from_bytes(raw)
        return {"classification": run_vit(pil_image)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# ---------- Base64 Support (optional) ----------
@app.post("/predict-vit-base64")
async def predict_vit_base64(request: Base64ImageRequest):
    """Classify an image supplied as a base64 string (data-URL prefix allowed)."""
    try:
        # split(",", 1)[-1] drops a "data:...;base64," prefix when present
        # and is a no-op otherwise — same effect as the conditional unpack.
        encoded = request.image.split(",", 1)[-1]
        image = load_image_from_bytes(base64.b64decode(encoded))
        return {"classification": run_vit(image)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"ViT base64 error: {str(e)}")
@app.post("/predict-yolo-base64")
async def predict_yolo_base64(request: Base64ImageRequest):
    """Detect objects in an image supplied as a base64 string (data-URL prefix allowed)."""
    try:
        # split(",", 1)[-1] drops a "data:...;base64," prefix when present
        # and is a no-op otherwise — same effect as the conditional unpack.
        encoded = request.image.split(",", 1)[-1]
        image = load_image_from_bytes(base64.b64decode(encoded))
        return {"detection": run_yolo(image)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"YOLO base64 error: {str(e)}")
# --- Run ---
if __name__ == "__main__":
    # Dev/standalone entry point; port 7860 is the Hugging Face Spaces default.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)