"""Tinker inference client. Supports both base models and fine-tuned checkpoints.""" import re import streamlit as st @st.cache_resource def _get_tinker_clients(model_name: str, sampler_path: str = ""): """ Initialise and cache the Tinker sampling client, renderer, and tokenizer. If sampler_path is provided, loads from that checkpoint (fine-tuned model). Otherwise, loads the base model_name. Cache key includes both so different variants get different clients. """ import tinker from tinker import types as tinker_types from tinker_cookbook import renderers from tinker_cookbook.model_info import get_recommended_renderer_name from tinker_cookbook.tokenizer_utils import get_tokenizer service_client = tinker.ServiceClient() if sampler_path: print(f"[MODEL] Loading fine-tuned checkpoint: {sampler_path}") sampling_client = service_client.create_sampling_client(model_path=sampler_path) else: print(f"[MODEL] Loading base model: {model_name}") sampling_client = service_client.create_sampling_client(base_model=model_name) tokenizer = get_tokenizer(model_name) renderer_name = get_recommended_renderer_name(model_name) renderer = renderers.get_renderer(renderer_name, tokenizer) return sampling_client, renderer, tinker_types def call_model(messages: list, cfg: dict) -> str: """Send a message list to Tinker and return cleaned response text.""" model_name = cfg["model_name"] sampler_path = cfg.get("sampler_path", "") print(f"[MODEL] model_name={model_name} sampler_path={sampler_path or '(base)'}") print(f"[MODEL] num_messages={len(messages)}") print(f"[MODEL] roles={[m['role'] for m in messages]}") if messages: print(f"[MODEL] system_prompt[:150]={messages[0]['content'][:150]}") try: from tinker_cookbook import renderers as tinker_renderers sampling_client, renderer, tinker_types = _get_tinker_clients(model_name, sampler_path) prompt = renderer.build_generation_prompt(messages) params = tinker_types.SamplingParams( max_tokens=1000, temperature=0.7, stop=renderer.get_stop_sequences(), ) result = sampling_client.sample( prompt=prompt, sampling_params=params, num_samples=1, ).result() parsed_message, _ = renderer.parse_response(result.sequences[0].tokens) content = tinker_renderers.format_content_as_string(parsed_message["content"]) content = re.sub(r".*?", "", content, flags=re.DOTALL).strip() content = re.sub(r"<\|[^|]*\|>", "", content).strip() match = re.search(r"(.{40,}?)\1{4,}", content, flags=re.DOTALL) if match: content = content[: match.start() + len(match.group(1))].strip() if not content or len(content.split()) < 3: raise ValueError("Model output cleanup yielded no usable content.") return content except Exception as e: print(f"[MODEL] Tinker error: {e}") return f"[Model error: {e}]"