Prathamesh Sable commited on
Commit
5c8f5bf
·
1 Parent(s): 325bd1a

bug fix and improvements

Browse files
main.py CHANGED
@@ -10,6 +10,8 @@ from dotenv import load_dotenv
10
  import os
11
  import uvicorn
12
  from pathlib import Path
 
 
13
 
14
  load_dotenv()
15
  # Load environment variables from .env file
@@ -20,6 +22,17 @@ templates = Jinja2Templates(directory="templates")
20
 
21
  app = FastAPI()
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  @app.get("/")
24
  def read_root():
25
  return RedirectResponse("/api")
 
10
  import os
11
  import uvicorn
12
  from pathlib import Path
13
+ import tensorflow as tf
14
+ import tensorflow_hub as hub
15
 
16
  load_dotenv()
17
  # Load environment variables from .env file
 
22
 
23
  app = FastAPI()
24
 
25
+ # Suppress TensorFlow warnings
26
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 0=all, 1=no INFO, 2=no WARNING, 3=no ERROR
27
+
28
+ # Store the model as a state variable in the app
29
+ @app.on_event("startup")
30
+ async def startup_event():
31
+ # Load model once during startup
32
+ print("Loading TensorFlow model...")
33
+ app.state.detector = hub.load("https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1").signatures['default']
34
+ print("TensorFlow model loaded successfully!")
35
+
36
  @app.get("/")
37
  def read_root():
38
  return RedirectResponse("/api")
routers/analysis.py CHANGED
@@ -94,13 +94,26 @@ async def process_ingredients_endpoint(product_ingredient: ProductIngredientsReq
94
 
95
  # Step 2: Generate aggregate analysis with product analyzer agent
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  product_analysis = await analyze_product_ingredients(
98
  ingredients_data=ingredient_results,
99
- user_preferences={
100
- "user_id": current_user.id,
101
- "allergies": current_user.preferences[0].allergens if current_user.preferences else None,
102
- "dietary_restrictions": current_user.preferences[0].dietary_restrictions if current_user.preferences else None
103
- } if current_user else {}
104
  )
105
 
106
  # print("Product analysis result:", product_analysis)
 
94
 
95
  # Step 2: Generate aggregate analysis with product analyzer agent
96
 
97
+ # Safely get user preferences, handling the case where the preferences table doesn't exist
98
+ user_preferences = {}
99
+ if current_user:
100
+ user_preferences["user_id"] = current_user.id
101
+ try:
102
+ # Only try to access preferences if the relationship exists
103
+ if hasattr(current_user, 'preferences') and current_user.preferences:
104
+ user_preferences["allergies"] = current_user.preferences[0].allergens
105
+ user_preferences["dietary_restrictions"] = current_user.preferences[0].dietary_restrictions
106
+ else:
107
+ user_preferences["allergies"] = None
108
+ user_preferences["dietary_restrictions"] = None
109
+ except Exception as e:
110
+ log_error(f"Error accessing user preferences: {e}", e)
111
+ user_preferences["allergies"] = None
112
+ user_preferences["dietary_restrictions"] = None
113
+
114
  product_analysis = await analyze_product_ingredients(
115
  ingredients_data=ingredient_results,
116
+ user_preferences=user_preferences
 
 
 
 
117
  )
118
 
119
  # print("Product analysis result:", product_analysis)
routers/product.py CHANGED
@@ -26,26 +26,13 @@ from dotenv import load_dotenv
26
  from services.ingredients import IngredientService
27
  from services.productAnalyzerAgent import analyze_product_ingredients
28
  from utils.db_utils import add_product_to_database
 
29
  from utils.fetch_data import fetch_product_data_from_api
 
30
 
31
 
32
  load_dotenv()
33
 
34
-
35
- UPLOADED_IMAGES_DIR = "uploaded_images"
36
- if not os.path.exists(UPLOADED_IMAGES_DIR):
37
- os.makedirs(UPLOADED_IMAGES_DIR)
38
-
39
-
40
- # TensorFlow model caching
41
- detector = None
42
-
43
-
44
- def load_detector():
45
- global detector
46
- if detector is None:
47
- detector = hub.load("https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1").signatures['default']
48
-
49
  VUFORIA_SERVER_ACCESS_KEY = os.getenv("VUFORIA_SERVER_ACCESS_KEY")
50
  VUFORIA_SERVER_SECRET_KEY = os.getenv("VUFORIA_SERVER_SECRET_KEY")
51
  VUFORIA_TARGET_DATABASE_NAME = os.getenv("VUFORIA_TARGET_DATABASE_NAME")
@@ -56,12 +43,11 @@ router = APIRouter()
56
 
57
  TARGET_CLASSES = set(["Food processor", "Fast food", "Food", "Seafood", "Snack"])
58
 
59
- def run_object_detection(image: Image.Image):
60
- load_detector() # Ensure model is loaded
 
61
  image_np = np.array(image)
62
- # Convert to tensor without specifying dtype
63
  input_tensor = tf.convert_to_tensor(image_np)[tf.newaxis, ...]
64
- # Convert to float32 and normalize to [0,1]
65
  input_tensor = tf.cast(input_tensor, tf.float32) / 255.0
66
  results = detector(input_tensor)
67
  results = {k: v.numpy() for k, v in results.items()}
@@ -70,26 +56,25 @@ def run_object_detection(image: Image.Image):
70
  def get_filtered_class_boxes(results):
71
  # for same class, keep the one with the highest score
72
  # and remove duplicates
73
- boxes = []
74
- classes = []
75
- scores = []
76
 
77
  for i in range(len(results["detection_scores"])):
78
  class_name = results["detection_class_entities"][i].decode("utf-8")
79
  box = results["detection_boxes"][i]
80
  score = results["detection_scores"][i]
81
  if class_name in TARGET_CLASSES:
82
- if class_name not in classes:
83
- boxes.append(box)
84
- classes.append(class_name)
85
- scores.append(score)
86
  else:
87
- index = classes.index(class_name)
88
- if score > scores[index]:
89
- boxes[index] = box
90
- classes[index] = class_name
91
- scores[index] = score
92
- return boxes, classes, scores
93
 
94
  def crop_image(image_np, box):
95
  ymin, xmin, ymax, xmax = box
@@ -179,7 +164,7 @@ async def create_product(
179
 
180
 
181
  @router.post("/process_image")
182
- async def process_image_endpoint(file: UploadFile = File(...), db: Session = Depends(get_db)):
183
  """
184
  Receives an image file, performs object detection, and returns information about detected objects.
185
  """
@@ -189,34 +174,30 @@ async def process_image_endpoint(file: UploadFile = File(...), db: Session = Dep
189
  image_data = await file.read()
190
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
191
 
192
- # Run object detection
193
- results, image_np = run_object_detection(image)
194
 
195
  # Get filtered class boxes
196
- boxes, class_names, scores = get_filtered_class_boxes(results)
197
 
198
- detected_objects = []
199
- for i in range(len(boxes)):
200
  # Crop the detected object
201
- cropped_img = crop_image(image_np, boxes[i])
202
-
203
- # Save the cropped image temporarily
204
- cropped_image_path = os.path.join(UPLOADED_IMAGES_DIR, f"detected_{class_names[i]}_{scores[i]:.2f}.jpg")
205
- cropped_img.save(cropped_image_path)
206
 
207
- # Find if a product with this image exists in the database
208
- product_repo = ProductRepository(db)
209
- product = product_repo.get_product_by_image_name(os.path.basename(cropped_image_path))
210
-
211
- detected_objects.append({
212
- "class_name": class_names[i],
213
- "score": float(scores[i]),
214
- "product_info": product.to_dict() if product else None # Assuming Product model has a to_dict method
215
- })
216
 
217
- return JSONResponse({"detected_objects": detected_objects})
 
 
 
 
218
  except Exception as e:
219
- log_error(f"Error processing image: {e}", exc_info=True)
220
  raise HTTPException(status_code=500, detail=f"Error processing image: {e}")
221
 
222
 
 
26
  from services.ingredients import IngredientService
27
  from services.productAnalyzerAgent import analyze_product_ingredients
28
  from utils.db_utils import add_product_to_database
29
+ from utils.vuforia_utils import add_target_to_vuforia, UPLOADED_IMAGES_DIR
30
  from utils.fetch_data import fetch_product_data_from_api
31
+ import uuid
32
 
33
 
34
  load_dotenv()
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  VUFORIA_SERVER_ACCESS_KEY = os.getenv("VUFORIA_SERVER_ACCESS_KEY")
37
  VUFORIA_SERVER_SECRET_KEY = os.getenv("VUFORIA_SERVER_SECRET_KEY")
38
  VUFORIA_TARGET_DATABASE_NAME = os.getenv("VUFORIA_TARGET_DATABASE_NAME")
 
43
 
44
  TARGET_CLASSES = set(["Food processor", "Fast food", "Food", "Seafood", "Snack"])
45
 
46
+ def run_object_detection(image: Image.Image, request: Request):
47
+ # Access the model from app state
48
+ detector = request.app.state.detector
49
  image_np = np.array(image)
 
50
  input_tensor = tf.convert_to_tensor(image_np)[tf.newaxis, ...]
 
51
  input_tensor = tf.cast(input_tensor, tf.float32) / 255.0
52
  results = detector(input_tensor)
53
  results = {k: v.numpy() for k, v in results.items()}
 
56
  def get_filtered_class_boxes(results):
57
  # for same class, keep the one with the highest score
58
  # and remove duplicates
59
+ high_boxes = None
60
+ high_classes = None
61
+ high_scores = None
62
 
63
  for i in range(len(results["detection_scores"])):
64
  class_name = results["detection_class_entities"][i].decode("utf-8")
65
  box = results["detection_boxes"][i]
66
  score = results["detection_scores"][i]
67
  if class_name in TARGET_CLASSES:
68
+ if high_boxes is None:
69
+ high_boxes = box
70
+ high_classes = class_name
71
+ high_scores = score
72
  else:
73
+ if score > high_scores:
74
+ high_boxes = box
75
+ high_classes = class_name
76
+ high_scores = score
77
+ return high_boxes, high_classes, high_scores
 
78
 
79
  def crop_image(image_np, box):
80
  ymin, xmin, ymax, xmax = box
 
164
 
165
 
166
  @router.post("/process_image")
167
+ async def process_image_endpoint(file: UploadFile = File(...), db: Session = Depends(get_db), request: Request = None):
168
  """
169
  Receives an image file, performs object detection, and returns information about detected objects.
170
  """
 
174
  image_data = await file.read()
175
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
176
 
177
+ # Run object detection with the request object
178
+ results, image_np = run_object_detection(image, request)
179
 
180
  # Get filtered class boxes
181
+ box, class_name, score = get_filtered_class_boxes(results)
182
 
 
 
183
  # Crop the detected object
184
+ cropped_img = crop_image(image_np, box)
 
 
 
 
185
 
186
+ # Save the cropped image temporarily
187
+ unique_id = uuid.uuid4().hex
188
+ cropped_image_name = f"detected_{class_name}_{score:.2f}_{unique_id}.jpg"
189
+ cropped_image_path = os.path.join(
190
+ UPLOADED_IMAGES_DIR, cropped_image_name
191
+ )
192
+ cropped_img.save(cropped_image_path)
 
 
193
 
194
+ return JSONResponse({
195
+ "class_name": class_name,
196
+ "score": float(score),
197
+ "image_name": cropped_image_name
198
+ })
199
  except Exception as e:
200
+ log_error(f"Error processing image: {e}", e)
201
  raise HTTPException(status_code=500, detail=f"Error processing image: {e}")
202
 
203
 
utils/db_utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from sqlalchemy.orm import Session
2
  from interfaces.ingredientModels import IngredientAnalysisResult
3
  from interfaces.productModels import ProductCreate
@@ -6,21 +7,61 @@ from logger_manager import log_info, log_error
6
  from fastapi import HTTPException
7
  import os
8
  from services.product_service import ProductService
9
- from routers.product import add_target_to_vuforia, UPLOADED_IMAGES_DIR # Assuming add_target_to_vuforia and UPLOADED_IMAGES_DIR are needed and will remain in product.py for now. If they are also moved, the import needs adjustment.
 
10
 
11
 
12
  def ingredient_db_to_pydantic(db_ingredient):
13
  """Convert a database ingredient model to a Pydantic model."""
14
- return IngredientAnalysisResult(
15
- name=db_ingredient.name,
16
- alternate_names=db_ingredient.alternate_names or [],
17
- is_found=True,
18
- id=db_ingredient.id,
19
- safety_rating=db_ingredient.safety_rating or 5,
20
- description=db_ingredient.description or "No description available",
21
- health_effects=db_ingredient.health_effects or ["Unknown"],
22
- details_with_source=[source.data for source in db_ingredient.sources]
23
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  async def add_product_to_database(
 
1
+ from typing import Dict, List,Any
2
  from sqlalchemy.orm import Session
3
  from interfaces.ingredientModels import IngredientAnalysisResult
4
  from interfaces.productModels import ProductCreate
 
7
  from fastapi import HTTPException
8
  import os
9
  from services.product_service import ProductService
10
+ from utils.vuforia_utils import add_target_to_vuforia, UPLOADED_IMAGES_DIR # Assuming add_target_to_vuforia and UPLOADED_IMAGES_DIR are needed and will remain in product.py for now. If they are also moved, the import needs adjustment.
11
+ import json
12
 
13
 
14
  def ingredient_db_to_pydantic(db_ingredient):
15
  """Convert a database ingredient model to a Pydantic model."""
16
+ try:
17
+ # Parse string fields that should be lists or dictionaries
18
+ if isinstance(db_ingredient.alternate_names, str):
19
+ alternate_names = json.loads(db_ingredient.alternate_names)
20
+ else:
21
+ alternate_names = db_ingredient.alternate_names or []
22
+
23
+ if isinstance(db_ingredient.health_effects, str):
24
+ health_effects = json.loads(db_ingredient.health_effects)
25
+ else:
26
+ health_effects = db_ingredient.health_effects or ["Unknown"]
27
+
28
+ # Handle details_with_source, which should be a list of dictionaries
29
+ if hasattr(db_ingredient, 'sources') and db_ingredient.sources:
30
+ details = []
31
+ for source in db_ingredient.sources:
32
+ if isinstance(source.data, str):
33
+ try:
34
+ details.append(json.loads(source.data))
35
+ except json.JSONDecodeError:
36
+ details.append({"source": "Unknown", "data": source.data})
37
+ else:
38
+ details.append(source.data)
39
+ else:
40
+ details = []
41
+
42
+ return IngredientAnalysisResult(
43
+ name=db_ingredient.name,
44
+ alternate_names=alternate_names,
45
+ is_found=True,
46
+ id=db_ingredient.id,
47
+ safety_rating=db_ingredient.safety_rating or 5,
48
+ description=db_ingredient.description or "No description available",
49
+ health_effects=health_effects,
50
+ details_with_source=details
51
+ )
52
+ except Exception as e:
53
+ log_error(f"Error converting DB ingredient to Pydantic model: {e}", e)
54
+ # Fallback with minimal valid data
55
+ return IngredientAnalysisResult(
56
+ name=db_ingredient.name,
57
+ alternate_names=[],
58
+ is_found=True,
59
+ id=db_ingredient.id,
60
+ safety_rating=db_ingredient.safety_rating or 5,
61
+ description=db_ingredient.description or "No description available",
62
+ health_effects=["Unknown"],
63
+ details_with_source=[]
64
+ )
65
 
66
 
67
  async def add_product_to_database(
utils/ingredient_utils.py CHANGED
@@ -4,11 +4,14 @@ from sqlalchemy.orm import Session
4
  from db.database import SessionLocal
5
  from db.repositories import IngredientRepository
6
  from interfaces.ingredientModels import IngredientAnalysisResult
 
7
  from services.ingredientFinderAgent import IngredientInfoAgentLangGraph
8
  from dotenv import load_dotenv
9
  from langsmith import traceable
10
  import pytz
11
 
 
 
12
  # Load environment variables
13
  load_dotenv()
14
 
@@ -20,45 +23,49 @@ llm_semaphore = asyncio.Semaphore(PARALLEL_RATE_LIMIT)
20
 
21
 
22
  @traceable
23
- async def process_single_ingredient(ingredient_name: str):
24
  """Process a single ingredient asynchronously with rate limiting"""
25
- # Create a new DB session for this specific task to avoid conflicts
26
- session = SessionLocal()
27
-
28
  try:
29
- # Check if ingredient exists in database
30
- repo = IngredientRepository(session)
31
- db_ingredient = repo.get_ingredient_by_name(ingredient_name)
32
-
33
- if db_ingredient:
34
- # Assuming ingredient_db_to_pydantic is now in a utils file, e.g., utils.db_utils
35
- from .db_utils import ingredient_db_to_pydantic
36
- ingredient_data = ingredient_db_to_pydantic(db_ingredient)
37
- return ingredient_data
38
- else:
39
- # Apply rate limiting for LLM calls only if not in database
40
- async with llm_semaphore:
41
- # Get from agent if not in database
42
- ingredient_finder = IngredientInfoAgentLangGraph()
43
-
44
- ingredient_data = await ingredient_finder.process_ingredient_async(ingredient_name)
45
-
46
- # Save to database for future use
47
- repo.create_ingredient(ingredient_data)
48
-
49
- return ingredient_data
 
 
 
 
 
 
 
 
 
50
  except Exception as e:
51
- # Return a minimal result on error to avoid failing the entire batch
 
52
  return IngredientAnalysisResult(
53
  name=ingredient_name,
54
  is_found=False,
 
 
55
  safety_rating=0,
56
- description=f"Error during processing: {str(e)}",
57
- health_effects=["Error during processing"],
58
- allergic_info=[],
59
- diet_type="unknown",
60
  details_with_source=[]
61
- )
62
- finally:
63
- # Important: Close the session when done
64
- session.close()
 
4
  from db.database import SessionLocal
5
  from db.repositories import IngredientRepository
6
  from interfaces.ingredientModels import IngredientAnalysisResult
7
+ from logger_manager import log_error, log_info
8
  from services.ingredientFinderAgent import IngredientInfoAgentLangGraph
9
  from dotenv import load_dotenv
10
  from langsmith import traceable
11
  import pytz
12
 
13
+ from utils.db_utils import ingredient_db_to_pydantic
14
+
15
  # Load environment variables
16
  load_dotenv()
17
 
 
23
 
24
 
25
  @traceable
26
+ async def process_single_ingredient(ingredient_name: str) -> IngredientAnalysisResult:
27
  """Process a single ingredient asynchronously with rate limiting"""
 
 
 
28
  try:
29
+ # First check if ingredient exists in the database
30
+ with SessionLocal() as db:
31
+ repo = IngredientRepository(db)
32
+ db_ingredient = repo.get_ingredient_by_name(ingredient_name)
33
+
34
+ if db_ingredient:
35
+ log_info(f"Using cached ingredient data for: {ingredient_name}")
36
+ return ingredient_db_to_pydantic(db_ingredient)
37
+
38
+ # If not in database, process it
39
+ log_info(f"Processing new ingredient: {ingredient_name}")
40
+ ingredient_finder = IngredientInfoAgentLangGraph()
41
+
42
+ try:
43
+ result = await ingredient_finder.process_ingredient_async(ingredient_name)
44
+ except RuntimeError:
45
+ result = ingredient_finder.process_ingredient(ingredient_name)
46
+
47
+ # Important: Add an id field even for new ingredients
48
+ # You can use a temporary id (will be replaced when saved to DB)
49
+ result.id = 0 # Temporary ID
50
+
51
+ # Save to database for future use
52
+ with SessionLocal() as db:
53
+ repo = IngredientRepository(db)
54
+ db_ingredient = repo.create_ingredient(result)
55
+ # Update with the real database ID
56
+ result.id = db_ingredient.id
57
+
58
+ return result
59
  except Exception as e:
60
+ log_error(f"Error processing ingredient {ingredient_name}: {e}", e)
61
+ # Return a minimal valid result for failed ingredients
62
  return IngredientAnalysisResult(
63
  name=ingredient_name,
64
  is_found=False,
65
+ id=0, # Add this missing required field
66
+ alternate_names=[],
67
  safety_rating=0,
68
+ description="Error processing this ingredient",
69
+ health_effects=["Unknown"],
 
 
70
  details_with_source=[]
71
+ )
 
 
 
utils/vuforia_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from logger_manager import log_info, log_error
3
+ from PIL import Image
4
+ import os
5
+ from pathlib import Path
6
+ from dotenv import load_dotenv
7
+ import requests
8
+ load_dotenv()
9
+
10
+ UPLOADED_IMAGES_DIR = "uploaded_images"
11
+ if not os.path.exists(UPLOADED_IMAGES_DIR):
12
+ os.makedirs(UPLOADED_IMAGES_DIR)
13
+
14
+
15
+ VUFORIA_SERVER_ACCESS_KEY = os.getenv("VUFORIA_SERVER_ACCESS_KEY")
16
+ VUFORIA_SERVER_SECRET_KEY = os.getenv("VUFORIA_SERVER_SECRET_KEY")
17
+ VUFORIA_TARGET_DATABASE_NAME = os.getenv("VUFORIA_TARGET_DATABASE_NAME")
18
+ VUFORIA_TARGET_DATABASE_ID = os.getenv("VUFORIA_TARGET_DATABASE_ID")
19
+
20
+ def get_vuforia_auth_headers():
21
+ """
22
+ Returns the authentication headers for Vuforia API requests.
23
+ """
24
+ return {
25
+ "Authorization": f"VWS {VUFORIA_SERVER_ACCESS_KEY}:{VUFORIA_SERVER_SECRET_KEY}",
26
+ "Content-Type": "application/json",
27
+ }
28
+
29
+
30
+ async def add_target_to_vuforia(image_name: str, image_path: str) -> str:
31
+ """
32
+ Adds a target to the Vuforia database and returns the Vuforia target ID.
33
+ """
34
+ log_info(f"Adding target {image_name} to Vuforia")
35
+
36
+ try:
37
+ with open(image_path, "rb") as image_file:
38
+ image_data = image_file.read()
39
+
40
+ url = f"https://vws.vuforia.com/targets"
41
+
42
+ headers = get_vuforia_auth_headers()
43
+ payload = {
44
+ "name": image_name,
45
+ "width": 1.0, # Default width
46
+ "image": image_data.hex(), # Convert image data to hex
47
+ "active_flag": True,
48
+ }
49
+
50
+ response = await requests.post(url, headers=headers, json=payload)
51
+ response_data = json.loads(response.text)
52
+ if response.status_code == 201:
53
+ log_info(
54
+ f"Target {image_name} added successfully with Vuforia ID: {response_data['target_id']}"
55
+ )
56
+ return response_data["target_id"]
57
+ else:
58
+ log_error(f"Failed to add target {image_name}: {response.text}")
59
+ raise Exception(f"Failed to add target {image_name}: {response.text}")
60
+ except Exception as e:
61
+ log_error(f"Error adding target {image_name}: {e}",e)
62
+ raise
63
+
64
+