Gaurav vashistha commited on
Commit
1f88d4a
Β·
1 Parent(s): 88ae637

Switch to Gemini 2.0 Flash Exp for improved availability

Browse files
Files changed (2) hide show
  1. agents/visual_analyst.py +8 -10
  2. requirements.txt +1 -1
agents/visual_analyst.py CHANGED
@@ -11,25 +11,25 @@ class VisualAnalyst:
11
  self.api_key = os.getenv("GEMINI_API_KEY")
12
  if not self.api_key:
13
  raise ValueError("GEMINI_API_KEY not found")
14
-
15
  genai.configure(api_key=self.api_key)
16
 
17
  print("πŸ” Checking available Gemini models...")
18
  try:
19
- # Filter for models that support content generation
20
  my_models = [m.name for m in genai.list_models() if 'generateContent' in m.supported_generation_methods]
21
  print(f"πŸ“‹ Available Models: {my_models}")
22
 
23
- # UPDATED PRIORITY: Force 'Pro' model to bypass Flash 404s
24
  preferred_order = [
25
- 'models/gemini-1.5-pro', # <--- The Heavy Hitter (Smarter, diff backend)
 
26
  'models/gemini-1.5-pro-001',
27
  'models/gemini-1.5-flash',
28
  'models/gemini-1.5-flash-001',
29
  'models/gemini-pro-vision'
30
  ]
31
 
32
- selected_model = "models/gemini-1.5-pro" # Default
33
 
34
  for model_name in preferred_order:
35
  if model_name in my_models:
@@ -40,8 +40,8 @@ class VisualAnalyst:
40
  self.model = genai.GenerativeModel(selected_model)
41
 
42
  except Exception as e:
43
- print(f"⚠️ Model list failed ({e}), defaulting to gemini-1.5-pro")
44
- self.model = genai.GenerativeModel('gemini-1.5-pro')
45
 
46
  async def analyze_image(self, image_path: str):
47
  # Adaptation: Read file path to bytes, as main.py passes a path
@@ -55,13 +55,11 @@ class VisualAnalyst:
55
  "visual_features": [f"Error reading file: {str(e)}"]
56
  }
57
 
58
- # Prompt for analysis
59
  prompt = (
60
  "Analyze this product image for an e-commerce listing. "
61
  "Return a JSON object with keys: main_color, product_type, design_style, visual_features."
62
  )
63
  try:
64
- # Note: Pro model is sometimes stricter with image formats, but 'parts' usually works.
65
  # Adaptation: Run in thread to allow async await
66
  response = await asyncio.to_thread(
67
  self.model.generate_content,
@@ -73,7 +71,7 @@ class VisualAnalyst:
73
 
74
  text = response.text
75
  if text.startswith('```json'): text = text[7:]
76
- if text.startswith('```'): text = text[:-3]
77
 
78
  return json.loads(text.strip())
79
  except Exception as e:
 
11
  self.api_key = os.getenv("GEMINI_API_KEY")
12
  if not self.api_key:
13
  raise ValueError("GEMINI_API_KEY not found")
14
+
15
  genai.configure(api_key=self.api_key)
16
 
17
  print("πŸ” Checking available Gemini models...")
18
  try:
 
19
  my_models = [m.name for m in genai.list_models() if 'generateContent' in m.supported_generation_methods]
20
  print(f"πŸ“‹ Available Models: {my_models}")
21
 
22
+ # UPDATED PRIORITY: GEMINI 2.0 FIRST
23
  preferred_order = [
24
+ 'models/gemini-2.0-flash-exp', # <--- Newest & Smartest (Available in logs)
25
+ 'models/gemini-1.5-pro',
26
  'models/gemini-1.5-pro-001',
27
  'models/gemini-1.5-flash',
28
  'models/gemini-1.5-flash-001',
29
  'models/gemini-pro-vision'
30
  ]
31
 
32
+ selected_model = "models/gemini-2.0-flash-exp" # Default to the new one
33
 
34
  for model_name in preferred_order:
35
  if model_name in my_models:
 
40
  self.model = genai.GenerativeModel(selected_model)
41
 
42
  except Exception as e:
43
+ print(f"⚠️ Model list failed ({e}), defaulting to gemini-2.0-flash-exp")
44
+ self.model = genai.GenerativeModel('models/gemini-2.0-flash-exp')
45
 
46
  async def analyze_image(self, image_path: str):
47
  # Adaptation: Read file path to bytes, as main.py passes a path
 
55
  "visual_features": [f"Error reading file: {str(e)}"]
56
  }
57
 
 
58
  prompt = (
59
  "Analyze this product image for an e-commerce listing. "
60
  "Return a JSON object with keys: main_color, product_type, design_style, visual_features."
61
  )
62
  try:
 
63
  # Adaptation: Run in thread to allow async await
64
  response = await asyncio.to_thread(
65
  self.model.generate_content,
 
71
 
72
  text = response.text
73
  if text.startswith('```json'): text = text[7:]
74
+ if text.endswith('```'): text = text[:-3]
75
 
76
  return json.loads(text.strip())
77
  except Exception as e:
requirements.txt CHANGED
@@ -8,6 +8,6 @@ langchain-groq
8
  pinecone>=3.0.0
9
  pydantic
10
  python-dotenv
11
- google-generativeai>=0.7.2
12
  groq
13
  Pillow
 
8
  pinecone>=3.0.0
9
  pydantic
10
  python-dotenv
11
+ google-generativeai>=0.8.3
12
  groq
13
  Pillow