File size: 3,015 Bytes
219ee1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
Thin client for the encoder Hugging Face Space.

Uses the official `gradio_client` so the URL contract isn't fragile across
Gradio versions. Reconstructs torch tensors from the base64 float16 payload.

Env vars
--------
ENCODER_SPACE_URL  Public URL of the Space, e.g.
                   https://USER-pseudoscorex-encoder.hf.space  OR  USER/SPACE-NAME
                   Required.
HF_TOKEN           Required only if the Space is private.
"""
import base64
import logging
import os
import threading
import time

import numpy as np
import torch
from gradio_client import Client

# Module-level logger; the name is fixed to "encoder_client" rather than
# derived from __name__.
logger = logging.getLogger("encoder_client")

# Configuration is read from the environment once, at import time.
# ENCODER_SPACE_URL falls back to "" when unset; _get_client() treats the
# empty string as "not configured" and raises.
ENCODER_SPACE_URL = os.getenv("ENCODER_SPACE_URL", "").strip()
# None when unset; _get_client() only attempts authentication when truthy.
HF_TOKEN = os.getenv("HF_TOKEN")

# Cached Gradio client, built on first use by _get_client().
_client = None
# Held by _get_client() while the cached client is being constructed.
_client_lock = threading.Lock()


def _get_client() -> Client:
    """Return the shared Gradio client, creating it on first call.

    The cached client is checked once without the lock and again after
    acquiring ``_client_lock`` before anything is built.

    When ``HF_TOKEN`` is set, the token is offered under each known keyword
    name in turn (``hf_token``, then ``token``); a ``TypeError`` from the
    constructor is taken to mean "keyword not supported" and the next name
    is tried. If neither is accepted, a warning is logged and the client is
    built without authentication.

    Raises:
        RuntimeError: if ``ENCODER_SPACE_URL`` is empty.
    """
    global _client
    if _client is None:
        with _client_lock:
            # Re-check under the lock: the cache may have been filled while
            # this thread was waiting.
            if _client is None:
                if not ENCODER_SPACE_URL:
                    raise RuntimeError(
                        "ENCODER_SPACE_URL is not set. Point it at your encoder Space, "
                        "e.g. https://USER-pseudoscorex-encoder.hf.space"
                    )
                logger.info("Connecting to encoder Space: %s", ENCODER_SPACE_URL)

                # gradio_client renamed/removed the auth kwarg between
                # versions, so every known spelling is attempted. With no
                # token there is nothing to attempt and the loop falls
                # straight through to the unauthenticated branch.
                auth_options = (
                    [{"hf_token": HF_TOKEN}, {"token": HF_TOKEN}] if HF_TOKEN else []
                )
                for auth_kwargs in auth_options:
                    try:
                        connection = Client(ENCODER_SPACE_URL, **auth_kwargs)
                    except TypeError:
                        continue
                    break
                else:
                    if HF_TOKEN:
                        # Public Spaces don't need a token, so degrade to
                        # no-auth rather than failing outright.
                        logger.warning(
                            "gradio_client did not accept hf_token/token; "
                            "connecting without auth (Space must be public)."
                        )
                    connection = Client(ENCODER_SPACE_URL)

                # Publish only once the client is fully constructed.
                _client = connection
    return _client


def encode_text(text: str, device: torch.device):
    """
    Encode ``text`` on the remote encoder Space and rebuild tensors locally.

    Calls the Space's ``/encode`` endpoint, decodes the base64 float16
    payload, reshapes it to the shape reported by the Space, upcasts it to
    float32 and moves it to ``device``.

    Returns (hidden, attention_mask, clean_tokens):
        hidden          torch.float32 tensor of shape (1, seq_len, 1024)
        attention_mask  torch.long    tensor of shape (1, seq_len)
        clean_tokens    list[str]

    (Shapes are as the Space is expected to report them; the actual layout
    comes from the response's ``shape`` field and is not re-checked here.)

    Raises
    ------
    RuntimeError
        If the client cannot be configured, if the response is not a dict
        containing ``hidden_b64``, or if any of ``shape``,
        ``attention_mask`` or ``clean_tokens`` is missing from it.
    """
    client = _get_client()

    t0 = time.perf_counter()
    out = client.predict(text, api_name="/encode")
    logger.debug("Space round-trip: %.2fs", time.perf_counter() - t0)

    if not isinstance(out, dict) or "hidden_b64" not in out:
        raise RuntimeError(f"Unexpected encoder Space response: {out!r}")

    # Validate the remaining fields up front so a partial response surfaces
    # as a descriptive RuntimeError instead of a bare KeyError further down.
    # Only key names are reported: the payload itself can be very large.
    missing = [
        key for key in ("shape", "attention_mask", "clean_tokens") if key not in out
    ]
    if missing:
        raise RuntimeError(
            f"Encoder Space response is missing field(s) {missing}; "
            f"got keys {list(out)}"
        )

    arr = np.frombuffer(base64.b64decode(out["hidden_b64"]), dtype=np.float16)
    arr = arr.reshape(out["shape"])  # (seq_len, 1024)

    # astype() copies, which also yields a writable buffer (frombuffer's
    # result is read-only) before handing it to torch.
    hidden = torch.from_numpy(arr.astype(np.float32)).unsqueeze(0).to(device)
    mask = torch.tensor(out["attention_mask"], dtype=torch.long).unsqueeze(0).to(device)
    return hidden, mask, out["clean_tokens"]