fengmiguoji committed on
Commit
351da5e
·
verified ·
1 Parent(s): 440ba6c

Upload 3 files

Browse files
src/utils/__init__.py CHANGED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2025/1/1
3
+ # @Author : wenshao
4
+ # @Email : wenshaoguo1026@gmail.com
5
+ # @Project : browser-use-webui
6
+ # @FileName: __init__.py
src/utils/agent_state.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
class AgentState:
    """Process-wide singleton holding the agent's stop flag and last browser state.

    Every call to ``AgentState()`` returns the same object, so independent
    callers can share one stop request and one "last valid state" slot.
    """

    _instance = None

    def __new__(cls):
        # Create the shared instance on first use; hand it back ever after.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every AgentState() call, even though __new__
        # returns the existing object; bail out once the state is set up so
        # repeated construction does not reset it.
        if hasattr(self, '_stop_requested'):
            return
        self._stop_requested = asyncio.Event()
        self.last_valid_state = None  # most recent valid browser state

    def request_stop(self):
        """Raise the stop flag."""
        self._stop_requested.set()

    def clear_stop(self):
        """Lower the stop flag and forget the last valid state."""
        self._stop_requested.clear()
        self.last_valid_state = None

    def is_stop_requested(self):
        """Return True while the stop flag is raised."""
        return self._stop_requested.is_set()

    def set_last_valid_state(self, state):
        """Remember *state* as the last valid browser state."""
        self.last_valid_state = state

    def get_last_valid_state(self):
        """Return the remembered browser state (None if nothing is stored)."""
        return self.last_valid_state
src/utils/utils.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2025/1/1
3
+ # @Author : wenshao
4
+ # @Email : wenshaoguo1026@gmail.com
5
+ # @Project : browser-use-webui
6
+ # @FileName: utils.py
7
+ import base64
8
+ import os
9
+ import time
10
+ from pathlib import Path
11
+ from typing import Dict, Optional
12
+
13
+ from langchain_anthropic import ChatAnthropic
14
+ from langchain_google_genai import ChatGoogleGenerativeAI
15
+ from langchain_ollama import ChatOllama
16
+ from langchain_openai import AzureChatOpenAI, ChatOpenAI
17
+ import gradio as gr
18
+
19
def _resolve_setting(kwargs, key, env_var=None, default=""):
    """Return kwargs[key] when truthy, else the env var (if named), else default."""
    value = kwargs.get(key, "")
    if value:
        return value
    if env_var is not None:
        return os.getenv(env_var, default)
    return default


def get_llm_model(provider: str, **kwargs):
    """
    Build a LangChain chat model for the given provider.

    :param provider: one of "anthropic", "openai", "deepseek", "gemini",
        "ollama" or "azure_openai".
    :param kwargs: optional settings read by this function:
        ``model_name`` (a per-provider default is used when absent),
        ``temperature`` (default 0.0),
        ``base_url`` and ``api_key`` (an empty or missing value falls back to
        the provider's environment variable and/or built-in default),
        ``num_ctx`` (Ollama only, default 128000).
    :return: the instantiated chat model object.
    :raises ValueError: if ``provider`` is not one of the supported names.
    """
    temperature = kwargs.get("temperature", 0.0)

    if provider == "anthropic":
        return ChatAnthropic(
            model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"),
            temperature=temperature,
            base_url=_resolve_setting(kwargs, "base_url", default="https://api.anthropic.com"),
            api_key=_resolve_setting(kwargs, "api_key", "ANTHROPIC_API_KEY"),
        )
    elif provider == "openai":
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=temperature,
            base_url=_resolve_setting(
                kwargs, "base_url", "OPENAI_ENDPOINT", "https://api.openai.com/v1"
            ),
            api_key=_resolve_setting(kwargs, "api_key", "OPENAI_API_KEY"),
        )
    elif provider == "deepseek":
        # DeepSeek is driven through the OpenAI-compatible client class.
        return ChatOpenAI(
            model=kwargs.get("model_name", "deepseek-chat"),
            temperature=temperature,
            base_url=_resolve_setting(kwargs, "base_url", "DEEPSEEK_ENDPOINT"),
            api_key=_resolve_setting(kwargs, "api_key", "DEEPSEEK_API_KEY"),
        )
    elif provider == "gemini":
        return ChatGoogleGenerativeAI(
            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
            temperature=temperature,
            google_api_key=_resolve_setting(kwargs, "api_key", "GOOGLE_API_KEY"),
        )
    elif provider == "ollama":
        return ChatOllama(
            model=kwargs.get("model_name", "qwen2.5:7b"),
            temperature=temperature,
            num_ctx=kwargs.get("num_ctx", 128000),
            # An empty base_url now falls back to the local default, matching
            # how every other provider branch treats an empty value.
            base_url=_resolve_setting(kwargs, "base_url", default="http://localhost:11434"),
        )
    elif provider == "azure_openai":
        return AzureChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=temperature,
            api_version="2024-05-01-preview",
            azure_endpoint=_resolve_setting(kwargs, "base_url", "AZURE_OPENAI_ENDPOINT"),
            api_key=_resolve_setting(kwargs, "api_key", "AZURE_OPENAI_API_KEY"),
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
112
+
113
+ # Predefined model names for common providers
114
# Maps each provider key to the model names offered for it; the first entry
# of every list is the one preselected in the dropdown.
model_names = {
    "anthropic": [
        "claude-3-5-sonnet-20240620",
        "claude-3-opus-20240229",
    ],
    "openai": [
        "gpt-4o",
        "gpt-4",
        "gpt-3.5-turbo",
    ],
    "deepseek": [
        "deepseek-chat",
    ],
    "gemini": [
        "gemini-2.0-flash-exp",
        "gemini-2.0-flash-thinking-exp",
        "gemini-1.5-flash-latest",
        "gemini-1.5-flash-8b-latest",
        "gemini-2.0-flash-thinking-exp-1219",
    ],
    "ollama": [
        "qwen2.5:7b",
        "llama2:7b",
    ],
    "azure_openai": [
        "gpt-4o",
        "gpt-4",
        "gpt-3.5-turbo",
    ],
}
122
+
123
+ # Callback to update the model name dropdown based on the selected provider
124
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """
    Update the model name dropdown with predefined models for the selected provider.

    :param llm_provider: provider key looked up in ``model_names``.
    :param api_key: accepted so existing callers keep working; not used here.
    :param base_url: accepted so existing callers keep working; not used here.
    :return: a ``gr.Dropdown`` listing the provider's predefined models with
        the first one selected, or an empty dropdown that allows a custom
        value when the provider has no predefined list.
    """
    # The environment lookups that used to run here computed values nobody
    # read, and crashed with AttributeError when llm_provider was None; an
    # unknown or missing provider now simply yields the empty dropdown.
    if llm_provider in model_names:
        choices = model_names[llm_provider]
        return gr.Dropdown(choices=choices, value=choices[0], interactive=True)
    return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
139
+
140
def encode_image(img_path):
    """Return the file at *img_path* as a base64 text string.

    A falsy path (``None`` or ``""``) yields ``None`` without touching the
    filesystem.
    """
    if img_path:
        with open(img_path, "rb") as image_file:
            raw_bytes = image_file.read()
        return base64.b64encode(raw_bytes).decode("utf-8")
    return None
146
+
147
+
148
def get_latest_files(directory: str, file_types: Optional[list] = None) -> Dict[str, Optional[str]]:
    """Get the latest recording and trace files.

    Searches *directory* recursively and, for each extension, reports the
    most recently modified matching file.

    :param directory: folder to search; it is created if it does not exist.
    :param file_types: file suffixes to look for; defaults to
        ``['.webm', '.zip']`` when omitted.
    :return: a dict mapping each requested suffix to the path of its newest
        file as a string, or to ``None`` when no file matches, the newest
        match was modified within the last second, or the search failed.
    """
    # Default built per call instead of sharing one mutable list object.
    if file_types is None:
        file_types = ['.webm', '.zip']

    latest_files: Dict[str, Optional[str]] = {ext: None for ext in file_types}

    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
        return latest_files

    for file_type in file_types:
        try:
            newest_path = None
            newest_mtime = None
            for candidate in Path(directory).rglob(f"*{file_type}"):
                try:
                    # Stat each file once and keep the result, so the
                    # freshness check below uses the same timestamp.
                    mtime = candidate.stat().st_mtime
                except OSError:
                    # The file vanished between listing and stat; skip it
                    # rather than losing the result for this extension.
                    continue
                if newest_mtime is None or mtime > newest_mtime:
                    newest_path, newest_mtime = candidate, mtime

            # Only return files that are complete (not being written)
            if newest_path is not None and time.time() - newest_mtime > 1.0:
                latest_files[file_type] = str(newest_path)
        except Exception as e:
            # Best-effort lookup: report the problem and leave this entry None.
            print(f"Error getting latest {file_type} file: {e}")

    return latest_files
168
async def capture_screenshot(browser_context):
    """Capture and encode a screenshot.

    Looks at the first context of ``browser_context.browser.playwright_browser``
    and screenshots one of its pages: the last page whose URL is not
    ``about:blank``, or the first page when every page is blank.

    :param browser_context: object exposing ``.browser.playwright_browser``.
    :return: the JPEG screenshot as a base64 text string, or ``None`` when
        there is no browser, no context, no page, or the capture fails.
    """
    # Local import keeps this block self-contained.
    import logging

    playwright_browser = browser_context.browser.playwright_browser

    # Reuse an existing context; without one there is nothing to capture.
    if not playwright_browser or not playwright_browser.contexts:
        return None
    playwright_context = playwright_browser.contexts[0]

    pages = playwright_context.pages if playwright_context else None
    if not pages:
        return None

    # Start from the first page, then let every non-blank page override it,
    # so the last non-blank page in the list is the one captured.
    active_page = pages[0]
    for page in pages:
        if page.url != "about:blank":
            active_page = page

    try:
        screenshot = await active_page.screenshot(
            type='jpeg',
            quality=75,
            scale="css"
        )
        return base64.b64encode(screenshot).decode('utf-8')
    except Exception:
        # Still best-effort (callers get None), but leave a trace instead of
        # discarding the failure silently.
        logging.getLogger(__name__).debug("Screenshot capture failed", exc_info=True)
        return None