SantoshKumar1310 commited on
Commit
f700012
·
verified ·
1 Parent(s): 1e52380

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +89 -0
model.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Model wrapper for LiteLLM
4
+ """
5
+
6
+ import os
7
+ from typing import List, Dict, Any, Optional
8
+
9
+ try:
10
+ import litellm
11
+ except ImportError:
12
+ print("⚠️ litellm not installed. Install with: pip install litellm")
13
+ litellm = None
14
+
15
+
16
+ class LiteLLMModel:
17
+ """Wrapper for LiteLLM models"""
18
+
19
+ def __init__(self, model_id: str):
20
+ self.model_id = model_id
21
+
22
+ # Set API keys from environment
23
+ if "gemini" in model_id.lower():
24
+ if not os.getenv("GEMINI_API_KEY"):
25
+ print("⚠️ GEMINI_API_KEY not set in environment")
26
+
27
+ def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
28
+ """
29
+ Generate response from model.
30
+
31
+ Returns:
32
+ Dict with 'content' and optionally 'tool_calls'
33
+ """
34
+ if not litellm:
35
+ return {"content": "Unknown - litellm not installed"}
36
+
37
+ try:
38
+ # Convert tools to OpenAI format
39
+ formatted_tools = None
40
+ if tools:
41
+ formatted_tools = [
42
+ {
43
+ "type": "function",
44
+ "function": {
45
+ "name": tool.name,
46
+ "description": tool.description,
47
+ "parameters": tool.parameters
48
+ }
49
+ }
50
+ for tool in tools
51
+ ]
52
+
53
+ # Call LiteLLM
54
+ response = litellm.completion(
55
+ model=self.model_id,
56
+ messages=messages,
57
+ tools=formatted_tools,
58
+ temperature=0.1
59
+ )
60
+
61
+ message = response.choices[0].message
62
+
63
+ result = {
64
+ "content": message.content or ""
65
+ }
66
+
67
+ # Handle tool calls
68
+ if hasattr(message, 'tool_calls') and message.tool_calls:
69
+ result["tool_calls"] = [
70
+ {
71
+ "name": tc.function.name,
72
+ "arguments": eval(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments
73
+ }
74
+ for tc in message.tool_calls
75
+ ]
76
+
77
+ return result
78
+
79
+ except Exception as e:
80
+ print(f"Model error: {e}")
81
+ return {"content": "Unknown"}
82
+
83
+
84
+ def get_model(model_type: str, model_id: str):
85
+ """Get model instance"""
86
+ if model_type == "LiteLLMModel":
87
+ return LiteLLMModel(model_id)
88
+ else:
89
+ raise ValueError(f"Unknown model type: {model_type}")