| """
|
| SSH 隧道 + ADB 设备统一管理 (v3)
|
| ==============================
|
| 给 restore.py / sync.py 共享, 解决以下并发/健壮性问题:
|
|
|
| 1) 动态本地端口 - 忽略 SSH 命令里的 -L 写死端口, 用 OS 分配
|
| 彻底避免不同云真机端口撞车
|
| 2) SSH keep-alive - paramiko Transport.set_keepalive, 默认 30s 一次
|
| 云真机平台会 kill 掉空闲 SSH, 心跳解决
|
| 3) 自动重连 - 命令检测到 'device offline' 等错误后重建隧道+adb,
|
| 重试 1 次. 不会无限循环
|
| 4) 硬超时兜底 - 全局看门狗线程, 超过 HKDY_HARD_TIMEOUT 秒的 session
|
| 强制清理 (防任务卡死持有资源)
|
| 5) 并发安全清理 - 只做 adb disconnect localhost:<自己的端口>,
|
| 绝不 adb kill-server (会踩死其他并发用户)
|
| 6) 进程退出兜底 - atexit 注册, 防脚本崩溃泄漏 SSH 连接
|
|
|
| 环境变量 (.env 或 os.environ):
|
| HKDY_HARD_TIMEOUT 任务硬超时秒数 默认 180
|
| HKDY_SSH_KEEPALIVE SSH keep-alive 秒数 默认 30
|
| HKDY_ADB_READY_TIMEOUT adb ready 等待秒数 默认 30
|
| HKDY_TUNNEL_BIND_RETRIES 绑定本地端口重试次数 默认 3
|
| HKDY_WATCHDOG_INTERVAL 看门狗扫描间隔秒数 默认 5
|
|
|
| 典型用法:
|
| with Session(raw_input_text, log=log_fn, hard_timeout=180) as sess:
|
| sess.sh("am force-stop ...")
|
| sess.push("local.xml", "/data/local/tmp/x.xml")
|
| sess.pull("/sdcard/verify.png", "./verify.png")
|
| # 退出 with 块时自动 close (即使异常也会清理)
|
| """
|
|
|
| import atexit
|
| import os
|
| import re
|
| import select
|
| import socket
|
| import subprocess
|
| import sys
|
| import threading
|
| import time
|
|
|
# Resolve paths relative to this file (not the CWD) so the bundled ./lib
# directory and the local .env are found however the script is launched.
_HERE = os.path.dirname(os.path.abspath(__file__))
_LIB = os.path.join(_HERE, "lib")
if os.path.isdir(_LIB):
    if _LIB not in sys.path:
        # Prepend so packages vendored in ./lib win over site-packages
        # (presumably used for the third-party imports below — confirm).
        sys.path.insert(0, _LIB)

# python-dotenv is optional: if it is missing, or loading the .env fails for
# any reason, silently fall back to the plain process environment.
try:
    from dotenv import load_dotenv
    load_dotenv(os.path.join(_HERE, ".env"))
except Exception:
    pass
|
|
|
| import paramiko
|
|
|
|
|
|
|
|
|
def _envint(name, default):
    """Return environment variable *name* parsed as an int.

    When the variable is unset, ``str(default)`` is parsed instead. If parsing
    fails (ValueError/TypeError), *default* itself is returned unchanged.
    """
    try:
        raw = os.environ.get(name, str(default))
        parsed = int(raw)
    except (ValueError, TypeError):
        parsed = default
    return parsed
|
|
|
|
|
# Tunables, overridable via environment / .env (see module docstring).
HARD_TIMEOUT = _envint("HKDY_HARD_TIMEOUT", 180)
# <= 0 disables keep-alive (the tunnel only enables it when > 0).
SSH_KEEPALIVE = _envint("HKDY_SSH_KEEPALIVE", 30)
ADB_READY_TIMEOUT = _envint("HKDY_ADB_READY_TIMEOUT", 30)
# This is the total number of bind attempts. At least one is required;
# 0 or a negative value would make every tunnel fail to bind.
TUNNEL_BIND_RETRIES = max(1, _envint("HKDY_TUNNEL_BIND_RETRIES", 3))
# time.sleep() raises on negative values and 0 would busy-spin the watchdog.
WATCHDOG_INTERVAL = max(1, _envint("HKDY_WATCHDOG_INTERVAL", 5))

# Keeps adb child processes from opening a console window on Windows.
# The stdlib constant (0x08000000, Windows-only, Python 3.7+) is used when
# available; 0 elsewhere, where creationflags must be 0.
_CREATE_NO_WINDOW = getattr(subprocess, "CREATE_NO_WINDOW", 0)
|
|
|
|
|
|
|
|
|
class _Registry:
    """Process-wide set of live Sessions plus the hard-timeout watchdog.

    The watchdog is a daemon thread started lazily on the first register().
    It runs for the rest of the process lifetime and is never restarted.
    """

    # Scan period used when WATCHDOG_INTERVAL is not positive.
    _FALLBACK_INTERVAL = 5

    def __init__(self):
        self._sessions = set()
        self._lock = threading.Lock()
        self._watchdog_started = False

    def register(self, s):
        """Start tracking session *s*; spawn the watchdog on first use."""
        with self._lock:
            self._sessions.add(s)
            if not self._watchdog_started:
                self._watchdog_started = True
                t = threading.Thread(
                    target=self._watchdog,
                    daemon=True, name="hkdy-watchdog",
                )
                t.start()

    def unregister(self, s):
        """Stop tracking *s* (no-op if it is not registered)."""
        with self._lock:
            self._sessions.discard(s)

    def snapshot(self):
        """Return a list copy of the tracked sessions, safe to iterate unlocked."""
        with self._lock:
            return list(self._sessions)

    def _watchdog(self):
        """Every WATCHDOG_INTERVAL seconds, force-clean sessions past their hard timeout."""
        while True:
            # The sleep is outside the per-session try below. A negative
            # interval would make time.sleep raise ValueError and kill this
            # one-shot thread, silently disabling hard timeouts for the whole
            # process. 0 would busy-spin. Fall back to a sane period instead.
            interval = WATCHDOG_INTERVAL
            if interval <= 0:
                interval = self._FALLBACK_INTERVAL
            time.sleep(interval)
            now = time.time()
            for s in self.snapshot():
                try:
                    if s._started_at is None:
                        continue  # not fully started yet
                    elapsed = now - s._started_at
                    if elapsed > s.hard_timeout and not s._timeout_fired:
                        s._on_timeout(elapsed)
                except Exception:
                    # Best effort: one misbehaving session must not stop
                    # the watchdog from policing the others.
                    pass


# Module-wide singleton shared by every Session in this process.
_registry = _Registry()
|
|
|
|
|
|
|
|
|
def parse_ssh_command(cmd):
    """Extract connection fields from an OpenSSH local-forward command line.

    Expects something like
    ``ssh -o... USER@HOST -p PORT -L LOCAL:RHOST:REMOTE -Nf``.
    Both the spaced and the attached option forms OpenSSH accepts are handled
    (``-p 2222`` / ``-p2222``, ``-L 5555:h:5555`` / ``-L5555:h:5555``).
    An optional bind address before the local port
    (``-L 127.0.0.1:5555:h:5555``) is also handled.

    Returns:
        (user, host, ssh_port, local_port, remote_host, remote_port).
        Session ignores local_port (it uses a dynamically assigned port);
        it is only a preference value.

    Raises:
        ValueError: if user@host, -p or -L cannot be found.
    """
    m_uh = re.search(r"\s([A-Za-z0-9_.\-]+@[A-Za-z0-9_.\-]+)", cmd)
    # (?<!\S) makes the flag start its own token, so text such as "x-p 9"
    # inside another option's value is not mistaken for the port flag.
    m_p = re.search(r"(?<!\S)-p\s*(\d+)", cmd)
    m_l = re.search(r"(?<!\S)-L\s*(?:[^\s:]+:)?(\d+):([^\s:]+):(\d+)", cmd)
    if not (m_uh and m_p and m_l):
        raise ValueError("SSH 命令缺少 user@host / -p / -L 字段")
    user, host = m_uh.group(1).split("@")
    return (user, host, int(m_p.group(1)),
            int(m_l.group(1)), m_l.group(2), int(m_l.group(3)))
|
|
|
|
|
|
|
|
|
def _pipe(chan, sock):
    """Relay bytes in both directions between an SSH channel and a local socket.

    Runs until either side reaches EOF or raises, then closes both ends.
    Intended to run on its own thread, one per forwarded connection.
    """
    try:
        while True:
            readable, _, _ = select.select([sock, chan], [], [])
            if sock in readable:
                data = sock.recv(65536)
                if not data:
                    break
                # Channel.send / socket.send may write only part of the
                # buffer and report how much. sendall loops until everything
                # is written, so forwarded adb traffic is never truncated.
                chan.sendall(data)
            if chan in readable:
                data = chan.recv(65536)
                if not data:
                    break
                sock.sendall(data)
    except Exception:
        # Best effort: any I/O error simply ends this forwarded connection.
        pass
    finally:
        try:
            chan.close()
        except Exception:
            pass
        try:
            sock.close()
        except Exception:
            pass
|
|
|
|
|
class _Tunnel:
    """``ssh -L``-style local port forward built on a paramiko SSH connection.

    Listens on 127.0.0.1:<local_port> and forwards each accepted connection
    over a "direct-tcpip" channel to (rhost, rport) on the SSH server's side.
    """

    # Seconds between accept() wake-ups used to notice close(). Closing a
    # socket from another thread does not reliably interrupt a blocking
    # accept() on every platform.
    _ACCEPT_POLL = 1.0

    def __init__(self, host, port, user, password, rhost, rport,
                 preferred_local=0):
        self.host, self.port = host, port
        self.user, self.password = user, password
        self.rhost, self.rport = rhost, rport
        self.preferred_local = preferred_local  # 0 = let the OS choose
        self.local_port = None  # actual bound port, set by start()
        self.client = None      # paramiko.SSHClient
        self.listener = None    # listening socket on 127.0.0.1
        self._stop = False

    def start(self, timeout=15):
        """Connect, bind the local listener and start the accept thread.

        On any failure everything acquired so far (SSH connection, socket) is
        released and the original exception is re-raised unchanged. Callers
        can still catch e.g. ``paramiko.AuthenticationException``.
        """
        try:
            transport = self._connect(timeout)
            self._bind_listener()
            self.local_port = self.listener.getsockname()[1]
            self.listener.listen(100)
            self.listener.settimeout(self._ACCEPT_POLL)
        except BaseException:
            self.close()
            raise
        threading.Thread(
            target=self._accept, args=(transport,),
            daemon=True, name=f"hkdy-tunnel-{self.local_port}",
        ).start()

    def _connect(self, timeout):
        """Open the SSH connection and return its transport (keep-alive enabled)."""
        self.client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy trusts any host key (no MITM protection).
        # Kept as-is, presumably because cloud-device hosts rotate keys; confirm.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.client.connect(
            self.host, port=self.port,
            username=self.user, password=self.password,
            allow_agent=False, look_for_keys=False, timeout=timeout,
        )
        transport = self.client.get_transport()
        if transport and SSH_KEEPALIVE > 0:
            transport.set_keepalive(SSH_KEEPALIVE)
        return transport

    def _bind_listener(self):
        """Bind self.listener to 127.0.0.1: preferred port first, then OS-assigned."""
        last_err = None
        # Always make at least one attempt, even if the setting is <= 0.
        for attempt in range(max(1, TUNNEL_BIND_RETRIES)):
            try_port = self.preferred_local if attempt == 0 else 0
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # On POSIX, SO_REUSEADDR only lets us rebind a port stuck in
            # TIME_WAIT. On Windows it would let us bind a port another
            # process is actively using, which would defeat the
            # collision-avoidance this tunnel relies on.
            if os.name != "nt":
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                sock.bind(("127.0.0.1", try_port))
            except OSError as e:
                last_err = e
                try:
                    sock.close()
                except Exception:
                    pass
                continue
            self.listener = sock
            return
        raise RuntimeError(f"无法绑定本地端口: {last_err}")

    def _accept(self, transport):
        """Accept local clients and spawn a _pipe thread per forwarded channel."""
        while not self._stop:
            try:
                sock, addr = self.listener.accept()
            except socket.timeout:
                continue  # periodic wake-up to re-check _stop
            except OSError:
                return  # listener was closed
            if self._stop:
                sock.close()
                return
            try:
                chan = transport.open_channel(
                    "direct-tcpip", (self.rhost, self.rport), addr,
                )
            except Exception:
                sock.close()
                continue
            threading.Thread(
                target=_pipe, args=(chan, sock), daemon=True,
            ).start()

    def close(self):
        """Stop accepting and close the listener and the SSH client (idempotent)."""
        self._stop = True
        for resource in (self.listener, self.client):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                pass
|
|
|
|
|
|
|
|
|
def _run(argv, timeout=120):
    """Execute *argv* directly (no shell) and return the CompletedProcess.

    stdout/stderr are captured and decoded as UTF-8, with undecodable bytes
    replaced, so both are always str. ``_CREATE_NO_WINDOW`` is passed as
    creationflags. ``subprocess.TimeoutExpired`` propagates once *timeout*
    seconds elapse.
    """
    run_options = {
        "capture_output": True,
        "encoding": "utf-8",
        "errors": "replace",
        "timeout": timeout,
        "creationflags": _CREATE_NO_WINDOW,
    }
    return subprocess.run(argv, **run_options)
|
|
|
|
|
|
|
|
|
class Session:
    """Own one task's SSH tunnel + adb device.

    Supports ``with``. The module watchdog force-closes the session after
    ``hard_timeout`` seconds. sh/push/pull rebuild the tunnel and retry once
    when adb reports a lost connection. Cleanup only disconnects this
    session's own serial and closes its own tunnel. It never runs
    ``adb kill-server``, because the adb server is shared with other
    concurrent users.
    """

    # adb error text meaning the transport (tunnel / adb link) is gone, as
    # opposed to the remote command itself failing. A bare "not found" or
    # "closed" was too broad: Android's shell prints e.g.
    # "sh: foo: inaccessible or not found" for an ordinary missing command.
    # That used to trigger a full SSH+adb rebuild and then re-run the
    # (possibly non-idempotent) command.
    _OFFLINE_RE = re.compile(
        r"offline"
        r"|device\b[^\n]*?\bnot found"
        r"|no devices"
        r"|error: closed"
        r"|connection closed"
        r"|connection refused"
        r"|cannot connect"
        r"|failed to connect",
        re.IGNORECASE,
    )

    def __init__(self, raw_input_text, *, log=None, hard_timeout=None,
                 preferred_local=None):
        """Parse credentials only; nothing connects until start() / ``with``.

        Args:
            raw_input_text: text containing an ``ssh ... -L ...`` line and a
                password line (see _parse_input).
            log: optional callable receiving one message string.
            hard_timeout: seconds before the watchdog force-closes the
                session; defaults to HARD_TIMEOUT.
            preferred_local: local port to try first; None/0 lets the OS
                pick. The -L local port inside the ssh command is ignored.

        Raises:
            ValueError: the ssh command or password is missing or malformed.
        """
        ssh_cmd, password = self._parse_input(raw_input_text)
        user, host, port, _ignored_local, rhost, rport = parse_ssh_command(ssh_cmd)
        self._cred = (host, port, user, password, rhost, rport)
        self._preferred_local = 0 if preferred_local is None else preferred_local
        self.log = log or (lambda _m: None)
        self.hard_timeout = HARD_TIMEOUT if hard_timeout is None else hard_timeout

        self._tunnel = None          # current _Tunnel; swapped on reconnect
        self._lock = threading.Lock()
        self._closed = False
        self._started_at = None      # time.time() once the tunnel is up
        self._timeout_fired = False  # set by _on_timeout (watchdog thread)

    @staticmethod
    def _parse_input(raw):
        """Split *raw* into (ssh_command_line, password).

        The ssh command is the first non-blank line starting with "ssh ". The
        password is the first non-blank line starting with neither "ssh " nor
        "adb " (both checks are case-insensitive; lines are stripped).

        Raises:
            ValueError: either piece is missing.
        """
        lines = [l.strip() for l in (raw or "").splitlines() if l.strip()]
        ssh = next((l for l in lines if l.lower().startswith("ssh ")), None)
        pwd = next(
            (l for l in lines if not l.lower().startswith(("ssh ", "adb "))),
            None,
        )
        if not ssh:
            raise ValueError("未找到 SSH 命令 (应以 'ssh' 开头)")
        if not pwd:
            raise ValueError("未找到密码 (不以 ssh/adb 开头的那一行)")
        return ssh, pwd

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never swallow the body's exception

    def start(self, timeout=15):
        """Open the tunnel, register with the watchdog and bring adb to "device".

        Idempotent once started.

        Raises:
            RuntimeError: the session is already closed, SSH fails, adb never
                becomes ready, or the hard timeout fires during startup.
        """
        with self._lock:
            if self._started_at is not None:
                return
            if self._closed:
                # Closed before it ever started: building a tunnel now would
                # leak it, because close() is already a no-op.
                self._check_timeout_fired()
                raise RuntimeError("Session 已关闭")
            host, port, user, pwd, rhost, rport = self._cred
            tunnel = _Tunnel(
                host, port, user, pwd, rhost, rport,
                preferred_local=self._preferred_local,
            )
            try:
                tunnel.start(timeout=timeout)
            except paramiko.AuthenticationException as e:
                tunnel.close()
                raise RuntimeError("SSH 认证失败: 用户名或密码错误") from e
            except Exception as e:
                tunnel.close()
                raise RuntimeError(f"SSH 连接失败: {e}") from e
            self._tunnel = tunnel
            time.sleep(0.5)  # give the forward a moment before adb uses it
            self._started_at = time.time()
            _registry.register(self)
        # Outside the lock: close() (on failure) and the watchdog both need it.
        try:
            self._adb_ready()
        except Exception:
            self.close()
            raise

    def _ensure_open(self, serial=None):
        """Raise if the session is closed; the hard-timeout error takes priority.

        When *serial* is given, first re-disconnect it. This undoes an
        ``adb connect`` this thread may have issued after close() had
        already run its own disconnect.
        """
        if not self._closed:
            return
        if serial is not None:
            try:
                _run(["adb", "disconnect", serial], timeout=5)
            except Exception:
                pass
        self._check_timeout_fired()
        raise RuntimeError("Session 已关闭")

    def _adb_ready(self):
        """Attach localhost:<local_port> to the adb server and run ``adb root``.

        Never kill-server: the adb server is shared with other concurrent users.
        Bails out early (cleaning up its own serial) if the session is closed
        meanwhile, e.g. by the watchdog.
        """
        serial = f"localhost:{self._tunnel.local_port}"
        deadline = time.time() + ADB_READY_TIMEOUT
        ok = False
        while time.time() < deadline:
            self._ensure_open(serial)
            _run(["adb", "connect", serial], timeout=10)
            time.sleep(0.8)
            r = _run(["adb", "-s", serial, "get-state"], timeout=5)
            if (r.stdout or "").strip() == "device":
                ok = True
                break
            _run(["adb", "disconnect", serial], timeout=5)
            time.sleep(0.8)
        self._ensure_open(serial)
        if not ok:
            raise RuntimeError(f"adb 无法进入 device 状态: {serial}")

        # adbd restarts after "adb root", so the connection must be re-established.
        _run(["adb", "-s", serial, "root"], timeout=20)
        time.sleep(1.5)

        deadline2 = time.time() + 10
        while time.time() < deadline2:
            self._ensure_open(serial)
            _run(["adb", "connect", serial], timeout=10)
            time.sleep(0.5)
            r = _run(["adb", "-s", serial, "get-state"], timeout=5)
            if (r.stdout or "").strip() == "device":
                self._ensure_open(serial)
                return
        self._ensure_open(serial)
        raise RuntimeError(f"adb root 后重连失败: {serial}")

    def _on_timeout(self, elapsed):
        """Watchdog callback: mark the timeout (once), log it and hard-close."""
        with self._lock:
            if self._timeout_fired or self._closed:
                return
            self._timeout_fired = True
        try:
            self.log(f"⚠ 硬超时 {elapsed:.1f}s (上限 {self.hard_timeout}s), 强制清理")
        except Exception:
            pass  # a broken log callback must not prevent cleanup
        self._hard_close()

    def _hard_close(self):
        """Close once: disconnect only our own serial and tunnel. Never kill-server."""
        with self._lock:
            if self._closed:
                return
            self._closed = True
            tunnel = self._tunnel
        _registry.unregister(self)
        if tunnel is None:
            return
        if tunnel.local_port is not None:
            try:
                _run(["adb", "disconnect", f"localhost:{tunnel.local_port}"], timeout=5)
            except Exception:
                pass
        try:
            tunnel.close()
        except Exception:
            pass

    def close(self):
        """Release the tunnel and adb serial. Safe to call repeatedly."""
        self._hard_close()

    @property
    def serial(self):
        """adb serial ``localhost:<port>``; raises if closed or not started."""
        if self._closed:
            raise RuntimeError("Session 已关闭")
        if not self._tunnel:
            raise RuntimeError("Session 未启动")
        return f"localhost:{self._tunnel.local_port}"

    @property
    def local_port(self):
        """Bound local tunnel port, or None before start()."""
        return self._tunnel.local_port if self._tunnel else None

    @property
    def timed_out(self):
        """True once the watchdog has fired the hard timeout."""
        return self._timeout_fired

    @staticmethod
    def _is_offline_error(msg):
        """True if adb output *msg* looks like a lost transport (worth a reconnect)."""
        return bool(Session._OFFLINE_RE.search(msg or ""))

    def _reconnect_once(self):
        """Tear down the current tunnel/serial and build fresh ones.

        The lock is held only to read/swap state, so the watchdog can still
        force-close the session promptly while a reconnect is in progress.

        Raises:
            RuntimeError: the session is (or becomes) closed.
            Exception: whatever tunnel start / _adb_ready raise.
        """
        self.log(" ⚠ 连接异常, 重建 SSH+adb...")
        with self._lock:
            if self._closed:
                raise RuntimeError("Session 已关闭, 无法重连")
            old = self._tunnel
        if old is not None:
            if old.local_port:
                try:
                    _run(["adb", "disconnect", f"localhost:{old.local_port}"], timeout=5)
                except Exception:
                    pass
            try:
                old.close()
            except Exception:
                pass

        host, port, user, pwd, rhost, rport = self._cred
        new = _Tunnel(host, port, user, pwd, rhost, rport, preferred_local=0)
        try:
            new.start()
        except Exception:
            # Don't leave self._tunnel pointing at a half-started tunnel.
            new.close()
            raise
        with self._lock:
            if self._closed:
                # close() ran while we were connecting; it could not see `new`.
                new.close()
                raise RuntimeError("Session 已关闭, 无法重连")
            self._tunnel = new
        time.sleep(0.5)
        self._adb_ready()
        self.log(f" 已重连 serial={self.serial}")

    def _check_timeout_fired(self):
        """Raise the hard-timeout error if the watchdog has fired."""
        if self._timeout_fired:
            raise RuntimeError(
                f"任务硬超时 (>{self.hard_timeout}s), 已被看门狗清理"
            )

    def sh(self, cmd, check=True, timeout=120, auto_reconnect=True):
        """Run ``adb shell cmd``; on an offline-type error reconnect and retry once.

        Returns:
            stdout as str. With check=False a failing command's stdout is
            returned instead of raising.

        Raises:
            RuntimeError: the session is closed/timed out, reconnect fails, or
                the command fails while check=True.
        """
        last_err = None
        for attempt in range(2 if auto_reconnect else 1):
            self._ensure_open()
            r = _run(["adb", "-s", self.serial, "shell", cmd], timeout=timeout)
            if r.returncode == 0:
                return r.stdout or ""
            err = (r.stderr or "") + (r.stdout or "")
            last_err = err
            if auto_reconnect and attempt == 0 and self._is_offline_error(err):
                try:
                    self._reconnect_once()
                except Exception as e:
                    # If the watchdog killed us, report that rather than the
                    # secondary "cannot reconnect" symptom.
                    self._check_timeout_fired()
                    raise RuntimeError(
                        f"adb shell 失败 (重连失败): {cmd}\n{err}\n重连错误: {e}"
                    ) from e
                continue
            if check:
                raise RuntimeError(f"adb shell 失败:\n{cmd}\n{err}")
            return r.stdout or ""
        # Defensive: every loop path above returns, raises or retries once.
        if check:
            raise RuntimeError(
                f"adb shell 失败 (重试后仍失败):\n{cmd}\n{last_err}"
            )
        return ""

    def _transfer(self, verb, args, timeout):
        """Shared body of push/pull: ``adb -s <serial> <verb> <args...>`` with one reconnect retry."""
        for attempt in range(2):
            self._ensure_open()
            r = _run(["adb", "-s", self.serial, verb, *args], timeout=timeout)
            if r.returncode == 0:
                return r.stdout
            err = (r.stderr or "") + (r.stdout or "")
            if attempt == 0 and self._is_offline_error(err):
                try:
                    self._reconnect_once()
                except Exception as e:
                    self._check_timeout_fired()
                    raise RuntimeError(f"adb {verb} 失败 (重连失败): {err}\n重连: {e}") from e
                continue
            raise RuntimeError(f"adb {verb} 失败: {err}")

    def push(self, local, remote, timeout=1800):
        """``adb push local remote``; returns adb's stdout, raises RuntimeError on failure."""
        return self._transfer("push", (local, remote), timeout)

    def pull(self, remote, local, timeout=300):
        """``adb pull remote local``; returns adb's stdout, raises RuntimeError on failure."""
        return self._transfer("pull", (remote, local), timeout)
|
|
|
|
|
|
|
|
|
def _atexit_cleanup():
    """Interpreter-exit hook: close every Session still in the registry.

    Guards against leaking SSH connections / adb serials when a script
    exits (or crashes) without closing its sessions.
    """
    for sess in _registry.snapshot():
        try:
            sess.close()
        except Exception:
            # Best effort: one failing session must not block the others.
            pass


atexit.register(_atexit_cleanup)
|
|
|