SantoshKumar1310 commited on
Commit
27c705a
·
verified ·
1 Parent(s): 4a6817b

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +103 -0
model.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model wrapper for LiteLLM
3
+ """
4
+
5
+ import os
6
+ from typing import List, Dict, Any, Optional
7
+
8
+ try:
9
+ import litellm
10
+ except ImportError:
11
+ print("⚠️ litellm not installed. Install with: pip install litellm")
12
+ litellm = None
13
+
14
+ class LiteLLMModel:
15
+ """Wrapper for LiteLLM models"""
16
+
17
+ def __init__(self, model_id: str):
18
+ self.model_id = model_id
19
+
20
+ # Set API keys from environment for Gemini
21
+ if "gemini" in model_id.lower():
22
+ if not os.getenv("GEMINI_API_KEY"):
23
+ print("⚠️ GEMINI_API_KEY not set in environment")
24
+
25
+ def generate(self, messages: List[Dict], tools: Optional[List] = None) -> Dict:
26
+ """
27
+ Generate response from model.
28
+ Returns:
29
+ Dict with 'content' and optionally 'tool_calls'
30
+ """
31
+ if not litellm:
32
+ return {"content": "Unknown - litellm not installed"}
33
+
34
+ try:
35
+ # Convert tools to OpenAI format
36
+ formatted_tools = None
37
+ if tools:
38
+ formatted_tools = [
39
+ {
40
+ "type": "function",
41
+ "function": {
42
+ "name": tool.name,
43
+ "description": tool.description,
44
+ "parameters": tool.parameters
45
+ }
46
+ }
47
+ for tool in tools
48
+ ]
49
+
50
+ # --- Provider Handling ---
51
+ if "gemini" in self.model_id.lower():
52
+ api_key = os.getenv("GEMINI_API_KEY")
53
+ if not api_key:
54
+ raise RuntimeError("GEMINI_API_KEY not set in environment")
55
+ provider = "google"
56
+ model_name = (
57
+ f"google/{self.model_id}"
58
+ if not self.model_id.startswith("google/")
59
+ else self.model_id
60
+ )
61
+ response = litellm.completion(
62
+ provider=provider,
63
+ model=model_name,
64
+ api_key=api_key,
65
+ messages=messages,
66
+ tools=formatted_tools,
67
+ temperature=0.1
68
+ )
69
+ else:
70
+ # For other models/providers, only pass model (configure as needed for your environment)
71
+ response = litellm.completion(
72
+ model=self.model_id,
73
+ messages=messages,
74
+ tools=formatted_tools,
75
+ temperature=0.1
76
+ )
77
+
78
+ message = response.choices[0].message
79
+ result = {
80
+ "content": message.content or ""
81
+ }
82
+
83
+ # Handle tool calls if present
84
+ if hasattr(message, 'tool_calls') and message.tool_calls:
85
+ result["tool_calls"] = [
86
+ {
87
+ "name": tc.function.name,
88
+ "arguments": eval(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments
89
+ }
90
+ for tc in message.tool_calls
91
+ ]
92
+ return result
93
+
94
+ except Exception as e:
95
+ print(f"Model error: {e}")
96
+ return {"content": "Unknown"}
97
+
98
+ def get_model(model_type: str, model_id: str):
99
+ """Get model instance"""
100
+ if model_type == "LiteLLMModel":
101
+ return LiteLLMModel(model_id)
102
+ else:
103
+ raise ValueError(f"Unknown model type: {model_type}")