File size: 3,492 Bytes
5c5ce54
27c705a
 
76e7fd5
27c705a
 
 
 
 
 
 
 
5c5ce54
27c705a
 
5c5ce54
27c705a
 
5c5ce54
5567bb1
 
 
 
 
5c5ce54
27c705a
 
 
5c5ce54
27c705a
 
 
 
 
 
 
 
 
 
 
 
 
 
5c5ce54
5567bb1
 
 
27c705a
5567bb1
5c5ce54
5567bb1
5c5ce54
39a3afe
5c5ce54
39a3afe
 
 
 
 
27c705a
5567bb1
27c705a
 
 
 
 
 
5c5ce54
27c705a
 
 
 
5c5ce54
27c705a
76e7fd5
 
 
 
 
 
 
 
 
 
 
 
27c705a
76e7fd5
 
5c5ce54
27c705a
5c5ce54
27c705a
 
 
 
5c5ce54
27c705a
 
 
 
5c5ce54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Model wrapper for LiteLLM"""

import os
import json
from typing import List, Dict, Any, Optional

try:
    import litellm
except ImportError:
    print("⚠️ litellm not installed. Install with: pip install litellm")
    litellm = None


class LiteLLMModel:
    """Chat-completion wrapper around ``litellm.completion``.

    The wrapper converts tool objects into the function-calling schema,
    sends the request, and flattens the first choice of the response into a
    plain dict so callers do not depend on LiteLLM's response objects.
    """

    def __init__(self, model_id: str):
        """Store the model id and validate provider credentials.

        Args:
            model_id: Model identifier passed to ``litellm.completion``.

        Raises:
            RuntimeError: If ``model_id`` contains "groq" (case-insensitive)
                and the ``GROQ_API_KEY`` environment variable is unset/empty.
        """
        self.model_id = model_id

        # Fail fast at construction time so a missing key is reported once,
        # instead of surfacing as an "Unknown" answer on every generate().
        if "groq" in model_id.lower() and not os.getenv("GROQ_API_KEY"):
            print("⚠️ GROQ_API_KEY not set in environment")
            raise RuntimeError("GROQ_API_KEY not set. Please add it to your Space secrets.")

    def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
        """Run one completion and return a simplified result dict.

        Args:
            messages: Chat messages forwarded unchanged to LiteLLM.
            tools: Optional tool objects exposing ``name``, ``description``
                and ``parameters`` attributes.

        Returns:
            A dict with a ``"content"`` string (empty when the model returned
            none) and, when the model requested tools, a ``"tool_calls"``
            list of ``{"id", "name", "arguments"}`` dicts. Any failure is
            reported as ``{"content": "Unknown"}``; a missing ``litellm``
            package as ``{"content": "Unknown - litellm not installed"}``.
        """
        if litellm is None:
            return {"content": "Unknown - litellm not installed"}

        # Deliberate best-effort boundary: any failure (bad tool object,
        # missing key, provider/network error) degrades to "Unknown".
        try:
            request: Dict[str, Any] = {
                "model": self.model_id,
                "messages": messages,
                "tools": self._format_tools(tools),
                "temperature": 0.1,
            }

            # Groq configuration: the key is read on every call and passed
            # explicitly to the completion request.
            if "groq" in self.model_id.lower():
                api_key = os.getenv("GROQ_API_KEY")
                if not api_key:
                    raise RuntimeError("GROQ_API_KEY not set in environment")

                print(f"DEBUG: Using Groq model: {self.model_id}")
                request["api_key"] = api_key

            response = litellm.completion(**request)

            message = response.choices[0].message
            result: Dict[str, Any] = {"content": message.content or ""}

            tool_calls = self._parse_tool_calls(message)
            if tool_calls:
                result["tool_calls"] = tool_calls

            return result

        except Exception as e:
            print(f"Model error: {e}")
            return {"content": "Unknown"}

    @staticmethod
    def _format_tools(tools: Optional[List]) -> Optional[List[Dict[str, Any]]]:
        """Convert tool objects to function-calling schema dicts (or None)."""
        if not tools:
            return None
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in tools
        ]

    @staticmethod
    def _parse_tool_calls(message: Any) -> List[Dict[str, Any]]:
        """Flatten ``message.tool_calls`` into plain dicts with dict arguments."""
        parsed: List[Dict[str, Any]] = []
        for tc in getattr(message, "tool_calls", None) or []:
            name = tc.function.name

            # Arguments may arrive as a JSON string; decode them when so.
            args = tc.function.arguments
            if isinstance(args, str):
                try:
                    args = json.loads(args)
                except json.JSONDecodeError:
                    args = {}
            # Valid JSON that is not an object (null, list, number) gets the
            # same empty-dict fallback as undecodable JSON.
            if not isinstance(args, dict):
                args = {}

            parsed.append({
                # Fall back to a synthetic id when the id is absent or empty.
                "id": getattr(tc, "id", None) or f"call_{name}",
                "name": name,
                "arguments": args,
            })
        return parsed


def get_model(model_type: str, model_id: str):
    """Instantiate the model wrapper named by ``model_type``.

    Args:
        model_type: Name of the wrapper class; only ``"LiteLLMModel"`` is
            recognised.
        model_id: Model identifier handed to the wrapper's constructor.

    Returns:
        A ``LiteLLMModel`` built for ``model_id``.

    Raises:
        ValueError: If ``model_type`` is not a recognised wrapper name.
    """
    # Guard clause: reject anything that is not a known wrapper up front.
    if model_type != "LiteLLMModel":
        raise ValueError(f"Unknown model type: {model_type}")
    return LiteLLMModel(model_id)