SantoshKumar1310's picture
Update model.py
76e7fd5 verified
"""Model wrapper for LiteLLM"""
import os
import json
from typing import List, Dict, Any, Optional
try:
import litellm
except ImportError:
print("⚠️ litellm not installed. Install with: pip install litellm")
litellm = None
class LiteLLMModel:
    """Wrapper around ``litellm.completion`` with a simplified response format.

    ``generate`` takes chat ``messages`` plus optional tool objects and returns
    a plain dict with ``"content"`` and, when the model requested tool calls,
    a ``"tool_calls"`` list.
    """

    def __init__(self, model_id: str):
        """Store the model id and fail fast when a Groq model has no API key.

        Args:
            model_id: LiteLLM model identifier, passed as ``model=`` to
                ``litellm.completion``.

        Raises:
            RuntimeError: if ``model_id`` contains "groq" (case-insensitive)
                and ``GROQ_API_KEY`` is not set in the environment.
        """
        self.model_id = model_id
        if self._is_groq() and not os.getenv("GROQ_API_KEY"):
            print("⚠️ GROQ_API_KEY not set in environment")
            raise RuntimeError("GROQ_API_KEY not set. Please add it to your Space secrets.")

    def _is_groq(self) -> bool:
        """Return True when ``model_id`` mentions "groq" (case-insensitive)."""
        return "groq" in self.model_id.lower()

    @staticmethod
    def _format_tools(tools: Optional[List]) -> Optional[List[Dict[str, Any]]]:
        """Convert tool objects into function-calling schema dicts.

        Each tool must expose ``name``, ``description`` and ``parameters``
        attributes. Returns None for a missing or empty ``tools`` list.
        """
        if not tools:
            return None
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in tools
        ]

    @staticmethod
    def _parse_arguments(args: Any) -> Any:
        """Decode tool-call arguments, falling back to ``{}``.

        String arguments are JSON-decoded. Undecodable strings, ``None`` and
        JSON ``null`` all become an empty dict. Other values pass through
        unchanged.
        """
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                return {}
        return {} if args is None else args

    @classmethod
    def _parse_tool_calls(cls, tool_calls) -> List[Dict[str, Any]]:
        """Normalize provider tool-call objects into ``id``/``name``/``arguments`` dicts."""
        return [
            {
                # Providers may omit the id or send None; synthesize one from the name.
                "id": getattr(tc, "id", None) or f"call_{tc.function.name}",
                "name": tc.function.name,
                "arguments": cls._parse_arguments(tc.function.arguments),
            }
            for tc in tool_calls
        ]

    def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
        """Run one chat completion and return a simplified response dict.

        Args:
            messages: Chat messages passed straight to ``litellm.completion``.
            tools: Optional tool objects exposing ``name``, ``description`` and
                ``parameters``. They are converted to function-calling schema.

        Returns:
            ``{"content": str}``, plus ``"tool_calls"`` (a list of dicts with
            ``id``, ``name``, ``arguments``) when the model requested tools.
            This method never raises. When litellm is missing, or on any error,
            it returns a dict whose ``content`` starts with ``"Unknown"``.
        """
        if not litellm:
            return {"content": "Unknown - litellm not installed"}
        try:
            completion_kwargs: Dict[str, Any] = {
                "model": self.model_id,
                "messages": messages,
                "tools": self._format_tools(tools),
                "temperature": 0.1,
            }
            if self._is_groq():
                api_key = os.getenv("GROQ_API_KEY")
                if not api_key:
                    raise RuntimeError("GROQ_API_KEY not set in environment")
                print(f"DEBUG: Using Groq model: {self.model_id}")
                completion_kwargs["api_key"] = api_key
            response = litellm.completion(**completion_kwargs)

            message = response.choices[0].message
            result: Dict[str, Any] = {"content": message.content or ""}
            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                result["tool_calls"] = self._parse_tool_calls(tool_calls)
            return result
        except Exception as e:
            # Deliberate best-effort boundary: callers get a sentinel answer
            # instead of an exception. This also catches the missing-key
            # RuntimeError raised above.
            print(f"Model error: {e}")
            return {"content": "Unknown"}
def get_model(model_type: str, model_id: str):
    """Build a model wrapper from its type name.

    Only ``"LiteLLMModel"`` is supported; ``model_id`` is forwarded to its
    constructor.

    Raises:
        ValueError: if ``model_type`` is anything else.
    """
    if model_type != "LiteLLMModel":
        raise ValueError(f"Unknown model type: {model_type}")
    return LiteLLMModel(model_id)