| """Tinker inference client. Supports both base models and fine-tuned checkpoints.""" | |
| import re | |
| import streamlit as st | |
@st.cache_resource
def _get_tinker_clients(model_name: str, sampler_path: str = ""):
    """
    Initialise and cache the Tinker sampling client, renderer, and tokenizer.

    If sampler_path is provided, loads from that checkpoint (fine-tuned model).
    Otherwise, loads the base model_name.
    Cache key includes both so different variants get different clients.
    """
    import tinker
    from tinker import types as tinker_types
    from tinker_cookbook import renderers
    from tinker_cookbook.model_info import get_recommended_renderer_name
    from tinker_cookbook.tokenizer_utils import get_tokenizer

    # Streamlit caches the returned clients per (model_name, sampler_path)
    # combination, so repeated calls reuse the same connection.
    service_client = tinker.ServiceClient()
    if sampler_path:
        print(f"[MODEL] Loading fine-tuned checkpoint: {sampler_path}")
        sampling_client = service_client.create_sampling_client(model_path=sampler_path)
    else:
        print(f"[MODEL] Loading base model: {model_name}")
        sampling_client = service_client.create_sampling_client(base_model=model_name)

    # Tokenizer and renderer are always derived from the base model name,
    # even when sampling from a fine-tuned checkpoint.
    tokenizer = get_tokenizer(model_name)
    renderer_name = get_recommended_renderer_name(model_name)
    renderer = renderers.get_renderer(renderer_name, tokenizer)
    return sampling_client, renderer, tinker_types

def call_model(messages: list, cfg: dict) -> str:
    """Send a message list to Tinker and return cleaned response text."""
    model_name = cfg["model_name"]
    sampler_path = cfg.get("sampler_path", "")
    print(f"[MODEL] model_name={model_name} sampler_path={sampler_path or '(base)'}")
    print(f"[MODEL] num_messages={len(messages)}")
    print(f"[MODEL] roles={[m['role'] for m in messages]}")
    if messages:
        print(f"[MODEL] system_prompt[:150]={messages[0]['content'][:150]}")

    try:
        from tinker_cookbook import renderers as tinker_renderers

        sampling_client, renderer, tinker_types = _get_tinker_clients(model_name, sampler_path)
        # Render the chat messages into a generation prompt and sample one completion.
        prompt = renderer.build_generation_prompt(messages)
        params = tinker_types.SamplingParams(
            max_tokens=1000,
            temperature=0.7,
            stop=renderer.get_stop_sequences(),
        )
        result = sampling_client.sample(
            prompt=prompt,
            sampling_params=params,
            num_samples=1,
        ).result()

        # Decode the sampled tokens back into a chat message and flatten it to plain text.
        parsed_message, _ = renderer.parse_response(result.sequences[0].tokens)
        content = tinker_renderers.format_content_as_string(parsed_message["content"])
| content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip() | |
| content = re.sub(r"<\|[^|]*\|>", "", content).strip() | |
| match = re.search(r"(.{40,}?)\1{4,}", content, flags=re.DOTALL) | |
| if match: | |
| content = content[: match.start() + len(match.group(1))].strip() | |
| if not content or len(content.split()) < 3: | |
| raise ValueError("Model output cleanup yielded no usable content.") | |
        return content
    except Exception as e:
        print(f"[MODEL] Tinker error: {e}")
        return f"[Model error: {e}]"