Josedcape committed on
Commit
545e280
·
verified ·
1 Parent(s): ee4bd26

Update src/utils/utils.py

Browse files
Files changed (1) hide show
  1. src/utils/utils.py +142 -207
src/utils/utils.py CHANGED
@@ -1,207 +1,142 @@
1
- import base64
2
- import os
3
- import time
4
- from pathlib import Path
5
- from typing import Dict, Optional
6
-
7
- from langchain_anthropic import ChatAnthropic
8
- from langchain_google_genai import ChatGoogleGenerativeAI
9
- from langchain_ollama import ChatOllama
10
- from langchain_openai import AzureChatOpenAI, ChatOpenAI
11
- import gradio as gr
12
-
13
- from .llm import DeepSeekR1ChatOpenAI
14
-
15
- def get_llm_model(provider: str, **kwargs):
16
- """
17
- 获取LLM 模型
18
- :param provider: 模型类型
19
- :param kwargs:
20
- :return:
21
- """
22
- if provider == "anthropic":
23
- if not kwargs.get("base_url", ""):
24
- base_url = "https://api.anthropic.com"
25
- else:
26
- base_url = kwargs.get("base_url")
27
-
28
- if not kwargs.get("api_key", ""):
29
- api_key = os.getenv("ANTHROPIC_API_KEY", "")
30
- else:
31
- api_key = kwargs.get("api_key")
32
-
33
- return ChatAnthropic(
34
- model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"),
35
- temperature=kwargs.get("temperature", 0.0),
36
- base_url=base_url,
37
- api_key=api_key,
38
- )
39
- elif provider == "openai":
40
- if not kwargs.get("base_url", ""):
41
- base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
42
- else:
43
- base_url = kwargs.get("base_url")
44
-
45
- if not kwargs.get("api_key", ""):
46
- api_key = os.getenv("OPENAI_API_KEY", "")
47
- else:
48
- api_key = kwargs.get("api_key")
49
-
50
- return ChatOpenAI(
51
- model=kwargs.get("model_name", "gpt-4o"),
52
- temperature=kwargs.get("temperature", 0.0),
53
- base_url=base_url,
54
- api_key=api_key,
55
- )
56
- elif provider == "deepseek":
57
- if not kwargs.get("base_url", ""):
58
- base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
59
- else:
60
- base_url = kwargs.get("base_url")
61
-
62
- if not kwargs.get("api_key", ""):
63
- api_key = os.getenv("DEEPSEEK_API_KEY", "")
64
- else:
65
- api_key = kwargs.get("api_key")
66
-
67
- if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
68
- return DeepSeekR1ChatOpenAI(
69
- model=kwargs.get("model_name", "deepseek-reasoner"),
70
- temperature=kwargs.get("temperature", 0.0),
71
- base_url=base_url,
72
- api_key=api_key,
73
- )
74
- else:
75
- return ChatOpenAI(
76
- model=kwargs.get("model_name", "deepseek-chat"),
77
- temperature=kwargs.get("temperature", 0.0),
78
- base_url=base_url,
79
- api_key=api_key,
80
- )
81
- elif provider == "gemini":
82
- if not kwargs.get("api_key", ""):
83
- api_key = os.getenv("GOOGLE_API_KEY", "")
84
- else:
85
- api_key = kwargs.get("api_key")
86
- return ChatGoogleGenerativeAI(
87
- model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
88
- temperature=kwargs.get("temperature", 0.0),
89
- google_api_key=api_key,
90
- )
91
- elif provider == "ollama":
92
- return ChatOllama(
93
- model=kwargs.get("model_name", "qwen2.5:7b"),
94
- temperature=kwargs.get("temperature", 0.0),
95
- num_ctx=kwargs.get("num_ctx", 32000),
96
- base_url=kwargs.get("base_url", "http://localhost:11434"),
97
- )
98
- elif provider == "azure_openai":
99
- if not kwargs.get("base_url", ""):
100
- base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
101
- else:
102
- base_url = kwargs.get("base_url")
103
- if not kwargs.get("api_key", ""):
104
- api_key = os.getenv("AZURE_OPENAI_API_KEY", "")
105
- else:
106
- api_key = kwargs.get("api_key")
107
- return AzureChatOpenAI(
108
- model=kwargs.get("model_name", "gpt-4o"),
109
- temperature=kwargs.get("temperature", 0.0),
110
- api_version="2024-05-01-preview",
111
- azure_endpoint=base_url,
112
- api_key=api_key,
113
- )
114
- else:
115
- raise ValueError(f"Unsupported provider: {provider}")
116
-
117
# Predefined model names for common providers
# Maps each provider id (as accepted by get_llm_model) to the model choices
# shown in the UI dropdown; the first entry of each list is the default.
model_names = {
    "anthropic": ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
    "gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219" ],
    "ollama": ["qwen2.5:7b", "llama2:7b"],
    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
}
126
-
127
# Callback to refresh the model-name dropdown when the provider changes
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """Return a gr.Dropdown listing the predefined models for *llm_provider*.

    Falls back to the <PROVIDER>_API_KEY / <PROVIDER>_BASE_URL environment
    variables when api_key / base_url are not supplied.
    """
    prefix = llm_provider.upper()
    api_key = api_key or os.getenv(f"{prefix}_API_KEY", "")
    base_url = base_url or os.getenv(f"{prefix}_BASE_URL", "")

    choices = model_names.get(llm_provider)
    if choices is None:
        # Unknown provider: empty dropdown that accepts free-form model names.
        return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
    return gr.Dropdown(choices=choices, value=choices[0], interactive=True)
143
-
144
def encode_image(img_path):
    """Return the file at *img_path* base64-encoded as UTF-8 text, or None
    when the path is falsy (None / empty string)."""
    if not img_path:
        return None
    raw = Path(img_path).read_bytes()
    return base64.b64encode(raw).decode("utf-8")
150
-
151
-
152
def get_latest_files(directory: str, file_types: Optional[list] = None) -> Dict[str, Optional[str]]:
    """Return the newest file of each requested extension under *directory*.

    :param directory: root directory to scan recursively (created if missing)
    :param file_types: extensions to look for; defaults to ['.webm', '.zip']
    :return: mapping extension -> path of the newest finished file, or None
             when no such file exists
    """
    # Resolve the default here instead of using a mutable default argument.
    if file_types is None:
        file_types = ['.webm', '.zip']

    latest_files: Dict[str, Optional[str]] = {ext: None for ext in file_types}

    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
        return latest_files

    for file_type in file_types:
        try:
            matches = list(Path(directory).rglob(f"*{file_type}"))
            if matches:
                latest = max(matches, key=lambda p: p.stat().st_mtime)
                # Skip files modified within the last second: they may still
                # be open for writing (e.g. an in-progress recording).
                if time.time() - latest.stat().st_mtime > 1.0:
                    latest_files[file_type] = str(latest)
        except Exception as e:
            # Best-effort scan: an error on one extension must not prevent
            # reporting the others.
            print(f"Error getting latest {file_type} file: {e}")

    return latest_files
172
async def capture_screenshot(browser_context):
    """Capture a JPEG screenshot of the active page, base64-encoded.

    :param browser_context: wrapper exposing .browser.playwright_browser
    :return: base64 string of the screenshot, or None when no usable
             context/page exists or the screenshot fails
    """
    # Extract the underlying Playwright browser instance.
    playwright_browser = browser_context.browser.playwright_browser

    # Reuse the first existing context; without one there is nothing to shoot.
    if playwright_browser and playwright_browser.contexts:
        playwright_context = playwright_browser.contexts[0]
    else:
        return None

    pages = playwright_context.pages if playwright_context else None

    # Prefer the last page showing real content; fall back to the first page.
    if pages:
        active_page = pages[0]
        for page in pages:
            if page.url != "about:blank":
                active_page = page
    else:
        return None

    try:
        screenshot = await active_page.screenshot(
            type='jpeg',
            quality=75,
            scale="css"
        )
        return base64.b64encode(screenshot).decode('utf-8')
    except Exception as e:
        # Best-effort: a failed screenshot must not crash the caller, but
        # the cause should not be silently discarded (was `return None`
        # with the exception unused).
        print(f"Error capturing screenshot: {e}")
        return None
 
1
+ import base64
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+ from typing import Dict, Optional
6
+
7
+ from langchain_anthropic import ChatAnthropic
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from langchain_ollama import ChatOllama
10
+ from langchain_openai import AzureChatOpenAI, ChatOpenAI
11
+ import gradio as gr
12
+
13
+ from .llm import DeepSeekR1ChatOpenAI
14
+
15
def get_llm_model(provider: str, **kwargs):
    """
    Build a LangChain chat model for *provider*.

    :param provider: one of "anthropic", "openai", "deepseek", "gemini",
                     "ollama", "azure_openai"
    :param kwargs: model_name, temperature, base_url, api_key, num_ctx
    :return: configured chat model instance
    :raises ValueError: if the provider is not supported
    """
    # NOTE: kwargs.get("api_key", os.getenv(...)) would keep an explicit
    # empty string passed by the caller (e.g. a blank UI field) and skip the
    # environment fallback. `value or default` treats "" / None as "not
    # provided", matching the previous revision's explicit checks.
    if provider == "anthropic":
        base_url = kwargs.get("base_url") or "https://api.anthropic.com"
        api_key = kwargs.get("api_key") or os.getenv("ANTHROPIC_API_KEY", "")
        return ChatAnthropic(
            model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "openai":
        base_url = kwargs.get("base_url") or os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
        api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY", "")
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "deepseek":
        base_url = kwargs.get("base_url") or os.getenv("DEEPSEEK_ENDPOINT", "")
        api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY", "")
        model_name = kwargs.get("model_name", "deepseek-chat")
        # The reasoner model uses the dedicated DeepSeek R1 client.
        llm_cls = DeepSeekR1ChatOpenAI if model_name == "deepseek-reasoner" else ChatOpenAI
        return llm_cls(
            model=model_name,
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "gemini":
        api_key = kwargs.get("api_key") or os.getenv("GOOGLE_API_KEY", "")
        return ChatGoogleGenerativeAI(
            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
            temperature=kwargs.get("temperature", 0.0),
            google_api_key=api_key,
        )
    elif provider == "ollama":
        # Ollama keeps the plain default: any explicit base_url is honored.
        return ChatOllama(
            model=kwargs.get("model_name", "qwen2.5:7b"),
            temperature=kwargs.get("temperature", 0.0),
            num_ctx=kwargs.get("num_ctx", 32000),
            base_url=kwargs.get("base_url", "http://localhost:11434"),
        )
    elif provider == "azure_openai":
        base_url = kwargs.get("base_url") or os.getenv("AZURE_OPENAI_ENDPOINT", "")
        api_key = kwargs.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY", "")
        return AzureChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=kwargs.get("temperature", 0.0),
            api_version="2024-05-01-preview",
            azure_endpoint=base_url,
            api_key=api_key,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
78
+
79
# Predefined model-name choices per provider, used to populate the UI
# dropdown; the first entry of each list is the default selection.
model_names = {
    "anthropic": ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
    "gemini": ["gemini-2.0-flash-exp", "gemini-1.5-flash-latest"],
    "ollama": ["qwen2.5:7b", "llama2:7b"],
    "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
}
87
+
88
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """Build the model-name dropdown for the chosen provider, falling back to
    <PROVIDER>_API_KEY / <PROVIDER>_BASE_URL environment variables when the
    credentials are not supplied."""
    if not api_key:
        api_key = os.getenv(f"{llm_provider.upper()}_API_KEY", "")
    if not base_url:
        base_url = os.getenv(f"{llm_provider.upper()}_BASE_URL", "")

    try:
        options = model_names[llm_provider]
    except KeyError:
        # Unknown provider: offer a free-form, empty dropdown instead.
        return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
    return gr.Dropdown(choices=options, value=options[0], interactive=True)
98
+
99
def encode_image(img_path):
    """Base64-encode the file at *img_path*; returns None for a falsy path."""
    if img_path:
        with open(img_path, "rb") as handle:
            payload = handle.read()
        return base64.b64encode(payload).decode("utf-8")
    return None
104
+
105
def get_latest_files(directory: str, file_types: Optional[list] = None) -> Dict[str, Optional[str]]:
    """Find the most recently modified file of each extension under *directory*.

    :param directory: root directory to scan recursively (created if missing)
    :param file_types: extensions to search for; defaults to ['.webm', '.zip']
    :return: mapping extension -> newest finished file path, or None
    """
    if file_types is None:
        # Default resolved per call instead of a shared mutable default.
        file_types = ['.webm', '.zip']

    latest_files: Dict[str, Optional[str]] = {ext: None for ext in file_types}

    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
        return latest_files

    for file_type in file_types:
        try:
            matches = list(Path(directory).rglob(f"*{file_type}"))
            if matches:
                latest = max(matches, key=lambda p: p.stat().st_mtime)
                # Ignore files touched within the last second: they may
                # still be being written (e.g. a recording in progress).
                if time.time() - latest.stat().st_mtime > 1.0:
                    latest_files[file_type] = str(latest)
        except Exception as e:
            # Keep scanning the remaining extensions on a per-type failure.
            print(f"Error getting latest {file_type} file: {e}")

    return latest_files
123
+
124
async def capture_screenshot(browser_context):
    """Capture and base64-encode a JPEG screenshot of the active page.

    :param browser_context: wrapper exposing .browser.playwright_browser
    :return: base64 screenshot string, or None when no context/page is
             available or the capture fails

    NOTE: page.screenshot() is a coroutine in Playwright's async API (the
    previous revision of this function awaited it). Calling it without
    `await` hands base64.b64encode a coroutine object, which raises and is
    swallowed by the except — so the sync version always returned None.
    """
    try:
        playwright_browser = browser_context.browser.playwright_browser
        if not playwright_browser or not playwright_browser.contexts:
            return None

        playwright_context = playwright_browser.contexts[0]
        pages = playwright_context.pages if playwright_context else []

        # Prefer the first page with real content; otherwise the first page.
        active_page = next((page for page in pages if page.url != "about:blank"), pages[0] if pages else None)

        if not active_page:
            return None

        screenshot = await active_page.screenshot(type='jpeg', quality=75, scale="css")
        return base64.b64encode(screenshot).decode('utf-8')
    except Exception as e:
        # Best-effort capture: log and signal "no screenshot" to the caller.
        print(f"Error capturing screenshot: {e}")
        return None