|
|
|
|
|
"""
|
|
|
Run inference using SFT local model + GRPO model from HuggingFace.
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
import logging
|
|
|
from pathlib import Path
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Resolve the repository root (two directory levels above this file) and put
# it on sys.path so that `src.*` modules import regardless of the CWD.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Guard against duplicate sys.path entries when this module is executed or
# imported more than once (the original appended unconditionally).
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time: timestamped, INFO-level lines.
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

# Named logger shared by this inference script.
logger = logging.getLogger("model_inference")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Import the project's inference wrapper; fail fast with a clear diagnostic
# when the project root was not wired onto sys.path correctly.
try:
    from src.utils.inference import GRPOModelInference
except Exception as e:
    # NOTE(review): "β" looks like a garbled emoji from a bad encoding
    # round-trip — confirm and restore the intended character.
    print("β Import src.utils.inference FAILED!")
    # Echo the computed root to make path misconfiguration easy to spot.
    print("PROJECT_ROOT =", PROJECT_ROOT)
    print("Error =", e)
    # Abort immediately: nothing below can run without GRPOModelInference.
    sys.exit(1)
|
|
|
|
|
|
|
|
|
def main(
    sft_model_path="models/sft/checkpoint-100",
    grpo_model_path="Dat1710/countdown-grpo-qwen2",
    base_model_id="Qwen/Qwen2.5-Math-1.5B",
    problem=None,
):
    """Load the SFT + GRPO models and solve one Countdown-style problem.

    Args:
        sft_model_path: Path to the local SFT checkpoint.
        grpo_model_path: HuggingFace repo id of the GRPO model.
        base_model_id: Base model the checkpoints were fine-tuned from.
        problem: Optional problem description; when omitted, the original
            hard-coded Countdown task is used.

    Returns:
        None. Prints the model output, the extracted answer, and whether the
        answer format was valid. Returns early if model loading fails.
    """
    print("π Starting inference...")

    # Build the combined inference wrapper; any failure (missing checkpoint,
    # network error, OOM) is reported and aborts the run gracefully.
    try:
        model_inference = GRPOModelInference(
            sft_model_path=sft_model_path,
            grpo_model_path=grpo_model_path,
            base_model_id=base_model_id,
            device="auto",
            dtype=torch.float16,
        )
        # BUGFIX: the original had this literal split across two physical
        # lines (a garbled emoji followed by a newline inside the string),
        # which is an unterminated-string syntax error; repaired here.
        print("Models loaded successfully!")
    except Exception as e:
        print("β Failed to load models!")
        print(e)
        return

    # Default to the original hard-coded Countdown task when no problem
    # description is supplied by the caller.
    if problem is None:
        problem = (
            "Your task: Use 53, 3, 47, and 36 exactly once each with "
            "only +, -, *, and / operators to create an expression equal to 133."
        )

    print("\nπ Problem:")
    print(problem)

    # Generate a single sampled solution from the loaded models.
    response, extracted_answer, is_valid = model_inference.solve_problem(
        problem_description=problem,
        max_new_tokens=512,
        temperature=0.8,
    )

    print("\n================= MODEL OUTPUT =================")
    print(response)
    print("================================================\n")

    print("π Extracted Answer:", extracted_answer)
    print("π Valid format:", is_valid)
|
# Standard script entry point: run inference only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|
|