prathameshks committed on
Commit
3edbe0b
·
1 Parent(s): 689e789

update to yolo

Browse files
Files changed (1) hide show
  1. routers/analysis.py +48 -56
routers/analysis.py CHANGED
@@ -2,33 +2,21 @@ import asyncio
2
  import os
3
  from datetime import datetime
4
  from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
5
- from fastapi.responses import JSONResponse
6
  import pytz
7
  from sqlalchemy.orm import Session
8
  from typing import List, Dict, Any
9
- from db.models import User
10
  from interfaces.ingredientModels import IngredientAnalysisResult, IngredientRequest
11
- from interfaces.productModels import ProductIngredientsRequest
12
  from services.auth_service import get_current_user
13
- from logger_manager import log_info, log_error,logger
14
  from db.database import get_db,SessionLocal
15
- from db.repositories import IngredientRepository
16
  from dotenv import load_dotenv
17
  from langsmith import traceable
18
-
19
- from PIL import Image
20
  import io
21
- import base64
22
- from fastapi.encoders import jsonable_encoder
23
- import uuid
24
- from typing import List
25
- from fastapi import APIRouter, File, Request, UploadFile
26
- from fastapi.responses import JSONResponse
27
- import cv2
28
- import numpy as np
29
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
30
-
31
-
32
  from services.ingredientFinderAgent import IngredientInfoAgentLangGraph
33
  from services.productAnalyzerAgent import analyze_product_ingredients
34
 
@@ -42,20 +30,8 @@ log_info(f"Using parallel rate limit of {PARALLEL_RATE_LIMIT}")
42
  # Create a semaphore to limit concurrent API calls
43
  llm_semaphore = asyncio.Semaphore(PARALLEL_RATE_LIMIT)
44
 
45
-
46
- # SAM model path
47
- SAM_CHECKPOINT = "models/mobile_sam.pt" # Replace with your SAM checkpoint file
48
-
49
- # SAM model setup
50
- sam_checkpoint = SAM_CHECKPOINT
51
- model_type = "vit_t"
52
-
53
- # Load SAM model
54
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
55
-
56
- # Initialize the mask generator
57
- mask_generator = SamAutomaticMaskGenerator(sam)
58
-
59
  UPLOADED_IMAGES_DIR = "uploaded_images"
60
  if not os.path.exists(UPLOADED_IMAGES_DIR):
61
  os.makedirs(UPLOADED_IMAGES_DIR)
@@ -75,41 +51,54 @@ def ingredient_db_to_pydantic(db_ingredient):
75
  details_with_source=[source.data for source in db_ingredient.sources]
76
  )
77
 
78
- def extract_product_from_image(image_path: str) -> str | None:
79
- """
80
- Extracts the product image from an image using SAM.
81
-
82
- Args:
83
- image_path: Path to the input image.
84
 
85
- Returns:
86
- Path to the extracted product image, or None if extraction failed.
87
- """
88
  try:
89
- # Load the image
90
  image = cv2.imread(image_path)
91
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
92
 
93
- # Generate masks
94
- masks = mask_generator.generate(image)
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- if not masks:
97
- print("No masks generated.")
98
  return None
99
 
100
- # Find the largest mask
101
- largest_mask = max(masks, key=lambda mask: mask['area'])
102
- mask = largest_mask["segmentation"]
 
103
 
 
 
 
 
 
 
 
 
104
  # Create a masked image
105
  masked_image = np.zeros_like(image)
106
- masked_image[mask] = image[mask]
107
 
108
  # Crop the image
109
- y_coords, x_coords = np.where(mask)
110
  x_min, x_max = np.min(x_coords), np.max(x_coords)
111
  y_min, y_max = np.min(y_coords), np.max(y_coords)
112
- cropped_image = masked_image[y_min:y_max, x_min:x_max]
113
 
114
  # Save the cropped image
115
  cropped_image_path = os.path.join(
@@ -123,6 +112,7 @@ def extract_product_from_image(image_path: str) -> str | None:
123
  print(f"Error during image processing: {e}")
124
  return None
125
 
 
126
  @router.post("/process_image")
127
  async def process_image(image: UploadFile = File(...)):
128
  """
@@ -145,7 +135,7 @@ async def process_image(image: UploadFile = File(...)):
145
  print("Image saved temporarily to:", temp_image_path)
146
 
147
  # Extract the product
148
- extracted_product_path = extract_product_from_image(temp_image_path)
149
 
150
  # Remove the temporary file
151
  os.remove(temp_image_path)
@@ -157,6 +147,7 @@ async def process_image(image: UploadFile = File(...)):
157
  {
158
  "message": "Product extracted successfully",
159
  "product_image_path": extracted_product_path,
 
160
  }
161
  )
162
  else:
@@ -169,6 +160,7 @@ async def process_image(image: UploadFile = File(...)):
169
  print("Error:", e)
170
  return JSONResponse({"error": str(e)}, status_code=500)
171
 
 
172
  # process single ingredient
173
  @router.post("/process_ingredient", response_model=IngredientAnalysisResult)
174
  @traceable
@@ -257,7 +249,7 @@ async def process_single_ingredient(ingredient_name: str):
257
 
258
  @router.post("/process_product_ingredients", response_model=Dict[str, Any])
259
  @traceable
260
- async def process_ingredients_endpoint(product_ingredient: ProductIngredientsRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
261
  log_info(f"process_ingredients_endpoint called for {len(product_ingredient.ingredients)} ingredients")
262
  ingredients = product_ingredient.ingredients
263
  try:
@@ -301,4 +293,4 @@ async def process_ingredients_endpoint(product_ingredient: ProductIngredientsReq
301
 
302
  except Exception as e:
303
  log_error(f"Error in process_ingredients_endpoint: {str(e)}")
304
- raise HTTPException(status_code=500, detail="Internal Server Error")
 
2
  import os
3
  from datetime import datetime
4
  from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
5
+ from fastapi.responses import JSONResponse, FileResponse
6
  import pytz
7
  from sqlalchemy.orm import Session
8
  from typing import List, Dict, Any
9
+ from db.models import User, Ingredient
10
  from interfaces.ingredientModels import IngredientAnalysisResult, IngredientRequest
11
+ from interfaces.productModels import ProductIngredientsRequest,ProductData
12
  from services.auth_service import get_current_user
13
+ from logger_manager import log_info, log_error, logger
14
  from db.database import get_db,SessionLocal
15
+ from db.repositories import IngredientRepository, ProductRepository
16
  from dotenv import load_dotenv
17
  from langsmith import traceable
 
 
18
  import io
19
+ from ultralytics import YOLO
 
 
 
 
 
 
 
 
 
 
20
  from services.ingredientFinderAgent import IngredientInfoAgentLangGraph
21
  from services.productAnalyzerAgent import analyze_product_ingredients
22
 
 
30
  # Create a semaphore to limit concurrent API calls
31
  llm_semaphore = asyncio.Semaphore(PARALLEL_RATE_LIMIT)
32
 
33
+ # Load YOLO model
34
+ yolo_model = YOLO("yolov8n-seg.pt") # Downloaded automatically if needed
 
 
 
 
 
 
 
 
 
 
 
 
35
  UPLOADED_IMAGES_DIR = "uploaded_images"
36
  if not os.path.exists(UPLOADED_IMAGES_DIR):
37
  os.makedirs(UPLOADED_IMAGES_DIR)
 
51
  details_with_source=[source.data for source in db_ingredient.sources]
52
  )
53
 
 
 
 
 
 
 
54
 
55
+ def extract_product_from_image_yolo(image_path: str) -> str | None:
56
+ """Extracts the product image using YOLOv8 with preprocessing and postprocessing."""
 
57
  try:
58
+ # Load image
59
  image = cv2.imread(image_path)
 
60
 
61
+ # Preprocessing: Resize image
62
+ target_size = (640, 640)
63
+ image_resized = cv2.resize(image, target_size)
64
+
65
+ # Run inference with YOLO
66
+ results = yolo_model(image_resized)
67
+
68
+ if not results:
69
+ print("No objects detected by YOLO.")
70
+ return None
71
+
72
+ # Process results
73
+ result = results[0]
74
+ masks = result.masks
75
 
76
+ if masks is None:
77
+ print("No segmentation masks found by YOLO.")
78
  return None
79
 
80
+ # Select the largest mask
81
+ largest_mask_index = np.argmax([mask.area for mask in masks])
82
+ largest_mask_tensor = masks[largest_mask_index].data.cpu()
83
+ largest_mask = largest_mask_tensor.numpy().astype(np.uint8)
84
 
85
+ # Resize the mask to the original image size
86
+ largest_mask = cv2.resize(largest_mask, (image.shape[1], image.shape[0]))
87
+
88
+ # Postprocessing: Basic mask cleanup (dilation/erosion)
89
+ kernel = np.ones((5, 5), np.uint8)
90
+ mask_cleaned = cv2.dilate(largest_mask, kernel, iterations=1)
91
+ mask_cleaned = cv2.erode(mask_cleaned, kernel, iterations=1)
92
+
93
  # Create a masked image
94
  masked_image = np.zeros_like(image)
95
+ masked_image[mask_cleaned.astype(bool)] = image[mask_cleaned.astype(bool)]
96
 
97
  # Crop the image
98
+ y_coords, x_coords = np.where(mask_cleaned)
99
  x_min, x_max = np.min(x_coords), np.max(x_coords)
100
  y_min, y_max = np.min(y_coords), np.max(y_coords)
101
+ cropped_image = masked_image[y_min:y_max, x_min:x_max]
102
 
103
  # Save the cropped image
104
  cropped_image_path = os.path.join(
 
112
  print(f"Error during image processing: {e}")
113
  return None
114
 
115
+
116
  @router.post("/process_image")
117
  async def process_image(image: UploadFile = File(...)):
118
  """
 
135
  print("Image saved temporarily to:", temp_image_path)
136
 
137
  # Extract the product
138
+ extracted_product_path = extract_product_from_image_yolo(temp_image_path)
139
 
140
  # Remove the temporary file
141
  os.remove(temp_image_path)
 
147
  {
148
  "message": "Product extracted successfully",
149
  "product_image_path": extracted_product_path,
150
+ "image": FileResponse(extracted_product_path, media_type="image/jpeg")
151
  }
152
  )
153
  else:
 
160
  print("Error:", e)
161
  return JSONResponse({"error": str(e)}, status_code=500)
162
 
163
+
164
  # process single ingredient
165
  @router.post("/process_ingredient", response_model=IngredientAnalysisResult)
166
  @traceable
 
249
 
250
  @router.post("/process_product_ingredients", response_model=Dict[str, Any])
251
  @traceable
252
+ async def process_ingredients_endpoint(product_ingredient: ProductIngredientsRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
253
  log_info(f"process_ingredients_endpoint called for {len(product_ingredient.ingredients)} ingredients")
254
  ingredients = product_ingredient.ingredients
255
  try:
 
293
 
294
  except Exception as e:
295
  log_error(f"Error in process_ingredients_endpoint: {str(e)}")
296
+ raise HTTPException(status_code=500, detail="Internal Server Error")