Dat1710's picture
Upload folder using huggingface_hub
00db46c verified
#!/usr/bin/env python3
"""
Run inference using SFT local model + GRPO model from HuggingFace.
"""
import os
import sys
import logging
from pathlib import Path
import torch
# ------------------------------------------------------------
# Make the project importable
# ------------------------------------------------------------
# parents[2] is two directories above the folder that contains this
# script; it is appended to sys.path so the `src.*` import below can
# be resolved from there.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
sys.path.extend([str(PROJECT_ROOT)])

# ------------------------------------------------------------
# Logging configuration
# ------------------------------------------------------------
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("model_inference")

# ------------------------------------------------------------
# Project import (must come after the sys.path tweak above)
# ------------------------------------------------------------
try:
    from src.utils.inference import GRPOModelInference
except Exception as import_error:
    # Any failure here is fatal: report what was tried and stop with
    # exit status 1.
    print("❌ Import src.utils.inference FAILED!")
    print(f"PROJECT_ROOT = {PROJECT_ROOT}")
    print(f"Error = {import_error}")
    raise SystemExit(1)
def main():
    """Load the SFT + GRPO models and run one sample problem through them.

    Builds a ``GRPOModelInference`` from a local SFT checkpoint path, a
    GRPO model id and a base model id, asks it to solve a single
    hard-coded arithmetic problem, and prints the raw response together
    with the extracted answer and the validity flag it returns.

    Returns:
        int: ``0`` after the result has been printed, ``1`` if
        constructing ``GRPOModelInference`` raised an exception.

    Raises:
        Exception: anything raised by ``solve_problem`` is not caught
        here and propagates to the caller.
    """
    print("🚀 Starting inference...")

    # ============================
    # Load model
    # ============================
    # Only the constructor call is guarded, so a failure while printing
    # the success message is not misreported as a model-load failure.
    try:
        model_inference = GRPOModelInference(
            # NOTE(review): relative path; presumably resolved against
            # the current working directory - confirm in
            # GRPOModelInference.
            sft_model_path="models/sft/checkpoint-100",
            grpo_model_path="Dat1710/countdown-grpo-qwen2",
            base_model_id="Qwen/Qwen2.5-Math-1.5B",
            device="auto",
            dtype=torch.float16,
        )
    except Exception as e:
        print("❌ Failed to load models!")
        print(e)
        # Report the failure to the caller instead of returning silently.
        return 1
    print("✅ Models loaded successfully!")

    # ============================
    # Run inference
    # ============================
    problem = (
        "Your task: Use 53, 3, 47, and 36 exactly once each with "
        "only +, -, *, and / operators to create an expression equal to 133."
    )
    print("\n📌 Problem:")
    print(problem)

    # solve_problem is unpacked as a 3-tuple: the response text, the
    # extracted answer and a validity flag.
    response, extracted_answer, is_valid = model_inference.solve_problem(
        problem_description=problem,
        max_new_tokens=512,
        temperature=0.8,
    )

    # ============================
    # Print output clearly
    # ============================
    print("\n================= MODEL OUTPUT =================")
    print(response)
    print("================================================\n")
    print("📘 Extracted Answer:", extracted_answer)
    print("🔍 Valid format:", is_valid)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a failed model load yields a
    # non-zero process exit code.
    sys.exit(main())