saaddar666 committed
Commit bc5f1b9 · verified · 1 Parent(s): 01e0839

Upload app.py

Files changed (1)
app.py +107 -64
app.py CHANGED
@@ -3,17 +3,16 @@ import base64
 import torch
 from PIL import Image
 from fastapi import FastAPI, File, UploadFile, HTTPException
-from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from ultralytics import YOLO
 from transformers import AutoImageProcessor, AutoModelForImageClassification
 
-# --- Configuration ---
+# --- App Config ---
 app = FastAPI(
     title="Food & Vegetable AI API",
-    description="Combined API for Food Classification (ViT) and Fruit/Veg Detection (YOLO)",
-    version="2.0.0"
+    description="Separate APIs for ViT Classification and YOLO Detection",
+    version="2.1.0"
 )
 
 app.add_middleware(
@@ -29,83 +28,127 @@ class Base64ImageRequest(BaseModel):
 
 # --- Model Loading ---
 print("Loading models...")
+
 try:
-    # 1. ViT Food Classifier
-    vit_model = AutoModelForImageClassification.from_pretrained("eslamxm/vit-base-food101")
-    vit_processor = AutoImageProcessor.from_pretrained("eslamxm/vit-base-food101")
-
-    # 2. YOLO Fruit/Veg Detector
-    # Ensure this file is in your root directory
-    yolo_model = YOLO('yolo_fruits_and_vegetables_v3.pt')
-
-    print("✓ All models loaded successfully!")
+    vit_model = AutoModelForImageClassification.from_pretrained(
+        "eslamxm/vit-base-food101"
+    )
+    vit_processor = AutoImageProcessor.from_pretrained(
+        "eslamxm/vit-base-food101"
+    )
+
+    yolo_model = YOLO("yolo_fruits_and_vegetables_v3.pt")
+
+    print("✓ Models loaded successfully")
 except Exception as e:
-    print(f"✗ Error loading models: {e}")
+    print(f"✗ Model loading failed: {e}")
     vit_model = None
     yolo_model = None
 
-# --- Utility Functions ---
-def process_pil_image(image: Image.Image):
-    """Common logic for YOLO detection and ViT classification."""
-    results = {"detection": None, "classification": None}
-
-    # YOLO Inference
-    if yolo_model:
-        y_results = yolo_model(image)
-        detections = []
-        summary = {}
-        for r in y_results:
-            for i in range(len(r.boxes)):
-                label = yolo_model.names[int(r.boxes.cls[i])]
-                detections.append({
-                    "label": label,
-                    "confidence": float(r.boxes.conf[i]),
-                    "bbox": r.boxes.xyxy[i].tolist()
-                })
-                summary[label] = summary.get(label, 0) + 1
-        results["detection"] = {"detections": detections, "summary": summary}
-
-    # ViT Inference
-    if vit_model:
-        inputs = vit_processor(images=image, return_tensors="pt")
-        with torch.no_grad():
-            outputs = vit_model(**inputs)
-        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
-        pred_id = probs.argmax().item()
-        results["classification"] = {
-            "label": vit_model.config.id2label[pred_id],
-            "confidence": round(probs[0][pred_id].item(), 4)
-        }
-
-    return results
-
-# --- Endpoints ---
+
+# --- Utility ---
+def load_image_from_bytes(image_bytes: bytes) -> Image.Image:
+    return Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+
+# --- YOLO Detection ---
+def run_yolo(image: Image.Image):
+    if not yolo_model:
+        raise HTTPException(status_code=500, detail="YOLO model not loaded")
+
+    results = yolo_model(image)
+
+    detections = []
+    summary = {}
+
+    for r in results:
+        for i in range(len(r.boxes)):
+            label = yolo_model.names[int(r.boxes.cls[i])]
+            detections.append({
+                "label": label,
+                "confidence": float(r.boxes.conf[i]),
+                "bbox": r.boxes.xyxy[i].tolist()
+            })
+            summary[label] = summary.get(label, 0) + 1
+
+    return {
+        "detections": detections,
+        "summary": summary
+    }
+
+
+# --- ViT Classification ---
+def run_vit(image: Image.Image):
+    if not vit_model:
+        raise HTTPException(status_code=500, detail="ViT model not loaded")
+
+    inputs = vit_processor(images=image, return_tensors="pt")
+
+    with torch.no_grad():
+        outputs = vit_model(**inputs)
+
+    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
+    pred_id = probs.argmax().item()
+
+    return {
+        "label": vit_model.config.id2label[pred_id],
+        "confidence": round(probs[0][pred_id].item(), 4)
+    }
+
+
+# --- Routes ---
 @app.get("/")
 async def root():
-    return {"message": "AI API is online. Use /predict-upload or /predict-base64."}
+    return {
+        "message": "API running",
+        "endpoints": ["/predict-vit", "/predict-yolo"]
+    }
 
-@app.post("/predict-upload")
-async def predict_file(file: UploadFile = File(...)):
-    """Upload a raw image file for full analysis."""
+
+# ---------- YOLO Endpoint ----------
+@app.post("/predict-yolo")
+async def predict_yolo(file: UploadFile = File(...)):
     try:
-        image = Image.open(file.file).convert("RGB")
-        return process_pil_image(image)
+        image = load_image_from_bytes(await file.read())
+        return {"detection": run_yolo(image)}
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
-@app.post("/predict-base64")
-async def predict_base64(request: Base64ImageRequest):
-    """Send a base64 string (useful for mobile apps)."""
+
+# ---------- ViT Endpoint ----------
+@app.post("/predict-vit")
+async def predict_vit(file: UploadFile = File(...)):
     try:
-        header, encoded = request.image.split(",", 1) if "," in request.image else (None, request.image)
+        image = load_image_from_bytes(await file.read())
+        return {"classification": run_vit(image)}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# ---------- Base64 Support (optional) ----------
+@app.post("/predict-vit-base64")
+async def predict_vit_base64(request: Base64ImageRequest):
+    try:
+        _, encoded = request.image.split(",", 1) if "," in request.image else (None, request.image)
         image_bytes = base64.b64decode(encoded)
-        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-        return process_pil_image(image)
+        image = load_image_from_bytes(image_bytes)
+        return {"classification": run_vit(image)}
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Base64 processing failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"ViT base64 error: {str(e)}")
+
+
+@app.post("/predict-yolo-base64")
+async def predict_yolo_base64(request: Base64ImageRequest):
+    try:
+        _, encoded = request.image.split(",", 1) if "," in request.image else (None, request.image)
+        image_bytes = base64.b64decode(encoded)
+        image = load_image_from_bytes(image_bytes)
+        return {"detection": run_yolo(image)}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"YOLO base64 error: {str(e)}")
+
 
+# --- Run ---
 if __name__ == "__main__":
     import uvicorn
-    # 7860 is the required port for Hugging Face Spaces
     uvicorn.run(app, host="0.0.0.0", port=7860)
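
For anyone exercising the split endpoints this commit introduces, a minimal client sketch follows. It is not part of the commit: it assumes the server is running locally on port 7860 (as in the __main__ block), uses the third-party requests library, and treats "test.jpg" as a hypothetical sample image path.

# Client sketch (illustration only, not part of the commit).
# Assumptions: app running locally on port 7860; "test.jpg" is a placeholder path.
import base64

import requests

BASE_URL = "http://localhost:7860"

# Multipart upload to the YOLO detection endpoint; the form field name
# must be "file" to match the UploadFile parameter.
with open("test.jpg", "rb") as f:
    r = requests.post(f"{BASE_URL}/predict-yolo", files={"file": f})
print(r.json())  # {"detection": {"detections": [...], "summary": {...}}}

# Same image against the ViT classification endpoint.
with open("test.jpg", "rb") as f:
    r = requests.post(f"{BASE_URL}/predict-vit", files={"file": f})
print(r.json())  # {"classification": {"label": ..., "confidence": ...}}

# Base64 variant: the JSON field name "image" matches Base64ImageRequest.
# Raw base64 works; a "data:image/jpeg;base64,..." prefix is also tolerated,
# since the handler splits on the first comma before decoding.
with open("test.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("ascii")
r = requests.post(f"{BASE_URL}/predict-vit-base64", json={"image": encoded})
print(r.json())

The same base64 payload shape applies to /predict-yolo-base64, which returns the "detection" structure instead of "classification".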