"""
Script to use a trained GRPO model for arithmetic countdown problems.

This script loads a model trained with train_grpo_hydra.py and provides
both interactive and batch evaluation modes for solving arithmetic problems.
"""

import logging
import sys
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Make the repository root importable before pulling in local modules.
sys.path.append(str(Path(__file__).parent.parent))

from src.dataset.grpo import map_problem_description_to_conversation_grpo
from src.utils.rewards import _is_valid_arithmetic_expression
from src.utils.string_helper import extract_answer

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("model_inference")


class GRPOModelInference:
    """Class for loading and running inference with a trained GRPO model."""

    def __init__(
        self,
        sft_model_path: str | None = None,
        grpo_model_path: str | None = None,
        base_model_id: str = "Qwen/Qwen2.5-Math-1.5B",
        device: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initialize the model inference class.

        Args:
            sft_model_path: Path to the trained SFT LoRA adapter directory
            grpo_model_path: Path to the trained GRPO LoRA adapter directory
            base_model_id: Base model identifier from Hugging Face
            device: Device to load the model on
            dtype: Torch data type for the model
        """
        self.sft_model_path = sft_model_path
        self.grpo_model_path = grpo_model_path
        self.base_model_id = base_model_id
        self.device = device
        self.dtype = dtype

        self.tokenizer = None
        self.model = None

        self._load_model()

    def _load_model(self) -> None:
        """Load the base model, LoRA adapters, and tokenizer."""
        logger.info(f"Loading base model: {self.base_model_id}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)

        # Note: the `dtype` keyword is the newer transformers spelling; older
        # releases expect `torch_dtype` instead.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            dtype=self.dtype,
            device_map=self.device,
        )

        # Merge each LoRA adapter into the weights before stacking the next,
        # so the GRPO adapter applies on top of the SFT-merged model.
        if self.sft_model_path:
            logger.info(f"Loading SFT LoRA adapters from: {self.sft_model_path}")
            self.model = PeftModel.from_pretrained(self.model, self.sft_model_path)
            self.model = self.model.merge_and_unload()

        if self.grpo_model_path:
            logger.info(f"Loading GRPO LoRA adapters from: {self.grpo_model_path}")
            self.model = PeftModel.from_pretrained(self.model, self.grpo_model_path)
            self.model = self.model.merge_and_unload()

        self.model.eval()
        logger.info("Model loaded successfully")

    def _format_conversation(self, problem_description: str) -> list[dict[str, str]]:
        """
        Format the problem description into the expected conversation format.

        Args:
            problem_description: The arithmetic problem description

        Returns:
            List of conversation messages
        """
        result = map_problem_description_to_conversation_grpo(
            {
                "problem_description": problem_description,
            }
        )
        return result["prompt"]
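
    # The exact message structure is defined by
    # map_problem_description_to_conversation_grpo; presumably it is a standard
    # chat list along these lines (roles and wording are an assumption, not
    # taken from this file):
    #
    #     [{"role": "system", "content": "..."},
    #      {"role": "user", "content": problem_description}]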

    def _generate_response(
        self,
        messages: list[dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        do_sample: bool = True,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate a response from the model given conversation messages.

        Args:
            messages: List of conversation messages
            max_new_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature
            do_sample: Whether to use sampling
            top_p: Top-p sampling parameter

        Returns:
            Generated response text
        """
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4096,
        )

        # Resolve the device the model weights live on; with device_map="auto"
        # the model may not expose a single `.device` attribute.
        if hasattr(self.model, "device"):
            device = self.model.device
        else:
            device = next(self.model.parameters()).device

        inputs = {k: v.to(device) for k, v in inputs.items()}

        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config,
            )

        # Decode only the newly generated tokens, skipping the prompt portion.
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
        )

        return response.strip()

    def solve_problem(
        self,
        problem_description: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> tuple[str, str, bool]:
        """
        Solve a single arithmetic countdown problem.

        Args:
            problem_description: The problem description
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Tuple of (full_response, extracted_answer, is_valid_format)
        """
        messages = self._format_conversation(problem_description)

        response = self._generate_response(
            messages, max_new_tokens=max_new_tokens, temperature=temperature
        )

        # Pull the final expression out of the response and check that it is a
        # well-formed arithmetic expression.
        extracted_answer = extract_answer(response)
        is_valid = _is_valid_arithmetic_expression(extracted_answer)

        return response, extracted_answer, is_valid
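

# Minimal usage sketch (the paths and problem text below are hypothetical
# placeholders, not taken from this repo's configs):
#
#     inference = GRPOModelInference(
#         sft_model_path="outputs/sft_lora",
#         grpo_model_path="outputs/grpo_lora",
#     )
#     response, answer, is_valid = inference.solve_problem(
#         "Using the numbers [2, 3, 7], create an equation that equals 17."
#     )
#     print(answer, is_valid)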