#!/usr/bin/env python3
"""Run inference using SFT local model + GRPO model from HuggingFace."""

import logging
import sys
from pathlib import Path

import torch

# ============================
# Fix PYTHONPATH
# ============================
# The script lives two directory levels below the project root; prepend the
# root so `src.*` absolute imports resolve regardless of the CWD.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
sys.path.append(str(PROJECT_ROOT))

# ============================
# Logging setup
# ============================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("model_inference")

# ============================
# Import AFTER fixing path
# ============================
try:
    from src.utils.inference import GRPOModelInference
except Exception as e:
    print("āŒ Import src.utils.inference FAILED!")
    print("PROJECT_ROOT =", PROJECT_ROOT)
    print("Error =", e)
    sys.exit(1)


def main() -> None:
    """Load the SFT + GRPO models and solve one hard-coded countdown problem.

    Prints the raw model output, the extracted answer, and whether the
    answer matched the expected format. Returns early (without raising)
    if model loading fails.
    """
    print("šŸš€ Starting inference...")

    # ============================
    # Load model
    # ============================
    try:
        model_inference = GRPOModelInference(
            sft_model_path="models/sft/checkpoint-100",
            grpo_model_path="Dat1710/countdown-grpo-qwen2",
            base_model_id="Qwen/Qwen2.5-Math-1.5B",
            device="auto",
            dtype=torch.float16,
        )
        print("āœ… Models loaded successfully!")
    except Exception:
        print("āŒ Failed to load models!")
        # logger.exception preserves the full traceback, which the previous
        # `print(e)` discarded — essential for diagnosing load failures.
        logger.exception("Model loading failed")
        return

    # ============================
    # Run inference
    # ============================
    problem = (
        "Your task: Use 53, 3, 47, and 36 exactly once each with "
        "only +, -, *, and / operators to create an expression equal to 133."
    )

    print("\nšŸ“Œ Problem:")
    print(problem)

    # solve_problem is expected to return (raw_response, extracted_answer,
    # is_valid_format) — see src.utils.inference.GRPOModelInference.
    response, extracted_answer, is_valid = model_inference.solve_problem(
        problem_description=problem,
        max_new_tokens=512,
        temperature=0.8,
    )

    # ============================
    # Print output clearly
    # ============================
    print("\n================= MODEL OUTPUT =================")
    print(response)
    print("================================================\n")

    print("šŸ“˜ Extracted Answer:", extracted_answer)
    print("šŸ” Valid format:", is_valid)


if __name__ == "__main__":
    main()