| """Tinker inference client. Supports both base models and fine-tuned checkpoints.""" |
| import json |
| import re |
|
|
| import streamlit as st |
|
|
|
|
def _should_print_model_input(cfg: dict) -> bool:
    """Return True when dumping of the model input is enabled.

    Enabled when either the ``print_model_input`` key in *cfg* (e.g. from
    study_config.yaml) or the ``PRINT_MODEL_INPUT`` environment variable holds
    a truthy value. String values ``""``, ``"0"``, ``"false"``, ``"no"`` and
    ``"off"`` (case-insensitive, surrounding whitespace ignored) count as
    False; plain ``bool()`` would treat any non-empty string, including
    ``"false"``, as True.
    """
    import os

    def _truthy(value: object) -> bool:
        # Config/env values may arrive as strings; interpret the usual
        # "off" spellings as False instead of relying on string truthiness.
        if isinstance(value, str):
            return value.strip().lower() not in {"", "0", "false", "no", "off"}
        return bool(value)

    return _truthy(cfg.get("print_model_input", False)) or _truthy(
        os.environ.get("PRINT_MODEL_INPUT", "")
    )
|
|
|
|
def _print_model_input(messages: list, prompt: object) -> None:
    """Dump the chat messages and the rendered prompt to stdout for debugging.

    Prints two banner-delimited sections:

    1. *messages* reduced to their ``role`` and ``content`` keys, as indented
       JSON. Missing keys become ``null``.
    2. The prompt: used as-is when it is a ``str``, otherwise via ``str()``.

    For a non-string prompt object, the second section is that object's
    ``str()`` form, which is not necessarily the literal text the model sees.
    """
    rule = "=" * 72

    def emit(text: str) -> None:
        # Flush each chunk immediately rather than leaving it in the buffer.
        print(text, flush=True)

    reduced = [
        dict(role=msg.get("role"), content=msg.get("content"))
        for msg in messages
    ]
    emit(f"\n{rule}\nMODEL INPUT — logical messages (role/content)\n{rule}")
    emit(json.dumps(reduced, indent=2, ensure_ascii=False))

    rendered = prompt if isinstance(prompt, str) else str(prompt)
    emit(f"\n{rule}\nMODEL INPUT — rendered prompt (build_generation_prompt)\n{rule}")
    emit(rendered)
    emit(f"{rule}\nEND MODEL INPUT\n")
|
|
|
|
@st.cache_resource
def _get_tinker_clients(model_name: str, sampler_path: str = ""):
    """Build and cache the Tinker sampling client and chat renderer.

    Args:
        model_name: Base model name. The tokenizer and renderer are always
            derived from it.
        sampler_path: Optional checkpoint path. When non-empty, the sampling
            client is created from this checkpoint (a fine-tuned model);
            otherwise it is created from ``model_name``.

    Returns:
        A tuple ``(sampling_client, renderer, tinker_types)``, where
        ``tinker_types`` is the ``tinker.types`` module, returned so callers
        can build request objects. The tokenizer is created only to build the
        renderer; it is not returned.

    ``st.cache_resource`` keys the cache on both arguments, so each
    (model_name, sampler_path) pair gets its own cached result.
    """
    import tinker
    from tinker import types as tinker_types
    from tinker_cookbook import renderers
    from tinker_cookbook.model_info import get_recommended_renderer_name
    from tinker_cookbook.tokenizer_utils import get_tokenizer

    service = tinker.ServiceClient()
    if sampler_path:
        print(f"[MODEL] Loading fine-tuned checkpoint: {sampler_path}")
        source = {"model_path": sampler_path}
    else:
        print(f"[MODEL] Loading base model: {model_name}")
        source = {"base_model": model_name}
    client = service.create_sampling_client(**source)

    # Tokenizer/renderer come from the base model name even for checkpoints
    # (presumably the checkpoint shares the base model's vocabulary and chat
    # format — confirm if non-standard fine-tunes are used).
    tok = get_tokenizer(model_name)
    renderer = renderers.get_renderer(get_recommended_renderer_name(model_name), tok)
    return client, renderer, tinker_types
|
|
|
|
def _clean_model_output(text: str) -> str:
    """Strip reasoning and special-token artifacts from raw model text.

    Steps, in order:
      1. Remove complete ``<think>...</think>`` spans.
      2. If an orphan ``</think>`` remains (no opening tag in the sampled text,
         presumably because the renderer placed it in the prompt), drop
         everything up to and including the last one.
      3. If an unterminated ``<think>`` remains (e.g. sampling stopped at
         ``max_tokens`` mid-reasoning), drop it and everything after it.
      4. Remove ``<|...|>`` special-token markers.
      5. If a chunk of 40+ characters repeats back-to-back 5+ times in total,
         truncate right after its first occurrence (degenerate looping).

    Returns the stripped result, which may be empty.
    """
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    text = text.partition("<think>")[0].strip()
    text = re.sub(r"<\|[^|]*\|>", "", text).strip()
    # (.{40,}?) is the repeating unit; \1{4,} requires at least four more
    # consecutive copies of it.
    loop = re.search(r"(.{40,}?)\1{4,}", text, flags=re.DOTALL)
    if loop:
        text = text[: loop.start() + len(loop.group(1))].strip()
    return text


def call_model(messages: list, cfg: dict) -> str:
    """Send a chat message list to Tinker and return the cleaned response text.

    Args:
        messages: Chat messages, each a dict with at least ``role`` and
            ``content``. They are rendered via ``renderer.build_generation_prompt``.
        cfg: Study config. Reads:
            - ``model_name`` (required)
            - ``sampler_path`` (optional checkpoint; empty means base model)
            - ``sampling_temperature`` (default 1.0)
            - ``max_tokens`` (default 1000)
            - the flag checked by ``_should_print_model_input``

    Returns:
        The cleaned response text on success. Returns ``"[Model error: ...]"``
        if anything in the sampling pipeline fails, including malformed
        messages and output that cleans up to fewer than three words.
        These failures are returned in-band, not raised.

    Raises:
        KeyError: if ``cfg`` has no ``model_name``.
        ValueError or TypeError: if ``sampling_temperature`` or
            ``max_tokens`` cannot be converted to a number.
    """
    model_name = cfg["model_name"]
    sampler_path = cfg.get("sampler_path", "")
    temperature = float(cfg.get("sampling_temperature", 1.0))
    max_tokens = int(cfg.get("max_tokens", 1000))
    print(
        f"[MODEL] model_name={model_name} sampler_path={sampler_path or '(base)'} "
        f"temperature={temperature}"
    )
    print(f"[MODEL] num_messages={len(messages)}")

    try:
        # Per-message diagnostics are inside the try so a malformed message
        # (missing 'role'/'content') is reported through the error path.
        # Outside the try, it would escape the function as an exception.
        print(f"[MODEL] roles={[m['role'] for m in messages]}")
        if messages:
            print(f"[MODEL] system_prompt[:150]={messages[0]['content'][:150]}")

        from tinker_cookbook import renderers as tinker_renderers

        sampling_client, renderer, tinker_types = _get_tinker_clients(model_name, sampler_path)

        prompt = renderer.build_generation_prompt(messages)
        if _should_print_model_input(cfg):
            _print_model_input(messages, prompt)

        params = tinker_types.SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            stop=renderer.get_stop_sequences(),
        )
        result = sampling_client.sample(
            prompt=prompt,
            sampling_params=params,
            num_samples=1,
        ).result()

        parsed_message, _ = renderer.parse_response(result.sequences[0].tokens)
        raw = tinker_renderers.format_content_as_string(parsed_message["content"])
        content = _clean_model_output(raw)
        # Fewer than three words (including empty) is treated as unusable.
        if len(content.split()) < 3:
            raise ValueError("Model output cleanup yielded no usable content.")
        return content

    except Exception as e:
        print(f"[MODEL] Tinker error: {e}")
        return f"[Model error: {e}]"