SantoshKumar1310 committed on
Commit
39a3afe
·
verified ·
1 Parent(s): 321243d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +15 -41
model.py CHANGED
@@ -11,29 +11,21 @@ except ImportError:
11
  print("⚠️ litellm not installed. Install with: pip install litellm")
12
  litellm = None
13
 
14
-
15
  class LiteLLMModel:
16
  """Wrapper for LiteLLM models"""
17
 
18
def __init__(self, model_id: str):
    """Remember the target model id and sanity-check Gemini credentials.

    Args:
        model_id: LiteLLM model identifier (e.g. a Gemini model name).
    """
    self.model_id = model_id

    # Gemini calls will need GEMINI_API_KEY later in generate(); warn up
    # front so misconfiguration is visible at construction time.
    wants_gemini = "gemini" in model_id.lower()
    if wants_gemini and not os.getenv("GEMINI_API_KEY"):
        print("⚠️ GEMINI_API_KEY not set in environment")
 
26
  def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
27
- """
28
- Generate response from model.
29
- Returns:
30
- Dict with 'content' and optionally 'tool_calls'
31
- """
32
  if not litellm:
33
  return {"content": "Unknown - litellm not installed"}
34
 
35
  try:
36
- # Convert tools to OpenAI format
37
  formatted_tools = None
38
  if tools:
39
  formatted_tools = [
@@ -48,45 +40,30 @@ class LiteLLMModel:
48
  for tool in tools
49
  ]
50
 
51
- # --- Provider Handling ---
52
  if "gemini" in self.model_id.lower():
53
  api_key = os.getenv("GEMINI_API_KEY")
54
  if not api_key:
55
  raise RuntimeError("GEMINI_API_KEY not set in environment")
 
56
 
57
- # Try with original model ID
58
- model_name_1 = self.model_id
59
- # Try adding "google/" prefix if missing
60
- model_name_2 = f"google/{self.model_id}" if not self.model_id.startswith("google/") else self.model_id
 
61
 
62
- # Debug print for environment details
63
  print("DEBUG: GEMINI_API_KEY last4:", api_key[-4:])
 
64
 
65
- # Try first model_name_1
66
- try:
67
- print(f"DEBUG: Trying Gemini with model_name_1 = {model_name_1}")
68
- response = litellm.completion(
69
- provider="google",
70
- model=model_name_1,
71
- api_key=api_key,
72
- messages=messages,
73
- tools=formatted_tools,
74
- temperature=0.1
75
- )
76
- except Exception as e1:
77
- print(f"DEBUG: model_name_1 failed: {e1}")
78
- # Try model_name_2
79
- print(f"DEBUG: Trying Gemini with model_name_2 = {model_name_2}")
80
- response = litellm.completion(
81
- provider="google",
82
- model=model_name_2,
83
- api_key=api_key,
84
- messages=messages,
85
- tools=formatted_tools,
86
- temperature=0.1
87
- )
88
  else:
89
- # For other models/providers, pass model without provider and api_key
90
  response = litellm.completion(
91
  model=self.model_id,
92
  messages=messages,
@@ -99,7 +76,6 @@ class LiteLLMModel:
99
  "content": message.content or ""
100
  }
101
 
102
- # Handle tool calls if present
103
  if hasattr(message, 'tool_calls') and message.tool_calls:
104
  result["tool_calls"] = [
105
  {
@@ -115,9 +91,7 @@ class LiteLLMModel:
115
  print(f"Model error: {e}")
116
  return {"content": "Unknown"}
117
 
118
-
119
  def get_model(model_type: str, model_id: str):
120
- """Get model instance"""
121
  if model_type == "LiteLLMModel":
122
  return LiteLLMModel(model_id)
123
  else:
 
11
  print("⚠️ litellm not installed. Install with: pip install litellm")
12
  litellm = None
13
 
 
14
  class LiteLLMModel:
15
  """Wrapper for LiteLLM models"""
16
 
17
  def __init__(self, model_id: str):
18
  self.model_id = model_id
19
 
 
20
  if "gemini" in model_id.lower():
21
  if not os.getenv("GEMINI_API_KEY"):
22
  print("⚠️ GEMINI_API_KEY not set in environment")
23
 
24
  def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
 
 
 
 
 
25
  if not litellm:
26
  return {"content": "Unknown - litellm not installed"}
27
 
28
  try:
 
29
  formatted_tools = None
30
  if tools:
31
  formatted_tools = [
 
40
  for tool in tools
41
  ]
42
 
 
43
  if "gemini" in self.model_id.lower():
44
  api_key = os.getenv("GEMINI_API_KEY")
45
  if not api_key:
46
  raise RuntimeError("GEMINI_API_KEY not set in environment")
47
+ provider = "google"
48
 
49
+ # Use a known supported Gemini model id from LiteLLM and Google docs
50
+ # Replace MODEL_ID in your app.py with one of these if needed
51
+ # Example model names: "vertex_ai/gemini-1.0-pro-001", "vertex_ai/gemini-2.5-flash"
52
+ # Choose one according to what is currently supported
53
+ supported_model_id = "vertex_ai/gemini-1.0-pro-001"
54
 
 
55
  print("DEBUG: GEMINI_API_KEY last4:", api_key[-4:])
56
+ print(f"DEBUG: Using supported Gemini model id: {supported_model_id}")
57
 
58
+ response = litellm.completion(
59
+ provider=provider,
60
+ model=supported_model_id, # Use supported model id here
61
+ api_key=api_key,
62
+ messages=messages,
63
+ tools=formatted_tools,
64
+ temperature=0.1
65
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  else:
 
67
  response = litellm.completion(
68
  model=self.model_id,
69
  messages=messages,
 
76
  "content": message.content or ""
77
  }
78
 
 
79
  if hasattr(message, 'tool_calls') and message.tool_calls:
80
  result["tool_calls"] = [
81
  {
 
91
  print(f"Model error: {e}")
92
  return {"content": "Unknown"}
93
 
 
94
  def get_model(model_type: str, model_id: str):
 
95
  if model_type == "LiteLLMModel":
96
  return LiteLLMModel(model_id)
97
  else: