# pdf-trainer-worker/backend/worker/sftp_store.py
# Last commit by Avinashnalla7 (76fb4bd):
#   "Remove dead _ensure_dir; finalize SFTP upload path logic"
from __future__ import annotations
def _rjoin(*parts):
    """
    Join remote path fragments with single slashes.

    Falsy fragments (None, "", 0) are ignored. Every other fragment is
    converted with ``str`` and has leading/trailing "/" removed. Fragments
    that become empty after stripping are dropped as well. The result
    therefore never starts or ends with a slash, e.g.
    ``_rjoin("/a/", None, "b") == "a/b"``.
    """
    stripped = (str(part).strip("/") for part in parts if part)
    return "/".join(segment for segment in stripped if segment)
import base64
import io
import os
from pathlib import Path
from typing import Optional, Tuple

import paramiko
def _pick_base_dir(sftp, preferred: str) -> str:
    """
    Return the first candidate base directory that accepts writes on the server.

    Candidates, in order: ``preferred`` (if non-blank), ``/pdfs``, then ``/``.
    Each one is normalized to a single leading slash with no trailing slash.
    Duplicates are probed only once. A candidate is accepted when a tiny
    ``.sftp_write_test`` file can be written into it and then removed.

    Args:
        sftp: An open SFTP client (uses its ``file`` and ``remove`` methods).
        preferred: Preferred base dir. Surrounding whitespace and slashes
            are normalized away.

    Returns:
        The normalized absolute path of the first usable candidate
        (``"/"`` for the root).

    Raises:
        RuntimeError: If no candidate passes the write probe.
    """
    raw = []
    if preferred and str(preferred).strip():
        raw.append(str(preferred).strip())
    raw.append("/pdfs")  # known-good on the deployment server
    raw.append("/")  # last resort: root

    candidates = []
    for item in raw:
        base = "/" + item.strip("/")
        if base not in candidates:  # e.g. preferred == "/pdfs"
            candidates.append(base)

    for base in candidates:
        # rstrip keeps the root probe at "/.sftp_write_test" instead of "//...".
        probe = base.rstrip("/") + "/.sftp_write_test"
        try:
            # Write and delete a tiny file to validate permissions.
            with sftp.file(probe, "wb") as f:
                f.write(b"ok")
            sftp.remove(probe)
        except Exception:
            # Deliberate best-effort probe: any failure just means
            # "try the next candidate".
            continue
        return base
    raise RuntimeError("No writable SFTP base dir found (tried preferred, /pdfs, /)")
def _load_private_key(key_bytes: bytes, passphrase: Optional[str]) -> paramiko.PKey:
    """
    Parse an in-memory private key, trying RSA, ECDSA and Ed25519 in that order.

    The key is parsed from memory, so it is never written to disk.

    Args:
        key_bytes: Private key file contents (PEM / OpenSSH text) as bytes.
        passphrase: Passphrase for an encrypted key, or None.

    Returns:
        The first successfully parsed ``paramiko.PKey``.

    Raises:
        UnicodeDecodeError: If ``key_bytes`` is not UTF-8 text.
        paramiko.SSHException: If no supported key type can parse the key.
            The message lists each type's failure.
    """
    text = key_bytes.decode("utf-8")
    failures = []
    for key_cls in (paramiko.RSAKey, paramiko.ECDSAKey, paramiko.Ed25519Key):
        try:
            return key_cls.from_private_key(io.StringIO(text), password=passphrase)
        except (paramiko.SSHException, ValueError) as exc:
            failures.append(f"{key_cls.__name__}: {exc}")
    raise paramiko.SSHException(
        "Could not load SFTP_KEY_B64 as an RSA, ECDSA or Ed25519 private key ("
        + "; ".join(failures)
        + ")"
    )


def _sftp_client() -> Tuple[paramiko.SFTPClient, paramiko.Transport]:
    """
    Open an authenticated SFTP session configured from ``SFTP_*`` environment variables.

    Environment:
        SFTP_HOST, SFTP_USER: required.
        SFTP_PORT: optional, defaults to 22.
        SFTP_KEY_B64: base64-encoded private key, used when set.
        SFTP_KEY_PASSPHRASE: optional passphrase for that key.
        SFTP_PASS: password, used when no key is given.

    Returns:
        ``(client, transport)``. The caller must close both.

    Raises:
        RuntimeError: If required configuration is missing or the SFTP
            subsystem could not be opened.
        ValueError: If SFTP_PORT is not an integer.
        paramiko.SSHException: On key parsing, connection or authentication
            failures.
    """
    host = (os.environ.get("SFTP_HOST") or "").strip()
    user = (os.environ.get("SFTP_USER") or "").strip()
    if not host or not user:
        raise RuntimeError("Missing SFTP_HOST or SFTP_USER")
    port = int((os.environ.get("SFTP_PORT") or "22").strip())
    password = os.environ.get("SFTP_PASS")
    key_b64 = os.environ.get("SFTP_KEY_B64")
    key_pass = os.environ.get("SFTP_KEY_PASSPHRASE")

    # Resolve credentials BEFORE creating the Transport. Constructing it with
    # (host, port) already opens a socket, which would leak on a config error.
    pkey = None
    if key_b64:
        pkey = _load_private_key(base64.b64decode(key_b64), key_pass)
    elif not password:
        raise RuntimeError("Missing SFTP_PASS (or provide SFTP_KEY_B64)")

    # NOTE(review): no hostkey is passed to connect(). Per the paramiko docs,
    # that skips host key verification, so the server identity is not checked.
    # Consider pinning it.
    transport = paramiko.Transport((host, port))
    try:
        if pkey is not None:
            transport.connect(username=user, pkey=pkey)
        else:
            transport.connect(username=user, password=password)
        client = paramiko.SFTPClient.from_transport(transport)
        if client is None:  # from_transport returns None if the channel can't be opened
            raise RuntimeError("Could not open SFTP session on transport")
    except BaseException:
        # Never hand back (or drop) a half-open connection.
        transport.close()
        raise
    return client, transport
def _write_tmp_key(key_bytes: bytes, path: str | os.PathLike = "/tmp/sftp_key.pem") -> str:
    """
    Write private-key bytes to ``path`` with owner-only permissions and return the path.

    The container filesystem is writable under /tmp, which is why that is the
    default location. Because the default path is predictable, the file is
    opened with ``O_NOFOLLOW`` (where the platform provides it), so a
    pre-planted symlink cannot redirect the write. The mode is forced to
    ``0600`` before any key bytes are written, so the key is never readable by
    other users, even if the file already existed with looser permissions.

    Args:
        key_bytes: Raw private key material.
        path: Destination file. Defaults to ``/tmp/sftp_key.pem``.

    Returns:
        ``str(path)``.

    Raises:
        OSError: If the file cannot be opened or written. This includes
            ``path`` being a symlink on platforms with ``O_NOFOLLOW``.
    """
    p = Path(path)
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_NOFOLLOW", 0)
    fd = os.open(p, flags, 0o600)
    with os.fdopen(fd, "wb") as fh:
        # The mode argument above only applies to newly created files; tighten existing ones too.
        os.fchmod(fh.fileno(), 0o600)
        fh.write(key_bytes)
    return str(p)
def _remote_exists(sftp: paramiko.SFTPClient, path: str) -> bool:
    """Return True if ``path`` exists on the server, False if stat reports it missing."""
    try:
        sftp.stat(path)
    except FileNotFoundError:
        return False
    return True


def _mkdir_p(sftp: paramiko.SFTPClient, remote_dir: str) -> None:
    """
    Create ``remote_dir`` and any missing parents on the SFTP server, like ``mkdir -p``.

    ``remote_dir`` is treated as absolute. Empty components (``a//b``, a
    trailing ``/``) are ignored. If ``mkdir`` fails but the directory exists
    afterwards (e.g. a concurrent worker created it in between), the failure
    is ignored.

    Raises:
        OSError: If a component is missing and cannot be created. Errors other
            than "not found" from ``stat`` also propagate.
    """
    cur = ""
    for part in filter(None, remote_dir.split("/")):
        cur += "/" + part
        if _remote_exists(sftp, cur):
            continue
        try:
            sftp.mkdir(cur)
        except OSError:
            # Lost a creation race? Then the directory exists now and we can move on.
            if not _remote_exists(sftp, cur):
                raise
def _remote_pdf_name(pdf_name: Optional[str], pdf_id: str) -> str:
    """
    Return a safe leaf filename for the uploaded PDF.

    Only the final path component of ``pdf_name`` is kept, with both "/" and
    "\\" treated as separators. A client-supplied name therefore cannot
    escape, or point into a non-existent subdirectory of, the per-PDF
    directory. Falls back to ``<pdf_id>.pdf`` when nothing usable remains.
    """
    leaf = (pdf_name or "").replace("\\", "/").rsplit("/", 1)[-1]
    if leaf in ("", ".", ".."):
        return f"{pdf_id}.pdf"
    return leaf


def store_to_sftp(pdf_id: str, template_id: str, cfg_json_bytes: bytes, pdf_bytes: bytes, pdf_name: str) -> str:
    """
    Upload trainer outputs to SFTP.

    Remote layout:
        <base>/<template_id>/<pdf_id>/
            trainer_config_<pdf_id>__<template_id>.json
            <basename of pdf_name, or <pdf_id>.pdf>

    ``<base>`` is read from the ``SFTP_BASE_DIR`` environment variable and
    defaults to ``/``. Missing directories are created before the files are
    written.

    Args:
        pdf_id: Identifier used as a directory name and in the config filename.
        template_id: Identifier used as a directory name and in the config filename.
        cfg_json_bytes: Trainer config contents, written verbatim.
        pdf_bytes: PDF contents, written verbatim.
        pdf_name: Original PDF filename. Only its last path component is used.

    Returns:
        The absolute remote directory the files were written to.
    """
    base = (os.environ.get("SFTP_BASE_DIR") or "/").strip()
    # NOTE(review): template_id / pdf_id are used as path components as-is;
    # confirm callers never pass values containing "/" or "..".
    remote_dir = "/" + _rjoin(base, template_id, pdf_id)
    remote_cfg = f"{remote_dir}/trainer_config_{pdf_id}__{template_id}.json"
    remote_pdf = f"{remote_dir}/{_remote_pdf_name(pdf_name, pdf_id)}"

    sftp, transport = _sftp_client()
    try:
        _mkdir_p(sftp, remote_dir)
        with sftp.open(remote_cfg, "wb") as f:
            f.write(cfg_json_bytes)
        with sftp.open(remote_pdf, "wb") as f:
            f.write(pdf_bytes)
        return remote_dir
    finally:
        # Close the SFTP channel first, then always the underlying transport.
        try:
            sftp.close()
        finally:
            transport.close()