"""Interactive inference for Samantha thought-loop checkpoints.

The script loads a ``ThoughtLoopT5Gemma`` export and drives it tick by tick.
When the model path has no ``sft_config.json`` it instead falls back to a raw
Hugging Face seq2seq baseline that answers each user turn directly.

Modes (see ``main``):
* ``--history-json``: replay prior turns before anything else.
* ``--user``: answer one message and print a JSON summary.
* ``--interactive`` (or no ``--user``): run a terminal session.
"""

from __future__ import annotations

import argparse
from collections import deque
import json
import math
import sys
import threading
import time
from pathlib import Path
from typing import Any

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

SCRIPT_DIR = Path(__file__).resolve().parent
# Make the sibling ``data`` and ``model`` modules importable even when the
# script is launched from a different working directory.
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))

from data import (  # noqa: E402  (depends on the sys.path tweak above)
    OBS_ROLE_AGENT_FEEDBACK,
    OBS_ROLE_NONE,
    OBS_ROLE_USER,
    build_assistant_feedback_observation,
    build_user_generation_observation,
    quantize_message_sequence,
    should_use_base_chat_template,
    tokenize_text,
)
from model import ThoughtLoopT5Gemma  # noqa: E402


def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description="Interactive inference for Samantha thought-loop checkpoints, with a direct HF baseline fallback."
    )
    parser.add_argument("--model-path", required=True, help="Local export directory or Hugging Face repo id.")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--history-json", default=None, help="Optional JSON array of {t, speaker, text} turns to replay.")
    parser.add_argument("--user", default=None, help="Optional single user message to answer.")
    parser.add_argument("--interactive", action="store_true", help="Run a live terminal session.")
    parser.add_argument(
        "--tick-seconds",
        type=float,
        default=None,
        help="Runtime tick size. Defaults to the trained value.",
    )
    parser.add_argument("--max-thought-seconds", type=float, default=300.0)
    parser.add_argument("--settle-thought-seconds", type=float, default=1.0)
    parser.add_argument("--max-autonomous-utterances", type=int, default=4, help="One-shot --user safety cap.")
    parser.add_argument("--gate-threshold", type=float, default=0.5, help="Trained binary gate decision threshold.")
    parser.add_argument("--max-new-tokens", type=int, default=160)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument(
        "--debug-gate",
        "--show-gate",
        dest="debug_gate",
        action="store_true",
        help="Print the gate head probability on every runtime tick.",
    )
    return parser.parse_args()


def build_generation_kwargs(args: argparse.Namespace) -> dict[str, Any]:
    """Translate CLI sampling options into keyword arguments for ``generate``.

    Sampling is enabled only when ``temperature > 0``; otherwise decoding is
    greedy and ``temperature``/``top_p`` are omitted entirely.
    """
    kwargs: dict[str, Any] = {
        "max_new_tokens": int(args.max_new_tokens),
        "do_sample": float(args.temperature) > 0.0,
    }
    if kwargs["do_sample"]:
        kwargs["temperature"] = float(args.temperature)
        kwargs["top_p"] = float(args.top_p)
    return kwargs


def normalize_history_speaker(speaker: str | None) -> str | None:
    """Map a free-form speaker label to ``"user"``/``"assistant"``.

    Matching is case-insensitive and whitespace-tolerant. Returns ``None``
    for labels that are not recognised (including empty/``None``).
    """
    normalized = str(speaker or "").strip().lower()
    if normalized in {"user", "human", "prompter"}:
        return "user"
    if normalized in {"assistant", "agent", "model", "gpt"}:
        return "assistant"
    return None


def load_base_seq2seq_model(
    model_path: str,
    *,
    device: str | torch.device,
) -> tuple[torch.nn.Module, Any]:
    """Load a plain HF seq2seq model and its tokenizer in eval mode.

    Uses bfloat16 on CUDA devices and float32 everywhere else.
    """
    resolved_device = torch.device(device)
    dtype = torch.bfloat16 if resolved_device.type == "cuda" else torch.float32
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.to(resolved_device)
    model.eval()
    return model, tokenizer


def load_inference_backend(
    model_path: str,
    *,
    device: str | torch.device,
) -> tuple[str, Any, Any | None]:
    """Load the thought-loop model, or fall back to a raw HF baseline.

    Returns:
        ``(kind, model, tokenizer)`` where ``kind`` is ``"thought_loop"``
        (tokenizer ``None``; the model carries its own) or ``"base_seq2seq"``.

    Raises:
        FileNotFoundError: for any missing file other than ``sft_config.json``.
    """
    try:
        model = ThoughtLoopT5Gemma.from_pretrained(model_path, device=device, map_location=device)
        model.eval()
        return "thought_loop", model, None
    except FileNotFoundError as exc:
        # Only a missing sft_config.json means "not a thought-loop export".
        # NOTE(review): for a remote repo id, ThoughtLoopT5Gemma.from_pretrained
        # may raise something other than FileNotFoundError - confirm.
        if Path(str(exc.filename)).name != "sft_config.json":
            raise
    print(
        "[inference] No sft_config.json found at model path; loading as a raw Hugging Face seq2seq baseline.",
        file=sys.stderr,
    )
    base_model, tokenizer = load_base_seq2seq_model(model_path, device=device)
    return "base_seq2seq", base_model, tokenizer


class BaseSeq2SeqSession:
    """Baseline session that re-prompts a seq2seq model with the full chat.

    Exposes the same ``replay_history``/``respond`` interface as
    ``ThoughtLoopSession`` so ``main`` can drive either backend. It has no
    Z state, ticks or speech gate: each ``respond`` call yields exactly one
    reply and reports zero thought time.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Any,
        device: str | torch.device,
        generation_kwargs: dict[str, Any],
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device(device)
        self.generation_kwargs = generation_kwargs
        # Accumulated chat turns as {"role", "content"} dicts.
        self.history: list[dict[str, str]] = []

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Append prior turns to the chat history, ordered by their ``t`` value.

        Turns with an unrecognised speaker or empty text are skipped.
        """
        ordered_messages = sorted(messages, key=lambda item: float(item.get("t", 0.0)))
        for message in ordered_messages:
            role = normalize_history_speaker(message.get("speaker"))
            text = str(message.get("text", "")).strip()
            if role is None or not text:
                continue
            self.history.append({"role": role, "content": text})

    def _has_chat_template(self) -> bool:
        """Return True when the tokenizer provides a chat template."""
        return bool(getattr(self.tokenizer, "chat_template", None))

    def _format_prompt(self, user_text: str) -> str:
        """Render history plus the new user turn into a single prompt string.

        Uses the tokenizer's chat template when present, otherwise a simple
        ``User:``/``Assistant:`` transcript ending with an open assistant turn.
        """
        messages = [*self.history, {"role": "user", "content": user_text}]
        if self._has_chat_template():
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        lines: list[str] = []
        for message in messages:
            role = "User" if message["role"] == "user" else "Assistant"
            lines.append(f"{role}: {message['content']}")
        lines.append("Assistant:")
        return "\n".join(lines)

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
        debug_gate: bool = False,
    ) -> tuple[list[str], float]:
        """Generate one reply to ``user_text`` and record both turns.

        The timing/gate keyword arguments exist only for interface parity
        with ``ThoughtLoopSession.respond`` and are ignored.

        Returns:
            ``([reply], 0.0)``.
        """
        del max_thought_seconds, settle_thought_seconds, max_autonomous_utterances, debug_gate
        prompt = self._format_prompt(user_text)
        # A rendered chat template already contains the special tokens it
        # needs; letting the tokenizer add them again would duplicate e.g. BOS.
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=not self._has_chat_template(),
        ).to(self.device)
        with torch.no_grad():
            generated = self.model.generate(**encoded, **self.generation_kwargs)
        response = self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()
        self.history.append({"role": "user", "content": user_text})
        self.history.append({"role": "assistant", "content": response})
        return [response], 0.0


class TerminalInputBuffer:
    """Lock-protected FIFO of raw input lines.

    Shared between the stdin reader thread and the tick loop.
    """

    def __init__(self) -> None:
        self._lines: deque[str] = deque()
        self._lock = threading.Lock()

    def push(self, line: str) -> None:
        """Append one line at the back of the queue."""
        with self._lock:
            self._lines.append(line)

    def drain(self) -> list[str]:
        """Remove and return every queued line, oldest first."""
        with self._lock:
            lines = list(self._lines)
            self._lines.clear()
        return lines

    def prepend_many(self, lines: list[str]) -> None:
        """Put ``lines`` back at the front of the queue, keeping their order."""
        if not lines:
            return
        with self._lock:
            for line in reversed(lines):
                self._lines.appendleft(line)

    def __len__(self) -> int:
        with self._lock:
            return len(self._lines)


def start_terminal_reader(line_buffer: TerminalInputBuffer, stop_event: threading.Event) -> threading.Thread:
    """Start a daemon thread that pushes stdin lines into ``line_buffer``.

    The thread sets ``stop_event`` when stdin reaches EOF. ``readline``
    blocks, so the thread only re-checks ``stop_event`` between lines.
    """

    def read_loop() -> None:
        while not stop_event.is_set():
            line = sys.stdin.readline()
            if line == "":
                # readline() returns "" only at EOF.
                stop_event.set()
                break
            line_buffer.push(line.rstrip("\n"))

    thread = threading.Thread(target=read_loop, name="terminal-input-reader", daemon=True)
    thread.start()
    return thread


def print_gate_debug(tick_index: int, event_name: str, gate_probability: float) -> None:
    """Print one gate-probability debug line for a runtime tick."""
    print(f"[tick={tick_index} event={event_name} gate={gate_probability:.6f}]", flush=True)


class ThoughtLoopSession:
    """Tick-driven session around a ``ThoughtLoopT5Gemma`` model.

    Each tick feeds one observation into the model's recurrent Z state and
    returns the sigmoid of the gate logits. The observation is one of:
    user text, the model's own previous utterance as feedback, or nothing.
    Callers generate a reply from the state when that probability crosses
    their threshold.
    """

    def __init__(
        self,
        model: ThoughtLoopT5Gemma,
        tick_seconds: float,
        gate_threshold: float,
        generation_kwargs: dict[str, Any],
    ) -> None:
        # Durations are divided by the tick size and the live loop sleeps for
        # it, so a zero/negative/non-finite value would crash or busy-spin.
        if not (math.isfinite(tick_seconds) and tick_seconds > 0.0):
            raise ValueError(f"tick_seconds must be a positive finite number, got {tick_seconds!r}.")
        self.model = model
        self.tokenizer = model.tokenizer
        self.tick_seconds = tick_seconds
        self.gate_threshold = gate_threshold
        self.generation_kwargs = generation_kwargs
        self.z_state = model.initial_state(batch_size=1, device=model.device)
        self.z_mask = model.initial_state_mask(batch_size=1, device=model.device)
        self.use_base_chat_template = should_use_base_chat_template(model.config, self.tokenizer)
        self.elapsed_seconds = 0.0
        self.next_tick_index = 0
        # Used only when quantizing replayed history; defaults to one tick.
        self.feedback_delay_seconds = float(
            model.config.get("rollout", {}).get("post_speech_feedback_delay_seconds", tick_seconds)
        )

    def _format_observation(self, observation_role: int, observation_text: str | None) -> str | None:
        """Wrap raw text in the role-specific observation format.

        User and agent-feedback text go through the corresponding ``data``
        builders; any other role passes the text through unchanged.
        """
        if observation_text is None:
            return None
        if observation_role == OBS_ROLE_USER:
            return build_user_generation_observation(
                self.tokenizer,
                observation_text,
                self.use_base_chat_template,
            )
        if observation_role == OBS_ROLE_AGENT_FEEDBACK:
            return build_assistant_feedback_observation(
                self.tokenizer,
                observation_text,
                self.use_base_chat_template,
            )
        return observation_text

    def _advance_raw(self, delta_seconds: float, observation_role: int, observation_text: str | None) -> float:
        """Run one model rollout step and return the gate probability.

        Updates ``z_state``, ``z_mask`` and ``elapsed_seconds`` in place.
        """
        formatted_observation_text = self._format_observation(observation_role, observation_text)
        input_ids, attention_mask = tokenize_text(
            tokenizer=self.tokenizer,
            text=formatted_observation_text,
            max_length=int(self.model.config["model"]["max_observation_tokens"]),
        )
        device = self.model.device
        input_ids_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
        attention_mask_tensor = torch.tensor([attention_mask], dtype=torch.long, device=device)
        role_tensor = torch.tensor([observation_role], dtype=torch.long, device=device)
        self.elapsed_seconds += delta_seconds
        with torch.no_grad():
            # NOTE(review): the since-last-user/agent timers are always zero
            # here; confirm the trained model expects that rather than real
            # elapsed times.
            self.z_state, gate_logits, self.z_mask = self.model.rollout_step_with_mask(
                z_state=self.z_state,
                state_mask=self.z_mask,
                observation_input_ids=input_ids_tensor,
                observation_attention_mask=attention_mask_tensor,
                observation_role_ids=role_tensor,
                delta_seconds=torch.tensor([delta_seconds], dtype=torch.float32, device=device),
                elapsed_seconds=torch.tensor([self.elapsed_seconds], dtype=torch.float32, device=device),
                since_last_user_seconds=torch.zeros(1, dtype=torch.float32, device=device),
                since_last_agent_seconds=torch.zeros(1, dtype=torch.float32, device=device),
            )
        return torch.sigmoid(gate_logits).item()

    def _advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Advance one tick (zero elapsed time on the very first tick)."""
        delta_seconds = 0.0 if self.next_tick_index == 0 else self.tick_seconds
        gate_probability = self._advance_raw(delta_seconds, observation_role, observation_text)
        self.next_tick_index += 1
        return gate_probability

    def advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Public wrapper: advance one tick and return the gate probability."""
        return self._advance_tick(observation_role, observation_text)

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Fast-forward the Z state through a recorded conversation.

        ``quantize_message_sequence`` assigns the messages to ticks. Every
        tick from the current one up to the last event tick is then advanced,
        with empty ticks fed as idle observations. Gate outputs are discarded.
        """
        quantized_events = quantize_message_sequence(
            messages=messages,
            tick_seconds=self.tick_seconds,
            feedback_delay_seconds=self.feedback_delay_seconds,
        )
        # NOTE(review): if several events land on the same tick only the last
        # one survives this dict - confirm quantize_message_sequence never
        # emits duplicate ticks.
        events_by_tick = {int(event["tick"]): event for event in quantized_events}
        if not events_by_tick:
            return
        last_tick = max(events_by_tick)
        while self.next_tick_index <= last_tick:
            event = events_by_tick.get(self.next_tick_index)
            if event is None:
                self._advance_tick(OBS_ROLE_NONE, None)
                continue
            if event["kind"] == "user":
                self._advance_tick(OBS_ROLE_USER, str(event["text"]))
            elif event["kind"] == "agent_feedback":
                self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, str(event["text"]))
            else:
                self._advance_tick(OBS_ROLE_NONE, None)

    def _generate_response(self) -> str:
        """Decode a reply from the current Z state."""
        generated = self.model.generate_from_state(
            self.z_state,
            encoder_attention_mask=self.z_mask,
            **self.generation_kwargs,
        )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    def generate_response(self) -> str:
        """Public wrapper: decode a reply from the current Z state."""
        return self._generate_response()

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
        debug_gate: bool = False,
    ) -> tuple[list[str], float]:
        """Feed ``user_text`` on the next tick, then think until done.

        After the user tick, each following tick is either idle or, right
        after an utterance, feeds that utterance back as agent feedback. The
        gate is ignored on the user tick itself. Whenever the gate probability
        reaches ``gate_threshold``, a reply is generated.

        Thinking stops when any of these happens:
        * ``max_autonomous_utterances`` replies have been produced;
        * at least one reply exists and ``settle_thought_seconds`` (rounded
          up to whole ticks, minimum one) passed since the last reply;
        * ``max_thought_seconds`` (rounded up to whole ticks, minimum one)
          of thought ticks elapsed.

        Returns:
            ``(utterances, thought_seconds)``. ``thought_seconds`` counts
            only the ticks after the user tick.
        """
        gate_probability = self._advance_tick(OBS_ROLE_USER, user_text)
        if debug_gate:
            print_gate_debug(self.next_tick_index - 1, "user", gate_probability)

        utterances: list[str] = []
        thought_ticks = 0
        settle_ticks = max(1, int(math.ceil(settle_thought_seconds / self.tick_seconds)))
        max_thought_ticks = max(1, int(math.ceil(max_thought_seconds / self.tick_seconds)))
        idle_ticks_after_last_speech = 0
        pending_assistant_feedback: str | None = None

        while thought_ticks < max_thought_ticks:
            if pending_assistant_feedback is not None:
                event_name = "assistant_feedback"
                gate_probability = self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, pending_assistant_feedback)
                pending_assistant_feedback = None
            else:
                event_name = "idle"
                gate_probability = self._advance_tick(OBS_ROLE_NONE, None)
            if debug_gate:
                print_gate_debug(self.next_tick_index - 1, event_name, gate_probability)
            thought_ticks += 1
            idle_ticks_after_last_speech += 1

            if gate_probability >= self.gate_threshold:
                response = self._generate_response()
                utterances.append(response)
                idle_ticks_after_last_speech = 0
                if len(utterances) >= max_autonomous_utterances:
                    # NOTE(review): when this cap (or max_thought_ticks) ends
                    # the loop right after speech, the final utterance is
                    # never fed back as agent feedback - confirm intended.
                    break
                pending_assistant_feedback = response
                continue

            if utterances and idle_ticks_after_last_speech >= settle_ticks:
                break

        return utterances, thought_ticks * self.tick_seconds


def consume_next_live_user_line(
    line_buffer: TerminalInputBuffer,
    *,
    stop_event: threading.Event,
    status_text: str,
) -> str | None:
    """Take at most one user line from ``line_buffer``, handling commands.

    Blank lines are dropped. Exit commands set ``stop_event`` and
    ``/status`` prints ``status_text``. The first remaining line is returned;
    any later lines are pushed back to the front of the buffer for later ticks.

    Returns:
        The selected user line, or ``None`` if there is none this tick.
    """
    lines = line_buffer.drain()
    if not lines:
        return None

    selected_line: str | None = None
    deferred_lines: list[str] = []
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue
        command = line.lower()
        if command in {"/exit", "/quit", "exit", "quit"}:
            stop_event.set()
            continue
        if command == "/status":
            print(status_text, flush=True)
            continue
        if selected_line is None:
            selected_line = line
        else:
            deferred_lines.append(line)

    line_buffer.prepend_many(deferred_lines)
    return selected_line


def run_live_thought_loop_terminal(
    session: ThoughtLoopSession,
    *,
    gate_threshold: float,
    debug_gate: bool,
) -> None:
    """Run a real-time terminal session, advancing one tick per ``tick_seconds``.

    On each tick, pending assistant feedback takes priority over new user
    input. The model may speak on any non-user tick whose gate probability
    reaches ``gate_threshold``. The loop ends on an exit command, stdin EOF,
    or Ctrl-C.
    """
    line_buffer = TerminalInputBuffer()
    stop_event = threading.Event()
    start_terminal_reader(line_buffer, stop_event)
    pending_assistant_feedback: str | None = None

    print("Live Samantha session.", flush=True)
    print(
        "Type anytime and press Enter. "
        f"Each line is one user event on the next {1.0 / session.tick_seconds:g} Hz tick.",
        flush=True,
    )
    print("Commands: /status, /exit", flush=True)

    try:
        # NOTE(review): the reader sets stop_event at stdin EOF, so lines still
        # queued at that moment (e.g. piped input) are never processed.
        while not stop_event.is_set():
            tick_started_at = time.monotonic()
            status_text = (
                f"[status] tick={session.next_tick_index} elapsed={session.elapsed_seconds:.1f}s "
                f"queued_user_lines={len(line_buffer)} pending_assistant_feedback={pending_assistant_feedback is not None}"
            )

            if pending_assistant_feedback is not None:
                observation_role = OBS_ROLE_AGENT_FEEDBACK
                observation_text: str | None = pending_assistant_feedback
                event_name = "assistant_feedback"
                pending_assistant_feedback = None
            else:
                user_line = consume_next_live_user_line(
                    line_buffer,
                    stop_event=stop_event,
                    status_text=status_text,
                )
                if stop_event.is_set():
                    break
                if user_line is not None:
                    observation_role = OBS_ROLE_USER
                    observation_text = user_line
                    event_name = "user"
                    print(f"user> {user_line}", flush=True)
                else:
                    observation_role = OBS_ROLE_NONE
                    observation_text = None
                    event_name = "idle"

            gate_probability = session.advance_tick(observation_role, observation_text)
            if debug_gate:
                print_gate_debug(session.next_tick_index - 1, event_name, gate_probability)

            # Never speak on the tick that delivers the user's own line.
            if observation_role != OBS_ROLE_USER and gate_probability >= gate_threshold:
                response = session.generate_response()
                print(f"assistant> {response}", flush=True)
                pending_assistant_feedback = response

            # Sleep off whatever remains of this tick's wall-clock budget.
            sleep_seconds = session.tick_seconds - (time.monotonic() - tick_started_at)
            if sleep_seconds > 0.0:
                time.sleep(sleep_seconds)
    except KeyboardInterrupt:
        print(flush=True)
    finally:
        stop_event.set()


def run_blocking_base_terminal(session: BaseSeq2SeqSession) -> None:
    """Run a simple turn-by-turn prompt loop for the raw HF baseline.

    Exits on an exit command, end of input (EOF) or Ctrl-C.
    """
    print("Raw HF baseline terminal mode. This model has no Samantha Z state or speech gate.", flush=True)
    print("Type a message and press Enter. Commands: /exit", flush=True)
    while True:
        try:
            user_text = input("user> ").strip()
        except (EOFError, KeyboardInterrupt):
            # End the prompt line so the shell prompt starts on a fresh line.
            print(flush=True)
            break
        if user_text.lower() in {"/exit", "/quit", "exit", "quit"}:
            break
        if not user_text:
            continue
        responses, _ = session.respond(
            user_text,
            max_thought_seconds=0.0,
            settle_thought_seconds=0.0,
            max_autonomous_utterances=1,
        )
        for response in responses:
            print(f"assistant> {response}", flush=True)


def main() -> None:
    """Parse arguments, build the session, and run the requested modes.

    Optional history replay runs first, then an optional one-shot ``--user``
    answer printed as JSON. An interactive terminal follows when
    ``--interactive`` is set or no ``--user`` message was given.
    """
    args = parse_args()
    backend_kind, model, tokenizer = load_inference_backend(args.model_path, device=args.device)
    generation_kwargs = build_generation_kwargs(args)

    if backend_kind == "thought_loop":
        tick_seconds = (
            float(args.tick_seconds)
            if args.tick_seconds is not None
            else float(model.config.get("rollout", {}).get("tick_seconds", 1.0))
        )
        session: Any = ThoughtLoopSession(
            model=model,
            tick_seconds=tick_seconds,
            gate_threshold=args.gate_threshold,
            generation_kwargs=generation_kwargs,
        )
    else:
        assert tokenizer is not None
        session = BaseSeq2SeqSession(
            model=model,
            tokenizer=tokenizer,
            device=args.device,
            generation_kwargs=generation_kwargs,
        )

    if args.history_json:
        history = json.loads(Path(args.history_json).read_text(encoding="utf-8"))
        if not isinstance(history, list):
            raise ValueError("History JSON must be a list of turn objects.")
        session.replay_history(history)

    if args.user:
        responses, thought_time = session.respond(
            args.user,
            max_thought_seconds=args.max_thought_seconds,
            settle_thought_seconds=args.settle_thought_seconds,
            max_autonomous_utterances=args.max_autonomous_utterances,
            debug_gate=bool(args.debug_gate),
        )
        print(
            json.dumps(
                {
                    "responses": responses,
                    "thought_seconds": thought_time,
                },
                indent=2,
                ensure_ascii=False,
            )
        )

    if args.interactive or args.user is None:
        if backend_kind == "thought_loop":
            run_live_thought_loop_terminal(
                session,
                gate_threshold=float(args.gate_threshold),
                debug_gate=bool(args.debug_gate),
            )
        else:
            run_blocking_base_terminal(session)


if __name__ == "__main__":
    main()