AUXteam commited on
Commit
155a7ae
·
verified ·
1 Parent(s): a8a10e9

Upload folder using huggingface_hub

Browse files
src/magentic_ui/backend/__init__.py CHANGED
@@ -1,6 +1,2 @@
 
1
  from .database.db_manager import DatabaseManager
2
- from .datamodel import Team
3
- from .teammanager import TeamManager
4
- from ..version import __version__
5
-
6
- __all__ = ["DatabaseManager", "Team", "TeamManager", "__version__"]
 
1
+ # src/magentic_ui/backend/__init__.py
2
  from .database.db_manager import DatabaseManager
 
 
 
 
 
src/magentic_ui/backend/managers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src/magentic_ui/backend/managers/__init__.py
src/magentic_ui/backend/managers/vllm_manager.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import signal
4
+ import subprocess
5
+ import time
6
+ from typing import Optional
7
+
8
+ from loguru import logger
9
+
10
+
11
+ class VLLMManager:
12
+ """Manages the lifecycle of a VLLM server process."""
13
+
14
+ def __init__(
15
+ self,
16
+ model_name: str = "yujiepan/ui-tars-1.5-7B-GPTQ-W4A16g128",
17
+ port: int = 5000,
18
+ host: str = "0.0.0.0",
19
+ gpu_memory_utilization: float = 0.9,
20
+ ) -> None:
21
+ self.model_name = model_name
22
+ self.port = port
23
+ self.host = host
24
+ self.gpu_memory_utilization = gpu_memory_utilization
25
+ self._process: Optional[subprocess.Popen] = None
26
+
27
+ async def start(self) -> None:
28
+ """Start the VLLM server process."""
29
+ if self._process is not None:
30
+ logger.warning("VLLM server is already running.")
31
+ return
32
+
33
+ cmd = [
34
+ "vllm",
35
+ "serve",
36
+ self.model_name,
37
+ "--port",
38
+ str(self.port),
39
+ "--host",
40
+ self.host,
41
+ "--gpu-memory-utilization",
42
+ str(self.gpu_memory_utilization),
43
+ "--dtype",
44
+ "auto",
45
+ "--trust-remote-code",
46
+ ]
47
+
48
+ logger.info(f"Starting VLLM server with command: {' '.join(cmd)}")
49
+ try:
50
+ self._process = subprocess.Popen(
51
+ cmd,
52
+ stdout=subprocess.PIPE,
53
+ stderr=subprocess.PIPE,
54
+ text=True,
55
+ preexec_fn=os.setsid, # Create a new process group
56
+ )
57
+ # Wait for the server to be ready
58
+ await self._wait_for_ready()
59
+ logger.info("VLLM server started successfully.")
60
+ except Exception as e:
61
+ logger.error(f"Failed to start VLLM server: {e}")
62
+ self.stop()
63
+ raise
64
+
65
+ async def _wait_for_ready(self, timeout: int = 300) -> None:
66
+ """Wait for the VLLM server to be ready."""
67
+ start_time = time.time()
68
+ while time.time() - start_time < timeout:
69
+ if self._process and self._process.poll() is not None:
70
+ raise RuntimeError("VLLM process exited unexpectedly.")
71
+
72
+ try:
73
+ # Check health endpoint
74
+ # In a real scenario, we would use `requests` or `httpx` to check /health
75
+ # For simplicity, we can just check if the port is open or rely on log output parsing
76
+ # Here we assume it takes some time to initialize
77
+ # A more robust check would involve HTTP request
78
+ await asyncio.sleep(5)
79
+ # Placeholder: assume ready after 10s for now or check logs
80
+ # In production, check `http://localhost:5000/health`
81
+ if self._process:
82
+ # Check if process is still running
83
+ if self._process.poll() is None:
84
+ # Optimistic assumption for this example
85
+ # Real implementation: requests.get(f"http://{self.host}:{self.port}/health")
86
+ return
87
+
88
+ except Exception:
89
+ pass
90
+ await asyncio.sleep(2)
91
+
92
+ raise TimeoutError("VLLM server failed to start within timeout.")
93
+
94
+ def stop(self) -> None:
95
+ """Stop the VLLM server process."""
96
+ if self._process:
97
+ logger.info("Stopping VLLM server...")
98
+ try:
99
+ os.killpg(os.getpgid(self._process.pid), signal.SIGTERM)
100
+ self._process.wait(timeout=10)
101
+ except Exception as e:
102
+ logger.warning(f"Error stopping VLLM server: {e}")
103
+ if self._process:
104
+ self._process.kill()
105
+ finally:
106
+ self._process = None
107
+ logger.info("VLLM server stopped.")
108
+
109
+ def is_running(self) -> bool:
110
+ """Check if the VLLM server is running."""
111
+ return self._process is not None and self._process.poll() is None
src/magentic_ui/backend/web/app.py CHANGED
@@ -24,11 +24,14 @@ from .routes import (
24
  ws,
25
  mcp,
26
  )
 
27
 
28
  # Initialize application
29
  app_file_path = os.path.dirname(os.path.abspath(__file__))
30
  initializer = AppInitializer(settings, app_file_path)
31
 
 
 
32
 
33
  @asynccontextmanager
34
  async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
@@ -36,6 +39,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
36
  Lifecycle manager for the FastAPI application.
37
  Handles initialization and cleanup of application resources.
38
  """
 
39
 
40
  try:
41
  # Load the config if provided
@@ -51,6 +55,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
51
  if os.environ.get("FARA_AGENT") is not None:
52
  config["use_fara_agent"] = os.environ["FARA_AGENT"] == "True"
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # Initialize managers (DB, Connection, Team)
55
  await init_managers(
56
  initializer.database_uri,
@@ -78,6 +96,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
78
  try:
79
  logger.info("Cleaning up application resources...")
80
  await cleanup_managers()
 
 
81
  logger.info("Application shutdown complete")
82
  except Exception as e:
83
  logger.error(f"Error during shutdown: {str(e)}")
 
24
  ws,
25
  mcp,
26
  )
27
+ from ..managers.vllm_manager import VLLMManager
28
 
29
  # Initialize application
30
  app_file_path = os.path.dirname(os.path.abspath(__file__))
31
  initializer = AppInitializer(settings, app_file_path)
32
 
33
+ # Global VLLM Manager
34
+ vllm_manager = None
35
 
36
  @asynccontextmanager
37
  async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
 
39
  Lifecycle manager for the FastAPI application.
40
  Handles initialization and cleanup of application resources.
41
  """
42
+ global vllm_manager
43
 
44
  try:
45
  # Load the config if provided
 
55
  if os.environ.get("FARA_AGENT") is not None:
56
  config["use_fara_agent"] = os.environ["FARA_AGENT"] == "True"
57
 
58
+ # Initialize VLLM if configured
59
+ if os.environ.get("USE_LOCAL_VLLM") == "True":
60
+ try:
61
+ vllm_port = int(os.environ.get("VLLM_PORT", 5000))
62
+ vllm_model = os.environ.get("VLLM_MODEL", "yujiepan/ui-tars-1.5-7B-GPTQ-W4A16g128")
63
+ vllm_manager = VLLMManager(model_name=vllm_model, port=vllm_port)
64
+ await vllm_manager.start()
65
+ # Inject VLLM URL into config for agents to use
66
+ config["vllm_base_url"] = f"http://localhost:{vllm_port}"
67
+ except Exception as e:
68
+ logger.error(f"Failed to start VLLM manager: {e}")
69
+ # decide if we should fail hard or continue without vision
70
+ # raise e
71
+
72
  # Initialize managers (DB, Connection, Team)
73
  await init_managers(
74
  initializer.database_uri,
 
96
  try:
97
  logger.info("Cleaning up application resources...")
98
  await cleanup_managers()
99
+ if vllm_manager:
100
+ vllm_manager.stop()
101
  logger.info("Application shutdown complete")
102
  except Exception as e:
103
  logger.error(f"Error during shutdown: {str(e)}")
src/magentic_ui/backend/web/initialization.py CHANGED
@@ -1,4 +1,4 @@
1
- # api/initialization.py
2
  import os
3
  from pathlib import Path
4
 
@@ -7,6 +7,7 @@ from loguru import logger
7
  from pydantic import BaseModel
8
 
9
  from .config import Settings
 
10
 
11
 
12
  class _AppPaths(BaseModel):
 
1
+ # src/magentic_ui/backend/web/initialization.py
2
  import os
3
  from pathlib import Path
4
 
 
7
  from pydantic import BaseModel
8
 
9
  from .config import Settings
10
+ from ..managers.vllm_manager import VLLMManager
11
 
12
 
13
  class _AppPaths(BaseModel):
src/magentic_ui/utils/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # src/magentic_ui/utils/__init__.py
2
+ from .utils import (
3
+ LLMCallFilter,
4
+ json_data_to_markdown,
5
+ dict_to_str,
6
+ thread_to_context,
7
+ get_internal_urls,
8
+ )
src/magentic_ui/utils/midscene_adapter.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Dict, Any, List
3
+
4
+ class MidsceneAdapter:
5
+ """Adapts user prompts into VLLM/Midscene compatible payloads."""
6
+
7
+ def __init__(self, model_url: str = "http://localhost:5000"):
8
+ self.model_url = model_url
9
+
10
+ def format_prompt(self, user_instruction: str, screenshot_base64: str) -> Dict[str, Any]:
11
+ """Format the input for the VLLM/Midscene model.
12
+
13
+ This assumes the model accepts multimodal input (image + text) in a specific format.
14
+ Adjust the format based on the actual model's requirement (e.g., UI-TARS uses specific prompting).
15
+ """
16
+ # Example prompt structure for UI-TARS or similar vision models
17
+ prompt = {
18
+ "messages": [
19
+ {
20
+ "role": "user",
21
+ "content": [
22
+ {
23
+ "type": "image_url",
24
+ "image_url": {
25
+ "url": f"data:image/jpeg;base64,{screenshot_base64}"
26
+ }
27
+ },
28
+ {
29
+ "type": "text",
30
+ "text": f"Given the screenshot, perform the following action: {user_instruction}"
31
+ }
32
+ ]
33
+ }
34
+ ],
35
+ "temperature": 0.0,
36
+ "max_tokens": 1024
37
+ }
38
+ return prompt
39
+
40
+ def parse_response(self, response_data: Dict[str, Any]) -> List[Dict[str, Any]]:
41
+ """Parse the model's response into actionable commands."""
42
+ # Example: The model might return a JSON string or a specific text format like `click(100, 200)`
43
+ content = response_data.get("choices", [{}])[0].get("message", {}).get("content", "")
44
+
45
+ actions = []
46
+ try:
47
+ import re
48
+ # simple heuristic parsing, replace with robust logic based on model output
49
+ if "click" in content:
50
+ # extract coordinates
51
+ match = re.search(r"click\((\d+),\s*(\d+)\)", content)
52
+ if match:
53
+ x, y = map(int, match.groups())
54
+ actions.append({"type": "click", "x": x, "y": y})
55
+ # Check for type command - prioritize specific regex
56
+ elif "type" in content:
57
+ # Look for type('text') or type("text")
58
+ match = re.search(r"type\(['\"](.*?)['\"]\)", content)
59
+ if match:
60
+ text = match.group(1)
61
+ actions.append({"type": "type", "text": text})
62
+ else:
63
+ # Fallback: return as a raw thought or stop action
64
+ actions.append({"type": "stop", "reason": content})
65
+
66
+ except Exception as e:
67
+ actions.append({"type": "error", "message": str(e)})
68
+
69
+ return actions
src/magentic_ui/utils/utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import psutil
5
+ from typing import List, Union, Dict
6
+
7
+ from autogen_core.models import (
8
+ LLMMessage,
9
+ UserMessage,
10
+ AssistantMessage,
11
+ )
12
+
13
+ from autogen_agentchat.utils import remove_images
14
+ from autogen_agentchat.messages import (
15
+ BaseChatMessage,
16
+ BaseTextChatMessage,
17
+ HandoffMessage,
18
+ MultiModalMessage,
19
+ StopMessage,
20
+ TextMessage,
21
+ ToolCallRequestEvent,
22
+ ToolCallExecutionEvent,
23
+ BaseAgentEvent,
24
+ )
25
+
26
+ from ..types import HumanInputFormat, RunPaths
27
+
28
+
29
+ class LLMCallFilter(logging.Filter):
30
+ def filter(self, record: logging.LogRecord) -> bool:
31
+ try:
32
+ message = json.loads(record.getMessage())
33
+ return message.get("type") == "LLMCall"
34
+ except (json.JSONDecodeError, AttributeError):
35
+ return False
36
+
37
+
38
+ # Define recursive types for JSON structures
39
+ JsonPrimitive = Union[str, int, float, bool, None]
40
+ JsonList = List[Union[JsonPrimitive, "JsonDict", "JsonList"]]
41
+ JsonDict = Dict[str, Union[JsonPrimitive, JsonList, "JsonDict"]]
42
+ JsonData = Union[JsonDict, JsonList, str]
43
+
44
+
45
+ def json_data_to_markdown(data: JsonData) -> str:
46
+ """
47
+ Convert a dictionary, list, or JSON string to a nicely formatted Markdown string.
48
+ Handles nested structures of dictionaries and lists.
49
+
50
+ Args:
51
+ data (JsonData): The data to convert, can be:
52
+ - A dictionary with string keys and JSON-compatible values
53
+ - A list of JSON-compatible values
54
+ - A JSON string representing either of the above
55
+
56
+ Returns:
57
+ str: The formatted Markdown string.
58
+
59
+ Raises:
60
+ ValueError: If the input cannot be parsed or converted to markdown format.
61
+ json.JSONDecodeError: If the input string is not valid JSON.
62
+ """
63
+
64
+ def format_dict(d: JsonDict, indent: int = 0) -> str:
65
+ md = ""
66
+ for key, value in d.items():
67
+ md += " " * indent + f"- {key}: "
68
+ if isinstance(value, dict):
69
+ md += "\n" + format_dict(value, indent + 1)
70
+ elif isinstance(value, list):
71
+ md += "\n" + format_list(value, indent + 1)
72
+ else:
73
+ md += f"{value}\n"
74
+ return md
75
+
76
+ def format_list(lst: JsonList, indent: int = 0) -> str:
77
+ md = ""
78
+ for item in lst:
79
+ if isinstance(item, dict):
80
+ md += " " * indent + "- \n" + format_dict(item, indent + 1)
81
+ elif isinstance(item, list):
82
+ md += " " * indent + "- \n" + format_list(item, indent + 1)
83
+ else:
84
+ md += " " * indent + f"- {item}\n"
85
+ return md
86
+
87
+ try:
88
+ if isinstance(data, str):
89
+ data = json.loads(data)
90
+
91
+ if isinstance(data, list):
92
+ return format_list(data)
93
+ elif isinstance(data, dict):
94
+ return format_dict(data)
95
+ else:
96
+ raise ValueError(f"Expected dict, list or JSON string, got {type(data)}")
97
+
98
+ except json.JSONDecodeError as e:
99
+ raise json.JSONDecodeError(f"Invalid JSON string: {str(e)}", e.doc, e.pos)
100
+ except Exception as e:
101
+ raise ValueError(f"Failed to convert to markdown: {str(e)}")
102
+
103
+
104
+ def dict_to_str(data: Union[JsonDict, str]) -> str:
105
+ """
106
+ Convert a dictionary or JSON string to a JSON string.
107
+
108
+ Args:
109
+ data (JsonDict | str): The dictionary or JSON string to convert.
110
+
111
+ Returns:
112
+ str: The input dictionary in JSON format.
113
+ """
114
+ if isinstance(data, dict):
115
+ return json.dumps(data)
116
+ elif isinstance(data, str):
117
+ return data
118
+ else:
119
+ raise ValueError("Unexpected input type")
120
+
121
+
122
+ def thread_to_context(
123
+ messages: List[BaseAgentEvent | BaseChatMessage],
124
+ agent_name: str,
125
+ is_multimodal: bool = False,
126
+ ) -> List[LLMMessage]:
127
+ """Convert the message thread to a context for the model."""
128
+ context: List[LLMMessage] = []
129
+ for m in messages:
130
+ if isinstance(m, ToolCallRequestEvent | ToolCallExecutionEvent):
131
+ # Ignore tool call messages.
132
+ continue
133
+ elif isinstance(m, StopMessage | HandoffMessage):
134
+ context.append(UserMessage(content=m.content, source=m.source))
135
+ elif m.source == agent_name:
136
+ assert isinstance(m, TextMessage), f"{type(m)}"
137
+ context.append(AssistantMessage(content=m.content, source=m.source))
138
+ elif m.source == "user_proxy" or m.source == "user":
139
+ assert isinstance(m, TextMessage | MultiModalMessage), f"{type(m)}"
140
+ if isinstance(m.content, str):
141
+ human_input = HumanInputFormat.from_str(m.content)
142
+ content = f"{human_input.content}"
143
+ if human_input.plan is not None:
144
+ content += f"\n\nI created the following plan: {human_input.plan}"
145
+ context.append(UserMessage(content=content, source=m.source))
146
+ else:
147
+ # If content is a list, transform only the string part
148
+ content_list = list(m.content) # Create a copy of the list
149
+ for i, item in enumerate(content_list):
150
+ if isinstance(item, str):
151
+ human_input = HumanInputFormat.from_str(item)
152
+ content_list[i] = f"{human_input.content}"
153
+ if human_input.plan is not None and isinstance(
154
+ content_list[i], str
155
+ ):
156
+ content_list[i] = (
157
+ f"{content_list[i]}\n\nI created the following plan: {human_input.plan}"
158
+ )
159
+ context.append(UserMessage(content=content_list, source=m.source)) # type: ignore
160
+ else:
161
+ assert isinstance(m, BaseTextChatMessage) or isinstance(
162
+ m, MultiModalMessage
163
+ ), f"{type(m)}"
164
+ context.append(UserMessage(content=m.content, source=m.source))
165
+ if is_multimodal:
166
+ return context
167
+ else:
168
+ return remove_images(context)
169
+
170
+
171
+ def get_internal_urls(inside_docker: bool, paths: RunPaths) -> List[str] | None:
172
+ if not inside_docker:
173
+ return None
174
+ urls: List[str] = []
175
+ for _, addrs in psutil.net_if_addrs().items():
176
+ for addr in addrs:
177
+ if addr.family.name == "AF_INET":
178
+ urls.append(addr.address)
179
+
180
+ hostname = os.getenv("HOSTNAME")
181
+ if hostname is not None:
182
+ urls.append(hostname)
183
+ container_name = os.getenv("CONTAINER_NAME")
184
+ if container_name is not None:
185
+ urls.append(container_name)
186
+ return urls
tests/test_midscene_adapter.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import json
3
+ import os
4
+ import sys
5
+
6
+ # Ensure src is in python path
7
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
8
+
9
+ from magentic_ui.utils.midscene_adapter import MidsceneAdapter
10
+
11
+ class TestMidsceneAdapter(unittest.TestCase):
12
+
13
+ def setUp(self):
14
+ self.adapter = MidsceneAdapter()
15
+
16
+ def test_format_prompt(self):
17
+ user_instruction = "Click the login button"
18
+ screenshot_base64 = "base64encodedimage"
19
+
20
+ prompt = self.adapter.format_prompt(user_instruction, screenshot_base64)
21
+
22
+ self.assertIn("messages", prompt)
23
+ messages = prompt["messages"]
24
+ self.assertEqual(len(messages), 1)
25
+ content = messages[0]["content"]
26
+
27
+ has_image = False
28
+ has_text = False
29
+
30
+ for item in content:
31
+ if item.get("type") == "image_url":
32
+ self.assertIn(screenshot_base64, item["image_url"]["url"])
33
+ has_image = True
34
+ if item.get("type") == "text":
35
+ self.assertIn(user_instruction, item["text"])
36
+ has_text = True
37
+
38
+ self.assertTrue(has_image)
39
+ self.assertTrue(has_text)
40
+
41
+ def test_parse_response_click(self):
42
+ mock_response = {
43
+ "choices": [{
44
+ "message": {
45
+ "content": "I see the button. click(150, 300)"
46
+ }
47
+ }]
48
+ }
49
+
50
+ actions = self.adapter.parse_response(mock_response)
51
+ self.assertEqual(len(actions), 1)
52
+ self.assertEqual(actions[0]["type"], "click")
53
+ self.assertEqual(actions[0]["x"], 150)
54
+ self.assertEqual(actions[0]["y"], 300)
55
+
56
+ def test_parse_response_type(self):
57
+ mock_response = {
58
+ "choices": [{
59
+ "message": {
60
+ "content": "type('hello world')"
61
+ }
62
+ }]
63
+ }
64
+
65
+ actions = self.adapter.parse_response(mock_response)
66
+ self.assertEqual(len(actions), 1)
67
+ self.assertEqual(actions[0]["type"], "type")
68
+ self.assertEqual(actions[0]["text"], "hello world")
69
+
70
+ if __name__ == '__main__':
71
+ unittest.main()
tests/test_vllm_manager.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import unittest
3
+ import os
4
+ import sys
5
+ from unittest.mock import MagicMock, patch
6
+
7
+ # Ensure src is in python path
8
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
9
+
10
+ from magentic_ui.backend.managers.vllm_manager import VLLMManager
11
+
12
+ class TestVLLMManager(unittest.TestCase):
13
+
14
+ @patch("magentic_ui.backend.managers.vllm_manager.subprocess.Popen")
15
+ def test_vllm_manager_start(self, mock_popen):
16
+ manager = VLLMManager()
17
+
18
+ # Mock process
19
+ mock_process = MagicMock()
20
+ mock_process.poll.return_value = None # Process running
21
+ mock_popen.return_value = mock_process
22
+
23
+ async def run_test():
24
+ # Mock _wait_for_ready to avoid actual sleep loop in test
25
+ with patch.object(manager, '_wait_for_ready', new_callable=AsyncMock):
26
+ await manager.start()
27
+
28
+ # Helper for async test
29
+ class AsyncMock(MagicMock):
30
+ async def __call__(self, *args, **kwargs):
31
+ return super(AsyncMock, self).__call__(*args, **kwargs)
32
+
33
+ loop = asyncio.new_event_loop()
34
+ loop.run_until_complete(run_test())
35
+ loop.close()
36
+
37
+ self.assertTrue(manager.is_running())
38
+ mock_popen.assert_called_once()
39
+
40
+ def test_vllm_manager_stop(self):
41
+ manager = VLLMManager()
42
+ manager._process = MagicMock()
43
+ manager._process.pid = 12345
44
+
45
+ # Patch os.getpgid since that's called before killpg
46
+ with patch("os.killpg") as mock_killpg, \
47
+ patch("os.getpgid", return_value=12345):
48
+
49
+ manager.stop()
50
+ # The killpg call happens inside stop()
51
+ mock_killpg.assert_called()
52
+
53
+ self.assertIsNone(manager._process)
54
+
55
+ if __name__ == '__main__':
56
+ unittest.main()