"""Interactive inference CLI for the recurrent Samantha SFT model.

The script loads a ``ThoughtLoopT5Gemma`` checkpoint, optionally replays a
recorded conversation to warm up the recurrent state, and then answers either
a single ``--user`` message or an interactive stream of messages.
"""

from __future__ import annotations

import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any

import torch

# Make the sibling ``data`` and ``model`` modules importable when the script
# is launched from a different working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))

from data import (  # noqa: E402  (must follow the sys.path bootstrap above)
    OBS_ROLE_AGENT_FEEDBACK,
    OBS_ROLE_NONE,
    OBS_ROLE_USER,
    quantize_message_sequence,
    tokenize_text,
)
from model import ThoughtLoopT5Gemma  # noqa: E402


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the inference script."""
    parser = argparse.ArgumentParser(description="Interactive inference for the recurrent Samantha SFT model.")
    parser.add_argument("--model-path", required=True, help="Local export directory or Hugging Face repo id.")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--history-json", default=None, help="Optional JSON array of {t, speaker, text} turns to replay.")
    parser.add_argument("--user", default=None, help="Optional single user message to answer.")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--tick-seconds", type=float, default=None)
    parser.add_argument("--max-thought-seconds", type=float, default=300.0)
    parser.add_argument("--settle-thought-seconds", type=float, default=1.0)
    parser.add_argument("--max-autonomous-utterances", type=int, default=4)
    parser.add_argument("--gate-threshold", type=float, default=0.5)
    parser.add_argument("--max-new-tokens", type=int, default=160)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.95)
    return parser.parse_args()


class ThoughtLoopSession:
    """Stateful wrapper that steps the model's recurrent state tick by tick.

    The session owns the recurrent ``z_state``, a running clock
    (``elapsed_seconds``) and the index of the next tick to execute. Each tick
    feeds one observation (possibly empty) to ``model.rollout_step`` and yields
    a gate probability; when that probability reaches ``gate_threshold`` during
    ``respond`` the session generates an utterance from the current state.
    """

    def __init__(
        self,
        model: ThoughtLoopT5Gemma,
        tick_seconds: float,
        gate_threshold: float,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ) -> None:
        """Create a session with a fresh recurrent state.

        Args:
            model: Loaded model; its tokenizer, device and config are reused.
            tick_seconds: Seconds advanced per tick. Must be positive.
            gate_threshold: Gate probability at or above which the agent speaks.
            max_new_tokens: Generation length cap passed to the model.
            temperature: Sampling temperature; sampling is disabled when it is
                not greater than zero.
            top_p: Nucleus-sampling parameter passed to the model.

        Raises:
            ValueError: If ``tick_seconds`` is not a positive number.
        """
        # ``respond`` divides by tick_seconds and every tick adds it to the
        # clock, so zero, negative and NaN values are rejected up front rather
        # than failing later with a less helpful error.
        if not tick_seconds > 0:
            raise ValueError(f"tick_seconds must be a positive number, got {tick_seconds!r}.")

        self.model = model
        self.tokenizer = model.tokenizer
        self.tick_seconds = tick_seconds
        self.gate_threshold = gate_threshold
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.z_state = model.initial_state(batch_size=1, device=model.device)
        self.elapsed_seconds = 0.0
        self.next_tick_index = 0
        # Falls back to one tick when the config does not specify a delay.
        self.feedback_delay_seconds = float(
            model.config.get("rollout", {}).get("post_speech_feedback_delay_seconds", tick_seconds)
        )

    def _advance_raw(self, delta_seconds: float, observation_role: int, observation_text: str | None) -> float:
        """Run one rollout step and return the resulting gate probability.

        Args:
            delta_seconds: Time added to the session clock for this step.
            observation_role: One of the ``OBS_ROLE_*`` identifiers.
            observation_text: Text observed on this step, or ``None``.

        Returns:
            Sigmoid of the gate logits produced by ``model.rollout_step``.
        """
        device = self.model.device
        input_ids, attention_mask = tokenize_text(
            tokenizer=self.tokenizer,
            text=observation_text,
            max_length=int(self.model.config["model"]["max_observation_tokens"]),
        )
        # Wrap everything in a leading batch dimension of size one.
        input_ids_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
        attention_mask_tensor = torch.tensor([attention_mask], dtype=torch.long, device=device)
        role_tensor = torch.tensor([observation_role], dtype=torch.long, device=device)

        self.elapsed_seconds += delta_seconds
        with torch.no_grad():
            # NOTE(review): since_last_user_seconds / since_last_agent_seconds
            # are always fed as zeros here. Confirm this matches how the model
            # was trained; if training used real elapsed times, the session
            # should track and pass them instead.
            self.z_state, gate_logits = self.model.rollout_step(
                z_state=self.z_state,
                observation_input_ids=input_ids_tensor,
                observation_attention_mask=attention_mask_tensor,
                observation_role_ids=role_tensor,
                delta_seconds=torch.tensor([delta_seconds], dtype=torch.float32, device=device),
                elapsed_seconds=torch.tensor([self.elapsed_seconds], dtype=torch.float32, device=device),
                since_last_user_seconds=torch.zeros(1, dtype=torch.float32, device=device),
                since_last_agent_seconds=torch.zeros(1, dtype=torch.float32, device=device),
            )
        return torch.sigmoid(gate_logits).item()

    def _advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Advance exactly one tick and return the gate probability.

        The very first tick of a session uses a zero time delta; every later
        tick advances the clock by ``tick_seconds``.
        """
        delta_seconds = 0.0 if self.next_tick_index == 0 else self.tick_seconds
        gate_probability = self._advance_raw(delta_seconds, observation_role, observation_text)
        self.next_tick_index += 1
        return gate_probability

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Replay recorded turns to bring the recurrent state up to date.

        The messages are quantized onto the tick grid by
        ``quantize_message_sequence``; ticks without an event are stepped with
        an empty observation. Gate probabilities are ignored during replay.

        Args:
            messages: Conversation turns in the format accepted by
                ``quantize_message_sequence``.
        """
        quantized_events = quantize_message_sequence(
            messages=messages,
            tick_seconds=self.tick_seconds,
            feedback_delay_seconds=self.feedback_delay_seconds,
        )
        # NOTE(review): keyed by tick, so if the quantizer can emit several
        # events for the same tick only the last one is replayed - confirm.
        events_by_tick = {int(event["tick"]): event for event in quantized_events}
        if not events_by_tick:
            return

        last_tick = max(events_by_tick)
        while self.next_tick_index <= last_tick:
            event = events_by_tick.get(self.next_tick_index)
            kind = None if event is None else event["kind"]
            if kind == "user":
                self._advance_tick(OBS_ROLE_USER, str(event["text"]))
            elif kind == "agent_feedback":
                self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, str(event["text"]))
            else:
                # No event on this tick, or an event kind that is not replayed.
                self._advance_tick(OBS_ROLE_NONE, None)

    def _generate_response(self) -> str:
        """Decode one utterance from the current recurrent state."""
        generated = self.model.generate_from_state(
            self.z_state,
            max_new_tokens=self.max_new_tokens,
            do_sample=self.temperature > 0.0,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
    ) -> tuple[list[str], float]:
        """Feed a user message and let the model think and speak.

        After the user tick, the loop keeps ticking until one of these holds:
        the thought budget is exhausted, ``max_autonomous_utterances`` have
        been produced, or at least one utterance exists and the model has
        stayed silent for the settle period.

        Args:
            user_text: The user's message.
            max_thought_seconds: Upper bound on post-message thinking time.
            settle_thought_seconds: Silence after the last utterance that ends
                the turn.
            max_autonomous_utterances: Maximum utterances per user message.

        Returns:
            ``(utterances, thought_seconds)`` where ``thought_seconds`` is the
            number of ticks spent after the user tick times ``tick_seconds``.
        """
        gate_probability = self._advance_tick(OBS_ROLE_USER, user_text)
        utterances: list[str] = []
        thought_ticks = 0
        # Both budgets are rounded up to whole ticks, with a minimum of one.
        settle_ticks = max(1, int(math.ceil(settle_thought_seconds / self.tick_seconds)))
        max_thought_ticks = max(1, int(math.ceil(max_thought_seconds / self.tick_seconds)))
        idle_ticks_after_last_speech = 0

        while thought_ticks < max_thought_ticks:
            if gate_probability >= self.gate_threshold:
                response = self._generate_response()
                utterances.append(response)
                # The agent observes its own utterance on the following tick.
                gate_probability = self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, response)
                thought_ticks += 1
                idle_ticks_after_last_speech = 0
                if len(utterances) >= max_autonomous_utterances:
                    break
                continue

            gate_probability = self._advance_tick(OBS_ROLE_NONE, None)
            thought_ticks += 1
            idle_ticks_after_last_speech += 1
            if utterances and idle_ticks_after_last_speech >= settle_ticks:
                break

        return utterances, thought_ticks * self.tick_seconds


def main() -> None:
    """Entry point: load the model, replay history, then answer messages."""
    args = parse_args()
    model = ThoughtLoopT5Gemma.from_pretrained(args.model_path, device=args.device, map_location=args.device)
    model.eval()

    # The command-line value wins; otherwise use the model config, then 0.1.
    tick_seconds = (
        float(args.tick_seconds)
        if args.tick_seconds is not None
        else float(model.config.get("rollout", {}).get("tick_seconds", 0.1))
    )
    session = ThoughtLoopSession(
        model=model,
        tick_seconds=tick_seconds,
        gate_threshold=args.gate_threshold,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    respond_options = {
        "max_thought_seconds": args.max_thought_seconds,
        "settle_thought_seconds": args.settle_thought_seconds,
        "max_autonomous_utterances": args.max_autonomous_utterances,
    }

    if args.history_json:
        history = json.loads(Path(args.history_json).read_text(encoding="utf-8"))
        if not isinstance(history, list):
            raise ValueError("History JSON must be a list of turn objects.")
        session.replay_history(history)

    if args.user:
        responses, thought_time = session.respond(args.user, **respond_options)
        print(
            json.dumps(
                {
                    "responses": responses,
                    "thought_seconds": thought_time,
                },
                indent=2,
                ensure_ascii=False,
            )
        )

    if args.interactive:
        print("Interactive mode. Type 'exit' to stop.")
        while True:
            try:
                user_text = input("user> ").strip()
            except (EOFError, KeyboardInterrupt):
                # Closed stdin (Ctrl-D / exhausted pipe) or Ctrl-C at the
                # prompt ends the session cleanly instead of with a traceback.
                print()
                break
            if user_text.lower() in {"exit", "quit"}:
                break
            responses, thought_time = session.respond(user_text, **respond_options)
            if not responses:
                print(f"agent> [kept thinking for {thought_time:.1f}s and did not cross the speak threshold]")
                continue
            for response in responses:
                print(f"agent> {response}")
            print(f"[thought_seconds={thought_time:.1f}]")


if __name__ == "__main__":
    main()