File size: 5,972 Bytes
88d4864
19d2650
 
 
 
 
 
 
 
 
 
 
41e0c9e
 
 
 
 
19d2650
 
 
 
 
 
 
 
 
 
41e0c9e
19d2650
25223be
19d2650
41e0c9e
 
 
 
19d2650
 
41e0c9e
 
 
 
 
 
 
25223be
19d2650
88d4864
25223be
 
19d2650
 
 
 
 
 
 
 
 
 
25223be
19d2650
 
25223be
 
 
 
41e0c9e
19d2650
 
 
 
41e0c9e
19d2650
 
41e0c9e
25223be
19d2650
25223be
19d2650
 
 
 
 
 
25223be
19d2650
 
 
 
 
41e0c9e
 
19d2650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41e0c9e
 
 
25223be
41e0c9e
 
 
 
19d2650
41e0c9e
19d2650
 
41e0c9e
 
 
 
 
19d2650
41e0c9e
 
 
19d2650
 
41e0c9e
19d2650
41e0c9e
19d2650
 
41e0c9e
 
 
 
 
 
19d2650
41e0c9e
 
 
 
19d2650
41e0c9e
 
 
 
 
 
 
19d2650
25223be
41e0c9e
 
19d2650
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""openbmb/MiniCPM5-1B (safetensors) wrapper via transformers, ZeroGPU-ready.

On HuggingFace ZeroGPU Spaces:
  * `spaces` is imported before torch, and the actual generation runs inside a
    `@spaces.GPU` function (the GPU is attached only for that call).
  * the model is placed on `cuda` at module level (ZeroGPU's CUDA emulation makes
    this work outside the decorated function), as the docs recommend.

Off ZeroGPU (local CPU/GPU, or `spaces` not installed) everything still works:
  * `@spaces.GPU` becomes a no-op, and the device falls back to a real CUDA GPU
    if present, otherwise CPU.
If anything is unavailable the app keeps running with a deterministic fallback.
"""
from __future__ import annotations

import os
import threading
import time

# IMPORTANT: import `spaces` BEFORE torch so it can patch CUDA for ZeroGPU.
try:
    import spaces  # noqa: F401
    _HAS_SPACES = True
except Exception:
    _HAS_SPACES = False

# True on a HuggingFace ZeroGPU Space. Only non-emptiness is checked.
# NOTE(review): any non-empty value (even "0" or "false") counts as ZeroGPU.
_ON_ZEROGPU = bool(os.environ.get("SPACES_ZERO_GPU"))

# Process-wide model state, populated by _load().
_MODEL = None
_TOKENIZER = None
_DEVICE = "cpu"  # set to "cuda" or "cpu" by _load()
_LOAD_LOCK = threading.Lock()  # serializes the one-time load in _load()
_LOAD_ERROR = None  # exception from a failed _load(); once set, _load() stops retrying

# System message prepended to every briefing() conversation.
SYSTEM_PROMPT = (
    "You are FLIGHTDECK, a terse air-traffic analyst. Answer only from the live "
    "flight data you are given. Be concise and use callsigns. Never invent flights."
)


def llm_disabled() -> bool:
    """Return True when the DISABLE_LLM environment variable asks to skip the model.

    Accepted truthy values are "1", "true" and "yes". Matching ignores case and
    surrounding whitespace, so "TRUE" or " Yes " also disable the model.
    An unset variable means the model is enabled.
    """
    return os.environ.get("DISABLE_LLM", "0").strip().lower() in {"1", "true", "yes"}


def _model_id() -> str:
    # Safetensors repo (transformers), overridable via LLM_REPO.
    return os.environ.get("LLM_REPO", "openbmb/MiniCPM5-1B")


def _gpu(fn):
    """Wrap ``fn`` with ``@spaces.GPU`` when ``spaces`` was importable; else return ``fn``.

    The wrap happens whenever the ``spaces`` package imported, not only on
    ZeroGPU. Off ZeroGPU the module relies on ``spaces.GPU`` behaving as a no-op.

    The GPU time budget is read from ZEROGPU_DURATION (seconds, default 60).
    This decorator runs at module import time (``@_gpu`` on ``_generate``).
    So a missing, non-integer or non-positive value falls back to 60 rather
    than raising and breaking the import.
    """
    if not _HAS_SPACES:
        return fn
    default_duration = 60
    raw = os.environ.get("ZEROGPU_DURATION", str(default_duration)).strip()
    try:
        duration = int(raw)
    except ValueError:
        duration = default_duration
    if duration <= 0:
        duration = default_duration
    return spaces.GPU(duration=duration)(fn)


def _apply_chat_template(messages, tokenizer) -> str:
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
    parts = [f"[{m.get('role', 'user').upper()}]\n{m.get('content', '')}" for m in messages]
    parts.append("[ASSISTANT]\n")
    return "\n".join(parts)


def _load():
    """Load model + tokenizer once and return the model, or None if loading failed.

    The lock-free check, lock acquisition and re-check together mean that
    concurrent callers trigger at most one load. A failure is stored in
    ``_LOAD_ERROR`` and nothing in this module clears it, so later calls
    return None without retrying. ``status()`` reports the stored error.
    Intended to run in the main process (ZeroGPU-safe).
    """
    global _MODEL, _TOKENIZER, _DEVICE, _LOAD_ERROR
    # Fast path without the lock: already loaded, or already failed.
    if _MODEL is not None or _LOAD_ERROR is not None:
        return _MODEL
    with _LOAD_LOCK:
        # Re-check: another thread may have finished (or failed) while we waited.
        if _MODEL is not None or _LOAD_ERROR is not None:
            return _MODEL
        try:
            # Imported here so the module itself imports without torch/transformers.
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            mid = _model_id()
            # On ZeroGPU, "cuda" is chosen even if no GPU is attached at this point.
            # The module docstring says ZeroGPU emulates CUDA outside @spaces.GPU.
            want_cuda = _ON_ZEROGPU or torch.cuda.is_available()
            _DEVICE = "cuda" if want_cuda else "cpu"
            # float16 on GPU, float32 on CPU.
            dtype = torch.float16 if want_cuda else torch.float32

            # trust_remote_code=True lets the repo's custom modeling/tokenizer code run.
            _TOKENIZER = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
            # NOTE(review): `dtype=` is the newer transformers keyword. Older
            # releases expect `torch_dtype=`; confirm against the pinned version.
            model = AutoModelForCausalLM.from_pretrained(
                mid, dtype=dtype, trust_remote_code=True)
            # Module-level cuda placement (works via ZeroGPU CUDA emulation).
            model.to(_DEVICE)
            model.eval()
            _MODEL = model
        except Exception as e:  # noqa: BLE001
            # Any exception (e.g. missing packages, download failure) is recorded.
            # _TOKENIZER and _DEVICE may already have been set before the failure.
            _LOAD_ERROR = e
            _MODEL = None
    return _MODEL


@_gpu
def _generate(prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate a continuation of ``prompt`` and return only the newly generated text.

    This is the only GPU-touching function. It runs on the ZeroGPU device when
    one is attached through the ``@_gpu`` wrapper.

    Sampling is enabled when ``temperature > 0``; otherwise decoding is greedy.
    The sampling-only knobs (``temperature``, ``top_p``) are forwarded only when
    sampling. transformers ignores them in greedy mode and warns when they are
    set alongside ``do_sample=False``.

    Assumes ``_load()`` has already succeeded, so ``_MODEL`` and ``_TOKENIZER`` are set.
    """
    import torch
    inputs = _TOKENIZER(prompt, return_tensors="pt").to(_DEVICE)
    sampling = temperature > 0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=sampling,
        pad_token_id=_TOKENIZER.eos_token_id,
    )
    if sampling:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
    with torch.no_grad():
        out = _MODEL.generate(**inputs, **gen_kwargs)
    # The output row starts with the prompt tokens; keep only what comes after them.
    prompt_len = inputs["input_ids"].shape[1]
    return _TOKENIZER.decode(out[0][prompt_len:], skip_special_tokens=True)


def status() -> str:
    """Describe the LLM's current state in one human-readable line.

    Checks run in order: disabled, then load error, then not-yet-loaded, then
    online. The online line names the execution mode (ZeroGPU, or the upper-cased device).
    """
    label = _model_id().rpartition("/")[2]
    if llm_disabled():
        return "LLM disabled (DISABLE_LLM=1)."
    err = _LOAD_ERROR
    if err is not None:
        return f"{label} unavailable: {type(err).__name__}: {err}"
    if _MODEL is None:
        return f"{label} not loaded yet (loads on first query)."
    if _HAS_SPACES and _ON_ZEROGPU:
        mode = "ZeroGPU"
    else:
        mode = _DEVICE.upper()
    return f"{label} online ({mode})."


def available() -> bool:
    """Return whether the model can be used, triggering the one-time load if needed.

    A disabled LLM short-circuits, so no load is attempted in that case.
    """
    return not llm_disabled() and _load() is not None


def complete(messages, *, max_tokens=512, temperature=0.2, top_p=0.9):
    """Chat completion used by the agents. Returns (text, latency_ms).

    Args:
        messages: chat messages (dicts with "role"/"content") for the prompt template.
        max_tokens: maximum number of new tokens to generate.
        temperature: sampling temperature; values <= 0 select greedy decoding.
        top_p: nucleus-sampling threshold (only used when sampling).

    Returns:
        A tuple of the stripped completion text and the generation latency in
        milliseconds. Latency is measured with a monotonic clock around the
        generation call only; model loading and templating are excluded.

    Raises:
        RuntimeError: with the ``status()`` message when the LLM is disabled via
            DISABLE_LLM or the model could not be loaded.
    """
    # Honour DISABLE_LLM like available()/briefing() do, instead of loading anyway.
    if llm_disabled() or _load() is None:
        raise RuntimeError(status())
    prompt = _apply_chat_template(messages, _TOKENIZER)
    # perf_counter is monotonic, unlike time.time(), so clock adjustments
    # cannot produce negative or inflated latencies.
    start = time.perf_counter()
    text = _generate(prompt, int(max_tokens), float(temperature), float(top_p))
    latency_ms = int((time.perf_counter() - start) * 1000)
    return str(text).strip(), latency_ms


def _fallback(question: str, context: str) -> str:
    return (
        "[AI offline — raw readout]\n"
        f"Q: {question}\n\n{context}\n\n"
        "(Enable the model — transformers + torch — for natural-language briefings.)"
    )


def briefing(question: str, context: str, max_tokens: int = 512) -> str:
    """Answer ``question`` about the live flight ``context`` in natural language.

    The raw ``_fallback`` readout is returned instead when the LLM is disabled
    or failed to load. It is also returned when the completion raises; the
    error text is then appended to the context.
    """
    usable = not llm_disabled() and _load() is not None
    if not usable:
        return _fallback(question, context)
    user_content = f"LIVE FLIGHT DATA:\n{context}\n\nQUESTION: {question}"
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    try:
        reply, _latency = complete(messages, max_tokens=max_tokens, temperature=0.4)
    except Exception as err:  # noqa: BLE001
        return _fallback(question, f"{context}\n\n(LLM error: {err})")
    return reply


# ZeroGPU wants the model placed at startup rather than on first use. So on
# ZeroGPU (and unless DISABLE_LLM is set) load eagerly at import. Everywhere
# else loading stays lazy so imports and tests remain fast.
if not llm_disabled() and _ON_ZEROGPU:
    _load()