File size: 8,469 Bytes
ef18673
 
 
 
a21a30f
1e799aa
 
 
ef18673
 
 
 
 
 
 
a21a30f
1e799aa
ef18673
 
 
 
 
 
1e799aa
 
 
 
 
 
 
 
a21a30f
 
 
 
 
 
 
 
 
 
 
 
ef18673
 
 
 
 
 
 
 
 
1e799aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c64fd6
 
 
 
 
 
 
 
 
 
ef18673
 
1e799aa
ef18673
1e799aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef18673
1e799aa
 
 
 
ef18673
 
 
1e799aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d2114e
 
1e799aa
 
3d2114e
 
 
1e799aa
 
 
 
 
 
 
 
3d2114e
1e799aa
3d2114e
 
 
 
 
 
 
1e799aa
 
 
 
 
 
 
 
 
3d2114e
1e799aa
3d2114e
1e799aa
 
 
 
ef18673
 
 
 
1e799aa
ef18673
 
 
 
 
1e799aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4f432f
 
 
 
 
 
 
4c64fd6
 
 
 
b4f432f
4c64fd6
 
b4f432f
 
 
 
 
a21a30f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""GPU-oriented FastAPI server for SAGE."""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Any, Optional

import torch
from fastapi import FastAPI
from pydantic import BaseModel

from model.config import ModelConfig
from model.model import SageTransformer
from serve.control_plane import build_control_router, get_runtime_access_info
from train.checkpoint import load_latest_checkpoint
from train.hardware import HardwareConfig


app = FastAPI(title="SAGE Server")

# Process-wide singletons, populated lazily by get_model() / get_tokenizer().
_MODEL: SageTransformer | None = None
_TOKENIZER = None
_MODEL_DEVICE: torch.device | None = None

# What get_model() / get_tokenizer() actually resolved and loaded; read back by chat_status().
_MODEL_STATE: dict[str, Any] = dict(
    model_config=None,
    checkpoint_dir=None,
    checkpoint_loaded=False,
    checkpoint_step=0,
    tokenizer_path=None,
)

# Reuses uvicorn's "uvicorn.error" logger, presumably so these messages share uvicorn's handlers.
_LOGGER = logging.getLogger("uvicorn.error")


def _print_startup_banner() -> None:
    """Log the URLs and password needed to reach the browser control UI.

    The local URL falls back to ``http://127.0.0.1:8000`` when the access info
    has none; the public URL line is emitted only when one is configured.
    NOTE(review): the login password is written to the log at INFO level, so
    anyone who can read the server logs can log in.
    """
    info = get_runtime_access_info()
    local_base = info["local_url"] or "http://127.0.0.1:8000"
    public_base = info["public_url"]
    _LOGGER.info("SAGE local URL: %s/", local_base.rstrip("/"))
    if public_base:
        _LOGGER.info("SAGE public URL: %s/", public_base.rstrip("/"))
    _LOGGER.info("SAGE login password: %s", info["password"])


class GenerationRequest(BaseModel):
    """Request schema for text generation.

    Body of ``POST /generate`` (and of the ``generate`` control-plane action):
    raw token ids in, greedy continuation out.
    """

    # Prompt token ids; at most the last ``context_length`` ids are fed to the model per step.
    input_ids: list[int]
    # Number of tokens to generate; negative values are treated as 0 by the generator.
    # NOTE(review): no upper bound is enforced here.
    max_new_tokens: int = 32


class ChatRequest(BaseModel):
    """Request schema for text generation through the tokenizer.

    Body of ``POST /chat``: a plain-text prompt that the server tokenizes,
    continues greedily, and decodes back to text.
    """

    # Prompt text; leading/trailing whitespace is stripped and a blank prompt is rejected.
    prompt: str
    # Number of tokens to generate; negative values are treated as 0 by the generator.
    max_new_tokens: int = 64


def get_generation_device() -> torch.device:
    """Return the device inference should run on: CUDA if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _resolve_model_config_path() -> Path:
    configured = Path(os.environ.get("SAGE_MODEL_CONFIG", "configs/model/1b.yaml"))
    return configured if configured.exists() else Path("configs/model/1b.yaml")


def _resolve_checkpoint_dir() -> Path:
    return Path(os.environ.get("SAGE_CHECKPOINT_DIR", "runs/default"))


def _resolve_tokenizer_path() -> Path:
    return Path(os.environ.get("SAGE_TOKENIZER_MODEL", "tokenizer/tokenizer.model"))


def configure_runtime_paths(model_config: str | None, checkpoint_dir: str | None, tokenizer_model: str | None) -> None:
    """Export the given runtime paths as ``SAGE_*`` environment variables.

    Falsy arguments (``None`` or ``""``) leave the corresponding variable
    untouched, so previously exported values survive.
    """
    overrides = (
        ("SAGE_MODEL_CONFIG", model_config),
        ("SAGE_CHECKPOINT_DIR", checkpoint_dir),
        ("SAGE_TOKENIZER_MODEL", tokenizer_model),
    )
    for env_name, value in overrides:
        if value:
            os.environ[env_name] = value


def get_model() -> SageTransformer:
    """Return the process-wide model, building it on first use.

    First call:

    - Builds a ``SageTransformer`` from the YAML at ``_resolve_model_config_path()``,
      or from ``ModelConfig()`` defaults when that file does not exist.
    - If the checkpoint directory exists, loads its latest checkpoint onto the CPU.
    - Records what was loaded in ``_MODEL_STATE`` and switches the model to eval mode.

    Every call: moves the cached model to ``get_generation_device()`` if it is
    not already there.

    The model is published to ``_MODEL`` only after it is fully initialized. If
    config parsing or checkpoint loading raises, nothing is cached and the next
    call retries from scratch. Previously a half-built model (not in eval mode,
    with no recorded state) stayed cached forever.
    """
    global _MODEL, _MODEL_DEVICE
    if _MODEL is None:
        config_path = _resolve_model_config_path()
        config = ModelConfig.from_yaml(config_path) if config_path.exists() else ModelConfig()
        model = SageTransformer(config)
        checkpoint_dir = _resolve_checkpoint_dir()
        checkpoint_step = 0
        if checkpoint_dir.exists():
            # Load on CPU; the model is moved to the inference device below.
            checkpoint_step = load_latest_checkpoint(model, None, None, None, str(checkpoint_dir), device="cpu")
        model.eval()
        _MODEL_STATE.update(
            {
                "model_config": str(config_path),
                "checkpoint_dir": str(checkpoint_dir),
                "checkpoint_loaded": checkpoint_step > 0,
                "checkpoint_step": checkpoint_step,
            }
        )
        # Publish last, so a failure above never leaves a partial model cached.
        _MODEL = model
    device = get_generation_device()
    if _MODEL_DEVICE != device:
        _MODEL = _MODEL.to(device)
        _MODEL_DEVICE = device
    return _MODEL


def get_tokenizer():
    """Return the cached SentencePiece tokenizer, loading it on first success.

    Records the resolved path in ``_MODEL_STATE["tokenizer_path"]`` on every
    uncached call. If the file does not exist, returns ``None`` without
    caching, so a tokenizer placed on disk later is picked up by the next call.
    """
    global _TOKENIZER
    if _TOKENIZER is not None:
        return _TOKENIZER
    tokenizer_path = _resolve_tokenizer_path()
    _MODEL_STATE["tokenizer_path"] = str(tokenizer_path)
    if tokenizer_path.exists():
        # Imported lazily so the server can start without the tokenizer package being used.
        from tokenizer.validate_tokenizer import load_processor

        _TOKENIZER = load_processor(str(tokenizer_path))
    return _TOKENIZER


def _generate_token_ids(input_ids: list[int], max_new_tokens: int) -> list[int]:
    """Greedily extend ``input_ids`` by ``max_new_tokens`` tokens.

    Each step feeds at most the last ``model.config.context_length`` tokens to
    the model and appends the argmax of the final position's logits. A
    negative ``max_new_tokens`` is treated as 0.

    Returns:
        The complete ``input_ids`` followed by the newly generated ids, so a
        caller can always recover the completion as ``result[len(input_ids):]``.
        Only the per-step model window is truncated, never the returned prompt.
        Previously long prompts were truncated in the result, which made
        ``/chat`` drop part or all of the completion.
        An empty ``input_ids`` returns ``[]``: there is no position to
        condition on, and a zero-length window would crash the model call.
    """
    generated = list(input_ids)
    if not generated:
        return generated
    model = get_model()
    device = get_generation_device()
    context_length = model.config.context_length
    with torch.inference_mode():
        for _ in range(max(0, int(max_new_tokens))):
            window = generated[-context_length:]
            tensor_ids = torch.tensor([window], dtype=torch.long, device=device)
            logits, _ = model(tensor_ids)
            next_token = int(torch.argmax(logits[:, -1, :], dim=-1).item())
            generated.append(next_token)
    return generated


def _latest_checkpoint_step(checkpoint_dir: Path) -> tuple[bool, int]:
    """Scan ``checkpoint_dir`` once for ``ckpt_step_*.pt`` files.

    Returns ``(found, step)``:

    - ``found`` is True if any file matches the pattern.
    - ``step`` is the largest integer suffix among the matches, or 0 if none parses.

    Steps are compared numerically, so ``ckpt_step_1000.pt`` outranks
    ``ckpt_step_900.pt`` (a lexicographic sort gets this wrong).
    """
    found = False
    latest = 0
    for candidate in checkpoint_dir.glob("ckpt_step_*.pt"):
        found = True
        try:
            step = int(candidate.stem.rsplit("_", 1)[-1])
        except ValueError:
            # e.g. "ckpt_step_final.pt": still a checkpoint, but carries no step number.
            continue
        latest = max(latest, step)
    return found, latest


def chat_status() -> dict[str, object]:
    """Return whether text chat is configured for the current server.

    Because the model loads lazily, the checkpoint fields fall back to what is
    on disk when ``get_model()`` has not recorded a loaded checkpoint yet.
    ``available`` only reflects whether a tokenizer could be loaded. ``warning``
    explains a missing tokenizer, or else a missing checkpoint (meaning chat
    will use random weights).
    """
    tokenizer = get_tokenizer()
    checkpoint_dir = _MODEL_STATE["checkpoint_dir"] or str(_resolve_checkpoint_dir())
    checkpoint_loaded = bool(_MODEL_STATE["checkpoint_loaded"])
    checkpoint_step = int(_MODEL_STATE["checkpoint_step"] or 0)
    if not checkpoint_loaded or checkpoint_step == 0:
        on_disk, latest_step = _latest_checkpoint_step(Path(checkpoint_dir))
        checkpoint_loaded = checkpoint_loaded or on_disk
        if checkpoint_step == 0 and checkpoint_loaded:
            checkpoint_step = latest_step
    available = tokenizer is not None
    warning = None
    if tokenizer is None:
        warning = "Tokenizer model not found. Train or place tokenizer/tokenizer.model before using browser chat."
    elif not checkpoint_loaded:
        warning = "No checkpoint loaded. Chat will run with randomly initialized model weights until you train or configure SAGE_CHECKPOINT_DIR."
    return {
        "available": available,
        "tokenizer_path": _MODEL_STATE["tokenizer_path"],
        "checkpoint_dir": checkpoint_dir,
        "checkpoint_loaded": checkpoint_loaded,
        "checkpoint_step": checkpoint_step,
        "warning": warning,
    }


@app.get("/health")
def health() -> dict[str, object]:
    """Report liveness, a hardware summary, and the current chat readiness."""
    # NOTE(review): the hardware profile is fixed at 1B params / 4096 context, not read from the loaded model config.
    hardware_summary = HardwareConfig(model_size_b=1.0, context_length=4096).summary()
    chat_info = chat_status()
    return {"status": "ok", "hardware": hardware_summary, "chat": chat_info}


@app.post("/generate")
def generate(request: GenerationRequest) -> dict[str, object]:
    """Greedily continue a raw token-id prompt and return the resulting id list."""
    tokens = _generate_token_ids(request.input_ids, request.max_new_tokens)
    return {"tokens": tokens}


@app.get("/chat/status")
def get_chat_status() -> dict[str, object]:
    """Serve the browser-chat readiness payload built by ``chat_status``."""
    status = chat_status()
    return status


@app.post("/chat")
def chat(request: ChatRequest) -> dict[str, object]:
    """Tokenize a text prompt, generate greedily, and decode the completion.

    Failures (no tokenizer, blank prompt) are reported in-band with
    ``success: False`` and a ``detail`` message rather than an HTTP error.
    Every response also carries the ``chat_status()`` fields.
    """

    def reject(detail: str) -> dict[str, object]:
        return {"success": False, "detail": detail, **chat_status()}

    tokenizer = get_tokenizer()
    if tokenizer is None:
        return reject("Tokenizer model not found. Train the tokenizer first or set SAGE_TOKENIZER_MODEL.")
    prompt = request.prompt.strip()
    if not prompt:
        return reject("Prompt cannot be empty.")
    prompt_ids = [*tokenizer.encode(prompt, out_type=int)]
    output_ids = _generate_token_ids(prompt_ids, request.max_new_tokens)
    # NOTE(review): this slice assumes _generate_token_ids returns the complete prompt
    # followed by the new ids; if it truncates long prompts, new ids are lost here.
    new_ids = output_ids[len(prompt_ids):]
    payload: dict[str, object] = {
        "success": True,
        "prompt": prompt,
        "response": tokenizer.decode(new_ids),
        "input_ids": prompt_ids,
        "output_ids": output_ids,
        "new_token_ids": new_ids,
    }
    payload.update(chat_status())
    return payload


def _health_action(_: dict[str, object]) -> dict[str, object]:
    """Control-plane adapter for ``health``; the action arguments are ignored."""
    result = health()
    return result


def _generate_action(args: dict[str, object]) -> dict[str, object]:
    """Control-plane adapter for ``generate``.

    ``input_ids`` may be:

    - a list or tuple of int-like values;
    - a single scalar, which is wrapped in a one-element list;
    - missing or ``None``, meaning an empty prompt.

    ``max_new_tokens`` falls back to 32 when missing *or* explicitly ``None``.
    Previously an explicit ``None`` raised TypeError, and a tuple was wrapped
    whole and then failed in ``int()``. Non-numeric values still raise from ``int()``.
    """
    raw_ids = args.get("input_ids")
    if raw_ids is None:
        raw_ids = []
    elif not isinstance(raw_ids, (list, tuple)):
        raw_ids = [raw_ids]
    raw_max = args.get("max_new_tokens")
    request = GenerationRequest(
        input_ids=[int(item) for item in raw_ids],  # type: ignore
        max_new_tokens=32 if raw_max is None else int(raw_max),  # type: ignore
    )
    return generate(request)


# Expose the control-plane actions: "health_check" -> _health_action, "generate" -> _generate_action.
app.include_router(build_control_router(dict(health_check=_health_action, generate=_generate_action)))


@app.on_event("startup")
def _startup_banner() -> None:
    """On application startup, log the browser control UI access details."""
    # NOTE(review): newer FastAPI releases deprecate on_event in favour of lifespan handlers.
    return _print_startup_banner()