SantoshKumar1310 commited on
Commit
5c5ce54
·
verified ·
1 Parent(s): 326e70a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -26
model.py CHANGED
@@ -1,6 +1,4 @@
1
- """
2
- Model wrapper for LiteLLM
3
- """
4
 
5
  import os
6
  from typing import List, Dict, Any, Optional
@@ -11,20 +9,21 @@ except ImportError:
11
  print("⚠️ litellm not installed. Install with: pip install litellm")
12
  litellm = None
13
 
 
14
  class LiteLLMModel:
15
  """Wrapper for LiteLLM models"""
16
-
17
  def __init__(self, model_id: str):
18
  self.model_id = model_id
19
-
20
  if "gemini" in model_id.lower():
21
  if not os.getenv("GEMINI_API_KEY"):
22
  print("⚠️ GEMINI_API_KEY not set in environment")
23
-
24
  def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
25
  if not litellm:
26
  return {"content": "Unknown - litellm not installed"}
27
-
28
  try:
29
  formatted_tools = None
30
  if tools:
@@ -39,25 +38,16 @@ class LiteLLMModel:
39
  }
40
  for tool in tools
41
  ]
42
-
43
  if "gemini" in self.model_id.lower():
44
  api_key = os.getenv("GEMINI_API_KEY")
45
  if not api_key:
46
  raise RuntimeError("GEMINI_API_KEY not set in environment")
47
- provider = "google"
48
-
49
- # Use a known supported Gemini model id from LiteLLM and Google docs
50
- # Replace MODEL_ID in your app.py with one of these if needed
51
- # Example model names: "vertex_ai/gemini-1.0-pro-001", "vertex_ai/gemini-2.5-flash"
52
- # Choose one according to what is currently supported
53
- supported_model_id = "vertex_ai/gemini-1.0-pro-001"
54
-
55
- print("DEBUG: GEMINI_API_KEY last4:", api_key[-4:])
56
- print(f"DEBUG: Using supported Gemini model id: {supported_model_id}")
57
-
58
  response = litellm.completion(
59
- provider=provider,
60
- model=supported_model_id, # Use supported model id here
61
  api_key=api_key,
62
  messages=messages,
63
  tools=formatted_tools,
@@ -70,12 +60,12 @@ class LiteLLMModel:
70
  tools=formatted_tools,
71
  temperature=0.1
72
  )
73
-
74
  message = response.choices[0].message
75
  result = {
76
  "content": message.content or ""
77
  }
78
-
79
  if hasattr(message, 'tool_calls') and message.tool_calls:
80
  result["tool_calls"] = [
81
  {
@@ -84,15 +74,16 @@ class LiteLLMModel:
84
  }
85
  for tc in message.tool_calls
86
  ]
87
-
88
  return result
89
-
90
  except Exception as e:
91
  print(f"Model error: {e}")
92
  return {"content": "Unknown"}
93
 
 
94
  def get_model(model_type: str, model_id: str):
95
  if model_type == "LiteLLMModel":
96
  return LiteLLMModel(model_id)
97
  else:
98
- raise ValueError(f"Unknown model type: {model_type}")
 
1
+ """Model wrapper for LiteLLM"""
 
 
2
 
3
  import os
4
  from typing import List, Dict, Any, Optional
 
9
  print("⚠️ litellm not installed. Install with: pip install litellm")
10
  litellm = None
11
 
12
+
13
  class LiteLLMModel:
14
  """Wrapper for LiteLLM models"""
15
+
16
  def __init__(self, model_id: str):
17
  self.model_id = model_id
18
+
19
  if "gemini" in model_id.lower():
20
  if not os.getenv("GEMINI_API_KEY"):
21
  print("⚠️ GEMINI_API_KEY not set in environment")
22
+
23
  def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
24
  if not litellm:
25
  return {"content": "Unknown - litellm not installed"}
26
+
27
  try:
28
  formatted_tools = None
29
  if tools:
 
38
  }
39
  for tool in tools
40
  ]
41
+
42
  if "gemini" in self.model_id.lower():
43
  api_key = os.getenv("GEMINI_API_KEY")
44
  if not api_key:
45
  raise RuntimeError("GEMINI_API_KEY not set in environment")
46
+
47
+ print(f"DEBUG: Using model id: {self.model_id}")
48
+
 
 
 
 
 
 
 
 
49
  response = litellm.completion(
50
+ model=self.model_id,
 
51
  api_key=api_key,
52
  messages=messages,
53
  tools=formatted_tools,
 
60
  tools=formatted_tools,
61
  temperature=0.1
62
  )
63
+
64
  message = response.choices[0].message
65
  result = {
66
  "content": message.content or ""
67
  }
68
+
69
  if hasattr(message, 'tool_calls') and message.tool_calls:
70
  result["tool_calls"] = [
71
  {
 
74
  }
75
  for tc in message.tool_calls
76
  ]
77
+
78
  return result
79
+
80
  except Exception as e:
81
  print(f"Model error: {e}")
82
  return {"content": "Unknown"}
83
 
84
+
85
  def get_model(model_type: str, model_id: str):
86
  if model_type == "LiteLLMModel":
87
  return LiteLLMModel(model_id)
88
  else:
89
+ raise ValueError(f"Unknown model type: {model_type}")