#!/usr/bin/env python3
"""

Script to use a trained GRPO model for arithmetic countdown problems.



This script loads a model trained with train_grpo_hydra.py and provides

both interactive and batch evaluation modes for solving arithmetic problems.

"""

import logging
import sys
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Add src to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from src.dataset.grpo import map_problem_description_to_conversation_grpo
from src.utils.rewards import _is_valid_arithmetic_expression
from src.utils.string_helper import extract_answer

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("model_inference")


class GRPOModelInference:
    """Class for loading and running inference with a trained GRPO model."""

    def __init__(
        self,
        sft_model_path: str | None = None,
        grpo_model_path: str | None = None,
        base_model_id: str = "Qwen/Qwen2.5-Math-1.5B",
        device: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initialize the model inference class.

        Args:
            sft_model_path: Path to the SFT LoRA adapter directory, if any
            grpo_model_path: Path to the GRPO LoRA adapter directory, if any
            base_model_id: Base model identifier from Hugging Face
            device: Device to load the model on
            dtype: Torch data type for the model
        """
        self.sft_model_path = sft_model_path
        self.grpo_model_path = grpo_model_path
        self.base_model_id = base_model_id
        self.device = device
        self.dtype = dtype

        self.tokenizer = None
        self.model = None

        self._load_model()

    def _load_model(self) -> None:
        """Load the base model, LoRA adapters, and tokenizer."""
        logger.info(f"Loading base model: {self.base_model_id}")

        # Load tokenizer; fall back to EOS as the pad token if none is defined,
        # so the pad_token_id passed to GenerationConfig below is never None
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load base model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            dtype=self.dtype,
            device_map=self.device,
        )

        # Load SFT model
        if self.sft_model_path:
            logger.info(f"Loading SFT LoRA adapters from: {self.sft_model_path}")
            self.model = PeftModel.from_pretrained(self.model, self.sft_model_path)
            self.model = self.model.merge_and_unload()

        # Load GRPO LoRA adapters on top of the (possibly SFT-merged) base model
        if self.grpo_model_path:
            logger.info(f"Loading GRPO LoRA adapters from: {self.grpo_model_path}")
            self.model = PeftModel.from_pretrained(self.model, self.grpo_model_path)
            self.model = self.model.merge_and_unload()

        self.model.eval()
        logger.info("Model loaded successfully")

    def _format_conversation(self, problem_description: str) -> list[dict[str, str]]:
        """

        Format the problem description into the expected conversation format.



        Args:

            problem_description: The arithmetic problem description



        Returns:

            List of conversation messages

        """
        result = map_problem_description_to_conversation_grpo(
            {
                "problem_description": problem_description,
            }
        )
        return result["prompt"]

    def _generate_response(
        self,
        messages: list[dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        do_sample: bool = True,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate a response from the model given conversation messages.

        Args:
            messages: List of conversation messages
            max_new_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature
            do_sample: Whether to use sampling
            top_p: Top-p sampling parameter

        Returns:
            Generated response text
        """
        # Format messages using the tokenizer's chat template
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Tokenize the input
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4096,
        )

        # Move to device
        if hasattr(self.model, "device"):
            device = self.model.device
        else:
            device = next(self.model.parameters()).device

        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate response
        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config,
            )

        # Decode only the new tokens
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
        )

        return response.strip()

    def solve_problem(
        self,
        problem_description: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
    ) -> tuple[str, str, bool]:
        """
        Solve a single arithmetic countdown problem.

        Args:
            problem_description: The problem description
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Tuple of (full_response, extracted_answer, is_valid_format)
        """
        # Format conversation
        messages = self._format_conversation(problem_description)

        # Generate response
        response = self._generate_response(
            messages, max_new_tokens=max_new_tokens, temperature=temperature
        )

        # Extract answer
        extracted_answer = extract_answer(response)
        is_valid = _is_valid_arithmetic_expression(extracted_answer)

        return response, extracted_answer, is_valid
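

# Minimal usage sketch. The adapter paths and the problem wording below are
# hypothetical placeholders, not paths confirmed by the training pipeline;
# point them at wherever train_grpo_hydra.py saved its LoRA adapters.
if __name__ == "__main__":
    inference = GRPOModelInference(
        sft_model_path="outputs/sft_adapter",  # hypothetical path
        grpo_model_path="outputs/grpo_adapter",  # hypothetical path
    )
    response, answer, is_valid = inference.solve_problem(
        "Using the numbers [2, 3, 7], create an expression that equals 21."
    )
    print(f"Extracted answer: {answer} (valid format: {is_valid})")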