# meta-r2 / scripts/inference.py
# github-actions[bot]
# Deploy Space snapshot
# ddbc1ba
"""
LifeStack Model — Inference Script
===================================
Usage:
python scripts/inference.py --model ./lifestack_model
python scripts/inference.py --model ./lifestack_model --scenario "My car broke down and I have a meeting in 2 hours"
python scripts/inference.py --model ./lifestack_model --interactive
python scripts/inference.py --model ./lifestack_model --demo
With no mode flag, the first built-in demo scenario is run.
"""
import argparse
import json
import re
import torch
from transformers import AutoTokenizer
# ── Load model ────────────────────────────────────────────────────────────────
def load_model(model_dir: str, base_model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"):
    """Load the LifeStack LoRA adapter for inference.

    Tries Unsloth first (faster, 4-bit weights). If ``unsloth`` cannot be
    imported, falls back to standard transformers + PEFT.

    Args:
        model_dir: Directory containing the adapter and its tokenizer files.
        base_model_name: Base checkpoint that the PEFT fallback loads before
            applying the adapter from ``model_dir``. Not used on the Unsloth
            path. Defaults to the Qwen2.5-1.5B-Instruct model this script
            targets.

    Returns:
        ``(model, tokenizer)``, with the model switched to inference mode.
    """
    try:
        # NOTE: an ImportError raised inside from_pretrained (not only by the
        # import statement itself) also sends us to the PEFT fallback below.
        from unsloth import FastLanguageModel
        print("Loading with Unsloth (fast)...")
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_dir,
            max_seq_length=2048,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(model)
        print("✅ Loaded via Unsloth")
        return model, tokenizer
    except ImportError:
        print("Unsloth not installed — using standard PEFT (slower)...")

    # Fallback runs outside the except block so its own errors are reported
    # cleanly instead of being chained onto the ImportError above.
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    base = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        # fp16 when CUDA is available, fp32 otherwise.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(base, model_dir)
    model.eval()
    print("✅ Loaded via PEFT")
    return model, tokenizer
# ── Build prompt ───────────────────────────────────────────────────────────────
# System prompt: describes the agent's role and the exact JSON schema
# (action_type / target_domain / metric_changes / resource_cost / reasoning)
# the model must answer with.
SYSTEM = "".join(
    [
        "You are LifeStack, an AI life-management agent. ",
        "Given a real-life crisis, respond with a single optimal action as valid JSON.\n\n",
        "Required JSON format:\n",
        '{"action_type": "negotiate|communicate|delegate|spend|reschedule|rest|deprioritize|execute", ',
        '"target_domain": "career|finances|relationships|physical_health|mental_wellbeing|time|transport_crisis|flight_crisis|code_merge_crisis", ',
        '"metric_changes": {"domain.submetric": delta_value}, ',
        '"resource_cost": {"time": hours, "money": dollars, "energy": units}, ',
        '"reasoning": "brief explanation of why this is the best action"}',
    ]
)
def build_prompt(scenario: str) -> str:
    """Render *scenario* as a ChatML (``<|im_start|>`` / ``<|im_end|>``) prompt.

    The system turn carries ``SYSTEM``. The user turn frames the text as a
    ``CRISIS:`` and asks for bare JSON. The prompt ends with an open assistant
    turn, so generation continues as the model's reply.
    """
    # NOTE(review): this wording should match the prompt the adapter was
    # fine-tuned on; confirm against the training script if it changes.
    return (
        f"<|im_start|>system\n{SYSTEM}<|im_end|>\n"
        "<|im_start|>user\n"
        f"CRISIS: {scenario}\n\n"
        "Respond with ONLY valid JSON — no markdown, no explanation outside the JSON.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
class JsonCompleteStopping(torch.nn.Module):
    """Stopping criterion for ``model.generate``: stop once the first JSON object closes.

    At every step, decodes the newly generated tokens (everything after
    ``prompt_len``) of the first sequence in the batch. It then reports
    whether the first top-level ``{...}`` object in that text is complete.
    Braces inside JSON string values (e.g. a ``"reasoning"`` text that
    mentions ``}``) are ignored, so they cannot end generation early.
    """

    def __init__(self, tokenizer, prompt_len: int):
        super().__init__()
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len  # number of prompt tokens to skip when decoding

    @staticmethod
    def _first_object_closed(text: str) -> bool:
        """Return True if *text* contains a complete first top-level JSON object.

        Everything before the first ``{`` is skipped, including stray ``}``.
        After it, braces are counted only outside double-quoted strings, and
        backslash escapes inside strings are honoured.
        """
        depth = 0
        in_string = False
        escaped = False
        for ch in text:
            if depth == 0:
                # Not inside an object yet: wait for the opening brace.
                if ch == "{":
                    depth = 1
                continue
            if in_string:
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    in_string = False
            elif ch == '"':
                in_string = True
            elif ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return True
        return False

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        text = self.tokenizer.decode(input_ids[0][self.prompt_len:], skip_special_tokens=True)
        return self._first_object_closed(text)
def extract_json_payload(completion: str) -> dict:
    """Parse the first complete JSON object in *completion*, ignoring surrounding text.

    Each ``{`` is tried in turn with ``json.JSONDecoder.raw_decode``. Leading
    prose is therefore skipped, and anything after the decoded object is
    ignored.

    Returns:
        The first decodable JSON object, or
        ``{"raw_output": completion, "parse_error": "no valid JSON object found"}``
        when none is found.
    """
    decoder = json.JSONDecoder()
    # Strip once and search/slice the SAME string, so candidate offsets line up.
    text = completion.strip()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    return {"raw_output": completion, "parse_error": "no valid JSON object found"}
# ── Inference ──────────────────────────────────────────────────────────────────
def resolve(model, tokenizer, scenario: str, temperature: float = 0.3,
            max_new_tokens: int = 160) -> dict:
    """Generate one LifeStack action for *scenario* and parse it as JSON.

    Args:
        model: Causal LM with a ``generate`` method (as returned by ``load_model``).
        tokenizer: Matching tokenizer.
        scenario: Free-text description of the crisis.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding, in which case no temperature is passed to ``generate``.
        max_new_tokens: Upper bound on generated tokens. Generation also
            stops as soon as the first JSON object is closed.

    Returns:
        The parsed JSON object, or the ``raw_output``/``parse_error`` dict
        produced by ``extract_json_payload`` when parsing fails.
    """
    prompt = build_prompt(scenario)
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]
    pad_token_id = (
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "stopping_criteria": [JsonCompleteStopping(tokenizer, prompt_len)],
    }
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    else:
        # Greedy decoding: passing a temperature here only triggers a
        # transformers "flag only used in sample-based generation" warning.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # Decode only the new tokens (not the prompt)
    completion = tokenizer.decode(
        output_ids[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()
    return extract_json_payload(completion)
# ── Built-in scenarios for quick demo ─────────────────────────────────────────
# Canned crises used by --demo (all of them) and by the no-flag default (the
# first one). Long entries are split with implicit string concatenation.
DEMO_SCENARIOS = [
    # travel + money + work
    ("My flight got cancelled and my card got declined at the rebooking desk. "
     "I have a client presentation tomorrow morning."),
    # transport
    ("My car broke down on the highway. "
     "The repair will take 3 days and I have no other transport."),
    # sleep / wellbeing / relationship
    ("I haven't slept properly in 2 weeks. "
     "My productivity is shot and my partner says I'm distant."),
    # finances
    "A surprise tax audit letter arrived. I owe $4,000 I don't have liquid.",
    # workload
    ("My boss just dropped a 12-hour task on me at 5PM Friday "
     "and said it's due Monday morning."),
    # commute delay
    ("The morning train is delayed 90 minutes "
     "and I have a 9AM client meeting I can't miss."),
    # overcommitment
    "I've been double-booked every weekend for a month and I can't say no to anyone.",
    # code / release
    "A critical git merge broke the staging environment 2 hours before a demo.",
]
# ── Main ───────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: parse arguments, load the model, run one mode.

    Mode precedence when several flags are given: ``--demo`` first, then
    ``--interactive``, then ``--scenario``. With none of them, the first
    built-in demo scenario is resolved.
    """
    n_demos = len(DEMO_SCENARIOS)

    parser = argparse.ArgumentParser(description="LifeStack Inference")
    parser.add_argument("--model", type=str, default="./lifestack_model",
                        help="Path to the downloaded/unzipped model directory")
    parser.add_argument("--scenario", type=str, default=None,
                        help="Describe your crisis (quoted string)")
    parser.add_argument("--interactive", action="store_true",
                        help="Interactive REPL — type your own crises")
    parser.add_argument("--demo", action="store_true",
                        help=f"Run all {n_demos} built-in demo scenarios")
    parser.add_argument("--temperature", type=float, default=0.3,
                        help="Generation temperature (default 0.3 = focused)")
    args = parser.parse_args()

    print(f"\n{'=' * 60}")
    print(" LifeStack — AI Life Management Agent")
    print(f"{'=' * 60}\n")

    model, tokenizer = load_model(args.model)
    print()

    if args.demo:
        for i, scenario in enumerate(DEMO_SCENARIOS, 1):
            # Only append an ellipsis when the preview is actually truncated.
            preview = scenario if len(scenario) <= 80 else f"{scenario[:80]}..."
            print(f"[{i}/{n_demos}] {preview}")
            action = resolve(model, tokenizer, scenario, args.temperature)
            print(json.dumps(action, indent=2))
            print()
    elif args.interactive:
        print("Interactive mode — type your crisis and press Enter.")
        print("Type 'quit' to exit.\n")
        while True:
            try:
                scenario = input("Crisis > ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if scenario.lower() in ("quit", "exit", "q"):
                break
            if not scenario:
                continue
            action = resolve(model, tokenizer, scenario, args.temperature)
            print("\nLifeStack Action:")
            print(json.dumps(action, indent=2))
            print()
    elif args.scenario:
        action = resolve(model, tokenizer, args.scenario, args.temperature)
        print("LifeStack Action:")
        print(json.dumps(action, indent=2))
    else:
        # Default: run 1 demo scenario
        scenario = DEMO_SCENARIOS[0]
        print(f"Demo scenario: {scenario}\n")
        action = resolve(model, tokenizer, scenario, args.temperature)
        print("LifeStack Action:")
        print(json.dumps(action, indent=2))


if __name__ == "__main__":
    main()