# File: test-true / inference.py
# Hugging Face file-view residue (not part of the program): uploaded by BRlkl
# ("Upload folder using huggingface_hub"), commit d97bf05 (verified), 9.29 kB.
from __future__ import annotations
import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any
import torch
# Absolute path of the directory that contains this script.
SCRIPT_DIR = Path(__file__).resolve().parent
# Put that directory at the front of sys.path (only once) before the `data`
# and `model` imports below run — presumably those are sibling modules shipped
# next to this file; confirm against the repository layout.
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))
from data import OBS_ROLE_AGENT_FEEDBACK, OBS_ROLE_NONE, OBS_ROLE_USER, quantize_message_sequence, tokenize_text
from model import ThoughtLoopT5Gemma
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Args:
        argv: Argument strings to parse instead of the process command line.
            ``None`` (the default) lets argparse read ``sys.argv[1:]``, so
            existing ``parse_args()`` callers behave exactly as before.

    Returns:
        The populated namespace. ``--model-path`` is the only required
        option; ``--tick-seconds`` stays ``None`` unless given explicitly.
    """
    parser = argparse.ArgumentParser(description="Interactive inference for the recurrent Samantha SFT model.")

    # Where to load the model from and where to run it.
    parser.add_argument("--model-path", required=True, help="Local export directory or Hugging Face repo id.")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")

    # What to feed the model: an optional history, a one-shot message, a REPL.
    parser.add_argument("--history-json", default=None, help="Optional JSON array of {t, speaker, text} turns to replay.")
    parser.add_argument("--user", default=None, help="Optional single user message to answer.")
    parser.add_argument("--interactive", action="store_true")

    # Timing and speak-gate controls of the thought loop.
    parser.add_argument("--tick-seconds", type=float, default=None)
    parser.add_argument("--max-thought-seconds", type=float, default=300.0)
    parser.add_argument("--settle-thought-seconds", type=float, default=1.0)
    parser.add_argument("--max-autonomous-utterances", type=int, default=4)
    parser.add_argument("--gate-threshold", type=float, default=0.5)

    # Text generation controls.
    parser.add_argument("--max-new-tokens", type=int, default=160)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.95)

    return parser.parse_args(argv)
class ThoughtLoopSession:
    """Stateful driver that steps a ``ThoughtLoopT5Gemma`` model tick by tick.

    The session owns the recurrent state (``z_state``), a simulated clock
    (``elapsed_seconds``) and the index of the next tick to run
    (``next_tick_index``). Every tick feeds one observation (user text, the
    agent's own utterance as feedback, or nothing) to ``model.rollout_step``
    and yields the probability of the model's speak gate.
    """

    def __init__(
        self,
        model: ThoughtLoopT5Gemma,
        tick_seconds: float,
        gate_threshold: float,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ) -> None:
        """Create a session with a fresh recurrent state.

        Args:
            model: Loaded model; its ``tokenizer``, ``device``, ``config``,
                ``initial_state``, ``rollout_step`` and ``generate_from_state``
                are used.
            tick_seconds: Simulated duration of one tick. Must be positive
                and finite.
            gate_threshold: Gate probability at or above which the agent
                speaks.
            max_new_tokens: Forwarded to ``generate_from_state``.
            temperature: Sampling temperature; sampling is disabled when it
                is not greater than zero.
            top_p: Forwarded to ``generate_from_state``.

        Raises:
            ValueError: If ``tick_seconds`` is not a positive, finite number.
        """
        # Fail fast: tick_seconds is a divisor in respond() and the clock
        # increment in _advance_tick(), so 0 / negative / NaN cannot work.
        if not (math.isfinite(tick_seconds) and tick_seconds > 0.0):
            raise ValueError(f"tick_seconds must be a positive, finite number, got {tick_seconds!r}.")
        self.model = model
        self.tokenizer = model.tokenizer
        self.tick_seconds = tick_seconds
        self.gate_threshold = gate_threshold
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.z_state = model.initial_state(batch_size=1, device=model.device)
        self.elapsed_seconds = 0.0
        self.next_tick_index = 0
        # Falls back to one tick when the model config does not specify it.
        self.feedback_delay_seconds = float(
            model.config.get("rollout", {}).get("post_speech_feedback_delay_seconds", tick_seconds)
        )

    def _advance_raw(self, delta_seconds: float, observation_role: int, observation_text: str | None) -> float:
        """Run one ``rollout_step`` and return the speak-gate probability.

        Advances ``elapsed_seconds`` by ``delta_seconds`` and replaces
        ``z_state`` with the state returned by the model.

        Args:
            delta_seconds: Simulated time elapsed since the previous step.
            observation_role: One of the ``OBS_ROLE_*`` ids.
            observation_text: Text observed on this step, or ``None``.

        Returns:
            Sigmoid of the gate logits as a Python float.
        """
        device = self.model.device
        input_ids, attention_mask = tokenize_text(
            tokenizer=self.tokenizer,
            text=observation_text,
            max_length=int(self.model.config["model"]["max_observation_tokens"]),
        )
        # Wrap in a list to add the batch dimension of size 1.
        input_ids_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
        attention_mask_tensor = torch.tensor([attention_mask], dtype=torch.long, device=device)
        role_tensor = torch.tensor([observation_role], dtype=torch.long, device=device)
        self.elapsed_seconds += delta_seconds
        with torch.no_grad():
            self.z_state, gate_logits = self.model.rollout_step(
                z_state=self.z_state,
                observation_input_ids=input_ids_tensor,
                observation_attention_mask=attention_mask_tensor,
                observation_role_ids=role_tensor,
                delta_seconds=torch.tensor([delta_seconds], dtype=torch.float32, device=device),
                elapsed_seconds=torch.tensor([self.elapsed_seconds], dtype=torch.float32, device=device),
                # NOTE(review): both "since last ..." features are always fed
                # as zero; the session never tracks them. Confirm this matches
                # how the model was trained.
                since_last_user_seconds=torch.zeros(1, dtype=torch.float32, device=device),
                since_last_agent_seconds=torch.zeros(1, dtype=torch.float32, device=device),
            )
        return torch.sigmoid(gate_logits).item()

    def _advance_tick(self, observation_role: int, observation_text: str | None) -> float:
        """Advance exactly one tick and return the speak-gate probability.

        The very first tick of a session uses a zero time delta; every later
        tick advances the clock by ``tick_seconds``.
        """
        delta_seconds = 0.0 if self.next_tick_index == 0 else self.tick_seconds
        gate_probability = self._advance_raw(delta_seconds, observation_role, observation_text)
        self.next_tick_index += 1
        return gate_probability

    def replay_history(self, messages: list[dict[str, Any]]) -> None:
        """Replay earlier conversation turns to warm up the recurrent state.

        The messages are quantized onto the tick grid with
        ``quantize_message_sequence``. Ticks without an event are stepped with
        an empty observation. When several events land on the same tick the
        last one wins, and ticks below ``next_tick_index`` are never visited.

        Args:
            messages: Turn dictionaries as accepted by
                ``quantize_message_sequence``.
        """
        quantized_events = quantize_message_sequence(
            messages=messages,
            tick_seconds=self.tick_seconds,
            feedback_delay_seconds=self.feedback_delay_seconds,
        )
        events_by_tick = {int(event["tick"]): event for event in quantized_events}
        if not events_by_tick:
            return
        last_tick = max(events_by_tick)
        # _advance_tick() increments next_tick_index, which drives this loop.
        while self.next_tick_index <= last_tick:
            event = events_by_tick.get(self.next_tick_index)
            if event is None:
                self._advance_tick(OBS_ROLE_NONE, None)
            elif event["kind"] == "user":
                self._advance_tick(OBS_ROLE_USER, str(event["text"]))
            elif event["kind"] == "agent_feedback":
                self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, str(event["text"]))
            else:
                # Unknown event kinds are stepped as empty ticks.
                self._advance_tick(OBS_ROLE_NONE, None)

    def _generate_response(self) -> str:
        """Decode one utterance from the current recurrent state.

        Returns:
            The decoded text with special tokens removed and outer whitespace
            stripped.
        """
        # Inference only: run without autograd tracking, consistent with the
        # rollout step in _advance_raw().
        with torch.no_grad():
            generated = self.model.generate_from_state(
                self.z_state,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.temperature > 0.0,
                temperature=self.temperature,
                top_p=self.top_p,
            )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    def respond(
        self,
        user_text: str,
        *,
        max_thought_seconds: float,
        settle_thought_seconds: float,
        max_autonomous_utterances: int,
    ) -> tuple[list[str], float]:
        """Feed a user message and let the model think and speak.

        The user message is observed on one tick. Afterwards each loop
        iteration either speaks (the gate probability reached
        ``gate_threshold``: an utterance is generated and fed back as an
        agent-feedback observation) or idles for one empty tick. The loop
        ends when the thought budget is used up, when the utterance cap is
        reached, or when at least one utterance was produced and the model
        then stayed silent for the settle period.

        Args:
            user_text: The user's message.
            max_thought_seconds: Thought budget; converted to ticks, at
                least one.
            settle_thought_seconds: Silence after speech that ends the turn;
                converted to ticks, at least one.
            max_autonomous_utterances: Cap on utterances per turn.

        Returns:
            ``(utterances, thought_seconds)``. ``thought_seconds`` counts the
            ticks spent after the user tick, multiplied by ``tick_seconds``.
        """
        gate_probability = self._advance_tick(OBS_ROLE_USER, user_text)
        utterances: list[str] = []
        thought_ticks = 0
        settle_ticks = max(1, int(math.ceil(settle_thought_seconds / self.tick_seconds)))
        max_thought_ticks = max(1, int(math.ceil(max_thought_seconds / self.tick_seconds)))
        idle_ticks_after_last_speech = 0
        while thought_ticks < max_thought_ticks:
            if gate_probability >= self.gate_threshold:
                response = self._generate_response()
                utterances.append(response)
                # The agent observes its own utterance on the next tick.
                gate_probability = self._advance_tick(OBS_ROLE_AGENT_FEEDBACK, response)
                thought_ticks += 1
                idle_ticks_after_last_speech = 0
                # NOTE(review): the cap is only checked after speaking, so a
                # cap <= 0 still allows one utterance.
                if len(utterances) >= max_autonomous_utterances:
                    break
                continue
            gate_probability = self._advance_tick(OBS_ROLE_NONE, None)
            thought_ticks += 1
            idle_ticks_after_last_speech += 1
            if utterances and idle_ticks_after_last_speech >= settle_ticks:
                break
        return utterances, thought_ticks * self.tick_seconds
def main() -> None:
    """Command-line entry point: load the model and run the requested modes.

    Steps, in order:
      1. Parse the arguments and load the model in eval mode.
      2. Pick the tick length: ``--tick-seconds`` if given, otherwise the
         model config's ``rollout.tick_seconds``, otherwise 0.1.
      3. Optionally replay a JSON history file into the session.
      4. Optionally answer a single ``--user`` message and print the result
         as JSON.
      5. Optionally run a prompt loop until ``exit`` / ``quit``, end of
         input, or Ctrl-C at the prompt.

    Raises:
        ValueError: If the history file does not contain a JSON list.
    """
    args = parse_args()
    model = ThoughtLoopT5Gemma.from_pretrained(args.model_path, device=args.device, map_location=args.device)
    model.eval()

    if args.tick_seconds is not None:
        tick_seconds = float(args.tick_seconds)
    else:
        tick_seconds = float(model.config.get("rollout", {}).get("tick_seconds", 0.1))

    session = ThoughtLoopSession(
        model=model,
        tick_seconds=tick_seconds,
        gate_threshold=args.gate_threshold,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    if args.history_json:
        history = json.loads(Path(args.history_json).read_text(encoding="utf-8"))
        if not isinstance(history, list):
            raise ValueError("History JSON must be a list of turn objects.")
        session.replay_history(history)

    # Keyword arguments shared by the one-shot and the interactive path.
    respond_options = {
        "max_thought_seconds": args.max_thought_seconds,
        "settle_thought_seconds": args.settle_thought_seconds,
        "max_autonomous_utterances": args.max_autonomous_utterances,
    }

    if not args.user and not args.interactive:
        # Without either mode the script would exit silently after loading.
        print("Nothing to do: pass --user and/or --interactive.", file=sys.stderr)

    if args.user:
        responses, thought_time = session.respond(args.user, **respond_options)
        print(
            json.dumps(
                {
                    "responses": responses,
                    "thought_seconds": thought_time,
                },
                indent=2,
                ensure_ascii=False,
            )
        )

    if args.interactive:
        print("Interactive mode. Type 'exit' to stop.")
        while True:
            try:
                user_text = input("user> ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D, exhausted piped stdin, or Ctrl-C at the prompt:
                # finish the prompt line and leave without a traceback.
                print()
                break
            if user_text.lower() in {"exit", "quit"}:
                break
            if not user_text:
                # A blank line is not a message; do not spend a user tick.
                continue
            responses, thought_time = session.respond(user_text, **respond_options)
            if not responses:
                print(f"agent> [kept thinking for {thought_time:.1f}s and did not cross the speak threshold]")
                continue
            for response in responses:
                print(f"agent> {response}")
            print(f"[thought_seconds={thought_time:.1f}]")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()