# Author: ehejin
# Commit 3885d4d: "new sampler model.py"
"""Tinker inference client. Supports both base models and fine-tuned checkpoints."""
import re
import streamlit as st
@st.cache_resource
def _get_tinker_clients(model_name: str, sampler_path: str = ""):
    """
    Build the Tinker sampling client and chat renderer for one model variant.

    The result is memoised by ``st.cache_resource`` on the call arguments, so
    each ``(model_name, sampler_path)`` combination gets its own cached objects
    and base / fine-tuned variants never share a client.

    Args:
        model_name: Base model identifier. Always used to select the tokenizer
            and the recommended renderer, even when a checkpoint is loaded.
        sampler_path: Optional checkpoint path. When non-empty, the sampling
            client is created from this checkpoint instead of the base model.

    Returns:
        A ``(sampling_client, renderer, tinker_types)`` tuple, where
        ``tinker_types`` is the ``tinker.types`` module (returned so callers
        can construct request types such as ``SamplingParams``).
    """
    # SDK imports are deferred until a client is actually requested.
    import tinker
    from tinker import types as tinker_types
    from tinker_cookbook import renderers
    from tinker_cookbook.model_info import get_recommended_renderer_name
    from tinker_cookbook.tokenizer_utils import get_tokenizer

    service = tinker.ServiceClient()

    use_checkpoint = bool(sampler_path)
    if use_checkpoint:
        print(f"[MODEL] Loading fine-tuned checkpoint: {sampler_path}")
        client_kwargs = {"model_path": sampler_path}
    else:
        print(f"[MODEL] Loading base model: {model_name}")
        client_kwargs = {"base_model": model_name}
    sampler = service.create_sampling_client(**client_kwargs)

    # Tokenizer and renderer are keyed on the base model name in both cases.
    tok = get_tokenizer(model_name)
    chat_renderer = renderers.get_renderer(
        get_recommended_renderer_name(model_name), tok
    )
    return sampler, chat_renderer, tinker_types
def _clean_response(text: str) -> str:
    """Strip reasoning blocks, special tokens and runaway repetition from *text*."""
    # Remove complete <think>...</think> reasoning spans.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    # A leftover closing tag has no opener in the response: everything before
    # it is reasoning (e.g. a chat template that opens <think> in the prompt).
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    # A leftover opening tag was never closed (e.g. generation hit the token
    # limit mid-reasoning): nothing after it is answer text.
    if "<think>" in text:
        text = text.split("<think>", 1)[0]
    # Drop special tokens of the form <|...|> that leaked into the text.
    text = re.sub(r"<\|[^|]*\|>", "", text).strip()
    # Degenerate loop: a chunk of 40+ chars immediately repeated 4+ more times.
    # Keep everything up to and including the first copy of that chunk.
    match = re.search(r"(.{40,}?)\1{4,}", text, flags=re.DOTALL)
    if match:
        text = text[: match.start() + len(match.group(1))]
    return text.strip()


def call_model(messages: list, cfg: dict) -> str:
    """Send *messages* to Tinker and return the cleaned response text.

    Args:
        messages: Chat messages; each must be a dict with ``"role"`` and
            ``"content"`` keys (both are read for logging below).
        cfg: Model configuration. ``"model_name"`` is required. Optional keys:
            ``"sampler_path"`` (fine-tuned checkpoint; ``""``/absent means the
            base model), ``"max_tokens"`` (default 1000) and
            ``"temperature"`` (default 0.7).

    Returns:
        The cleaned model reply. Any exception raised while sampling or
        cleaning (including output with fewer than three words left after
        cleanup) is printed and returned as a ``"[Model error: ...]"`` string
        instead of being raised.

    Raises:
        KeyError: If ``cfg`` lacks ``"model_name"`` or a message lacks
            ``"role"``/``"content"``; these are read before the guarded section.
    """
    model_name = cfg["model_name"]
    sampler_path = cfg.get("sampler_path", "")
    print(f"[MODEL] model_name={model_name} sampler_path={sampler_path or '(base)'}")
    print(f"[MODEL] num_messages={len(messages)}")
    print(f"[MODEL] roles={[m['role'] for m in messages]}")
    if messages:
        # Logged as the system prompt; assumes messages[0] is the system message.
        print(f"[MODEL] system_prompt[:150]={messages[0]['content'][:150]}")

    try:
        from tinker_cookbook import renderers as tinker_renderers

        sampling_client, renderer, tinker_types = _get_tinker_clients(model_name, sampler_path)
        prompt = renderer.build_generation_prompt(messages)
        params = tinker_types.SamplingParams(
            max_tokens=cfg.get("max_tokens", 1000),
            temperature=cfg.get("temperature", 0.7),
            stop=renderer.get_stop_sequences(),
        )
        result = sampling_client.sample(
            prompt=prompt,
            sampling_params=params,
            num_samples=1,
        ).result()
        parsed_message, _ = renderer.parse_response(result.sequences[0].tokens)
        raw_text = tinker_renderers.format_content_as_string(parsed_message["content"])
        content = _clean_response(raw_text)
        # Fewer than three words (including empty) is treated as unusable output.
        if len(content.split()) < 3:
            raise ValueError("Model output cleanup yielded no usable content.")
        return content
    except Exception as e:
        # Deliberate boundary: failures are reported as text, not raised.
        print(f"[MODEL] Tinker error: {e}")
        return f"[Model error: {e}]"