# AND / conn.py
# ziren28's picture
# v2.6: Hubble sync + D1 database
# 126cf9c verified
"""
SSH 隧道 + ADB 设备统一管理 (v3)
==============================
给 restore.py / sync.py 共享, 解决以下并发/健壮性问题:
1) 动态本地端口 - 忽略 SSH 命令里的 -L 写死端口, 用 OS 分配
彻底避免不同云真机端口撞车
2) SSH keep-alive - paramiko Transport.set_keepalive, 默认 30s 一次
云真机平台会 kill 掉空闲 SSH, 心跳解决
3) 自动重连 - 命令检测到 'device offline' 等错误后重建隧道+adb,
重试 1 次. 不会无限循环
4) 硬超时兜底 - 全局看门狗线程, 超过 HKDY_HARD_TIMEOUT 秒的 session
强制清理 (防任务卡死持有资源)
5) 并发安全清理 - 只做 adb disconnect localhost:<自己的端口>,
绝不 adb kill-server (会踩死其他并发用户)
6) 进程退出兜底 - atexit 注册, 防脚本崩溃泄漏 SSH 连接
环境变量 (.env 或 os.environ):
HKDY_HARD_TIMEOUT 任务硬超时秒数 默认 180
HKDY_SSH_KEEPALIVE SSH keep-alive 秒数 默认 30
HKDY_ADB_READY_TIMEOUT adb ready 等待秒数 默认 30
HKDY_TUNNEL_BIND_RETRIES 绑定本地端口重试次数 默认 3
HKDY_WATCHDOG_INTERVAL 看门狗扫描间隔秒数 默认 5
典型用法:
    with Session(raw_input_text, log=log_fn, hard_timeout=180) as sess:
        sess.sh("am force-stop ...")
        sess.push("local.xml", "/data/local/tmp/x.xml")
        sess.pull("/sdcard/verify.png", "./verify.png")
    # 退出 with 块时自动 close (即使异常也会清理)
"""
import atexit
import os
import re
import select
import socket
import subprocess
import sys
import threading
import time
# Directory that contains this module; used to locate the sibling "lib"
# folder and the ".env" file below.
_HERE = os.path.dirname(os.path.abspath(__file__))
# Optional "lib" directory next to this file. When it exists it is inserted
# at the front of sys.path, so modules inside it win over other copies.
_LIB = os.path.join(_HERE, "lib")
if os.path.isdir(_LIB) and _LIB not in sys.path:
    sys.path.insert(0, _LIB)
# Try to auto-load .env (uses python-dotenv when it is importable; otherwise
# the module just relies on os.environ). Any failure here is ignored on
# purpose so that a missing package or unreadable file never blocks import.
try:
    from dotenv import load_dotenv
    load_dotenv(os.path.join(_HERE, ".env"))
except Exception:
    pass
import paramiko
# ============ 配置 ============
def _envint(name, default):
    """Return environment variable *name* parsed as an int.

    When the variable is unset, the textual form of *default* is parsed
    instead. If parsing fails (ValueError/TypeError), *default* itself is
    returned unchanged.
    """
    try:
        text = os.environ.get(name, str(default))
        parsed = int(text)
    except (TypeError, ValueError):
        parsed = default
    return parsed
# Hard per-session timeout in seconds (env HKDY_HARD_TIMEOUT, default 180).
HARD_TIMEOUT = _envint("HKDY_HARD_TIMEOUT", 180)
# SSH keep-alive interval in seconds (env HKDY_SSH_KEEPALIVE, default 30).
SSH_KEEPALIVE = _envint("HKDY_SSH_KEEPALIVE", 30)
# Seconds to wait for adb to report the "device" state
# (env HKDY_ADB_READY_TIMEOUT, default 30).
ADB_READY_TIMEOUT = _envint("HKDY_ADB_READY_TIMEOUT", 30)
# Number of attempts when binding the local tunnel port
# (env HKDY_TUNNEL_BIND_RETRIES, default 3).
TUNNEL_BIND_RETRIES = _envint("HKDY_TUNNEL_BIND_RETRIES", 3)
# Seconds between watchdog scans (env HKDY_WATCHDOG_INTERVAL, default 5).
WATCHDOG_INTERVAL = _envint("HKDY_WATCHDOG_INTERVAL", 5)
# Value passed as subprocess creationflags: 0x08000000 on Windows, 0 elsewhere.
# NOTE(review): 0x08000000 looks like the Win32 CREATE_NO_WINDOW flag — confirm.
_CREATE_NO_WINDOW = 0x08000000 if os.name == "nt" else 0
# ============ 全局 Session 注册表 + 看门狗 ============
class _Registry:
    """Set of live sessions plus the watchdog thread that enforces timeouts."""

    def __init__(self):
        self._sessions = set()
        self._lock = threading.Lock()
        self._watchdog_started = False

    def register(self, s):
        """Track *s*; the first registration also starts the watchdog thread."""
        with self._lock:
            self._sessions.add(s)
            if self._watchdog_started:
                return
            self._watchdog_started = True
            watchdog = threading.Thread(
                target=self._watchdog,
                daemon=True,
                name="hkdy-watchdog",
            )
            watchdog.start()

    def unregister(self, s):
        """Stop tracking *s*; unknown sessions are ignored."""
        with self._lock:
            self._sessions.discard(s)

    def snapshot(self):
        """Return a list copy of the tracked sessions, taken under the lock."""
        with self._lock:
            return [*self._sessions]

    def _watchdog(self):
        """Scan all sessions every WATCHDOG_INTERVAL seconds, forever.

        A started session whose age exceeds its ``hard_timeout`` and whose
        timeout has not fired yet gets ``_on_timeout(elapsed)`` called.
        Errors raised while handling one session are swallowed so the scan
        keeps running for the others.
        """
        while True:
            time.sleep(WATCHDOG_INTERVAL)
            now = time.time()
            for session in self.snapshot():
                try:
                    started = session._started_at
                    if started is None:
                        continue
                    elapsed = now - started
                    overdue = elapsed > session.hard_timeout
                    if overdue and not session._timeout_fired:
                        session._on_timeout(elapsed)
                except Exception:
                    pass


_registry = _Registry()
# ============ SSH 命令解析 ============
def parse_ssh_command(cmd):
    """Extract connection details from an ``ssh ... -p PORT -L L:rhost:R`` line.

    Expected shape: ``ssh -o... USER@HOST -p PORT -L LOCAL:rhost:REMOTE -Nf``.

    Returns:
        A 6-tuple ``(user, host, ssh_port, local_port, remote_host,
        remote_port)`` with the three ports converted to ``int``.

    Raises:
        ValueError: if ``user@host``, ``-p`` or ``-L`` cannot be found.
    """
    m_uh = re.search(r"\s([A-Za-z0-9_.\-]+@[A-Za-z0-9_.\-]+)", cmd)
    # ssh accepts the option value with or without a separating space
    # ("-p 22" and "-p22"), so the whitespace is optional. The flag itself
    # must start a token, otherwise "-p3" inside a host name such as
    # "node-p3.example.com" would be taken for the port.
    m_p = re.search(r"(?:^|\s)-p\s*(\d+)", cmd)
    m_l = re.search(r"(?:^|\s)-L\s*(\d+):([^\s:]+):(\d+)", cmd)
    if not (m_uh and m_p and m_l):
        raise ValueError("SSH 命令缺少 user@host / -p / -L 字段")
    # The character class above excludes "@", so there is exactly one.
    user, host = m_uh.group(1).split("@")
    return (
        user,
        host,
        int(m_p.group(1)),
        int(m_l.group(1)),
        m_l.group(2),
        int(m_l.group(3)),
    )
# ============ 内部: 隧道 + 双向管道 ============
def _pipe(chan, sock):
    """Copy bytes in both directions between *chan* and *sock* until one ends.

    Both arguments must work with ``select.select`` and provide ``recv``,
    ``sendall`` and ``close``. The loop stops when either side returns an
    empty read or any exception occurs; both endpoints are then closed.
    Errors are swallowed on purpose: this runs in a daemon thread per
    connection and a broken connection simply ends the pipe.
    """
    try:
        while True:
            readable, _, _ = select.select([sock, chan], [], [])
            if sock in readable:
                data = sock.recv(65536)
                if not data:
                    break
                # sendall, not send: send() may write only part of the
                # buffer, which would silently drop the remainder.
                chan.sendall(data)
            if chan in readable:
                data = chan.recv(65536)
                if not data:
                    break
                sock.sendall(data)
    except Exception:
        pass
    finally:
        for endpoint in (chan, sock):
            try:
                endpoint.close()
            except Exception:
                pass
class _Tunnel:
    """Local TCP listener forwarded to ``rhost:rport`` over one SSH connection.

    Each connection accepted on ``127.0.0.1:local_port`` is bridged to a
    ``direct-tcpip`` channel on the SSH transport by its own daemon thread.
    """

    def __init__(self, host, port, user, password, rhost, rport,
                 preferred_local=0):
        self.host, self.port = host, port
        self.user, self.password = user, password
        self.rhost, self.rport = rhost, rport
        self.preferred_local = preferred_local  # 0 = let the OS pick a port
        self.local_port = None  # set by start() once the listener is bound
        self.client = None
        self.listener = None
        self._stop = False

    def start(self, timeout=15):
        """Connect over SSH, bind the local listener and start accepting.

        Args:
            timeout: SSH connect timeout in seconds.

        Raises:
            RuntimeError: if no local port could be bound.
            Exception: whatever ``paramiko`` raises while connecting.

        On any failure everything opened so far is closed again, so a
        failed start never leaves an SSH connection or socket behind.
        """
        try:
            self.client = paramiko.SSHClient()
            self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            self.client.connect(
                self.host, port=self.port,
                username=self.user, password=self.password,
                allow_agent=False, look_for_keys=False, timeout=timeout,
            )
            transport = self.client.get_transport()
            if transport and SSH_KEEPALIVE > 0:
                transport.set_keepalive(SSH_KEEPALIVE)
            self.listener = self._bind_listener()
            self.local_port = self.listener.getsockname()[1]
            self.listener.listen(100)
            threading.Thread(
                target=self._accept, args=(transport,),
                daemon=True, name=f"hkdy-tunnel-{self.local_port}",
            ).start()
        except Exception:
            # Release partial state without touching _stop.
            self._release()
            self.listener = None
            self.client = None
            self.local_port = None
            raise

    def _bind_listener(self):
        """Return a socket bound on 127.0.0.1, preferring ``preferred_local``.

        The first attempt uses the preferred port; later attempts ask the OS
        for a free one (port 0). At least one attempt is always made.
        """
        last_err = None
        for attempt in range(max(1, TUNNEL_BIND_RETRIES)):
            try_port = self.preferred_local if attempt == 0 else 0
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                if os.name == "nt":
                    # On Windows SO_REUSEADDR lets a second socket bind a
                    # port that is already in use, which would hide the very
                    # collision we want to detect. Ask for exclusive use.
                    exclusive = getattr(socket, "SO_EXCLUSIVEADDRUSE", None)
                    if exclusive is not None:
                        sock.setsockopt(socket.SOL_SOCKET, exclusive, 1)
                else:
                    sock.setsockopt(
                        socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(("127.0.0.1", try_port))
            except OSError as e:
                last_err = e
                try:
                    sock.close()
                except Exception:
                    pass
                continue
            return sock
        raise RuntimeError(f"无法绑定本地端口: {last_err}")

    def _accept(self, transport):
        """Accept local connections and bridge each one to the remote end."""
        while not self._stop:
            try:
                sock, addr = self.listener.accept()
            except OSError:
                # Listener was closed (normally by close()).
                return
            try:
                chan = transport.open_channel(
                    "direct-tcpip", (self.rhost, self.rport), addr,
                )
            except Exception:
                # Could not open the channel: drop this client, keep serving.
                sock.close()
                continue
            threading.Thread(
                target=_pipe, args=(chan, sock), daemon=True,
            ).start()

    def _release(self):
        """Best-effort close of the listener socket and the SSH client."""
        for resource in (self.listener, self.client):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                pass

    def close(self):
        """Stop accepting and close listener + SSH client. Safe to repeat."""
        self._stop = True
        self._release()
# ============ ADB 原生 ============
def _run(argv, timeout=120):
    """Run *argv* as a subprocess and return the ``CompletedProcess``.

    Output is captured and decoded as UTF-8 with undecodable bytes replaced.
    *timeout* (seconds) is passed to ``subprocess.run``, and
    ``_CREATE_NO_WINDOW`` is passed as ``creationflags``.
    """
    options = {
        "capture_output": True,
        "encoding": "utf-8",
        "errors": "replace",
        "timeout": timeout,
        "creationflags": _CREATE_NO_WINDOW,
    }
    completed = subprocess.run(argv, **options)
    return completed
# ============ 对外主类: Session ============
class Session:
    """Own the SSH tunnel and the adb device entry for one task.

    Supports ``with``. A session is bounded by a hard timeout enforced by
    the module watchdog, reconnects once when an adb command reports an
    offline-looking error, and on cleanup only disconnects its own
    ``localhost:<port>`` serial (never ``adb kill-server``).
    """

    def __init__(self, raw_input_text, *, log=None, hard_timeout=None,
                 preferred_local=None):
        """Parse the pasted connection text; nothing is connected yet.

        Args:
            raw_input_text: multi-line text containing an ``ssh ...`` line
                and a password line.
            log: optional callable taking one message string.
            hard_timeout: seconds before the watchdog force-closes the
                session; ``None`` uses ``HARD_TIMEOUT``.
            preferred_local: ``None`` or ``0`` lets the OS pick the local
                port; ``N`` tries port N first and falls back to an
                OS-assigned port on collision.

        Raises:
            ValueError: if the ssh line or the password is missing or the
                ssh line cannot be parsed.
        """
        ssh_cmd, password = self._parse_input(raw_input_text)
        # The local port written in the -L option is deliberately unused.
        u, h, p, _lp_pref, rh, rp = parse_ssh_command(ssh_cmd)
        self._cred = (h, p, u, password, rh, rp)  # needed again on reconnect
        self._preferred_local = 0 if preferred_local is None else preferred_local
        self.log = log or (lambda _m: None)
        self.hard_timeout = hard_timeout if hard_timeout is not None else HARD_TIMEOUT
        self._tunnel = None
        self._lock = threading.Lock()
        self._closed = False
        self._started_at = None
        self._timeout_fired = False

    # ---------- input parsing ----------
    @staticmethod
    def _parse_input(raw):
        """Split pasted text into ``(ssh_command_line, password)``.

        Blank lines are ignored. The ssh command is the first line starting
        with ``ssh ``; the password is the first line that starts with
        neither ``ssh `` nor ``adb `` (case-insensitive).
        """
        lines = [ln.strip() for ln in (raw or "").splitlines() if ln.strip()]
        ssh = next((ln for ln in lines if ln.lower().startswith("ssh ")), None)
        pwd = next(
            (ln for ln in lines
             if not ln.lower().startswith(("ssh ", "adb "))),
            None,
        )
        if not ssh:
            raise ValueError("未找到 SSH 命令 (应以 'ssh' 开头)")
        if not pwd:
            raise ValueError("未找到密码 (不以 ssh/adb 开头的那一行)")
        return ssh, pwd

    # ---------- with ----------
    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # let exceptions propagate

    # ---------- lifecycle ----------
    @staticmethod
    def _discard_tunnel(tunnel):
        """Close *tunnel*, ignoring any error (best-effort cleanup)."""
        try:
            tunnel.close()
        except Exception:
            pass

    def start(self, timeout=15):
        """Open the tunnel, register with the watchdog and attach adb.

        Calling it again on an already started session is a no-op.

        Raises:
            RuntimeError: SSH authentication/connection failure, adb never
                reaching the ``device`` state, or the session having been
                closed before it was ever started.
        """
        with self._lock:
            if self._started_at is not None:
                return
            if self._closed:
                # close() is a no-op on a closed session, so a tunnel
                # opened now could never be released.
                raise RuntimeError("Session 已关闭")
            host, port, user, pwd, rhost, rport = self._cred
            tunnel = _Tunnel(
                host, port, user, pwd, rhost, rport,
                preferred_local=self._preferred_local,
            )
            try:
                tunnel.start(timeout=timeout)
            except paramiko.AuthenticationException as e:
                # Close before dropping the reference to avoid a leaked
                # SSH connection.
                self._discard_tunnel(tunnel)
                raise RuntimeError("SSH 认证失败: 用户名或密码错误") from e
            except Exception as e:
                self._discard_tunnel(tunnel)
                raise RuntimeError(f"SSH 连接失败: {e}") from e
            self._tunnel = tunnel
            time.sleep(0.5)
            self._started_at = time.time()
        _registry.register(self)  # register with the watchdog (outside the lock)
        try:
            self._adb_ready()
        except Exception:
            self.close()
            raise

    def _adb_ready(self):
        """Attach ``localhost:<local_port>`` to the adb server and run adb root.

        Never calls ``adb kill-server``: the adb server is shared and that
        would break other concurrent users.

        Raises:
            RuntimeError: if the device does not reach the ``device`` state
                within ``ADB_READY_TIMEOUT`` seconds, or not again within
                10 seconds after ``adb root``.
        """
        port = self._tunnel.local_port
        serial = f"localhost:{port}"
        deadline = time.time() + ADB_READY_TIMEOUT
        ok = False
        while time.time() < deadline:
            _run(["adb", "connect", serial], timeout=10)
            time.sleep(0.8)
            r = _run(["adb", "-s", serial, "get-state"], timeout=5)
            if (r.stdout or "").strip() == "device":
                ok = True
                break
            # Drop the stale entry before the next connect attempt.
            _run(["adb", "disconnect", serial], timeout=5)
            time.sleep(0.8)
        if not ok:
            raise RuntimeError(f"adb 无法进入 device 状态: {serial}")
        # Switch to root (adbd restarts briefly).
        _run(["adb", "-s", serial, "root"], timeout=20)
        time.sleep(1.5)
        # After root, wait a little and reconnect.
        deadline2 = time.time() + 10
        while time.time() < deadline2:
            _run(["adb", "connect", serial], timeout=10)
            time.sleep(0.5)
            r = _run(["adb", "-s", serial, "get-state"], timeout=5)
            if (r.stdout or "").strip() == "device":
                return
        raise RuntimeError(f"adb root 后重连失败: {serial}")

    def _on_timeout(self, elapsed):
        """Watchdog callback: mark the timeout as fired and force-close once."""
        with self._lock:
            if self._timeout_fired or self._closed:
                return
            self._timeout_fired = True
        try:
            self.log(f"⚠ 硬超时 {elapsed:.1f}s (上限 {self.hard_timeout}s), 强制清理")
        except Exception:
            pass  # a failing logger must not prevent the cleanup
        self._hard_close()

    def _hard_close(self):
        """Release this session's own adb serial and tunnel, exactly once.

        Only ``adb disconnect localhost:<own port>`` is issued, never
        ``kill-server``.
        """
        with self._lock:
            if self._closed:
                return
            self._closed = True
            port = self._tunnel.local_port if self._tunnel else None
        _registry.unregister(self)
        if port is not None:
            try:
                _run(["adb", "disconnect", f"localhost:{port}"], timeout=5)
            except Exception:
                pass
        if self._tunnel:
            self._discard_tunnel(self._tunnel)

    def close(self):
        """Close the session; repeated calls are no-ops."""
        self._hard_close()

    # ---------- properties ----------
    @property
    def serial(self):
        """adb serial ``localhost:<port>``; raises RuntimeError if unusable."""
        if self._closed:
            raise RuntimeError("Session 已关闭")
        if not self._tunnel:
            raise RuntimeError("Session 未启动")
        return f"localhost:{self._tunnel.local_port}"

    @property
    def local_port(self):
        """Local tunnel port, or ``None`` when there is no tunnel."""
        return self._tunnel.local_port if self._tunnel else None

    @property
    def timed_out(self):
        """True once the hard timeout has fired."""
        return self._timeout_fired

    # ---------- automatic reconnect ----------
    @staticmethod
    def _is_offline_error(msg):
        """Return True when *msg* contains one of the offline-style keywords.

        NOTE(review): matching is by plain substring on the combined
        output, so a failing shell command that merely prints e.g.
        "not found" is also treated as a connection problem.
        """
        m = (msg or "").lower()
        keys = (
            "offline", "not found", "closed",
            "connection refused", "cannot connect", "failed to connect",
            "device not found", "error: no devices",
        )
        return any(k in m for k in keys)

    def _reconnect_once(self):
        """Tear down the current tunnel/adb entry and build a fresh one.

        Raises:
            RuntimeError: if the session is already closed (with the
                hard-timeout message when that is the reason) or if adb
                does not become ready; tunnel start errors propagate as-is.
        """
        self.log(" ⚠ 连接异常, 重建 SSH+adb...")
        # NOTE(review): the lock is held for the whole rebuild, so close()
        # and the watchdog's timeout handler wait until it finishes.
        with self._lock:
            if self._closed:
                # Report the real cause when the watchdog closed us.
                self._check_timeout_fired()
                raise RuntimeError("Session 已关闭, 无法重连")
            old_port = self._tunnel.local_port if self._tunnel else None
            if old_port:
                try:
                    _run(["adb", "disconnect", f"localhost:{old_port}"], timeout=5)
                except Exception:
                    pass
            if self._tunnel:
                self._discard_tunnel(self._tunnel)
            host, port, user, pwd, rhost, rport = self._cred
            self._tunnel = _Tunnel(
                host, port, user, pwd, rhost, rport,
                preferred_local=0,  # always OS-assigned on reconnect
            )
            self._tunnel.start()
            time.sleep(0.5)
            self._adb_ready()
            self.log(f" 已重连 serial={self.serial}")

    # ---------- adb methods (timeout + reconnect) ----------
    def _check_timeout_fired(self):
        """Raise RuntimeError if the watchdog already timed this session out."""
        if self._timeout_fired:
            raise RuntimeError(
                f"任务硬超时 (>{self.hard_timeout}s), 已被看门狗清理"
            )

    def _adb_once(self, adb_args, timeout):
        """Run ``adb -s <serial> <adb_args>`` once on an open session."""
        if self._closed:
            self._check_timeout_fired()
            raise RuntimeError("Session 已关闭")
        return _run(["adb", "-s", self.serial, *adb_args], timeout=timeout)

    def _adb_retrying(self, adb_args, timeout, allow_reconnect,
                      reconnect_failed_msg):
        """Run an adb command, reconnecting and retrying at most once.

        The retry happens only when *allow_reconnect* is true and the first
        attempt failed with offline-looking output. If the reconnect itself
        fails, RuntimeError is raised with the text produced by
        ``reconnect_failed_msg(err_output, exception)``.

        Returns:
            The ``CompletedProcess`` of the last attempt (it may have failed).
        """
        result = self._adb_once(adb_args, timeout)
        if result.returncode == 0 or not allow_reconnect:
            return result
        err = (result.stderr or "") + (result.stdout or "")
        if not self._is_offline_error(err):
            return result
        try:
            self._reconnect_once()
        except Exception as e:
            raise RuntimeError(reconnect_failed_msg(err, e)) from e
        return self._adb_once(adb_args, timeout)

    def sh(self, cmd, check=True, timeout=120, auto_reconnect=True):
        """Run ``adb shell <cmd>`` and return its stdout.

        Args:
            cmd: shell command string.
            check: raise RuntimeError on a non-zero exit status; when
                false, stdout of the failed command is returned instead.
            timeout: seconds allowed per adb invocation.
            auto_reconnect: reconnect and retry once on an offline-style
                error.

        Raises:
            RuntimeError: session closed or timed out, reconnect failed, or
                the command failed while *check* is true.
        """
        r = self._adb_retrying(
            ["shell", cmd], timeout, auto_reconnect,
            lambda err, e: (
                f"adb shell 失败 (重连失败): {cmd}\n{err}\n重连错误: {e}"
            ),
        )
        if r.returncode != 0 and check:
            err = (r.stderr or "") + (r.stdout or "")
            raise RuntimeError(f"adb shell 失败:\n{cmd}\n{err}")
        return r.stdout or ""

    def _transfer(self, verb, src, dst, timeout):
        """Shared body of push/pull: ``adb <verb> src dst`` with one retry."""
        r = self._adb_retrying(
            [verb, src, dst], timeout, True,
            lambda err, e: f"adb {verb} 失败 (重连失败): {err}\n重连: {e}",
        )
        if r.returncode != 0:
            err = (r.stderr or "") + (r.stdout or "")
            raise RuntimeError(f"adb {verb} 失败: {err}")
        return r.stdout

    def push(self, local, remote, timeout=1800):
        """``adb push local remote``; returns adb's stdout.

        Raises:
            RuntimeError: on failure (after at most one reconnect + retry).
        """
        return self._transfer("push", local, remote, timeout)

    def pull(self, remote, local, timeout=300):
        """``adb pull remote local``; returns adb's stdout.

        Raises:
            RuntimeError: on failure (after at most one reconnect + retry).
        """
        return self._transfer("pull", remote, local, timeout)
# ============ 进程退出兜底 ============
def _atexit_cleanup():
    """Close every session still registered when the interpreter exits.

    Each close is best-effort: a failure on one session is ignored so the
    remaining sessions are still cleaned up.
    """
    leftovers = _registry.snapshot()
    for session in leftovers:
        try:
            session.close()
        except Exception:
            pass


atexit.register(_atexit_cleanup)