| """Tinker inference client. Supports both base models and fine-tuned checkpoints.""" |
| import re |
|
|
| import streamlit as st |
|
|
|
|
@st.cache_resource
def _get_tinker_clients(model_name: str, sampler_path: str = ""):
    """
    Build (once per distinct argument pair) the objects needed to sample from Tinker.

    The function is wrapped in ``st.cache_resource``, so both arguments take
    part in the cache key and each model/checkpoint combination gets its own
    cached result.

    Args:
        model_name: Base model identifier. It is always used to pick the
            tokenizer and the recommended renderer, even when a checkpoint
            is being loaded.
        sampler_path: Optional checkpoint path. When non-empty, the sampling
            client is created from this path instead of from ``model_name``.

    Returns:
        A ``(sampling_client, renderer, tinker_types)`` tuple, where
        ``tinker_types`` is the ``tinker.types`` module. The tokenizer is only
        used to construct the renderer and is not returned.
    """
    # Imported lazily so the heavy Tinker packages load only on first use.
    import tinker
    from tinker import types as tinker_types
    from tinker_cookbook import renderers
    from tinker_cookbook.model_info import get_recommended_renderer_name
    from tinker_cookbook.tokenizer_utils import get_tokenizer

    service = tinker.ServiceClient()
    if not sampler_path:
        print(f"[MODEL] Loading base model: {model_name}")
        sampler = service.create_sampling_client(base_model=model_name)
    else:
        print(f"[MODEL] Loading fine-tuned checkpoint: {sampler_path}")
        sampler = service.create_sampling_client(model_path=sampler_path)

    # Tokenizer and renderer are keyed on the base model name in both branches.
    tok = get_tokenizer(model_name)
    chat_renderer = renderers.get_renderer(
        get_recommended_renderer_name(model_name),
        tok,
    )
    return sampler, chat_renderer, tinker_types
|
|
|
|
# Patterns used to scrub raw model output; compiled once at import time.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)
_DANGLING_THINK_RE = re.compile(r"<think>.*\Z", flags=re.DOTALL)
_SPECIAL_TOKEN_RE = re.compile(r"<\|[^|]*\|>")
_REPEAT_LOOP_RE = re.compile(r"(.{40,}?)\1{4,}", flags=re.DOTALL)


def _clean_model_output(content: str) -> str:
    """Strip reasoning blocks, special tokens and repetition loops from raw text.

    Raises:
        ValueError: If fewer than three whitespace-separated words remain.
    """
    content = _THINK_BLOCK_RE.sub("", content)
    # A generation cut off mid-reasoning has no closing tag; drop the
    # unterminated block so partial reasoning is never returned as the answer.
    content = _DANGLING_THINK_RE.sub("", content).strip()
    # Remove leftover control tokens of the form <|...|>.
    content = _SPECIAL_TOKEN_RE.sub("", content).strip()

    # A chunk of 40+ characters followed by 4 or more immediate copies of
    # itself is treated as a degenerate loop; keep only the first occurrence.
    loop = _REPEAT_LOOP_RE.search(content)
    if loop:
        content = content[: loop.start() + len(loop.group(1))].strip()

    if not content or len(content.split()) < 3:
        raise ValueError("Model output cleanup yielded no usable content.")
    return content


def call_model(messages: list, cfg: dict) -> str:
    """Send a message list to Tinker and return cleaned response text.

    Args:
        messages: Chat messages; each is indexed with ``"role"`` and
            ``"content"`` for logging and is passed to the renderer as-is.
        cfg: Configuration mapping. ``"model_name"`` is required;
            ``"sampler_path"`` is optional and selects a checkpoint.

    Returns:
        The cleaned response text, or a string of the form
        ``"[Model error: <reason>]"`` when anything inside the request,
        sampling or cleanup steps fails.

    Raises:
        KeyError: If ``cfg`` has no ``"model_name"`` entry.
    """
    model_name = cfg["model_name"]
    sampler_path = cfg.get("sampler_path", "")
    print(f"[MODEL] model_name={model_name} sampler_path={sampler_path or '(base)'}")

    try:
        # Request logging lives inside the try so that a malformed message
        # (missing "role"/"content", or a non-sized ``messages``) produces the
        # same error string as any other failure instead of raising.
        print(f"[MODEL] num_messages={len(messages)}")
        print(f"[MODEL] roles={[m['role'] for m in messages]}")
        if messages:
            print(f"[MODEL] system_prompt[:150]={messages[0]['content'][:150]}")

        from tinker_cookbook import renderers as tinker_renderers

        sampling_client, renderer, tinker_types = _get_tinker_clients(model_name, sampler_path)

        prompt = renderer.build_generation_prompt(messages)
        params = tinker_types.SamplingParams(
            max_tokens=1000,
            temperature=0.7,
            stop=renderer.get_stop_sequences(),
        )
        result = sampling_client.sample(
            prompt=prompt,
            sampling_params=params,
            num_samples=1,
        ).result()

        # Only one sample is requested, so only the first sequence is parsed.
        parsed_message, _ = renderer.parse_response(result.sequences[0].tokens)
        content = tinker_renderers.format_content_as_string(parsed_message["content"])

        return _clean_model_output(content)

    except Exception as e:
        # Deliberate catch-all: the caller always receives a string.
        print(f"[MODEL] Tinker error: {e}")
        return f"[Model error: {e}]"