File size: 3,492 Bytes
5c5ce54 27c705a 76e7fd5 27c705a 5c5ce54 27c705a 5c5ce54 27c705a 5c5ce54 5567bb1 5c5ce54 27c705a 5c5ce54 27c705a 5c5ce54 5567bb1 27c705a 5567bb1 5c5ce54 5567bb1 5c5ce54 39a3afe 5c5ce54 39a3afe 27c705a 5567bb1 27c705a 5c5ce54 27c705a 5c5ce54 27c705a 76e7fd5 27c705a 76e7fd5 5c5ce54 27c705a 5c5ce54 27c705a 5c5ce54 27c705a 5c5ce54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
"""Model wrapper for LiteLLM"""
import os
import json
from typing import List, Dict, Any, Optional
# litellm is an optional dependency: when it is missing, bind the name to None
# so this module still imports; code that needs litellm checks for None first.
try:
    import litellm
except ImportError:
    litellm = None
    print("⚠️ litellm not installed. Install with: pip install litellm")
class LiteLLMModel:
    """Thin wrapper around ``litellm.completion`` returning plain dicts.

    ``generate`` always returns a dict with a ``"content"`` string and, when
    the model requested tool calls, a ``"tool_calls"`` list whose items have
    ``"id"``, ``"name"`` and ``"arguments"`` keys.
    """

    def __init__(self, model_id: str):
        """Store *model_id*, failing fast when a Groq model has no API key.

        Args:
            model_id: LiteLLM model identifier. Any id containing "groq"
                (case-insensitive) is treated as a Groq model.

        Raises:
            RuntimeError: a Groq model was requested but ``GROQ_API_KEY`` is
                not set in the environment.
        """
        self.model_id = model_id
        if self._is_groq() and not os.getenv("GROQ_API_KEY"):
            print("⚠️ GROQ_API_KEY not set in environment")
            raise RuntimeError("GROQ_API_KEY not set. Please add it to your Space secrets.")

    def _is_groq(self) -> bool:
        """Return True when ``model_id`` names a Groq model (substring match)."""
        return "groq" in self.model_id.lower()

    def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
        """Run one chat completion and normalise the reply into a dict.

        Args:
            messages: Chat messages, forwarded unchanged to ``litellm.completion``.
            tools: Optional tool objects exposing ``name``, ``description`` and
                ``parameters`` attributes; converted by ``_format_tools``.

        Returns:
            ``{"content": str}`` plus an optional ``"tool_calls"`` list (see
            ``_parse_message``). If litellm is not installed the result is
            ``{"content": "Unknown - litellm not installed"}``; if any
            ``Exception`` occurs while building or sending the request or
            parsing the reply, it is printed and ``{"content": "Unknown"}``
            is returned instead of raising.
        """
        if litellm is None:
            return {"content": "Unknown - litellm not installed"}
        try:
            request: Dict[str, Any] = {
                "model": self.model_id,
                "messages": messages,
                "tools": self._format_tools(tools),
                "temperature": 0.1,  # fixed low sampling temperature
            }
            if self._is_groq():
                api_key = os.getenv("GROQ_API_KEY")
                if not api_key:
                    # __init__ already checked this; it only fires if the
                    # variable was unset afterwards, and is reported below as
                    # "Unknown" like any other failure.
                    raise RuntimeError("GROQ_API_KEY not set in environment")
                print(f"DEBUG: Using Groq model: {self.model_id}")
                request["api_key"] = api_key
            response = litellm.completion(**request)
            return self._parse_message(response.choices[0].message)
        except Exception as e:
            # Deliberate best-effort boundary: callers get a placeholder
            # answer instead of an exception.
            print(f"Model error: {e}")
            return {"content": "Unknown"}

    @staticmethod
    def _format_tools(tools: Optional[List]) -> Optional[List[Dict[str, Any]]]:
        """Convert tool objects to OpenAI-style ``{"type": "function", ...}`` specs.

        Returns None for a missing or empty tool list so that no tools are sent.
        """
        if not tools:
            return None
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in tools
        ]

    @classmethod
    def _parse_message(cls, message: Any) -> Dict[str, Any]:
        """Turn a response message into ``{"content": str[, "tool_calls": [...]]}``.

        A ``None`` content becomes ``""``; the ``"tool_calls"`` key is present
        only when the message carries a non-empty ``tool_calls`` sequence.
        """
        result: Dict[str, Any] = {"content": message.content or ""}
        tool_calls = getattr(message, "tool_calls", None)
        if tool_calls:
            result["tool_calls"] = [cls._parse_tool_call(tc) for tc in tool_calls]
        return result

    @staticmethod
    def _parse_tool_call(tc: Any) -> Dict[str, Any]:
        """Normalise one tool call into ``{"id", "name", "arguments"}``.

        String arguments are decoded as JSON; undecodable or null arguments
        become ``{}``. A missing or empty call id falls back to ``call_<name>``.
        """
        name = tc.function.name
        args = tc.function.arguments
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                args = {}
        if args is None:
            args = {}
        return {
            # hasattr() alone is not enough: the attribute can exist but be None.
            "id": getattr(tc, "id", None) or f"call_{name}",
            "name": name,
            "arguments": args,
        }
def get_model(model_type: str, model_id: str):
    """Construct a model wrapper selected by its type name.

    Args:
        model_type: Name of the wrapper class; only "LiteLLMModel" is supported.
        model_id: Model identifier handed to the wrapper's constructor.

    Returns:
        A new ``LiteLLMModel`` for *model_id*. Errors raised by its
        constructor propagate unchanged.

    Raises:
        ValueError: *model_type* is not a supported wrapper name.
    """
    if model_type != "LiteLLMModel":
        raise ValueError(f"Unknown model type: {model_type}")
    return LiteLLMModel(model_id)