""" SSH 隧道 + ADB 设备统一管理 (v3) ============================== 给 restore.py / sync.py 共享, 解决以下并发/健壮性问题: 1) 动态本地端口 - 忽略 SSH 命令里的 -L 写死端口, 用 OS 分配 彻底避免不同云真机端口撞车 2) SSH keep-alive - paramiko Transport.set_keepalive, 默认 30s 一次 云真机平台会 kill 掉空闲 SSH, 心跳解决 3) 自动重连 - 命令检测到 'device offline' 等错误后重建隧道+adb, 重试 1 次. 不会无限循环 4) 硬超时兜底 - 全局看门狗线程, 超过 HKDY_HARD_TIMEOUT 秒的 session 强制清理 (防任务卡死持有资源) 5) 并发安全清理 - 只做 adb disconnect localhost:<自己的端口>, 绝不 adb kill-server (会踩死其他并发用户) 6) 进程退出兜底 - atexit 注册, 防脚本崩溃泄漏 SSH 连接 环境变量 (.env 或 os.environ): HKDY_HARD_TIMEOUT 任务硬超时秒数 默认 180 HKDY_SSH_KEEPALIVE SSH keep-alive 秒数 默认 30 HKDY_ADB_READY_TIMEOUT adb ready 等待秒数 默认 30 HKDY_TUNNEL_BIND_RETRIES 绑定本地端口重试次数 默认 3 HKDY_WATCHDOG_INTERVAL 看门狗扫描间隔秒数 默认 5 典型用法: with Session(raw_input_text, log=log_fn, hard_timeout=180) as sess: sess.sh("am force-stop ...") sess.push("local.xml", "/data/local/tmp/x.xml") sess.pull("/sdcard/verify.png", "./verify.png") # 退出 with 块时自动 close (即使异常也会清理) """ import atexit import os import re import select import socket import subprocess import sys import threading import time _HERE = os.path.dirname(os.path.abspath(__file__)) _LIB = os.path.join(_HERE, "lib") if os.path.isdir(_LIB) and _LIB not in sys.path: sys.path.insert(0, _LIB) # 尝试自动加载 .env (有 python-dotenv 就用, 没有就依赖 os.environ) try: from dotenv import load_dotenv load_dotenv(os.path.join(_HERE, ".env")) except Exception: pass import paramiko # ============ 配置 ============ def _envint(name, default): try: return int(os.environ.get(name, str(default))) except (ValueError, TypeError): return default HARD_TIMEOUT = _envint("HKDY_HARD_TIMEOUT", 180) SSH_KEEPALIVE = _envint("HKDY_SSH_KEEPALIVE", 30) ADB_READY_TIMEOUT = _envint("HKDY_ADB_READY_TIMEOUT", 30) TUNNEL_BIND_RETRIES = _envint("HKDY_TUNNEL_BIND_RETRIES", 3) WATCHDOG_INTERVAL = _envint("HKDY_WATCHDOG_INTERVAL", 5) _CREATE_NO_WINDOW = 0x08000000 if os.name == "nt" else 0 # ============ 全局 Session 注册表 + 看门狗 ============ class _Registry: def __init__(self): self._sessions = set() self._lock = threading.Lock() self._watchdog_started = False def register(self, s): with self._lock: self._sessions.add(s) if not self._watchdog_started: self._watchdog_started = True t = threading.Thread( target=self._watchdog, daemon=True, name="hkdy-watchdog", ) t.start() def unregister(self, s): with self._lock: self._sessions.discard(s) def snapshot(self): with self._lock: return list(self._sessions) def _watchdog(self): """5 秒扫一次所有 session, 超时的强制清理.""" while True: time.sleep(WATCHDOG_INTERVAL) now = time.time() for s in self.snapshot(): try: if s._started_at is None: continue elapsed = now - s._started_at if elapsed > s.hard_timeout and not s._timeout_fired: s._on_timeout(elapsed) except Exception: pass _registry = _Registry() # ============ SSH 命令解析 ============ def parse_ssh_command(cmd): """从 'ssh -o... USER@HOST -p PORT -L LOCAL:rhost:REMOTE -Nf' 提取 6 元组. 注: LOCAL 字段会被 Session 忽略 (改用动态分配), 仅作为偏好值.""" m_uh = re.search(r"\s([A-Za-z0-9_.\-]+@[A-Za-z0-9_.\-]+)", cmd) m_p = re.search(r"-p\s+(\d+)", cmd) m_l = re.search(r"-L\s+(\d+):([^\s:]+):(\d+)", cmd) if not (m_uh and m_p and m_l): raise ValueError("SSH 命令缺少 user@host / -p / -L 字段") u, h = m_uh.group(1).split("@") return (u, h, int(m_p.group(1)), int(m_l.group(1)), m_l.group(2), int(m_l.group(3))) # ============ 内部: 隧道 + 双向管道 ============ def _pipe(chan, sock): try: while True: r, _, _ = select.select([sock, chan], [], []) if sock in r: d = sock.recv(65536) if not d: break chan.send(d) if chan in r: d = chan.recv(65536) if not d: break sock.send(d) except Exception: pass finally: try: chan.close() except Exception: pass try: sock.close() except Exception: pass class _Tunnel: def __init__(self, host, port, user, password, rhost, rport, preferred_local=0): self.host, self.port = host, port self.user, self.password = user, password self.rhost, self.rport = rhost, rport self.preferred_local = preferred_local # 0 = 让 OS 分配 self.local_port = None self.client = None self.listener = None self._stop = False def start(self, timeout=15): self.client = paramiko.SSHClient() self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.client.connect( self.host, port=self.port, username=self.user, password=self.password, allow_agent=False, look_for_keys=False, timeout=timeout, ) transport = self.client.get_transport() if transport and SSH_KEEPALIVE > 0: transport.set_keepalive(SSH_KEEPALIVE) # 绑端口: 优先指定端口, 失败则回退 OS 分配 (并发多任务更安全) bound = False last_err = None for attempt in range(TUNNEL_BIND_RETRIES): try_port = self.preferred_local if attempt == 0 else 0 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: sock.bind(("127.0.0.1", try_port)) self.listener = sock bound = True break except OSError as e: last_err = e try: sock.close() except Exception: pass if not bound: raise RuntimeError(f"无法绑定本地端口: {last_err}") self.local_port = self.listener.getsockname()[1] self.listener.listen(100) threading.Thread( target=self._accept, args=(transport,), daemon=True, name=f"hkdy-tunnel-{self.local_port}", ).start() def _accept(self, transport): while not self._stop: try: sock, addr = self.listener.accept() except OSError: return try: chan = transport.open_channel( "direct-tcpip", (self.rhost, self.rport), addr, ) except Exception: sock.close() continue threading.Thread( target=_pipe, args=(chan, sock), daemon=True, ).start() def close(self): self._stop = True for x in (self.listener, self.client): try: x and x.close() except Exception: pass # ============ ADB 原生 ============ def _run(argv, timeout=120): return subprocess.run( argv, capture_output=True, encoding="utf-8", errors="replace", timeout=timeout, creationflags=_CREATE_NO_WINDOW, ) # ============ 对外主类: Session ============ class Session: """托管一次任务的 SSH 隧道 + adb 设备. 线程安全, 支持 with, 带硬超时+自动重连, 并发安全清理.""" def __init__(self, raw_input_text, *, log=None, hard_timeout=None, preferred_local=None): ssh_cmd, password = self._parse_input(raw_input_text) u, h, p, lp_pref, rh, rp = parse_ssh_command(ssh_cmd) self._cred = (h, p, u, password, rh, rp) # 重连要用 # preferred_local=None → 直接用 OS 分配 (并发最安全) # preferred_local=0 → 同上 # preferred_local=N → 先试 N, 撞车再 OS 分配 if preferred_local is None: preferred_local = 0 self._preferred_local = preferred_local self.log = log or (lambda _m: None) self.hard_timeout = hard_timeout if hard_timeout is not None else HARD_TIMEOUT self._tunnel = None self._lock = threading.Lock() self._closed = False self._started_at = None self._timeout_fired = False # ---------- 解析输入 ---------- @staticmethod def _parse_input(raw): lines = [l.strip() for l in (raw or "").splitlines() if l.strip()] ssh = next((l for l in lines if l.lower().startswith("ssh ")), None) pwd = next( (l for l in lines if not l.lower().startswith(("ssh ", "adb "))), None, ) if not ssh: raise ValueError("未找到 SSH 命令 (应以 'ssh' 开头)") if not pwd: raise ValueError("未找到密码 (不以 ssh/adb 开头的那一行)") return ssh, pwd # ---------- with ---------- def __enter__(self): self.start() return self def __exit__(self, exc_type, exc, tb): self.close() return False # 异常继续往上抛 # ---------- 生命周期 ---------- def start(self, timeout=15): with self._lock: if self._started_at is not None: return host, port, user, pwd, rhost, rport = self._cred self._tunnel = _Tunnel( host, port, user, pwd, rhost, rport, preferred_local=self._preferred_local, ) try: self._tunnel.start(timeout=timeout) except paramiko.AuthenticationException: self._tunnel = None raise RuntimeError("SSH 认证失败: 用户名或密码错误") except Exception as e: self._tunnel = None raise RuntimeError(f"SSH 连接失败: {e}") time.sleep(0.5) self._started_at = time.time() _registry.register(self) # 注册到看门狗 (在锁外) try: self._adb_ready() except Exception: self.close() raise def _adb_ready(self): """把 localhost:local_port 挂进 adb server, 并完成 adb root. 绝不 kill-server (共享, 会影响其他并发用户).""" port = self._tunnel.local_port serial = f"localhost:{port}" deadline = time.time() + ADB_READY_TIMEOUT ok = False while time.time() < deadline: _run(["adb", "connect", serial], timeout=10) time.sleep(0.8) r = _run(["adb", "-s", serial, "get-state"], timeout=5) if (r.stdout or "").strip() == "device": ok = True break _run(["adb", "disconnect", serial], timeout=5) time.sleep(0.8) if not ok: raise RuntimeError(f"adb 无法进入 device 状态: {serial}") # 切 root (adbd 会重启一下) _run(["adb", "-s", serial, "root"], timeout=20) time.sleep(1.5) # root 后再等一会儿 + 重连 deadline2 = time.time() + 10 while time.time() < deadline2: _run(["adb", "connect", serial], timeout=10) time.sleep(0.5) r = _run(["adb", "-s", serial, "get-state"], timeout=5) if (r.stdout or "").strip() == "device": return raise RuntimeError(f"adb root 后重连失败: {serial}") def _on_timeout(self, elapsed): with self._lock: if self._timeout_fired or self._closed: return self._timeout_fired = True try: self.log(f"⚠ 硬超时 {elapsed:.1f}s (上限 {self.hard_timeout}s), 强制清理") except Exception: pass self._hard_close() def _hard_close(self): """并发安全. 只清自己的 serial + 自己的 tunnel. 不 kill-server.""" with self._lock: if self._closed: return self._closed = True port = self._tunnel.local_port if self._tunnel else None _registry.unregister(self) if port is not None: try: _run(["adb", "disconnect", f"localhost:{port}"], timeout=5) except Exception: pass if self._tunnel: try: self._tunnel.close() except Exception: pass def close(self): self._hard_close() # ---------- 属性 ---------- @property def serial(self): if self._closed: raise RuntimeError("Session 已关闭") if not self._tunnel: raise RuntimeError("Session 未启动") return f"localhost:{self._tunnel.local_port}" @property def local_port(self): return self._tunnel.local_port if self._tunnel else None @property def timed_out(self): return self._timeout_fired # ---------- 自动重连 ---------- @staticmethod def _is_offline_error(msg): m = (msg or "").lower() keys = ( "offline", "not found", "closed", "connection refused", "cannot connect", "failed to connect", "device not found", "error: no devices", ) return any(k in m for k in keys) def _reconnect_once(self): self.log(" ⚠ 连接异常, 重建 SSH+adb...") with self._lock: if self._closed: raise RuntimeError("Session 已关闭, 无法重连") old_port = self._tunnel.local_port if self._tunnel else None if old_port: try: _run(["adb", "disconnect", f"localhost:{old_port}"], timeout=5) except Exception: pass if self._tunnel: try: self._tunnel.close() except Exception: pass host, port, user, pwd, rhost, rport = self._cred self._tunnel = _Tunnel( host, port, user, pwd, rhost, rport, preferred_local=0, # 重连时直接用 OS 分配 ) self._tunnel.start() time.sleep(0.5) self._adb_ready() self.log(f" 已重连 serial={self.serial}") # ---------- ADB 方法 (带超时+重连) ---------- def _check_timeout_fired(self): if self._timeout_fired: raise RuntimeError( f"任务硬超时 (>{self.hard_timeout}s), 已被看门狗清理" ) def sh(self, cmd, check=True, timeout=120, auto_reconnect=True): """adb shell 包装. offline 时自动重连 1 次.""" last_err = None for attempt in range(2 if auto_reconnect else 1): if self._closed: self._check_timeout_fired() raise RuntimeError("Session 已关闭") r = _run(["adb", "-s", self.serial, "shell", cmd], timeout=timeout) if r.returncode == 0: return r.stdout or "" err = (r.stderr or "") + (r.stdout or "") last_err = err if auto_reconnect and attempt == 0 and self._is_offline_error(err): try: self._reconnect_once() except Exception as e: raise RuntimeError( f"adb shell 失败 (重连失败): {cmd}\n{err}\n重连错误: {e}" ) continue if check: raise RuntimeError(f"adb shell 失败:\n{cmd}\n{err}") return r.stdout or "" if check: raise RuntimeError( f"adb shell 失败 (重试后仍失败):\n{cmd}\n{last_err}" ) return "" def push(self, local, remote, timeout=1800): for attempt in range(2): if self._closed: self._check_timeout_fired() raise RuntimeError("Session 已关闭") r = _run(["adb", "-s", self.serial, "push", local, remote], timeout=timeout) if r.returncode == 0: return r.stdout err = (r.stderr or "") + (r.stdout or "") if attempt == 0 and self._is_offline_error(err): try: self._reconnect_once() except Exception as e: raise RuntimeError(f"adb push 失败 (重连失败): {err}\n重连: {e}") continue raise RuntimeError(f"adb push 失败: {err}") def pull(self, remote, local, timeout=300): for attempt in range(2): if self._closed: self._check_timeout_fired() raise RuntimeError("Session 已关闭") r = _run(["adb", "-s", self.serial, "pull", remote, local], timeout=timeout) if r.returncode == 0: return r.stdout err = (r.stderr or "") + (r.stdout or "") if attempt == 0 and self._is_offline_error(err): try: self._reconnect_once() except Exception as e: raise RuntimeError(f"adb pull 失败 (重连失败): {err}\n重连: {e}") continue raise RuntimeError(f"adb pull 失败: {err}") # ============ 进程退出兜底 ============ def _atexit_cleanup(): for s in _registry.snapshot(): try: s.close() except Exception: pass atexit.register(_atexit_cleanup)