"""
Interactive REPL for testing a trained physics problem-solving model.
"""

import argparse
import json
import re
from pathlib import Path

import torch
import yaml

from qwen2_model import Transformer
from tokenizer import Tokenizer
from generation_utils import generate
from tokenizer_wrapper import decode_token_ids


SYSTEM_MESSAGE = (
    "You are a helpful physics tutor. You first think about the reasoning process "
    "in your mind and then provide the user with the answer."
)
USER_TEMPLATE = (
    "{question}\n"
    "Show your reasoning in <think> </think> tags. "
    "Then provide your final answer in <answer> </answer> tags."
)
RESPONSE_PROMPT = "Let me solve this step by step.\n<think>"
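# RESPONSE_PROMPT ends with an unclosed <think> tag on purpose: generation
# starts inside the reasoning block, and print_response() below recovers the
# <think>/<answer> sections with regexes.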


def load_model_and_tokenizer(config_path, checkpoint_path=None):
    """Load the model and tokenizer from a config file and an optional checkpoint."""
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    pretrained_model_path = Path(config["model"]["pretrained_model_path"])
    device = torch.device(config["model"]["device"])

    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    dtype = dtype_map.get(config["model"]["dtype"], torch.bfloat16)

    tokenizer = Tokenizer(str(pretrained_model_path / "tokenizer.json"))
    model = Transformer.from_pretrained(pretrained_model_path, device=device)

    if checkpoint_path:
        print(f"Loading checkpoint from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Training checkpoints wrap the weights in a dict with metadata;
        # bare state dicts are loaded as-is.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
            print(f"Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
        else:
            state_dict = checkpoint

        model.load_state_dict(state_dict)
        print("Checkpoint loaded successfully!")

    model.eval()

    return model, tokenizer, device, dtype, config
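
# For reference, a minimal config this loader can consume. Field names are
# taken from the lookups above; the values are illustrative, not from the repo:
#
#   model:
#     pretrained_model_path: /path/to/qwen2-base
#     device: cuda
#     dtype: bfloat16
#   training:
#     max_gen_len: 512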


def generate_response(model, tokenizer, question, device, dtype,
                      max_gen_len=512, temperature=0.7, top_p=0.9):
    """Generate a response for a given physics question."""
    user_message = USER_TEMPLATE.format(question=question)
    # Render the chat messages in the model's template and append
    # RESPONSE_PROMPT so decoding starts mid-<think>.
    prefix = tokenizer.encode_chat_with_response_prompt(
        [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": user_message},
        ],
        RESPONSE_PROMPT,
    )

    tokens = tokenizer.tokenize(prefix)
    prefix_token_ids = tokens.ids

    print("\nGenerating response...")
    with torch.inference_mode():
        generated_token_ids, is_finished = generate(
            model=model,
            tokenizer=tokenizer,
            prompt_token_ids=prefix_token_ids,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            device=device,
            dtype=dtype,
        )

    generated_text = decode_token_ids(tokenizer, generated_token_ids)

    # Return the full transcript (prompt + completion) so callers can parse the
    # <think>/<answer> tags, plus a flag for whether generation finished before
    # hitting max_gen_len.
    return prefix + generated_text, is_finished


def extract_answer(text):
    """Extract the answer from <answer> tags."""
    answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if answer_match:
        return answer_match.group(1).strip()
    return None
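
# e.g. extract_answer("<answer> 42 J </answer>") -> "42 J". The match is
# non-greedy and, via re.DOTALL, spans newlines inside the tags.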


def print_response(full_text):
    """Pretty print the model's response."""
    think_match = re.search(r"<think>(.*?)</think>", full_text, re.DOTALL)
    answer_match = re.search(r"<answer>(.*?)</answer>", full_text, re.DOTALL)

    print("\n" + "=" * 80)

    if think_match:
        print("\n🤔 REASONING:")
        print("-" * 80)
        print(think_match.group(1).strip())

    if answer_match:
        print("\n✅ ANSWER:")
        print("-" * 80)
        print(answer_match.group(1).strip())
    else:
        print("\n⚠️ WARNING: No answer tags found in response")
        print("\nFull response:")
        print("-" * 80)
        print(full_text)

    print("=" * 80 + "\n")


def interactive_mode(model, tokenizer, device, dtype, config):
    """Run interactive REPL mode."""
    print("\n" + "=" * 80)
    print("Physics Problem Solver - Interactive Mode")
    print("=" * 80)
    print("\nCommands:")
    print(" - Type your physics question and press Enter")
    print(" - Type 'quit' or 'exit' to exit")
    print(" - Type 'config' to change generation parameters")
    print(" - Type 'example' to see example questions")
    print("=" * 80 + "\n")

    # Session-level generation settings; the 'config' command edits them below.
    max_gen_len = config["training"].get("max_gen_len", 512)
    temperature = 0.7
    top_p = 0.9

    while True:
        try:
            user_input = input("\n📝 Enter physics question (or command): ").strip()

            if not user_input:
                continue

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("\nGoodbye! 👋")
                break

            if user_input.lower() == 'example':
                print("\nExample questions:")
                print(" 1. A ball is thrown upward with velocity 20 m/s. What is its maximum height?")
                print(" 2. Calculate the force needed to accelerate a 5 kg object at 3 m/s²")
                print(" 3. What is the wavelength of light with frequency 5×10¹⁴ Hz?")
                print(" 4. A 2 kg block slides down a 30° incline. What is its acceleration?")
                continue

            if user_input.lower() == 'config':
                print("\nCurrent settings:")
                print(f" max_gen_len: {max_gen_len}")
                print(f" temperature: {temperature}")
                print(f" top_p: {top_p}")

                try:
                    new_max_len = input(f"\nNew max_gen_len [{max_gen_len}]: ").strip()
                    if new_max_len:
                        max_gen_len = int(new_max_len)

                    new_temp = input(f"New temperature [{temperature}]: ").strip()
                    if new_temp:
                        temperature = float(new_temp)

                    new_top_p = input(f"New top_p [{top_p}]: ").strip()
                    if new_top_p:
                        top_p = float(new_top_p)

                    print("\n✓ Configuration updated!")
                except ValueError:
                    print("\n✗ Invalid input. Configuration unchanged.")
                continue

            full_text, is_finished = generate_response(
                model=model,
                tokenizer=tokenizer,
                question=user_input,
                device=device,
                dtype=dtype,
                max_gen_len=max_gen_len,
                temperature=temperature,
                top_p=top_p,
            )

            print_response(full_text)

            if not is_finished:
                print("⚠️ Note: Response was truncated (reached max_gen_len)")

        except KeyboardInterrupt:
            print("\n\nInterrupted. Type 'quit' to exit.\n")
            continue
        except Exception as e:
            print(f"\n✗ Error: {e}\n")
            continue


def batch_inference_mode(model, tokenizer, device, dtype, config, questions_file, output_file):
    """Run batch inference on a file of questions (one question per line)."""
    print(f"\nRunning batch inference on {questions_file}...")

    max_gen_len = config["training"].get("max_gen_len", 512)

    # One question per line; blank lines are skipped.
    with open(questions_file, 'r') as f:
        questions = [line.strip() for line in f if line.strip()]

    print(f"Found {len(questions)} questions")

    results = []
    for i, question in enumerate(questions, 1):
        print(f"\n[{i}/{len(questions)}] Processing: {question[:60]}...")

        full_text, is_finished = generate_response(
            model=model,
            tokenizer=tokenizer,
            question=question,
            device=device,
            dtype=dtype,
            max_gen_len=max_gen_len,
            temperature=0.7,
            top_p=0.9,
        )

        answer = extract_answer(full_text)

        results.append({
            'question': question,
            'full_response': full_text,
            'answer': answer,
            'is_finished': is_finished,
        })

    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\n✓ Results saved to {output_file}")


def main():
    parser = argparse.ArgumentParser(description="Interactive inference for physics problem solver")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML file")
    parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint (optional)")
    parser.add_argument("--batch", action="store_true", help="Run batch inference mode")
    parser.add_argument("--questions", type=str, help="Path to questions file (for batch mode)")
    parser.add_argument("--output", type=str, default="results.json", help="Output file (for batch mode)")

    args = parser.parse_args()

    # Validate argument combinations before paying the model-loading cost.
    if args.batch and not args.questions:
        parser.error("--questions file required for batch mode")

    print("Loading model and tokenizer...")
    model, tokenizer, device, dtype, config = load_model_and_tokenizer(
        args.config,
        args.checkpoint,
    )
    print("✓ Model loaded successfully!\n")

    if args.batch:
        batch_inference_mode(model, tokenizer, device, dtype, config, args.questions, args.output)
    else:
        interactive_mode(model, tokenizer, device, dtype, config)


if __name__ == "__main__":
    main()