sarveshpatel commited on
Commit
3e107da
·
verified ·
1 Parent(s): e89c2af

Create app/package_manager.py

Browse files
Files changed (1) hide show
  1. app/package_manager.py +146 -0
app/package_manager.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Package installation and checking for the configured environment."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ from pathlib import Path
8
+
9
+ from app.config import Settings, get_settings
10
+ from app.models import PackageCheckResult
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class PackageManager:
16
+ """Manages package installation and checking."""
17
+
18
+ def __init__(self, settings: Settings | None = None):
19
+ self._settings = settings or get_settings()
20
+ self._install_lock = asyncio.Lock()
21
+
22
+ async def install_packages(self, packages: list[str]) -> dict:
23
+ """Install packages into the configured environment."""
24
+ async with self._install_lock:
25
+ cmd = self._settings.get_pip_command() + packages
26
+ logger.info("Installing packages: %s", " ".join(packages))
27
+
28
+ try:
29
+ process = await asyncio.create_subprocess_exec(
30
+ *cmd,
31
+ stdout=asyncio.subprocess.PIPE,
32
+ stderr=asyncio.subprocess.PIPE,
33
+ )
34
+
35
+ stdout_bytes, stderr_bytes = await asyncio.wait_for(
36
+ process.communicate(),
37
+ timeout=300, # 5 minute timeout for installs
38
+ )
39
+
40
+ stdout = stdout_bytes.decode("utf-8", errors="replace")
41
+ stderr = stderr_bytes.decode("utf-8", errors="replace")
42
+ success = process.returncode == 0
43
+
44
+ if success:
45
+ logger.info("Successfully installed: %s", ", ".join(packages))
46
+ else:
47
+ logger.error("Failed to install packages: %s", stderr)
48
+
49
+ return {
50
+ "success": success,
51
+ "stdout": stdout,
52
+ "stderr": stderr,
53
+ "return_code": process.returncode,
54
+ "packages": packages,
55
+ }
56
+ except asyncio.TimeoutError:
57
+ return {
58
+ "success": False,
59
+ "stderr": "Package installation timed out after 300s",
60
+ "stdout": "",
61
+ "return_code": -1,
62
+ "packages": packages,
63
+ }
64
+ except Exception as e:
65
+ logger.exception("Error installing packages")
66
+ return {
67
+ "success": False,
68
+ "stderr": str(e),
69
+ "stdout": "",
70
+ "return_code": -1,
71
+ "packages": packages,
72
+ }
73
+
74
+ async def check_packages(self, packages: list[str]) -> list[PackageCheckResult]:
75
+ """Check if packages are installed."""
76
+ python_exec = self._settings.get_python_executable()
77
+
78
+ # Build a single script that checks all packages
79
+ check_script = "import importlib, json\nresults = {}\n"
80
+ for pkg in packages:
81
+ # Normalize package name for import
82
+ import_name = pkg.split("==")[0].split(">=")[0].split("<=")[0].split("[")[0]
83
+ import_name = import_name.replace("-", "_").lower()
84
+ check_script += f"""
85
+ try:
86
+ mod = importlib.import_module("{import_name}")
87
+ version = getattr(mod, "__version__", "unknown")
88
+ results["{pkg}"] = {{"installed": True, "version": version}}
89
+ except ImportError:
90
+ results["{pkg}"] = {{"installed": False, "version": None}}
91
+ """
92
+ check_script += "print(json.dumps(results))\n"
93
+
94
+ try:
95
+ if self._settings.env_type.value == "conda":
96
+ cmd = f"{python_exec} -c {repr(check_script)}"
97
+ process = await asyncio.create_subprocess_shell(
98
+ cmd,
99
+ stdout=asyncio.subprocess.PIPE,
100
+ stderr=asyncio.subprocess.PIPE,
101
+ )
102
+ else:
103
+ process = await asyncio.create_subprocess_exec(
104
+ python_exec,
105
+ "-c",
106
+ check_script,
107
+ stdout=asyncio.subprocess.PIPE,
108
+ stderr=asyncio.subprocess.PIPE,
109
+ )
110
+
111
+ stdout_bytes, _ = await asyncio.wait_for(
112
+ process.communicate(), timeout=30
113
+ )
114
+
115
+ import json
116
+ results_data = json.loads(stdout_bytes.decode("utf-8").strip())
117
+
118
+ return [
119
+ PackageCheckResult(
120
+ package=pkg,
121
+ installed=info["installed"],
122
+ version=info.get("version"),
123
+ )
124
+ for pkg, info in results_data.items()
125
+ ]
126
+ except Exception as e:
127
+ logger.exception("Error checking packages")
128
+ return [
129
+ PackageCheckResult(package=pkg, installed=False) for pkg in packages
130
+ ]
131
+
132
+
133
+ _manager: PackageManager | None = None
134
+
135
+
136
+ def get_package_manager() -> PackageManager:
137
+ global _manager
138
+ if _manager is None:
139
+ _manager = PackageManager()
140
+ return _manager
141
+
142
+
143
+ def reset_package_manager(settings: Settings | None = None) -> PackageManager:
144
+ global _manager
145
+ _manager = PackageManager(settings)
146
+ return _manager