""" Interactive REPL for testing trained physics problem-solving model. """ import argparse from pathlib import Path import torch import yaml from qwen2_model import Transformer from tokenizer import Tokenizer from generation_utils import generate from tokenizer_wrapper import decode_token_ids SYSTEM_MESSAGE = ( "You are a helpful physics tutor. You first think about the reasoning process " "in your mind and then provide the user with the answer." ) USER_TEMPLATE = ( "{question}\n" "Show your reasoning in tags. " "Then provide your final answer in tags." ) RESPONSE_PROMPT = "Let me solve this step by step.\n" def load_model_and_tokenizer(config_path, checkpoint_path=None): """Load model and tokenizer from config and checkpoint.""" with open(config_path, "r") as f: config = yaml.safe_load(f) pretrained_model_path = Path(config["model"]["pretrained_model_path"]) device = torch.device(config["model"]["device"]) dtype_map = { "bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32, } dtype = dtype_map.get(config["model"]["dtype"], torch.bfloat16) # Load tokenizer tokenizer = Tokenizer(str(pretrained_model_path / "tokenizer.json")) # Load model model = Transformer.from_pretrained(pretrained_model_path, device=device) # Load checkpoint if provided if checkpoint_path: print(f"Loading checkpoint from {checkpoint_path}...") checkpoint = torch.load(checkpoint_path, map_location=device) # Handle different checkpoint formats if isinstance(checkpoint, dict): if "model_state_dict" in checkpoint: # Checkpoint contains model_state_dict, optimizer_state_dict, etc. state_dict = checkpoint["model_state_dict"] print(f"Loaded checkpoint from step {checkpoint.get('step', 'unknown')}") else: # Checkpoint is already a state dict state_dict = checkpoint else: state_dict = checkpoint model.load_state_dict(state_dict) print("Checkpoint loaded successfully!") model.eval() return model, tokenizer, device, dtype, config def generate_response(model, tokenizer, question, device, dtype, max_gen_len=512, temperature=0.7, top_p=0.9): """Generate a response for a given physics question.""" # Format the prompt user_message = USER_TEMPLATE.format(question=question) prefix = tokenizer.encode_chat_with_response_prompt( [ {"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": user_message}, ], RESPONSE_PROMPT, ) # Tokenize tokens = tokenizer.tokenize(prefix) prefix_token_ids = tokens.ids # Generate print("\nGenerating response...") with torch.inference_mode(): generated_token_ids, is_finished = generate( model=model, tokenizer=tokenizer, prompt_token_ids=prefix_token_ids, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, device=device, dtype=dtype, ) # Decode generated_text = decode_token_ids(tokenizer, generated_token_ids) return prefix + generated_text, is_finished def extract_answer(text): """Extract the answer from tags.""" import re answer_match = re.search(r"(.*?)", text, re.DOTALL) if answer_match: return answer_match.group(1).strip() return None def print_response(full_text): """Pretty print the model's response.""" import re # Try to extract think and answer sections think_match = re.search(r"(.*?)", full_text, re.DOTALL) answer_match = re.search(r"(.*?)", full_text, re.DOTALL) print("\n" + "="*80) if think_match: print("\n🤔 REASONING:") print("-" * 80) print(think_match.group(1).strip()) if answer_match: print("\n✅ ANSWER:") print("-" * 80) print(answer_match.group(1).strip()) else: print("\n⚠️ WARNING: No answer tags found in response") print("\nFull response:") 
print("-" * 80) print(full_text) print("="*80 + "\n") def interactive_mode(model, tokenizer, device, dtype, config): """Run interactive REPL mode.""" print("\n" + "="*80) print("Physics Problem Solver - Interactive Mode") print("="*80) print("\nCommands:") print(" - Type your physics question and press Enter") print(" - Type 'quit' or 'exit' to exit") print(" - Type 'config' to change generation parameters") print(" - Type 'example' to see example questions") print("="*80 + "\n") # Default generation parameters max_gen_len = config["training"].get("max_gen_len", 512) temperature = 0.7 top_p = 0.9 while True: try: user_input = input("\n📝 Enter physics question (or command): ").strip() if not user_input: continue if user_input.lower() in ['quit', 'exit', 'q']: print("\nGoodbye! 👋") break if user_input.lower() == 'example': print("\nExample questions:") print(" 1. A ball is thrown upward with velocity 20 m/s. What is its maximum height?") print(" 2. Calculate the force needed to accelerate a 5kg object at 3 m/s²") print(" 3. What is the wavelength of light with frequency 5×10¹⁴ Hz?") print(" 4. A 2kg block slides down a 30° incline. What is its acceleration?") continue if user_input.lower() == 'config': print(f"\nCurrent settings:") print(f" max_gen_len: {max_gen_len}") print(f" temperature: {temperature}") print(f" top_p: {top_p}") try: new_max_len = input(f"\nNew max_gen_len [{max_gen_len}]: ").strip() if new_max_len: max_gen_len = int(new_max_len) new_temp = input(f"New temperature [{temperature}]: ").strip() if new_temp: temperature = float(new_temp) new_top_p = input(f"New top_p [{top_p}]: ").strip() if new_top_p: top_p = float(new_top_p) print("\n✓ Configuration updated!") except ValueError: print("\n✗ Invalid input. Configuration unchanged.") continue # Generate response full_text, is_finished = generate_response( model=model, tokenizer=tokenizer, question=user_input, device=device, dtype=dtype, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, ) # Print response print_response(full_text) if not is_finished: print("⚠️ Note: Response was truncated (reached max_gen_len)") except KeyboardInterrupt: print("\n\nInterrupted. 
def batch_inference_mode(model, tokenizer, device, dtype, config, questions_file, output_file):
    """Run batch inference on a file of questions."""
    print(f"\nRunning batch inference on {questions_file}...")
    max_gen_len = config["training"].get("max_gen_len", 512)

    # Read questions
    with open(questions_file, 'r') as f:
        questions = [line.strip() for line in f if line.strip()]
    print(f"Found {len(questions)} questions")

    results = []
    for i, question in enumerate(questions, 1):
        print(f"\n[{i}/{len(questions)}] Processing: {question[:60]}...")
        full_text, is_finished = generate_response(
            model=model,
            tokenizer=tokenizer,
            question=question,
            device=device,
            dtype=dtype,
            max_gen_len=max_gen_len,
            temperature=0.7,
            top_p=0.9,
        )
        answer = extract_answer(full_text)
        results.append({
            'question': question,
            'full_response': full_text,
            'answer': answer,
            'is_finished': is_finished,
        })

    # Save results
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\n✓ Results saved to {output_file}")


def main():
    parser = argparse.ArgumentParser(description="Interactive inference for physics problem solver")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML file")
    parser.add_argument("--checkpoint", type=str, help="Path to model checkpoint (optional)")
    parser.add_argument("--batch", action="store_true", help="Run batch inference mode")
    parser.add_argument("--questions", type=str, help="Path to questions file (for batch mode)")
    parser.add_argument("--output", type=str, default="results.json", help="Output file (for batch mode)")
    args = parser.parse_args()

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    model, tokenizer, device, dtype, config = load_model_and_tokenizer(
        args.config, args.checkpoint
    )
    print("✓ Model loaded successfully!\n")

    if args.batch:
        if not args.questions:
            print("Error: --questions file required for batch mode")
            return
        batch_inference_mode(model, tokenizer, device, dtype, config,
                             args.questions, args.output)
    else:
        interactive_mode(model, tokenizer, device, dtype, config)


if __name__ == "__main__":
    main()
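# Example invocations. The script filename and file paths below are assumptions
# for illustration; substitute the actual names in your setup:
#
#   # Interactive REPL using the base pretrained weights from the config:
#   python inference_repl.py --config config.yaml
#
#   # Interactive REPL with a fine-tuned checkpoint:
#   python inference_repl.py --config config.yaml --checkpoint ckpt.pt
#
#   # Batch mode: answer every question in questions.txt, write JSON results:
#   python inference_repl.py --config config.yaml --checkpoint ckpt.pt \
#       --batch --questions questions.txt --output results.json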