File size: 2,445 Bytes
4ee3607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# =============================================================================
# smollm.py
# SmolLM2 Inference Engine
# SmolLM2 Service Space
# Copyright 2026 - Volkan Kücükbudak
# Apache License V2 + ESOL 1.1
# =============================================================================

import logging
import torch
from typing import Optional
import model as model_module

# Module logger, registered under the fixed name "smollm".
logger = logging.getLogger("smollm")

# Lazily-initialised singletons. All three start as None and are assigned by
# load(); is_ready() and device_info() read them.
_tokenizer = None
_model     = None
_device    = None


def load():
    """
    Initialise the tokenizer and model once; later calls return immediately.

    Reads the model id and the from_pretrained keyword arguments from
    ``model_module``, selects "cuda" when torch reports it available
    (otherwise "cpu"), stores tokenizer, model and device in the module
    globals, and finally pushes a model card with the model id and device.
    """
    global _tokenizer, _model, _device

    # A populated _model means a previous call already finished loading.
    if _model is not None:
        return

    # transformers is imported here, on first load, not at module import time.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = model_module.get_model_id()
    pretrained_kwargs = model_module.get_model_kwargs()

    if torch.cuda.is_available():
        _device = "cuda"
    else:
        _device = "cpu"

    logger.info(f"Loading {model_id} on {_device}...")
    # The same keyword arguments are forwarded to both from_pretrained calls.
    _tokenizer = AutoTokenizer.from_pretrained(model_id, **pretrained_kwargs)
    causal_lm = AutoModelForCausalLM.from_pretrained(model_id, **pretrained_kwargs)
    _model = causal_lm.to(_device)
    logger.info(f"Model ready [{_device}]")

    # Update model card on startup
    card_fields = {
        "model_id": model_id,
        "device": _device,
    }
    model_module.push_model_card(card_fields)


async def complete(
    prompt: str,
    system_prompt: str = "",
    max_tokens: int = 150,
    temperature: float = 0.2,
) -> str:
    """
    Run SmolLM2 inference for a single user prompt.

    Args:
        prompt: Content of the user message.
        system_prompt: Optional system message; skipped when blank.
        max_tokens: Maximum number of newly generated tokens.
        temperature: Values > 0 enable sampling (with top_p=0.9);
            values <= 0 disable sampling.

    Returns:
        Generated text string (prompt tokens removed, special tokens
        skipped, surrounding whitespace stripped).
    Raises:
        RuntimeError on inference failure.
    """
    load()

    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    sampling = temperature > 0

    try:
        text = _tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The chat template already emits the special tokens it needs, so the
        # second tokenization pass must not add them again.
        inputs = _tokenizer.encode(
            text, return_tensors="pt", add_special_tokens=False
        ).to(_device)
        # Single unpadded sequence: every position is a real token. Passing
        # the mask explicitly is required because pad_token_id is set to the
        # EOS id below, so generate() cannot derive the mask on its own.
        attention_mask = torch.ones_like(inputs)

        with torch.no_grad():
            outputs = _model.generate(
                inputs,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                # Sampling-only options are passed as None when sampling is off.
                temperature=temperature if sampling else None,
                do_sample=sampling,
                top_p=0.9 if sampling else None,
                pad_token_id=_tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation; keep only the continuation.
        new_tokens = outputs[0][inputs.shape[-1]:]
        return _tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except RuntimeError:
        # Already matches the documented contract (including subclasses).
        raise
    except Exception as exc:
        # Honour the documented contract: surface any other failure as
        # RuntimeError while keeping the original exception as the cause.
        raise RuntimeError(f"SmolLM2 inference failed: {exc}") from exc


def is_ready() -> bool:
    """Report whether the module-level model has been assigned by load()."""
    if _model is None:
        return False
    return True

def device_info() -> str:
    """Return the stored device name, or "not loaded" when none is set."""
    if _device:
        return _device
    return "not loaded"