| from __future__ import annotations |
|
|
| import argparse |
| from collections import deque |
| import json |
| import math |
| import sys |
| import threading |
| import time |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
# Absolute directory of this script. It is put at the front of sys.path (once)
# so the `data` and `model` imports that follow can resolve modules located
# next to this file regardless of the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
if sys.path.count(str(SCRIPT_DIR)) == 0:
    sys.path.insert(0, str(SCRIPT_DIR))
|
|
| from data import ( |
| OBS_ROLE_AGENT_FEEDBACK, |
| OBS_ROLE_NONE, |
| OBS_ROLE_USER, |
| build_assistant_feedback_observation, |
| build_user_generation_observation, |
| quantize_message_sequence, |
| should_use_base_chat_template, |
| tokenize_text, |
| ) |
| from model import ThoughtLoopT5Gemma |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace. ``--model-path`` is the only required option;
        ``--device`` defaults to ``"cuda"`` when CUDA is available, else
        ``"cpu"``. ``--debug-gate`` and ``--show-gate`` are aliases that both
        set ``debug_gate``.
    """
    cli = argparse.ArgumentParser(
        description="Interactive inference for Samantha thought-loop checkpoints, with a direct HF baseline fallback."
    )
    add = cli.add_argument

    # Model location and placement.
    add("--model-path", required=True, help="Local export directory or Hugging Face repo id.")
    add("--device", default="cuda" if torch.cuda.is_available() else "cpu")

    # Conversation input modes.
    add("--history-json", default=None, help="Optional JSON array of {t, speaker, text} turns to replay.")
    add("--user", default=None, help="Optional single user message to answer.")
    add("--interactive", action="store_true", help="Run a live terminal session.")

    # Thought-loop timing and gating.
    add("--tick-seconds", type=float, default=None, help="Runtime tick size. Defaults to the trained value.")
    add("--max-thought-seconds", type=float, default=300.0)
    add("--settle-thought-seconds", type=float, default=1.0)
    add("--max-autonomous-utterances", type=int, default=4, help="One-shot --user safety cap.")
    add("--gate-threshold", type=float, default=0.5, help="Trained binary gate decision threshold.")

    # Text generation.
    add("--max-new-tokens", type=int, default=160)
    add("--temperature", type=float, default=0.8)
    add("--top-p", type=float, default=0.95)

    # Diagnostics.
    add(
        "--debug-gate",
        "--show-gate",
        dest="debug_gate",
        action="store_true",
        help="Print the gate head probability on every runtime tick.",
    )
    return cli.parse_args()
|
|
|
|
def build_generation_kwargs(args: argparse.Namespace) -> dict[str, Any]:
    """Translate CLI sampling options into keyword arguments for generation.

    Args:
        args: Namespace providing ``max_new_tokens``, ``temperature`` and
            ``top_p``.

    Returns:
        A dict with ``max_new_tokens`` and ``do_sample``. Sampling is enabled
        only for a strictly positive temperature, in which case
        ``temperature`` and ``top_p`` are included as well; otherwise they are
        omitted (and ``top_p`` is not read at all).
    """
    token_budget = int(args.max_new_tokens)
    sampling_enabled = float(args.temperature) > 0.0
    kwargs: dict[str, Any] = {"max_new_tokens": token_budget, "do_sample": sampling_enabled}
    if not sampling_enabled:
        return kwargs
    kwargs.update(temperature=float(args.temperature), top_p=float(args.top_p))
    return kwargs
|
|
|
|
def normalize_history_speaker(speaker: str | None) -> str | None:
    """Map a free-form speaker label onto ``"user"`` or ``"assistant"``.

    Matching ignores case and surrounding whitespace. Falsy input and
    unrecognised labels yield ``None``.
    """
    label = str(speaker or "").strip().lower()
    alias_table = (
        ("user", ("user", "human", "prompter")),
        ("assistant", ("assistant", "agent", "model", "gpt")),
    )
    for canonical, aliases in alias_table:
        if label in aliases:
            return canonical
    return None
|
|
|
|
def load_base_seq2seq_model(
    model_path: str,
    *,
    device: str | torch.device,
) -> tuple[torch.nn.Module, Any]:
    """Load a plain Hugging Face seq2seq model and its tokenizer.

    Args:
        model_path: Value handed to both ``from_pretrained`` calls.
        device: Target device; the model is moved there after loading.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode. Weights are loaded
        as bfloat16 when the target device type is ``"cuda"`` and as float32
        otherwise.
    """
    target = torch.device(device)
    if target.type == "cuda":
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = torch.float32

    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=weights_dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    seq2seq.to(target)
    seq2seq.eval()
    return seq2seq, tokenizer
|
|
|
|
def load_inference_backend(
    model_path: str,
    *,
    device: str | torch.device,
) -> tuple[str, Any, Any | None]:
    """Load a thought-loop checkpoint, or fall back to a raw seq2seq baseline.

    Args:
        model_path: Checkpoint directory or repo id.
        device: Device to load onto.

    Returns:
        ``("thought_loop", model, None)`` when the thought-loop loader
        succeeds, otherwise ``("base_seq2seq", model, tokenizer)``.

    Raises:
        FileNotFoundError: Re-raised when the missing file is anything other
            than ``sft_config.json``; only that file's absence triggers the
            baseline fallback.
    """
    try:
        thought_model = ThoughtLoopT5Gemma.from_pretrained(model_path, device=device, map_location=device)
        thought_model.eval()
        return "thought_loop", thought_model, None
    except FileNotFoundError as missing:
        missing_name = Path(str(missing.filename)).name
        if missing_name != "sft_config.json":
            raise
        # No SFT config at this path: announce the fallback on stderr so stdout
        # stays reserved for model output.
        print(
            "[inference] No sft_config.json found at model path; loading as a raw Hugging Face seq2seq baseline.",
            file=sys.stderr,
        )
        fallback_model, fallback_tokenizer = load_base_seq2seq_model(model_path, device=device)
        return "base_seq2seq", fallback_model, fallback_tokenizer
|
|
|
|
class BaseSeq2SeqSession:
    """Turn-based chat session around a plain Hugging Face seq2seq model.

    The session keeps the running conversation as a list of
    ``{"role", "content"}`` dicts and re-renders the full history into a
    prompt on every call to :meth:`respond`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Any,
        device: str | torch.device,
        generation_kwargs: dict[str, Any],
    ) -> None:
        """Store the model, tokenizer, device and generation settings."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device(device)
        self.generation_kwargs = generation_kwargs
        # Conversation so far, oldest first.
        self.history: list[dict[str, str]] = []

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Append prior turns to the history, ordered by their ``t`` value.

        Turns whose speaker is not recognised, or whose text is empty after
        stripping, are skipped. A missing ``t`` is treated as ``0.0``.
        """
        ordered_messages = sorted(messages, key=lambda item: float(item.get("t", 0.0)))
        for message in ordered_messages:
            role = normalize_history_speaker(message.get("speaker"))
            text = str(message.get("text", "")).strip()
            if role is None or not text:
                continue
            self.history.append({"role": role, "content": text})

    def _format_prompt(self, user_text: str) -> str:
        """Render the history plus the new user turn into a prompt string.

        Uses the tokenizer's chat template when it has one; otherwise falls
        back to ``"User: ..."`` / ``"Assistant: ..."`` lines ending with an
        open ``"Assistant:"`` line.
        """
        messages = [*self.history, {"role": "user", "content": user_text}]
        if getattr(self.tokenizer, "chat_template", None):
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

        lines: list[str] = []
        for message in messages:
            role = "User" if message["role"] == "user" else "Assistant"
            lines.append(f"{role}: {message['content']}")
        lines.append("Assistant:")
        return "\n".join(lines)

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
        debug_gate: bool = False,
    ) -> tuple[list[str], float]:
        """Generate one reply to ``user_text`` and record both turns.

        The timing parameters and ``debug_gate`` are accepted only so that
        this method can be called with the same keyword arguments as
        ``ThoughtLoopSession.respond``; they are ignored here.

        Returns:
            ``([response], 0.0)`` — a single utterance and zero thought time.
        """
        # Accepted for signature parity with the thought-loop session; unused.
        del max_thought_seconds, settle_thought_seconds, max_autonomous_utterances, debug_gate
        prompt = self._format_prompt(user_text)
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            generated = self.model.generate(**encoded, **self.generation_kwargs)
        response = self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()
        # History is updated only after generation succeeded.
        self.history.append({"role": "user", "content": user_text})
        self.history.append({"role": "assistant", "content": response})
        return [response], 0.0
|
|
|
|
class TerminalInputBuffer:
    """Lock-guarded FIFO of text lines."""

    def __init__(self) -> None:
        """Create an empty buffer with its own lock."""
        self._lock = threading.Lock()
        self._lines: deque[str] = deque()

    def push(self, line: str) -> None:
        """Add ``line`` at the back of the queue."""
        with self._lock:
            self._lines.append(line)

    def drain(self) -> list[str]:
        """Remove and return every queued line, oldest first."""
        with self._lock:
            drained = [*self._lines]
            self._lines.clear()
        return drained

    def prepend_many(self, lines: list[str]) -> None:
        """Put ``lines`` back at the front of the queue, keeping their order."""
        if lines:
            with self._lock:
                # extendleft pushes items one by one onto the left end, so the
                # input is reversed to make lines[0] end up first.
                self._lines.extendleft(reversed(lines))

    def __len__(self) -> int:
        """Return the number of queued lines."""
        with self._lock:
            return len(self._lines)
|
|
|
|
def start_terminal_reader(line_buffer: TerminalInputBuffer, stop_event: threading.Event) -> threading.Thread:
    """Start a daemon thread that copies stdin lines into ``line_buffer``.

    Each line is pushed without its trailing newline. When ``readline``
    returns the empty string (end of input) the thread sets ``stop_event``
    and exits; it also exits once ``stop_event`` is observed set between
    reads.

    Returns:
        The already-started reader thread.
    """

    def pump_stdin() -> None:
        while True:
            if stop_event.is_set():
                return
            raw = sys.stdin.readline()
            if raw == "":
                # End of input: tell the consumer loop to shut down as well.
                stop_event.set()
                return
            line_buffer.push(raw.rstrip("\n"))

    reader = threading.Thread(target=pump_stdin, name="terminal-input-reader", daemon=True)
    reader.start()
    return reader
|
|
|
|
def print_gate_debug(tick_index: int, event_name: str, gate_probability: float) -> None:
    """Print one flushed diagnostic line with the tick, event and gate value.

    The gate probability is rendered with six decimal places.
    """
    message = "[tick={} event={} gate={:.6f}]".format(tick_index, event_name, gate_probability)
    print(message, flush=True)
|
|
|
|
class ThoughtLoopSession:
    """Tick-by-tick driver for a ``ThoughtLoopT5Gemma`` model.

    The session owns the model's rolling state (``z_state`` / ``z_mask``),
    the simulated elapsed time, and the index of the next tick. Every tick
    feeds one optional observation to the model and yields a gate
    probability; text is generated from the current state on request.
    """

    def __init__(
        self,
        model: ThoughtLoopT5Gemma,
        tick_seconds: float,
        gate_threshold: float,
        generation_kwargs: dict[str, Any],
    ) -> None:
        """Initialise state for a fresh session with batch size 1.

        Args:
            model: Loaded thought-loop model; supplies the tokenizer, device,
                config and initial state.
            tick_seconds: Simulated seconds added per tick after the first.
            gate_threshold: Gate probability at or above which
                :meth:`respond` generates an utterance.
            generation_kwargs: Extra keyword arguments for generation.
        """
        device = model.device
        self.model = model
        self.tokenizer = model.tokenizer
        self.tick_seconds = tick_seconds
        self.gate_threshold = gate_threshold
        self.generation_kwargs = generation_kwargs
        self.z_state = model.initial_state(batch_size=1, device=device)
        self.z_mask = model.initial_state_mask(batch_size=1, device=model.device)
        self.use_base_chat_template = should_use_base_chat_template(model.config, self.tokenizer)
        self.elapsed_seconds = 0.0
        self.next_tick_index = 0
        # Falls back to one tick when the rollout config does not set a delay.
        rollout_config = model.config.get("rollout", {})
        self.feedback_delay_seconds = float(
            rollout_config.get("post_speech_feedback_delay_seconds", tick_seconds)
        )

    def _format_observation(self, observation_role: int, observation_text: str | None) -> str | None:
        """Wrap observation text in the role-specific format.

        ``None`` passes through unchanged, and text for roles other than user
        or agent feedback is returned as-is.
        """
        if observation_text is None:
            return None
        if observation_role == OBS_ROLE_USER:
            formatter = build_user_generation_observation
        elif observation_role == OBS_ROLE_AGENT_FEEDBACK:
            formatter = build_assistant_feedback_observation
        else:
            return observation_text
        return formatter(self.tokenizer, observation_text, self.use_base_chat_template)

    def _advance_raw(self, delta_seconds: float, observation_role: int, observation_text: str | None) -> float:
        """Run one model step and return the sigmoid of the gate logit.

        The observation is formatted and tokenized (truncated per the model
        config's ``max_observation_tokens``), ``elapsed_seconds`` is advanced
        by ``delta_seconds``, and the state and mask are replaced by the
        model's outputs.
        """
        device = self.model.device
        prepared_text = self._format_observation(observation_role, observation_text)
        token_ids, token_mask = tokenize_text(
            tokenizer=self.tokenizer,
            text=prepared_text,
            max_length=int(self.model.config["model"]["max_observation_tokens"]),
        )
        ids_batch = torch.tensor([token_ids], dtype=torch.long, device=device)
        mask_batch = torch.tensor([token_mask], dtype=torch.long, device=device)
        role_batch = torch.tensor([observation_role], dtype=torch.long, device=device)

        # Time moves forward only after the inputs were built successfully.
        self.elapsed_seconds += delta_seconds

        with torch.no_grad():
            delta_batch = torch.tensor([delta_seconds], dtype=torch.float32, device=device)
            elapsed_batch = torch.tensor([self.elapsed_seconds], dtype=torch.float32, device=device)
            # NOTE(review): both "since last ..." inputs are always zero here;
            # whether that matches how the model was trained is not visible
            # from this file — confirm.
            new_state, gate_logits, new_mask = self.model.rollout_step_with_mask(
                z_state=self.z_state,
                state_mask=self.z_mask,
                observation_input_ids=ids_batch,
                observation_attention_mask=mask_batch,
                observation_role_ids=role_batch,
                delta_seconds=delta_batch,
                elapsed_seconds=elapsed_batch,
                since_last_user_seconds=torch.zeros(1, dtype=torch.float32, device=device),
                since_last_agent_seconds=torch.zeros(1, dtype=torch.float32, device=device),
            )
            self.z_state = new_state
            self.z_mask = new_mask
        return torch.sigmoid(gate_logits).item()

    def _advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Advance one tick and bump ``next_tick_index``.

        The very first tick carries a time delta of zero; every later tick
        advances time by ``tick_seconds``.
        """
        is_first_tick = self.next_tick_index == 0
        step_seconds = 0.0 if is_first_tick else self.tick_seconds
        probability = self._advance_raw(step_seconds, observation_role, observation_text)
        self.next_tick_index += 1
        return probability

    def advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Public entry point for advancing a single tick."""
        return self._advance_tick(observation_role, observation_text)

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Feed a recorded conversation through the model tick by tick.

        Messages are quantized onto ticks; every tick from the current index
        through the last event tick is advanced, with idle observations on
        ticks that carry no event (or an event of an unrecognised kind). When
        several events map to one tick, the last one is the one replayed.
        """
        tick_to_event = {
            int(event["tick"]): event
            for event in quantize_message_sequence(
                messages=messages,
                tick_seconds=self.tick_seconds,
                feedback_delay_seconds=self.feedback_delay_seconds,
            )
        }
        if not tick_to_event:
            return

        final_tick = max(tick_to_event)
        while self.next_tick_index <= final_tick:
            event = tick_to_event.get(self.next_tick_index)
            kind = None if event is None else event["kind"]
            if kind == "user":
                self._advance_tick(OBS_ROLE_USER, str(event["text"]))
            elif kind == "agent_feedback":
                self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, str(event["text"]))
            else:
                self._advance_tick(OBS_ROLE_NONE, None)

    def _generate_response(self) -> str:
        """Decode text from the current state, stripped of special tokens."""
        output_ids = self.model.generate_from_state(
            self.z_state,
            encoder_attention_mask=self.z_mask,
            **self.generation_kwargs,
        )
        decoded = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return decoded.strip()

    def generate_response(self) -> str:
        """Public entry point for generating text from the current state."""
        return self._generate_response()

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
        debug_gate: bool = False,
    ) -> tuple[list[str], float]:
        """Submit one user message and collect the utterances that follow.

        The user message occupies one tick (its gate value is only reported,
        never acted on). Afterwards the session ticks until one of:

        * the thought-tick budget derived from ``max_thought_seconds`` is
          spent,
        * ``max_autonomous_utterances`` utterances have been produced, or
        * at least one utterance exists and the gate has stayed closed for
          the number of ticks derived from ``settle_thought_seconds``.

        Each utterance is fed back to the model as an agent-feedback
        observation on the following tick, except the one that hits the
        utterance cap.

        Returns:
            ``(utterances, thought_ticks * tick_seconds)``.
        """
        user_gate = self._advance_tick(OBS_ROLE_USER, user_text)
        if debug_gate:
            print_gate_debug(self.next_tick_index - 1, "user", user_gate)

        spoken: list[str] = []
        ticks_spent = 0
        # Both limits are rounded up to whole ticks and are at least one tick.
        settle_ticks = max(1, int(math.ceil(settle_thought_seconds / self.tick_seconds)))
        tick_budget = max(1, int(math.ceil(max_thought_seconds / self.tick_seconds)))
        quiet_ticks = 0
        queued_feedback: str | None = None

        while ticks_spent < tick_budget:
            if queued_feedback is None:
                label = "idle"
                gate = self._advance_tick(OBS_ROLE_NONE, None)
            else:
                label = "assistant_feedback"
                gate = self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, queued_feedback)
                queued_feedback = None
            if debug_gate:
                print_gate_debug(self.next_tick_index - 1, label, gate)
            ticks_spent += 1
            quiet_ticks += 1

            if gate >= self.gate_threshold:
                reply = self._generate_response()
                spoken.append(reply)
                quiet_ticks = 0
                if len(spoken) >= max_autonomous_utterances:
                    break
                # Let the model observe its own utterance on the next tick.
                queued_feedback = reply
            elif spoken and quiet_ticks >= settle_ticks:
                break

        return spoken, ticks_spent * self.tick_seconds
|
|
|
|
def consume_next_live_user_line(
    line_buffer: TerminalInputBuffer,
    *,
    stop_event: threading.Event,
    status_text: str,
) -> str | None:
    """Take the next user message from the buffer, handling commands inline.

    All queued lines are drained. Blank lines are dropped, exit commands set
    ``stop_event``, and ``/status`` prints ``status_text``; command matching
    is case-insensitive. Of the remaining lines the first is returned and
    the rest are pushed back to the front of the buffer in order.

    Returns:
        The stripped message, or ``None`` if no message line was queued.
    """
    drained = line_buffer.drain()
    if not drained:
        return None

    quit_words = ("/exit", "/quit", "exit", "quit")
    messages: list[str] = []
    for raw in drained:
        text = raw.strip()
        if not text:
            continue
        lowered = text.lower()
        if lowered in quit_words:
            stop_event.set()
        elif lowered == "/status":
            print(status_text, flush=True)
        else:
            messages.append(text)

    # Everything after the first message waits for a later tick.
    line_buffer.prepend_many(messages[1:])
    return messages[0] if messages else None
|
|
|
|
def run_live_thought_loop_terminal(
    session: ThoughtLoopSession,
    *,
    gate_threshold: float,
    debug_gate: bool,
) -> None:
    """Run the live terminal loop, advancing the session once per tick.

    A background thread collects stdin lines. On every tick exactly one
    observation is fed to the session: pending assistant feedback takes
    priority, then the next queued user line, otherwise an idle observation.
    On ticks that are not user ticks, a gate probability at or above
    ``gate_threshold`` makes the assistant speak; that utterance is fed back
    as the next tick's observation. The loop ends when an exit command or
    end of input sets the stop event.

    Args:
        session: Session to drive.
        gate_threshold: Speak when the gate probability reaches this value.
        debug_gate: Print the gate probability for every tick.
    """
    line_buffer = TerminalInputBuffer()
    stop_event = threading.Event()
    start_terminal_reader(line_buffer, stop_event)
    pending_assistant_feedback: str | None = None

    print("Live Samantha session.", flush=True)
    # Report the tick size actually in use rather than assuming one second,
    # since it can be overridden on the command line or by the model config.
    print(
        "Type anytime and press Enter. Each line is one user event on the next "
        f"tick (one tick every {session.tick_seconds:g}s).",
        flush=True,
    )
    print("Commands: /status, /exit", flush=True)

    while not stop_event.is_set():
        tick_started_at = time.monotonic()
        # Snapshot taken before this tick's input is consumed.
        status_text = (
            f"[status] tick={session.next_tick_index} elapsed={session.elapsed_seconds:.1f}s "
            f"queued_user_lines={len(line_buffer)} pending_assistant_feedback={pending_assistant_feedback is not None}"
        )

        if pending_assistant_feedback is not None:
            # The assistant's last utterance is observed before any new user
            # input; queued user lines wait for the following tick.
            observation_role = OBS_ROLE_AGENT_FEEDBACK
            observation_text = pending_assistant_feedback
            event_name = "assistant_feedback"
            pending_assistant_feedback = None
        else:
            user_line = consume_next_live_user_line(
                line_buffer,
                stop_event=stop_event,
                status_text=status_text,
            )
            if stop_event.is_set():
                break

            if user_line is not None:
                observation_role = OBS_ROLE_USER
                observation_text = user_line
                event_name = "user"
                print(f"user> {user_line}", flush=True)
            else:
                observation_role = OBS_ROLE_NONE
                observation_text = None
                event_name = "idle"

        gate_probability = session.advance_tick(observation_role, observation_text)
        if debug_gate:
            print_gate_debug(session.next_tick_index - 1, event_name, gate_probability)

        # The gate is never acted on during the tick that delivers user text.
        if observation_role != OBS_ROLE_USER and gate_probability >= gate_threshold:
            response = session.generate_response()
            print(f"assistant> {response}", flush=True)
            pending_assistant_feedback = response

        # Sleep off whatever is left of the tick; skip sleeping when the model
        # work already took longer than one tick.
        sleep_seconds = session.tick_seconds - (time.monotonic() - tick_started_at)
        if sleep_seconds > 0.0:
            time.sleep(sleep_seconds)
|
|
|
|
def run_blocking_base_terminal(session: BaseSeq2SeqSession) -> None:
    """Run a blocking prompt/response loop for the raw baseline model.

    Reads one line at a time with ``input``. Blank lines are ignored; an
    exit command (case-insensitive) or end of input ends the loop.

    Args:
        session: Session whose ``respond`` produces the replies to print.
    """
    print("Raw HF baseline terminal mode. This model has no Samantha Z state or speech gate.", flush=True)
    print("Type a message and press Enter. Commands: /exit", flush=True)
    while True:
        try:
            user_text = input("user> ").strip()
        except EOFError:
            # Closed or piped stdin ran out: finish quietly instead of
            # crashing, and terminate the dangling prompt line.
            print(flush=True)
            break
        if user_text.lower() in {"/exit", "/quit", "exit", "quit"}:
            break
        if not user_text:
            continue
        # The timing arguments are required by the signature but have no
        # effect on the baseline session.
        responses, _ = session.respond(
            user_text,
            max_thought_seconds=0.0,
            settle_thought_seconds=0.0,
            max_autonomous_utterances=1,
        )
        for response in responses:
            print(f"assistant> {response}", flush=True)
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Loads the backend for ``--model-path``, builds the matching session,
    optionally replays ``--history-json``, answers ``--user`` once (printing
    a JSON object with ``responses`` and ``thought_seconds``), and then runs
    a terminal session when ``--interactive`` is set or no ``--user`` was
    given.

    Raises:
        ValueError: If the resolved tick size is not a positive, finite
            number, or if the history file does not contain a JSON list.
    """
    args = parse_args()
    backend_kind, model, tokenizer = load_inference_backend(args.model_path, device=args.device)
    generation_kwargs = build_generation_kwargs(args)
    is_thought_loop = backend_kind == "thought_loop"

    if is_thought_loop:
        # An explicit --tick-seconds wins over the value stored with the model.
        tick_seconds = (
            float(args.tick_seconds)
            if args.tick_seconds is not None
            else float(model.config.get("rollout", {}).get("tick_seconds", 1.0))
        )
        # The tick size is used as a divisor and as a sleep interval, so zero,
        # negative, NaN or infinite values are rejected before any work starts.
        if not math.isfinite(tick_seconds) or tick_seconds <= 0.0:
            raise ValueError(
                f"Tick size must be a positive, finite number of seconds; got {tick_seconds!r}."
            )
        session: Any = ThoughtLoopSession(
            model=model,
            tick_seconds=tick_seconds,
            gate_threshold=args.gate_threshold,
            generation_kwargs=generation_kwargs,
        )
    else:
        assert tokenizer is not None
        session = BaseSeq2SeqSession(
            model=model,
            tokenizer=tokenizer,
            device=args.device,
            generation_kwargs=generation_kwargs,
        )

    if args.history_json:
        history = json.loads(Path(args.history_json).read_text(encoding="utf-8"))
        if not isinstance(history, list):
            raise ValueError("History JSON must be a list of turn objects.")
        session.replay_history(history)

    if args.user:
        respond_kwargs: dict[str, Any] = {
            "max_thought_seconds": args.max_thought_seconds,
            "settle_thought_seconds": args.settle_thought_seconds,
            "max_autonomous_utterances": args.max_autonomous_utterances,
        }
        # Only the thought-loop session has a gate to report on; the baseline
        # session is not handed an option it has no use for.
        if is_thought_loop:
            respond_kwargs["debug_gate"] = bool(args.debug_gate)
        responses, thought_time = session.respond(args.user, **respond_kwargs)
        print(
            json.dumps(
                {
                    "responses": responses,
                    "thought_seconds": thought_time,
                },
                indent=2,
                ensure_ascii=False,
            )
        )

    if args.interactive or args.user is None:
        if is_thought_loop:
            run_live_thought_loop_terminal(
                session,
                gate_threshold=float(args.gate_threshold),
                debug_gate=bool(args.debug_gate),
            )
        else:
            run_blocking_base_terminal(session)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|