"""Shared generation helpers for Horizon 2 (causal LMs, optional RAG context).""" from __future__ import annotations import json import time from dataclasses import dataclass, asdict from typing import Any # Small model for fast verification / CPU smoke (poor text quality; use --model for real runs). SMOKE_MODEL_ID = "sshleifer/tiny-gpt2" # Sensible default for local quality (still small; override with HORIZON2_DEFAULT_MODEL or --model). DEFAULT_INSTRUCTION_MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct" @dataclass class OneSample: id: int input: str output: str seconds: float n_prompt_tokens: int n_new_tokens: int def pick_device(explicit: str) -> str: import torch if explicit == "auto": if torch.cuda.is_available(): return "cuda" if torch.backends.mps.is_available() if hasattr(torch.backends, "mps") else False: # type: ignore[union-attr] return "mps" return "cpu" return explicit def set_seed(seed: int) -> None: import random import numpy as np import torch random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def build_user_prompt( task: str, text: str, *, context: str | None = None, ) -> str: c = (context or "").strip() if c: ctx_block = ( "You must use ONLY the following CONTEXT; do not invent facts.\n\n" f"CONTEXT:\n{c}\n\n" ) else: ctx_block = "" t = text.strip() if task == "summarize": return ( f"{ctx_block}Summarize the user text in 2-4 short sentences. Be concise.\n\n" f"USER_TEXT:\n{t}" ) if task == "reformulate": return ( f"{ctx_block}Rewrite USER_TEXT as a clear, professional support reply. " f"Keep the same meaning. Under 120 words if possible.\n\n" f"USER_TEXT:\n{t}" ) if task == "grounded": if not c: raise ValueError("task 'grounded' requires --context or --context-file") return ( f"{ctx_block}Answer the user using ONLY the context above. If the context does not " f"contain the answer, say you do not have enough information.\n\nUSER_QUESTION:\n{t}" ) raise ValueError(f"unknown task: {task!r} (use summarize, reformulate, or grounded)") DEFAULT_CHAT_SYSTEM = ( "You are Universal Brain, a concise and accurate assistant. " "Answer the user clearly. If you lack information, say so. " "Keep replies focused unless the user asks for depth." ) ChatMessage = dict[str, str] # role, content def format_multiturn_for_model( tokenizer: Any, messages: list[ChatMessage], ) -> str: """Build a single prompt string from chat history (OpenAI-style role dicts).""" clean: list[dict[str, str]] = [] for m in messages: role = (m.get("role") or "").strip().lower() content = (m.get("content") or "").strip() if not content or role not in ("system", "user", "assistant"): continue clean.append({"role": role, "content": content}) if not clean: raise ValueError("no valid chat messages") if getattr(tokenizer, "chat_template", None): try: return tokenizer.apply_chat_template( clean, tokenize=False, add_generation_prompt=True, ) except Exception: pass chunks: list[str] = [] for m in clean: label = m["role"].upper() chunks.append(f"{label}: {m['content']}") chunks.append("ASSISTANT:") return "\n\n".join(chunks) def generate_chat_reply( lm: LoadedLM, messages: list[ChatMessage], *, max_new_tokens: int, seed: int, do_sample: bool = True, ) -> tuple[str, int, int, float]: """Complete the next assistant turn given full message list (incl. 


def generate_chat_reply(
    lm: LoadedLM,
    messages: list[ChatMessage],
    *,
    max_new_tokens: int,
    seed: int,
    do_sample: bool = True,
) -> tuple[str, int, int, float]:
    """Complete the next assistant turn given the full message list (incl. system/user/assistant)."""
    prompt = format_multiturn_for_model(lm.tokenizer, messages)
    return generate_completion(
        lm,
        prompt,
        max_new_tokens=max_new_tokens,
        seed=seed,
        do_sample=do_sample,
    )


def format_for_model(
    tokenizer: Any,
    user_prompt: str,
) -> str:
    if getattr(tokenizer, "chat_template", None):
        try:
            messages = [{"role": "user", "content": user_prompt}]
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            pass
    return f"{user_prompt}\n\n### Assistant\n"


@dataclass
class LoadedLM:
    model: Any
    tokenizer: Any
    device: str


def load_causal_lm(
    model_id: str,
    device: str,
) -> LoadedLM:
    import os
    import sys

    # Must run before `import torch` on first use (e.g. horizon2_server on Windows).
    if sys.platform == "win32":
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("MKL_NUM_THREADS", "1")
        os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if sys.platform == "win32":
        torch.set_num_threads(1)
        try:
            torch.set_num_interop_threads(1)
        except RuntimeError:
            pass

    d = device if device in ("cpu", "cuda", "mps") else "cpu"

    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token

    if d == "cuda":
        dt = (
            torch.bfloat16
            if torch.cuda.is_bf16_supported()
            else torch.float16
        )
    else:
        dt = torch.float32

    def _from_pretrained(extra: dict[str, Any]) -> Any:
        # Prefer `dtype` (newer Transformers); fall back to `torch_dtype` (older).
        try:
            return AutoModelForCausalLM.from_pretrained(
                model_id, trust_remote_code=True, dtype=dt, **extra
            )
        except TypeError:
            return AutoModelForCausalLM.from_pretrained(
                model_id, trust_remote_code=True, torch_dtype=dt, **extra
            )

    # Retry with progressively fewer options (compat + stability on Windows CPU).
    if d == "cpu":
        extras: tuple[dict[str, Any], ...] = (
            {"low_cpu_mem_usage": True, "attn_implementation": "eager"},
            {"low_cpu_mem_usage": True},
            {},
        )
    else:
        extras = ({"low_cpu_mem_usage": True}, {})

    model = None
    last_err: BaseException | None = None
    for extra in extras:
        try:
            model = _from_pretrained(extra)
            break
        except (TypeError, ValueError, OSError) as e:
            last_err = e
            continue
    if model is None:
        raise RuntimeError(
            f"Failed to load causal LM {model_id!r}; last error: {last_err!r}"
        ) from last_err

    model.eval()
    model = model.to(d)
    return LoadedLM(model=model, tokenizer=tok, device=d)


def generate_completion(
    lm: LoadedLM,
    prompt: str,
    *,
    max_new_tokens: int,
    seed: int,
    do_sample: bool = True,
) -> tuple[str, int, int, float]:
    import torch
    from transformers import set_seed as hf_set_seed

    set_seed(seed)
    hf_set_seed(seed)

    tok = lm.tokenizer
    t0 = time.perf_counter()

    enc = tok(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
        padding="longest",
    )
    input_ids = enc["input_ids"]
    attention_mask = enc.get("attention_mask")
    if lm.device == "cuda":
        input_ids = input_ids.to("cuda")
        if attention_mask is not None:
            attention_mask = attention_mask.to("cuda")
    elif lm.device == "mps":
        input_ids = input_ids.to("mps")
        if attention_mask is not None:
            attention_mask = attention_mask.to("mps")

    n_prompt = int(input_ids.shape[1])

    gen_kw: dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tok.eos_token_id,
    }
    if attention_mask is not None:
        gen_kw["attention_mask"] = attention_mask
    if do_sample:
        gen_kw["do_sample"] = True
        gen_kw["temperature"] = 0.7
        gen_kw["top_p"] = 0.9
    else:
        gen_kw["do_sample"] = False

    with torch.inference_mode():
        out = lm.model.generate(input_ids, **gen_kw)

    full = out[0]
    new_tokens = full[n_prompt:]
    text = tok.decode(new_tokens, skip_special_tokens=True)
    text = (text or "").strip()

    dt = time.perf_counter() - t0
    n_new = int(new_tokens.shape[0])
    return text, n_prompt, n_new, dt


def run_json_artifact(
    *,
    model_id: str,
    device: str,
    task: str,
    max_new_tokens: int,
    seed: int,
    samples_in: list[tuple[str, str | None]],
    do_sample: bool = True,
) -> dict[str, Any]:
    import transformers

    lm = load_causal_lm(model_id, device)

    out_samples: list[OneSample] = []
    for i, (raw_text, ctx) in enumerate(samples_in):
        up = build_user_prompt(task, raw_text, context=ctx)
        prompt = format_for_model(lm.tokenizer, up)
        out, np_, nn_, sec = generate_completion(
            lm,
            prompt,
            max_new_tokens=max_new_tokens,
            seed=seed + i,
            do_sample=do_sample,
        )
        out_samples.append(
            OneSample(
                id=i,
                input=raw_text,
                output=out,
                seconds=round(sec, 4),
                n_prompt_tokens=np_,
                n_new_tokens=nn_,
            )
        )

    return {
        "horizon": 2,
        "schema": "horizon2_generative_run/1.0",
        "model_id": model_id,
        "device": lm.device,
        "transformers_version": transformers.__version__,
        "task": task,
        "max_new_tokens": max_new_tokens,
        "seed": seed,
        "samples": [asdict(s) for s in out_samples],
    }


def dump_json(d: dict[str, Any], path: str) -> None:
    from pathlib import Path

    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text(json.dumps(d, indent=2) + "\n", encoding="utf-8")
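

# Illustrative usage sketch (the sample text and output path are hypothetical,
# not part of this module): build a small summarization artifact on CPU with
# the smoke model and write it to disk.
#
#   artifact = run_json_artifact(
#       model_id=SMOKE_MODEL_ID,
#       device="cpu",
#       task="summarize",
#       max_new_tokens=48,
#       seed=0,
#       samples_in=[("My order arrived damaged and support has not replied.", None)],
#   )
#   dump_json(artifact, "artifacts/horizon2_smoke_run.json")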