| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
|
|
# Put this file's own directory at the front of the import path so the sibling
# ``data`` and ``model`` imports below resolve regardless of the current working
# directory. The insert is skipped when the directory is already listed.
SCRIPT_DIR = Path(__file__).resolve().parent
if not any(entry == str(SCRIPT_DIR) for entry in sys.path):
    sys.path.insert(0, str(SCRIPT_DIR))
|
|
| from data import OBS_ROLE_AGENT_FEEDBACK, OBS_ROLE_NONE, OBS_ROLE_USER, quantize_message_sequence, tokenize_text |
| from model import ThoughtLoopT5Gemma |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    ``--tick-seconds`` is validated here. Every thinking budget is converted to a
    tick count by dividing by the tick length, so zero, negative or non-finite
    values are rejected with a usage error instead of failing later mid-session.

    Returns:
        The parsed argument namespace.
    """

    def positive_float(text: str) -> float:
        # argparse ``type=`` callable: parse a float and require it to be finite and > 0.
        value = float(text)
        if not math.isfinite(value) or value <= 0.0:
            raise argparse.ArgumentTypeError(f"must be a finite number greater than 0, got {text!r}")
        return value

    parser = argparse.ArgumentParser(description="Interactive inference for the recurrent Samantha SFT model.")
    parser.add_argument("--model-path", required=True, help="Local export directory or Hugging Face repo id.")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--history-json", default=None, help="Optional JSON array of {t, speaker, text} turns to replay.")
    parser.add_argument("--user", default=None, help="Optional single user message to answer.")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument(
        "--tick-seconds",
        type=positive_float,
        default=None,
        help="Seconds per recurrent tick; defaults to the model config's rollout.tick_seconds.",
    )
    parser.add_argument(
        "--max-thought-seconds",
        type=float,
        default=300.0,
        help="Upper bound on thinking time after each user message.",
    )
    parser.add_argument(
        "--settle-thought-seconds",
        type=float,
        default=1.0,
        help="Once the agent has spoken, stop after this long without further speech.",
    )
    parser.add_argument(
        "--max-autonomous-utterances",
        type=int,
        default=4,
        help="Stop after this many agent utterances per user message.",
    )
    parser.add_argument(
        "--gate-threshold",
        type=float,
        default=0.5,
        help="The agent speaks when sigmoid(gate logit) is at least this value.",
    )
    parser.add_argument("--max-new-tokens", type=int, default=160)
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature; 0 disables sampling.",
    )
    parser.add_argument("--top-p", type=float, default=0.95)
    return parser.parse_args()
|
|
|
|
class ThoughtLoopSession:
    """Drive a ``ThoughtLoopT5Gemma`` model one fixed-length tick at a time.

    The session owns the model's recurrent state (``z_state``), an elapsed-time
    clock and a tick counter. Each tick passes at most one observation to
    ``model.rollout_step``: a user message, the agent's own utterance fed back,
    or nothing. The sigmoid of the returned gate logit is compared with
    ``gate_threshold`` to decide whether the agent speaks.
    """

    # Slack applied when turning a duration into a tick count, so float noise such
    # as 1.1 / 0.1 == 11.000000000000002 does not round up to an extra tick.
    _TICK_EPSILON = 1e-9

    def __init__(
        self,
        model: ThoughtLoopT5Gemma,
        tick_seconds: float,
        gate_threshold: float,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ) -> None:
        """Create a session starting from the model's initial recurrent state.

        Args:
            model: Loaded model; its ``tokenizer``, ``device``, ``config`` and
                ``initial_state`` are used by the session.
            tick_seconds: Length of one tick in seconds.
            gate_threshold: Minimum gate probability at which the agent speaks.
            max_new_tokens: Generation budget per utterance.
            temperature: Sampling temperature; ``<= 0`` passes ``do_sample=False``.
            top_p: Nucleus-sampling cutoff forwarded to generation.

        Raises:
            ValueError: If ``tick_seconds`` is not a finite positive number. Every
                duration-to-tick conversion divides by it.
        """
        if not math.isfinite(tick_seconds) or tick_seconds <= 0.0:
            raise ValueError(f"tick_seconds must be a finite positive number, got {tick_seconds!r}")
        self.model = model
        self.tokenizer = model.tokenizer
        self.tick_seconds = tick_seconds
        self.gate_threshold = gate_threshold
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.z_state = model.initial_state(batch_size=1, device=model.device)
        self.elapsed_seconds = 0.0
        # Index of the next tick to run; tick 0 is run with a zero time delta.
        self.next_tick_index = 0
        # Delay handed to quantize_message_sequence during history replay; falls
        # back to one tick when the model config does not set it.
        self.feedback_delay_seconds = float(
            model.config.get("rollout", {}).get("post_speech_feedback_delay_seconds", tick_seconds)
        )

    @classmethod
    def _seconds_to_ticks(cls, seconds: float, tick_seconds: float) -> int:
        """Return how many ticks cover ``seconds``, rounding up and never returning less than 1."""
        return max(1, math.ceil(seconds / tick_seconds - cls._TICK_EPSILON))

    def _advance_raw(self, delta_seconds: float, observation_role: int, observation_text: str | None) -> float:
        """Run one rollout step and return the gate probability (sigmoid of the gate logit).

        ``observation_text`` is tokenized with ``tokenize_text``; ``None`` is passed
        through as-is for ticks without an observation. On success the elapsed
        clock advances by ``delta_seconds`` and ``self.z_state`` is replaced.
        """
        device = self.model.device
        input_ids, attention_mask = tokenize_text(
            tokenizer=self.tokenizer,
            text=observation_text,
            max_length=int(self.model.config["model"]["max_observation_tokens"]),
        )
        elapsed_seconds = self.elapsed_seconds + delta_seconds

        with torch.no_grad():
            next_state, gate_logits = self.model.rollout_step(
                z_state=self.z_state,
                observation_input_ids=torch.tensor([input_ids], dtype=torch.long, device=device),
                observation_attention_mask=torch.tensor([attention_mask], dtype=torch.long, device=device),
                observation_role_ids=torch.tensor([observation_role], dtype=torch.long, device=device),
                delta_seconds=torch.tensor([delta_seconds], dtype=torch.float32, device=device),
                elapsed_seconds=torch.tensor([elapsed_seconds], dtype=torch.float32, device=device),
                # NOTE(review): these two features are always zero at inference; the
                # session does not track when the user/agent last spoke. Confirm this
                # matches how they were produced in training.
                since_last_user_seconds=torch.zeros(1, dtype=torch.float32, device=device),
                since_last_agent_seconds=torch.zeros(1, dtype=torch.float32, device=device),
            )

        # Commit clock and state together so a failed step leaves the session unchanged.
        self.elapsed_seconds = elapsed_seconds
        self.z_state = next_state
        return torch.sigmoid(gate_logits).item()

    def _advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Advance exactly one tick and return its gate probability.

        The first tick of a session uses a zero time delta; later ticks use
        ``tick_seconds``.
        """
        delta_seconds = 0.0 if self.next_tick_index == 0 else self.tick_seconds
        gate_probability = self._advance_raw(delta_seconds, observation_role, observation_text)
        self.next_tick_index += 1
        return gate_probability

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Feed a past conversation through the model, tick by tick.

        ``quantize_message_sequence`` maps ``messages`` onto tick indices. Ticks
        with no event, or whose event kind is neither ``"user"`` nor
        ``"agent_feedback"``, are advanced with no observation. Gate outputs are
        ignored during replay, and events whose tick is already behind
        ``self.next_tick_index`` are never visited.
        """
        quantized_events = quantize_message_sequence(
            messages=messages,
            tick_seconds=self.tick_seconds,
            feedback_delay_seconds=self.feedback_delay_seconds,
        )
        # NOTE(review): if two events ever share a tick, only the last one is kept
        # here; confirm quantize_message_sequence guarantees unique ticks.
        events_by_tick = {int(event["tick"]): event for event in quantized_events}
        if not events_by_tick:
            return

        last_tick = max(events_by_tick)
        while self.next_tick_index <= last_tick:
            event = events_by_tick.get(self.next_tick_index)
            if event is None:
                self._advance_tick(OBS_ROLE_NONE, None)
                continue

            if event["kind"] == "user":
                self._advance_tick(OBS_ROLE_USER, str(event["text"]))
            elif event["kind"] == "agent_feedback":
                self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, str(event["text"]))
            else:
                self._advance_tick(OBS_ROLE_NONE, None)

    def _generate_response(self) -> str:
        """Decode one utterance from the current recurrent state.

        Only the first sequence of the result is used. This presumably assumes
        ``generate_from_state`` returns a batch of token-id sequences (TODO confirm).
        """
        generated = self.model.generate_from_state(
            self.z_state,
            max_new_tokens=self.max_new_tokens,
            do_sample=self.temperature > 0.0,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
    ) -> tuple[list[str], float]:
        """Observe a user message, then keep ticking until the agent is done talking.

        After the user tick, each loop iteration either speaks (gate probability at
        or above ``gate_threshold``) or idles for one tick. The loop stops when:

        * ``max_thought_seconds`` worth of ticks have been used;
        * ``max_autonomous_utterances`` utterances were produced (checked after
          each utterance, so at least one is produced if the gate opens); or
        * the agent has spoken and then stayed silent for ``settle_thought_seconds``.

        Returns:
            ``(utterances, thought_seconds)``, where ``thought_seconds`` counts the
            ticks used after the user tick, multiplied by ``tick_seconds``.
        """
        gate_probability = self._advance_tick(OBS_ROLE_USER, user_text)
        utterances: list[str] = []
        thought_ticks = 0
        settle_ticks = self._seconds_to_ticks(settle_thought_seconds, self.tick_seconds)
        max_thought_ticks = self._seconds_to_ticks(max_thought_seconds, self.tick_seconds)
        idle_ticks_after_last_speech = 0

        while thought_ticks < max_thought_ticks:
            if gate_probability >= self.gate_threshold:
                response = self._generate_response()
                utterances.append(response)
                # NOTE(review): the agent hears its own utterance on the very next
                # tick, whereas replay_history passes self.feedback_delay_seconds to
                # quantize_message_sequence. These agree when that delay is one tick
                # (the config fallback); confirm for other configured delays.
                gate_probability = self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, response)
                thought_ticks += 1
                idle_ticks_after_last_speech = 0
                if len(utterances) >= max_autonomous_utterances:
                    break
                continue

            gate_probability = self._advance_tick(OBS_ROLE_NONE, None)
            thought_ticks += 1
            idle_ticks_after_last_speech += 1
            if utterances and idle_ticks_after_last_speech >= settle_ticks:
                break

        return utterances, thought_ticks * self.tick_seconds
|
|
|
|
def load_history(path: str | Path) -> list[Any]:
    """Read a replay history (a JSON array of turn objects) from ``path``.

    Raises:
        ValueError: If the top-level JSON value is not a list.
    """
    history = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(history, list):
        raise ValueError("History JSON must be a list of turn objects.")
    return history


def _respond(session: ThoughtLoopSession, user_text: str, args: argparse.Namespace) -> tuple[list[str], float]:
    """Answer one user turn, using the thinking limits from the CLI arguments."""
    return session.respond(
        user_text,
        max_thought_seconds=args.max_thought_seconds,
        settle_thought_seconds=args.settle_thought_seconds,
        max_autonomous_utterances=args.max_autonomous_utterances,
    )


def _run_interactive(session: ThoughtLoopSession, args: argparse.Namespace) -> None:
    """Chat on stdin/stdout until 'exit'/'quit', end of input, or Ctrl-C at the prompt."""
    print("Interactive mode. Type 'exit' to stop.")
    while True:
        try:
            user_text = input("user> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D, end of piped stdin, or Ctrl-C at the prompt: leave cleanly
            # instead of dumping a traceback.
            print()
            break
        if user_text.lower() in {"exit", "quit"}:
            break
        responses, thought_time = _respond(session, user_text, args)
        if not responses:
            print(f"agent> [kept thinking for {thought_time:.1f}s and did not cross the speak threshold]")
            continue
        for response in responses:
            print(f"agent> {response}")
        print(f"[thought_seconds={thought_time:.1f}]")


def main() -> None:
    """CLI entry point: load the model, optionally replay history, then answer.

    A single ``--user`` message is answered as a JSON object on stdout.
    ``--interactive`` then starts a prompt loop; both may be combined.
    """
    args = parse_args()
    model = ThoughtLoopT5Gemma.from_pretrained(args.model_path, device=args.device, map_location=args.device)
    model.eval()
    # An explicit --tick-seconds wins; otherwise use the model config (0.1 s if unset).
    if args.tick_seconds is not None:
        tick_seconds = float(args.tick_seconds)
    else:
        tick_seconds = float(model.config.get("rollout", {}).get("tick_seconds", 0.1))

    session = ThoughtLoopSession(
        model=model,
        tick_seconds=tick_seconds,
        gate_threshold=args.gate_threshold,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    if args.history_json:
        session.replay_history(load_history(args.history_json))

    if args.user:
        responses, thought_time = _respond(session, args.user, args)
        print(
            json.dumps(
                {
                    "responses": responses,
                    "thought_seconds": thought_time,
                },
                indent=2,
                ensure_ascii=False,
            )
        )

    if args.interactive:
        _run_interactive(session, args)
|
|
|
|
# Run the CLI only when executed as a script; importing this module has no side
# effects beyond the sys.path setup and imports at the top of the file.
if __name__ == "__main__":
    main()
|
|