#!/usr/bin/env python3
"""
Run inference using SFT local model + GRPO model from HuggingFace.
"""
import os
import sys
import logging
from pathlib import Path

import torch

# ============================
# Fix PYTHONPATH
# ============================
# The script lives two directories below the project root; make the root
# importable so `src.utils.inference` resolves.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
sys.path.append(str(PROJECT_ROOT))

# ============================
# Logging setup
# ============================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("model_inference")

# ============================
# Import AFTER fixing path
# ============================
# Fail fast with diagnostics if the project package is not importable —
# the rest of the script is useless without it.
try:
    from src.utils.inference import GRPOModelInference
except Exception as e:
    print("[ERROR] Import src.utils.inference FAILED!")
    print("PROJECT_ROOT =", PROJECT_ROOT)
    print("Error =", e)
    sys.exit(1)
def main():
    """Load the SFT+GRPO inference model and solve one sample Countdown problem.

    Prints the raw model response, the extracted answer expression, and
    whether the response matched the expected output format. Returns None;
    exits early (with a message) if model loading fails.
    """
    print("Starting inference...")

    # ============================
    # Load model
    # ============================
    try:
        model_inference = GRPOModelInference(
            sft_model_path="models/sft/checkpoint-100",
            grpo_model_path="Dat1710/countdown-grpo-qwen2",
            base_model_id="Qwen/Qwen2.5-Math-1.5B",
            device="auto",
            dtype=torch.float16,
        )
        # NOTE(review): original message was a mojibake emoji plus a string
        # literal broken across two lines; restored as one readable line.
        print("[OK] Models loaded successfully!")
    except Exception as e:
        # Best-effort script: report and bail out rather than crash.
        print("[ERROR] Failed to load models!")
        print(e)
        return

    # ============================
    # Run inference
    # ============================
    problem = (
        "Your task: Use 53, 3, 47, and 36 exactly once each with "
        "only +, -, *, and / operators to create an expression equal to 133."
    )
    print("\nProblem:")
    print(problem)

    response, extracted_answer, is_valid = model_inference.solve_problem(
        problem_description=problem,
        max_new_tokens=512,
        temperature=0.8,
    )

    # ============================
    # Print output clearly
    # ============================
    print("\n================= MODEL OUTPUT =================")
    print(response)
    print("================================================\n")
    print("Extracted Answer:", extracted_answer)
    print("Valid format:", is_valid)
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()