"""
Interactive REPL for testing a trained physics problem-solving model.
"""
import argparse
import json
import re
from pathlib import Path

import torch
import yaml

from qwen2_model import Transformer
from tokenizer import Tokenizer
from generation_utils import generate
from tokenizer_wrapper import decode_token_ids
SYSTEM_MESSAGE = (
"You are a helpful physics tutor. You first think about the reasoning process "
"in your mind and then provide the user with the answer."
)
USER_TEMPLATE = (
    "{question}\n"
    "Show your reasoning in <think> </think> tags. "
    "Then provide your final answer in <answer> </answer> tags."
)
RESPONSE_PROMPT = "Let me solve this step by step.\n<think>"
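
# The two messages plus RESPONSE_PROMPT are rendered by
# encode_chat_with_response_prompt() so that generation continues inside an
# already-open <think> block. Assuming a standard Qwen2-style chat template
# (the exact template lives in the tokenizer module), the encoded prefix
# looks roughly like:
#
#   <|im_start|>system
#   You are a helpful physics tutor. ...<|im_end|>
#   <|im_start|>user
#   {question}
#   Show your reasoning in <think> </think> tags. ...<|im_end|>
#   <|im_start|>assistant
#   Let me solve this step by step.
#   <think>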
def load_model_and_tokenizer(config_path, checkpoint_path=None):
"""Load model and tokenizer from config and checkpoint."""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
pretrained_model_path = Path(config["model"]["pretrained_model_path"])
device = torch.device(config["model"]["device"])
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
dtype = dtype_map.get(config["model"]["dtype"], torch.bfloat16)
# Load tokenizer
tokenizer = Tokenizer(str(pretrained_model_path / "tokenizer.json"))
# Load model
model = Transformer.from_pretrained(pretrained_model_path, device=device)
# Load checkpoint if provided
if checkpoint_path:
print(f"Loading checkpoint from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location=device)
# Handle different checkpoint formats
if isinstance(checkpoint, dict):
if "model_state_dict" in checkpoint:
# Checkpoint contains model_state_dict, optimizer_state_dict, etc.
state_dict = checkpoint["model_state_dict"]
print(f"Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
else:
# Checkpoint is already a state dict
state_dict = checkpoint
else:
state_dict = checkpoint
model.load_state_dict(state_dict)
print("Checkpoint loaded successfully!")
model.eval()
return model, tokenizer, device, dtype, config
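
# Minimal config sketch covering only the keys this script reads; the real
# config may carry additional fields, and the values below are placeholders:
#
#   model:
#     pretrained_model_path: /path/to/qwen2-base
#     device: cuda
#     dtype: bfloat16
#   training:
#     max_gen_len: 512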
def generate_response(model, tokenizer, question, device, dtype, max_gen_len=512, temperature=0.7, top_p=0.9):
"""Generate a response for a given physics question."""
# Format the prompt
user_message = USER_TEMPLATE.format(question=question)
prefix = tokenizer.encode_chat_with_response_prompt(
[
{"role": "system", "content": SYSTEM_MESSAGE},
{"role": "user", "content": user_message},
],
RESPONSE_PROMPT,
)
# Tokenize
tokens = tokenizer.tokenize(prefix)
prefix_token_ids = tokens.ids
# Generate
print("\nGenerating response...")
with torch.inference_mode():
generated_token_ids, is_finished = generate(
model=model,
tokenizer=tokenizer,
prompt_token_ids=prefix_token_ids,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
device=device,
dtype=dtype,
)
# Decode
generated_text = decode_token_ids(tokenizer, generated_token_ids)
return prefix + generated_text, is_finished
def extract_answer(text):
"""Extract the answer from tags."""
import re
answer_match = re.search(r"(.*?)", text, re.DOTALL)
if answer_match:
return answer_match.group(1).strip()
return None
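
# Example behavior (hypothetical strings, shown for illustration only):
#   >>> extract_answer("<think>F = ma = 5 * 3</think><answer>15 N</answer>")
#   '15 N'
#   >>> extract_answer("no answer tags here") is None
#   True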
def print_response(full_text):
"""Pretty print the model's response."""
    # Try to extract the think and answer sections
    think_match = re.search(r"<think>(.*?)</think>", full_text, re.DOTALL)
    answer_match = re.search(r"<answer>(.*?)</answer>", full_text, re.DOTALL)
print("\n" + "="*80)
if think_match:
print("\n🤔 REASONING:")
print("-" * 80)
print(think_match.group(1).strip())
if answer_match:
print("\n✅ ANSWER:")
print("-" * 80)
print(answer_match.group(1).strip())
else:
print("\n⚠️ WARNING: No answer tags found in response")
print("\nFull response:")
print("-" * 80)
print(full_text)
print("="*80 + "\n")
def interactive_mode(model, tokenizer, device, dtype, config):
"""Run interactive REPL mode."""
print("\n" + "="*80)
print("Physics Problem Solver - Interactive Mode")
print("="*80)
print("\nCommands:")
print(" - Type your physics question and press Enter")
print(" - Type 'quit' or 'exit' to exit")
print(" - Type 'config' to change generation parameters")
print(" - Type 'example' to see example questions")
print("="*80 + "\n")
# Default generation parameters
max_gen_len = config["training"].get("max_gen_len", 512)
temperature = 0.7
top_p = 0.9
while True:
try:
user_input = input("\n📝 Enter physics question (or command): ").strip()
if not user_input:
continue
if user_input.lower() in ['quit', 'exit', 'q']:
print("\nGoodbye! 👋")
break
if user_input.lower() == 'example':
print("\nExample questions:")
print(" 1. A ball is thrown upward with velocity 20 m/s. What is its maximum height?")
print(" 2. Calculate the force needed to accelerate a 5kg object at 3 m/s²")
print(" 3. What is the wavelength of light with frequency 5×10¹⁴ Hz?")
print(" 4. A 2kg block slides down a 30° incline. What is its acceleration?")
continue
if user_input.lower() == 'config':
print(f"\nCurrent settings:")
print(f" max_gen_len: {max_gen_len}")
print(f" temperature: {temperature}")
print(f" top_p: {top_p}")
try:
new_max_len = input(f"\nNew max_gen_len [{max_gen_len}]: ").strip()
if new_max_len:
max_gen_len = int(new_max_len)
new_temp = input(f"New temperature [{temperature}]: ").strip()
if new_temp:
temperature = float(new_temp)
new_top_p = input(f"New top_p [{top_p}]: ").strip()
if new_top_p:
top_p = float(new_top_p)
print("\n✓ Configuration updated!")
except ValueError:
print("\n✗ Invalid input. Configuration unchanged.")
continue
# Generate response
full_text, is_finished = generate_response(
model=model,
tokenizer=tokenizer,
question=user_input,
device=device,
dtype=dtype,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
)
# Print response
print_response(full_text)
if not is_finished:
print("⚠️ Note: Response was truncated (reached max_gen_len)")
except KeyboardInterrupt:
print("\n\nInterrupted. Type 'quit' to exit.\n")
continue
except Exception as e:
print(f"\n✗ Error: {e}\n")
continue
def batch_inference_mode(model, tokenizer, device, dtype, config, questions_file, output_file):
"""Run batch inference on a file of questions."""
print(f"\nRunning batch inference on {questions_file}...")
max_gen_len = config["training"].get("max_gen_len", 512)
# Read questions
with open(questions_file, 'r') as f:
questions = [line.strip() for line in f if line.strip()]
print(f"Found {len(questions)} questions")
results = []
for i, question in enumerate(questions, 1):
print(f"\n[{i}/{len(questions)}] Processing: {question[:60]}...")
full_text, is_finished = generate_response(
model=model,
tokenizer=tokenizer,
question=question,
device=device,
dtype=dtype,
max_gen_len=max_gen_len,
temperature=0.7,
top_p=0.9,
)
answer = extract_answer(full_text)
results.append({
'question': question,
'full_response': full_text,
'answer': answer,
'is_finished': is_finished,
})
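    # Each result record holds the raw question, the full decoded response,
    # the extracted <answer> text (None if the tags are missing), and whether
    # generation finished before hitting max_gen_len.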
# Save results
with open(output_file, 'w') as f:
json.dump(results, f, indent=2)
print(f"\n✓ Results saved to {output_file}")
def main():
parser = argparse.ArgumentParser(description="Interactive inference for physics problem solver")
parser.add_argument("--config", type=str, required=True, help="Path to config YAML file")
parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint (optional)")
parser.add_argument("--batch", action="store_true", help="Run batch inference mode")
parser.add_argument("--questions", type=str, help="Path to questions file (for batch mode)")
parser.add_argument("--output", type=str, default="results.json", help="Output file (for batch mode)")
args = parser.parse_args()
# Load model and tokenizer
print("Loading model and tokenizer...")
model, tokenizer, device, dtype, config = load_model_and_tokenizer(
args.config,
args.checkpoint
)
print("✓ Model loaded successfully!\n")
if args.batch:
if not args.questions:
print("Error: --questions file required for batch mode")
return
batch_inference_mode(model, tokenizer, device, dtype, config, args.questions, args.output)
else:
interactive_mode(model, tokenizer, device, dtype, config)
if __name__ == "__main__":
main()