|
|
"""Model wrapper for LiteLLM""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
from typing import List, Dict, Any, Optional |
|
|
|
|
|
try: |
|
|
import litellm |
|
|
except ImportError: |
|
|
print("⚠️ litellm not installed. Install with: pip install litellm") |
|
|
litellm = None |
|
|
|
|
|
|
|
|
class LiteLLMModel:
    """Wrapper for LiteLLM models (chat completion with optional tool calling).

    Model ids containing ``"groq"`` (case-insensitive) are treated as Groq
    models and require the ``GROQ_API_KEY`` environment variable.
    """

    def __init__(self, model_id: str):
        """Store the model id and fail fast when a Groq model has no API key.

        Args:
            model_id: LiteLLM model identifier, forwarded as ``model=`` to
                ``litellm.completion``.

        Raises:
            RuntimeError: If ``model_id`` refers to a Groq model and
                ``GROQ_API_KEY`` is not set in the environment.
        """
        self.model_id = model_id

        if self._is_groq() and not os.getenv("GROQ_API_KEY"):
            print("⚠️ GROQ_API_KEY not set in environment")
            raise RuntimeError("GROQ_API_KEY not set. Please add it to your Space secrets.")

    def _is_groq(self) -> bool:
        """Return True when the model id names a Groq model."""
        return "groq" in self.model_id.lower()

    @staticmethod
    def _convert_tool_call(tc: Any) -> Dict[str, Any]:
        """Normalise one response tool call into ``{"id", "name", "arguments"}``."""
        name = tc.function.name
        args = tc.function.arguments
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except ValueError:
                # Malformed JSON from the model (json.JSONDecodeError is a
                # ValueError): degrade to "no arguments" rather than failing
                # the whole response.
                args = {}
        return {
            # Some providers omit the id or send None/"", so synthesise one
            # from the function name in that case.
            "id": getattr(tc, "id", None) or f"call_{name}",
            "name": name,
            "arguments": args,
        }

    def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
        """Run a chat completion and return a simplified result dict.

        Args:
            messages: Chat messages passed unchanged to ``litellm.completion``.
            tools: Optional tool objects. Each must expose ``name``,
                ``description`` and ``parameters`` attributes.

        Returns:
            ``{"content": str}``, plus a ``"tool_calls"`` list of
            ``{"id", "name", "arguments"}`` dicts when the model requested
            tool calls. This is best-effort: if ``litellm`` is missing, the
            result is ``{"content": "Unknown - litellm not installed"}``. On
            any other error, the error is printed and the result is
            ``{"content": "Unknown"}``.
        """
        if not litellm:
            return {"content": "Unknown - litellm not installed"}

        try:
            formatted_tools = None
            if tools:
                formatted_tools = [
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        },
                    }
                    for tool in tools
                ]

            completion_kwargs: Dict[str, Any] = {
                "model": self.model_id,
                "messages": messages,
                "tools": formatted_tools,
                "temperature": 0.1,
            }

            if self._is_groq():
                # Re-read the key at call time: the environment may have
                # changed since __init__ validated it.
                api_key = os.getenv("GROQ_API_KEY")
                if not api_key:
                    raise RuntimeError("GROQ_API_KEY not set in environment")
                print(f"DEBUG: Using Groq model: {self.model_id}")
                completion_kwargs["api_key"] = api_key

            response = litellm.completion(**completion_kwargs)

            message = response.choices[0].message
            result: Dict[str, Any] = {"content": message.content or ""}

            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                result["tool_calls"] = [self._convert_tool_call(tc) for tc in tool_calls]

            return result

        except Exception as e:
            # Deliberate best-effort boundary: callers expect a dict, never an exception.
            print(f"Model error: {e}")
            return {"content": "Unknown"}
|
|
|
|
|
|
|
|
def get_model(model_type: str, model_id: str):
    """Instantiate the model wrapper named by ``model_type``.

    Args:
        model_type: Wrapper class name. Only ``"LiteLLMModel"`` is recognised.
        model_id: Identifier handed to the wrapper's constructor.

    Returns:
        A ``LiteLLMModel`` built for ``model_id``.

    Raises:
        ValueError: If ``model_type`` is not a recognised wrapper name.
    """
    if model_type != "LiteLLMModel":
        raise ValueError(f"Unknown model type: {model_type}")
    return LiteLLMModel(model_id)