# test-true2 / inference.py
# Hugging Face Hub file-page header (uploader: BRlkl; commit bc7437c, verified;
# "Upload folder using huggingface_hub"; file size 20.8 kB).
from __future__ import annotations
import argparse
from collections import deque
import json
import math
import sys
import threading
import time
from pathlib import Path
from typing import Any
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Put this script's own directory at the front of sys.path (once) so the
# `data` and `model` modules imported below resolve regardless of the
# current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))
from data import (
OBS_ROLE_AGENT_FEEDBACK,
OBS_ROLE_NONE,
OBS_ROLE_USER,
build_assistant_feedback_observation,
build_user_generation_observation,
quantize_message_sequence,
should_use_base_chat_template,
tokenize_text,
)
from model import ThoughtLoopT5Gemma
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour; passing an
            explicit list allows programmatic use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Interactive inference for Samantha thought-loop checkpoints, with a direct HF baseline fallback."
    )

    # Model / device selection.
    parser.add_argument("--model-path", required=True, help="Local export directory or Hugging Face repo id.")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")

    # Conversation input.
    parser.add_argument("--history-json", default=None, help="Optional JSON array of {t, speaker, text} turns to replay.")
    parser.add_argument("--user", default=None, help="Optional single user message to answer.")
    parser.add_argument("--interactive", action="store_true", help="Run a live terminal session.")

    # Runtime tick / gate behaviour.
    parser.add_argument("--tick-seconds", type=float, default=None, help="Runtime tick size. Defaults to the trained value.")
    parser.add_argument("--max-thought-seconds", type=float, default=300.0)
    parser.add_argument("--settle-thought-seconds", type=float, default=1.0)
    parser.add_argument("--max-autonomous-utterances", type=int, default=4, help="One-shot --user safety cap.")
    parser.add_argument("--gate-threshold", type=float, default=0.5, help="Trained binary gate decision threshold.")

    # Text generation.
    parser.add_argument("--max-new-tokens", type=int, default=160)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.95)

    # Both spellings set the same ``debug_gate`` destination.
    parser.add_argument(
        "--debug-gate",
        "--show-gate",
        dest="debug_gate",
        action="store_true",
        help="Print the gate head probability on every runtime tick.",
    )
    return parser.parse_args(argv)
def build_generation_kwargs(args: argparse.Namespace) -> dict[str, Any]:
    """Translate parsed CLI options into keyword arguments for generation.

    Sampling is switched on only for a strictly positive temperature; the
    ``temperature`` and ``top_p`` keys are included only in that case.

    Args:
        args: Namespace providing ``max_new_tokens``, ``temperature`` and
            ``top_p``.

    Returns:
        A dict with ``max_new_tokens`` and ``do_sample``, plus the sampling
        parameters when sampling is enabled.
    """
    token_budget = int(args.max_new_tokens)
    sampling_enabled = float(args.temperature) > 0.0
    generation_options: dict[str, Any] = {
        "max_new_tokens": token_budget,
        "do_sample": sampling_enabled,
    }
    if not sampling_enabled:
        # Greedy decoding: leave the sampling knobs out entirely.
        return generation_options
    generation_options["temperature"] = float(args.temperature)
    generation_options["top_p"] = float(args.top_p)
    return generation_options
def normalize_history_speaker(speaker: str | None) -> str | None:
    """Map a free-form speaker label onto ``"user"`` or ``"assistant"``.

    Matching ignores case and surrounding whitespace.

    Args:
        speaker: Raw speaker label from a history record; may be ``None``.

    Returns:
        ``"user"``, ``"assistant"``, or ``None`` when the label is not one of
        the recognised aliases.
    """
    canonical_by_alias = {
        "user": "user",
        "human": "user",
        "prompter": "user",
        "assistant": "assistant",
        "agent": "assistant",
        "model": "assistant",
        "gpt": "assistant",
    }
    # Falsy labels (None, "") collapse to the empty string, which has no alias.
    alias = str(speaker or "").strip().lower()
    return canonical_by_alias.get(alias)
def load_base_seq2seq_model(
    model_path: str,
    *,
    device: str | torch.device,
) -> tuple[torch.nn.Module, Any]:
    """Load a plain Hugging Face seq2seq model and its tokenizer.

    Weights are loaded as bfloat16 when the target device is CUDA and as
    float32 otherwise. The model is moved to the device and put in eval mode.

    Args:
        model_path: Path or identifier handed to ``from_pretrained``.
        device: Target device (string or ``torch.device``).

    Returns:
        ``(model, tokenizer)``.
    """
    target_device = torch.device(device)
    if target_device.type == "cuda":
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = torch.float32

    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=weights_dtype)
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(model_path)

    seq2seq_model.to(target_device)
    seq2seq_model.eval()
    return seq2seq_model, seq2seq_tokenizer
def load_inference_backend(
    model_path: str,
    *,
    device: str | torch.device,
) -> tuple[str, Any, Any | None]:
    """Load a thought-loop checkpoint, or fall back to a raw seq2seq baseline.

    The fallback is taken only when loading fails with a ``FileNotFoundError``
    whose missing file is named ``sft_config.json``; any other
    ``FileNotFoundError`` is re-raised unchanged.

    Args:
        model_path: Path or identifier of the checkpoint.
        device: Target device (string or ``torch.device``).

    Returns:
        ``("thought_loop", model, None)`` for a thought-loop checkpoint, or
        ``("base_seq2seq", model, tokenizer)`` for the baseline fallback.
    """
    try:
        thought_loop_model = ThoughtLoopT5Gemma.from_pretrained(model_path, device=device, map_location=device)
        thought_loop_model.eval()
    except FileNotFoundError as exc:
        missing_file_name = Path(str(exc.filename)).name
        if missing_file_name == "sft_config.json":
            # Warn on stderr so stdout stays reserved for model output.
            print(
                "[inference] No sft_config.json found at model path; loading as a raw Hugging Face seq2seq baseline.",
                file=sys.stderr,
            )
            baseline_model, baseline_tokenizer = load_base_seq2seq_model(model_path, device=device)
            return "base_seq2seq", baseline_model, baseline_tokenizer
        raise
    else:
        return "thought_loop", thought_loop_model, None
class BaseSeq2SeqSession:
    """Turn-based chat session around a plain Hugging Face seq2seq model.

    The session keeps the transcript as a list of ``{"role", "content"}``
    dicts and re-renders the whole transcript into a prompt on every turn.
    ``respond`` mirrors the keyword interface of ``ThoughtLoopSession.respond``
    so callers can treat both session types uniformly.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Any,
        device: str | torch.device,
        generation_kwargs: dict[str, Any],
    ) -> None:
        """Store the model, tokenizer, device and generation options.

        Args:
            model: Object exposing ``generate(**inputs, **generation_kwargs)``.
            tokenizer: Callable tokenizer exposing ``decode`` and, optionally,
                ``chat_template`` / ``apply_chat_template``.
            device: Device the encoded prompt is moved to before generation.
            generation_kwargs: Extra keyword arguments forwarded to
                ``model.generate``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device(device)
        self.generation_kwargs = generation_kwargs
        # Transcript so far, oldest turn first.
        self.history: list[dict[str, str]] = []

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Append prior turns to the transcript, ordered by their ``t`` value.

        Records whose speaker is not recognised by
        ``normalize_history_speaker`` or whose text is empty after stripping
        are skipped. A missing ``t`` sorts as ``0.0``.

        Args:
            messages: Records with optional ``t``, ``speaker`` and ``text``.
        """
        ordered_messages = sorted(messages, key=lambda item: float(item.get("t", 0.0)))
        for message in ordered_messages:
            role = normalize_history_speaker(message.get("speaker"))
            text = str(message.get("text", "")).strip()
            if role is None or not text:
                continue
            self.history.append({"role": role, "content": text})

    def _format_prompt(self, user_text: str) -> str:
        """Render the transcript plus the new user message into a prompt.

        Uses the tokenizer's chat template when it defines one; otherwise
        falls back to plain ``User:`` / ``Assistant:`` lines ending with an
        open ``Assistant:`` line.
        """
        messages = [*self.history, {"role": "user", "content": user_text}]
        if getattr(self.tokenizer, "chat_template", None):
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        lines: list[str] = []
        for message in messages:
            role = "User" if message["role"] == "user" else "Assistant"
            lines.append(f"{role}: {message['content']}")
        lines.append("Assistant:")
        return "\n".join(lines)

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
        debug_gate: bool = False,
    ) -> tuple[list[str], float]:
        """Generate one reply to ``user_text`` and record the exchange.

        Args:
            user_text: The new user message.
            max_thought_seconds: Ignored; accepted for interface parity.
            settle_thought_seconds: Ignored; accepted for interface parity.
            max_autonomous_utterances: Ignored; accepted for interface parity.
            debug_gate: Ignored; accepted so callers that pass the same
                keyword set to either session type do not hit a ``TypeError``.
                This session has no gate to report.

        Returns:
            ``([response], 0.0)`` — a single utterance and zero thought time.
        """
        del max_thought_seconds, settle_thought_seconds, max_autonomous_utterances, debug_gate
        prompt = self._format_prompt(user_text)
        # NOTE(review): when the chat template already emits special tokens
        # (e.g. BOS), tokenizing the rendered string with default settings may
        # duplicate them — confirm against the tokenizer in use.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            generated = self.model.generate(**encoded, **self.generation_kwargs)
        response = self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()
        # Record the turn only after generation succeeded.
        self.history.append({"role": "user", "content": user_text})
        self.history.append({"role": "assistant", "content": response})
        return [response], 0.0
class TerminalInputBuffer:
    """Lock-guarded FIFO of text lines.

    Every operation takes the internal lock, so one thread can push lines
    while another drains them.
    """

    def __init__(self) -> None:
        """Create an empty buffer."""
        self._lock = threading.Lock()
        self._lines: deque[str] = deque()

    def push(self, line: str) -> None:
        """Append ``line`` at the back of the queue."""
        with self._lock:
            self._lines.append(line)

    def drain(self) -> list[str]:
        """Remove and return every queued line, oldest first."""
        with self._lock:
            drained = [*self._lines]
            self._lines.clear()
        return drained

    def prepend_many(self, lines: list[str]) -> None:
        """Put ``lines`` back at the front of the queue, keeping their order."""
        if lines:
            with self._lock:
                # extendleft inserts one item at a time, so feed it the
                # reversed sequence to end up with the original order.
                self._lines.extendleft(reversed(lines))

    def __len__(self) -> int:
        """Return the number of queued lines."""
        with self._lock:
            return len(self._lines)
def start_terminal_reader(line_buffer: TerminalInputBuffer, stop_event: threading.Event) -> threading.Thread:
    """Start a daemon thread that copies stdin lines into ``line_buffer``.

    The thread stops when ``stop_event`` is set, and sets ``stop_event``
    itself when stdin reaches end-of-file. Trailing newlines are removed
    before a line is pushed.

    Args:
        line_buffer: Destination for the lines read.
        stop_event: Shared shutdown flag.

    Returns:
        The already-started reader thread.
    """

    def _pump_stdin() -> None:
        while True:
            if stop_event.is_set():
                return
            raw = sys.stdin.readline()
            if not raw:
                # readline() yields an empty string only at end-of-file.
                stop_event.set()
                return
            line_buffer.push(raw.rstrip("\n"))

    reader = threading.Thread(target=_pump_stdin, name="terminal-input-reader", daemon=True)
    reader.start()
    return reader
def print_gate_debug(tick_index: int, event_name: str, gate_probability: float) -> None:
    """Print one flushed debug line with the tick, event and gate probability.

    Args:
        tick_index: Index of the tick being reported.
        event_name: Label of the observation fed on that tick.
        gate_probability: Gate probability, printed with six decimals.
    """
    debug_line = f"[tick={tick_index} event={event_name} gate={gate_probability:.6f}]"
    print(debug_line, flush=True)
class ThoughtLoopSession:
    """Tick-driven session around a ``ThoughtLoopT5Gemma`` model.

    The session owns the recurrent state (``z_state`` / ``z_mask``) and
    advances it one tick at a time. Each tick feeds an optional observation
    (a user message, the assistant's own previous utterance, or nothing) and
    yields a gate probability; when that probability reaches the threshold a
    response is generated from the current state.
    """

    def __init__(
        self,
        model: ThoughtLoopT5Gemma,
        tick_seconds: float,
        gate_threshold: float,
        generation_kwargs: dict[str, Any],
    ) -> None:
        """Initialise state and timing from the model and its config.

        Args:
            model: Loaded thought-loop model; its ``tokenizer``, ``config`` and
                ``device`` attributes are used.
            tick_seconds: Duration of one runtime tick, in seconds.
            gate_threshold: Gate probability at or above which ``respond``
                generates an utterance.
            generation_kwargs: Keyword arguments forwarded to
                ``model.generate_from_state``.

        Raises:
            ValueError: If ``tick_seconds`` is not a finite positive number.
        """
        # A zero tick divides by zero in respond(), and a non-positive tick
        # leaves the live terminal loop without any sleep between ticks.
        if not math.isfinite(tick_seconds) or tick_seconds <= 0.0:
            raise ValueError(f"tick_seconds must be a finite positive number, got {tick_seconds!r}")
        self.model = model
        self.tokenizer = model.tokenizer
        self.tick_seconds = tick_seconds
        self.gate_threshold = gate_threshold
        self.generation_kwargs = generation_kwargs
        self.z_state = model.initial_state(batch_size=1, device=model.device)
        self.z_mask = model.initial_state_mask(batch_size=1, device=model.device)
        self.use_base_chat_template = should_use_base_chat_template(model.config, self.tokenizer)
        # Total simulated time fed to the model so far.
        self.elapsed_seconds = 0.0
        # Index of the tick that the next advance will execute.
        self.next_tick_index = 0
        # Falls back to one tick when the rollout config does not set a delay.
        # NOTE(review): only replay_history() uses this value; respond() always
        # feeds assistant feedback on the very next tick.
        self.feedback_delay_seconds = float(model.config.get("rollout", {}).get("post_speech_feedback_delay_seconds", tick_seconds))

    def _format_observation(self, observation_role: int, observation_text: str | None) -> str | None:
        """Wrap raw observation text in the role-specific format.

        Returns ``None`` for a missing observation, the role-specific
        rendering for user and assistant-feedback roles, and the text
        unchanged for any other role.
        """
        if observation_text is None:
            return None
        if observation_role == OBS_ROLE_USER:
            return build_user_generation_observation(
                self.tokenizer,
                observation_text,
                self.use_base_chat_template,
            )
        if observation_role == OBS_ROLE_AGENT_FEEDBACK:
            return build_assistant_feedback_observation(
                self.tokenizer,
                observation_text,
                self.use_base_chat_template,
            )
        return observation_text

    def _advance_raw(self, delta_seconds: float, observation_role: int, observation_text: str | None) -> float:
        """Run one model rollout step and return the gate probability.

        Adds ``delta_seconds`` to ``elapsed_seconds``, updates ``z_state`` and
        ``z_mask`` in place on the session, and returns the sigmoid of the
        gate logits as a Python float.
        """
        formatted_observation_text = self._format_observation(observation_role, observation_text)
        input_ids, attention_mask = tokenize_text(
            tokenizer=self.tokenizer,
            text=formatted_observation_text,
            max_length=int(self.model.config["model"]["max_observation_tokens"]),
        )
        # Wrap everything in a leading batch dimension of size one.
        input_ids_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.model.device)
        attention_mask_tensor = torch.tensor([attention_mask], dtype=torch.long, device=self.model.device)
        role_tensor = torch.tensor([observation_role], dtype=torch.long, device=self.model.device)
        self.elapsed_seconds += delta_seconds
        with torch.no_grad():
            # NOTE(review): the since-last-user / since-last-agent features are
            # always fed as zeros here — confirm this matches training.
            self.z_state, gate_logits, self.z_mask = self.model.rollout_step_with_mask(
                z_state=self.z_state,
                state_mask=self.z_mask,
                observation_input_ids=input_ids_tensor,
                observation_attention_mask=attention_mask_tensor,
                observation_role_ids=role_tensor,
                delta_seconds=torch.tensor([delta_seconds], dtype=torch.float32, device=self.model.device),
                elapsed_seconds=torch.tensor([self.elapsed_seconds], dtype=torch.float32, device=self.model.device),
                since_last_user_seconds=torch.zeros(1, dtype=torch.float32, device=self.model.device),
                since_last_agent_seconds=torch.zeros(1, dtype=torch.float32, device=self.model.device),
            )
        return torch.sigmoid(gate_logits).item()

    def _advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Advance by exactly one tick and return the gate probability.

        The very first tick of a session uses a time delta of zero; every
        later tick uses ``tick_seconds``.
        """
        delta_seconds = 0.0 if self.next_tick_index == 0 else self.tick_seconds
        gate_probability = self._advance_raw(delta_seconds, observation_role, observation_text)
        self.next_tick_index += 1
        return gate_probability

    def advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Public wrapper around ``_advance_tick``."""
        return self._advance_tick(observation_role, observation_text)

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Feed a recorded conversation through the model tick by tick.

        The messages are quantized to ticks by ``quantize_message_sequence``;
        the session then advances from its current tick up to the last event
        tick, feeding an empty observation on ticks without an event.

        Args:
            messages: Recorded turns, passed unchanged to the quantizer.
        """
        quantized_events = quantize_message_sequence(
            messages=messages,
            tick_seconds=self.tick_seconds,
            feedback_delay_seconds=self.feedback_delay_seconds,
        )
        # NOTE(review): if two events share a tick the later one silently
        # replaces the earlier — presumably the quantizer prevents this.
        events_by_tick = {int(event["tick"]): event for event in quantized_events}
        if not events_by_tick:
            return
        last_tick = max(events_by_tick)
        while self.next_tick_index <= last_tick:
            event = events_by_tick.get(self.next_tick_index)
            if event is None:
                self._advance_tick(OBS_ROLE_NONE, None)
                continue
            if event["kind"] == "user":
                self._advance_tick(OBS_ROLE_USER, str(event["text"]))
            elif event["kind"] == "agent_feedback":
                self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, str(event["text"]))
            else:
                # Unknown event kinds still consume their tick, as idle time.
                self._advance_tick(OBS_ROLE_NONE, None)

    def _generate_response(self) -> str:
        """Decode one utterance from the current recurrent state."""
        generated = self.model.generate_from_state(
            self.z_state,
            encoder_attention_mask=self.z_mask,
            **self.generation_kwargs,
        )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    def generate_response(self) -> str:
        """Public wrapper around ``_generate_response``."""
        return self._generate_response()

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
        debug_gate: bool = False,
    ) -> tuple[list[str], float]:
        """Feed one user message and collect the utterances that follow.

        The user message occupies one tick whose gate value is only reported,
        never acted on. The session then ticks until one of these holds: the
        thought budget is used up, the utterance cap is reached, or at least
        one utterance was produced and ``settle_thought_seconds`` worth of
        ticks passed without another. Each utterance is fed back to the model
        as assistant feedback on the following tick.

        Args:
            user_text: The user message to answer.
            max_thought_seconds: Upper bound on post-message thinking time
                (rounded up to whole ticks, at least one).
            settle_thought_seconds: Quiet time after the last utterance that
                ends the turn (rounded up to whole ticks, at least one).
            max_autonomous_utterances: Cap on utterances for this call. The cap
                is checked after an utterance is recorded, so one utterance
                can be produced even when the cap is below one.
            debug_gate: When true, print the gate probability of every tick.

        Returns:
            ``(utterances, thought_seconds)`` where ``thought_seconds`` counts
            the ticks after the user tick multiplied by ``tick_seconds``.
        """
        gate_probability = self._advance_tick(OBS_ROLE_USER, user_text)
        if debug_gate:
            print_gate_debug(self.next_tick_index - 1, "user", gate_probability)
        utterances: list[str] = []
        thought_ticks = 0
        settle_ticks = max(1, int(math.ceil(settle_thought_seconds / self.tick_seconds)))
        max_thought_ticks = max(1, int(math.ceil(max_thought_seconds / self.tick_seconds)))
        idle_ticks_after_last_speech = 0
        pending_assistant_feedback: str | None = None
        while thought_ticks < max_thought_ticks:
            if pending_assistant_feedback is not None:
                # Let the model observe what it just said before anything else.
                event_name = "assistant_feedback"
                gate_probability = self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, pending_assistant_feedback)
                pending_assistant_feedback = None
            else:
                event_name = "idle"
                gate_probability = self._advance_tick(OBS_ROLE_NONE, None)
            if debug_gate:
                print_gate_debug(self.next_tick_index - 1, event_name, gate_probability)
            thought_ticks += 1
            idle_ticks_after_last_speech += 1
            if gate_probability >= self.gate_threshold:
                response = self._generate_response()
                utterances.append(response)
                idle_ticks_after_last_speech = 0
                if len(utterances) >= max_autonomous_utterances:
                    # NOTE(review): the final utterance is not fed back as
                    # assistant feedback when the loop stops here (or when the
                    # tick budget ends right after speaking) — confirm this is
                    # acceptable for sessions that continue afterwards.
                    break
                pending_assistant_feedback = response
                continue
            if utterances and idle_ticks_after_last_speech >= settle_ticks:
                break
        return utterances, thought_ticks * self.tick_seconds
def consume_next_live_user_line(
    line_buffer: TerminalInputBuffer,
    *,
    stop_event: threading.Event,
    status_text: str,
) -> str | None:
    """Take the next user message from the buffer, handling commands inline.

    All queued lines are drained. Blank lines are dropped, exit commands set
    ``stop_event``, ``/status`` prints ``status_text``, and the first
    remaining line is returned. Any further lines are put back at the front
    of the buffer for later ticks.

    Args:
        line_buffer: Buffer of raw terminal lines.
        stop_event: Set when an exit command is seen.
        status_text: Text printed in response to ``/status``.

    Returns:
        The stripped user line, or ``None`` when nothing usable was queued.
    """
    drained = line_buffer.drain()
    if not drained:
        return None

    exit_commands = {"/exit", "/quit", "exit", "quit"}
    user_lines: list[str] = []
    for entry in (raw.strip() for raw in drained):
        if not entry:
            continue
        lowered = entry.lower()
        if lowered in exit_commands:
            stop_event.set()
        elif lowered == "/status":
            print(status_text, flush=True)
        else:
            user_lines.append(entry)

    # Only one user line is consumed per call; the rest go back in order.
    line_buffer.prepend_many(user_lines[1:])
    return user_lines[0] if user_lines else None
def run_live_thought_loop_terminal(
    session: ThoughtLoopSession,
    *,
    gate_threshold: float,
    debug_gate: bool,
) -> None:
    """Run a real-time terminal session, one model tick per tick interval.

    A background thread collects typed lines. On every tick the loop feeds
    exactly one observation: pending assistant feedback first, otherwise the
    next queued user line, otherwise nothing. When the gate probability of a
    non-user tick reaches ``gate_threshold`` the assistant speaks, and that
    utterance becomes the next tick's feedback. The loop ends on an exit
    command or when stdin reaches end-of-file.

    Args:
        session: Session to drive.
        gate_threshold: Gate probability at or above which the assistant
            speaks.
        debug_gate: When true, print the gate probability of every tick.
    """
    line_buffer = TerminalInputBuffer()
    stop_event = threading.Event()
    start_terminal_reader(line_buffer, stop_event)
    pending_assistant_feedback: str | None = None

    print("Live Samantha session.", flush=True)
    # Report the configured interval rather than assuming one tick per second.
    print(
        f"Type anytime and press Enter. Each line is one user event on the next tick (every {session.tick_seconds:g}s).",
        flush=True,
    )
    print("Commands: /status, /exit", flush=True)

    while not stop_event.is_set():
        tick_started_at = time.monotonic()
        # Built before the buffer is drained so /status reports the queue
        # length as it was at the start of this tick.
        status_text = (
            f"[status] tick={session.next_tick_index} elapsed={session.elapsed_seconds:.1f}s "
            f"queued_user_lines={len(line_buffer)} pending_assistant_feedback={pending_assistant_feedback is not None}"
        )
        if pending_assistant_feedback is not None:
            # Feedback takes priority; queued user lines wait for a later tick.
            observation_role = OBS_ROLE_AGENT_FEEDBACK
            observation_text = pending_assistant_feedback
            event_name = "assistant_feedback"
            pending_assistant_feedback = None
        else:
            user_line = consume_next_live_user_line(
                line_buffer,
                stop_event=stop_event,
                status_text=status_text,
            )
            if stop_event.is_set():
                break
            if user_line is not None:
                observation_role = OBS_ROLE_USER
                observation_text = user_line
                event_name = "user"
                print(f"user> {user_line}", flush=True)
            else:
                observation_role = OBS_ROLE_NONE
                observation_text = None
                event_name = "idle"

        gate_probability = session.advance_tick(observation_role, observation_text)
        if debug_gate:
            print_gate_debug(session.next_tick_index - 1, event_name, gate_probability)
        # The gate is never acted on during the tick that delivers a user line.
        if observation_role != OBS_ROLE_USER and gate_probability >= gate_threshold:
            response = session.generate_response()
            print(f"assistant> {response}", flush=True)
            pending_assistant_feedback = response

        # Wait out the rest of the tick, but wake immediately if the reader
        # thread signals shutdown (stdin closed) in the meantime.
        remaining_seconds = session.tick_seconds - (time.monotonic() - tick_started_at)
        if remaining_seconds > 0.0:
            stop_event.wait(remaining_seconds)
def run_blocking_base_terminal(session: BaseSeq2SeqSession) -> None:
    """Run a simple prompt/response terminal loop for the baseline backend.

    Blank input is ignored. The loop ends on an exit command or when stdin
    reaches end-of-file.

    Args:
        session: Session whose ``respond`` method answers each message.
    """
    print("Raw HF baseline terminal mode. This model has no Samantha Z state or speech gate.", flush=True)
    print("Type a message and press Enter. Commands: /exit", flush=True)
    while True:
        try:
            user_text = input("user> ").strip()
        except EOFError:
            # stdin was closed (Ctrl-D or end of piped input): leave the loop
            # cleanly instead of surfacing a traceback.
            break
        if user_text.lower() in {"/exit", "/quit", "exit", "quit"}:
            break
        if not user_text:
            continue
        responses, _ = session.respond(
            user_text,
            max_thought_seconds=0.0,
            settle_thought_seconds=0.0,
            max_autonomous_utterances=1,
        )
        for response in responses:
            print(f"assistant> {response}", flush=True)
def main() -> None:
    """Entry point: load a backend, optionally replay history, then answer.

    With a non-empty ``--user`` message the script answers it once and prints
    a JSON object with ``responses`` and ``thought_seconds``. It enters the
    interactive terminal afterwards when ``--interactive`` is given, and also
    whenever no user message was supplied.

    Raises:
        ValueError: If the ``--history-json`` file does not contain a list.
    """
    args = parse_args()
    backend_kind, model, tokenizer = load_inference_backend(args.model_path, device=args.device)
    generation_kwargs = build_generation_kwargs(args)
    is_thought_loop = backend_kind == "thought_loop"

    if is_thought_loop:
        # An explicit --tick-seconds wins; otherwise use the checkpoint's
        # rollout setting, defaulting to one second.
        tick_seconds = (
            float(args.tick_seconds)
            if args.tick_seconds is not None
            else float(model.config.get("rollout", {}).get("tick_seconds", 1.0))
        )
        session: Any = ThoughtLoopSession(
            model=model,
            tick_seconds=tick_seconds,
            gate_threshold=args.gate_threshold,
            generation_kwargs=generation_kwargs,
        )
    else:
        # The baseline loader always returns a tokenizer alongside the model.
        assert tokenizer is not None
        session = BaseSeq2SeqSession(
            model=model,
            tokenizer=tokenizer,
            device=args.device,
            generation_kwargs=generation_kwargs,
        )

    if args.history_json:
        history = json.loads(Path(args.history_json).read_text(encoding="utf-8"))
        if not isinstance(history, list):
            raise ValueError("History JSON must be a list of turn objects.")
        session.replay_history(history)

    # An empty --user string counts as "no message", so it falls through to
    # the interactive terminal instead of exiting without doing anything.
    has_user_message = bool(args.user)
    if has_user_message:
        respond_kwargs: dict[str, Any] = {
            "max_thought_seconds": args.max_thought_seconds,
            "settle_thought_seconds": args.settle_thought_seconds,
            "max_autonomous_utterances": args.max_autonomous_utterances,
        }
        if is_thought_loop:
            # Gate debugging only applies to the thought-loop backend.
            respond_kwargs["debug_gate"] = bool(args.debug_gate)
        responses, thought_time = session.respond(args.user, **respond_kwargs)
        print(
            json.dumps(
                {
                    "responses": responses,
                    "thought_seconds": thought_time,
                },
                indent=2,
                ensure_ascii=False,
            )
        )

    if args.interactive or not has_user_message:
        if is_thought_loop:
            run_live_thought_loop_terminal(
                session,
                gate_threshold=float(args.gate_threshold),
                debug_gate=bool(args.debug_gate),
            )
        else:
            run_blocking_base_terminal(session)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()