File size: 10,607 Bytes
fd5d391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""

vLLM Model Wrapper for HuggingFace Models

This file creates an AI agent that uses vLLM to run large language models from HuggingFace.

vLLM provides fast GPU-accelerated inference for LLMs.

"""

# Import Python's operating system module to interact with environment variables
import os
# Import random module for fallback random move selection
import random
# Import Optional type hint for parameters that can be None
from typing import Optional
# Import vLLM's main classes: LLM for model loading, SamplingParams for generation settings
from vllm import LLM, SamplingParams

# Import the base protocol that defines what an agent should look like
from .base import AgentLike
# Import utility functions for parsing game observations and moves
from ..utils.parsing import (
    extract_legal_moves,      # Function: gets list of valid moves from observation text
    extract_forbidden,         # Function: gets list of forbidden moves
    slice_board_and_moves,     # Function: creates compact version of board state
    strip_think,               # Function: removes thinking text from model output
    MOVE_RE                    # Regular expression: pattern to find moves like "[A0 B0]"
)
# Import prompt management classes
from ..prompts import PromptPack, get_prompt_pack


class VLLMAgent(AgentLike):
    """Stratego agent that selects moves with a vLLM-hosted HuggingFace model.

    Loads any HuggingFace causal-LM via vLLM for fast GPU inference, prompts
    it with the current game observation, and parses a move of the form
    ``[A0 B0]`` out of the generated text. Falls back to a random legal move
    if the model fails to produce a valid one.

    Implements the ``AgentLike`` protocol (callable with an observation
    string, returning a move string).
    """

    def __init__(
        self,
        model_name: str,
        system_prompt: Optional[str] = None,
        prompt_pack: Optional[PromptPack | str] = None,
        temperature: float = 0.2,
        top_p: float = 0.9,
        max_tokens: int = 64,
        gpu_memory_utilization: float = 0.3,
        tensor_parallel_size: int = 1,
        download_dir: str = "/scratch/hm24/.cache/huggingface",
        **kwargs,
    ):
        """Load the model into GPU memory and configure generation.

        Args:
            model_name: HuggingFace model ID (e.g. ``"google/gemma-2-2b-it"``).
            system_prompt: Custom system prompt; ``None`` uses the prompt
                pack's default.
            prompt_pack: A ``PromptPack``, a pack name, or ``None`` for the
                default pack.
            temperature: Sampling randomness (0.0 = deterministic).
            top_p: Nucleus-sampling probability mass threshold.
            max_tokens: Maximum tokens generated per response.
            gpu_memory_utilization: Fraction of GPU memory vLLM may use
                (0.0–1.0).
            tensor_parallel_size: Number of GPUs to shard the model across.
            download_dir: Directory where model weights are cached.
            **kwargs: Extra keyword arguments forwarded to ``vllm.LLM``.
        """
        # Kept for display purposes (e.g. reporting which model made a move).
        self.model_name = model_name

        # prompt_pack may arrive as a pack name (str), None (use default),
        # or an already-built PromptPack instance.
        if isinstance(prompt_pack, str) or prompt_pack is None:
            self.prompt_pack: PromptPack = get_prompt_pack(prompt_pack)
        else:
            self.prompt_pack = prompt_pack

        # An explicit system_prompt overrides the pack's default behaviour text.
        self.system_prompt = system_prompt if system_prompt is not None else self.prompt_pack.system

        # Point HuggingFace caches at download_dir (e.g. scratch storage
        # instead of the home directory). TRANSFORMERS_CACHE is deprecated in
        # recent transformers releases but kept for older versions; HF_HOME is
        # the current canonical variable. Must be set before vLLM touches the
        # hub, hence before the LLM(...) call below.
        os.environ["HF_HOME"] = download_dir
        os.environ["TRANSFORMERS_CACHE"] = download_dir

        print(f"🤖 Loading {model_name} with vLLM...")
        print(f"📁 Cache directory: {download_dir}")

        # Build the vLLM engine: downloads (if needed) and loads the model
        # weights onto the GPU(s).
        self.llm = LLM(
            model=model_name,
            download_dir=download_dir,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tensor_parallel_size,
            # Required for models whose repos ship custom modeling code.
            trust_remote_code=True,
            **kwargs,
        )

        # Generation settings shared by every call. The stop strings cut the
        # model off before it starts hallucinating further game transcript.
        self.sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=["\n\n", "Player", "Legal moves:"],
        )

        print(f"✅ Model loaded successfully!")

    def _llm_once(self, prompt: str) -> str:
        """Generate a single response for *prompt* and return the cleaned text.

        The system prompt is prepended, the model is sampled once, and any
        thinking markers (e.g. ``<think>...</think>``) are stripped from the
        output.

        Args:
            prompt: User-level prompt text.

        Returns:
            The model's response with thinking text removed.
        """
        full_prompt = f"{self.system_prompt}\n\n{prompt}"
        # vLLM batches requests; we submit one and read the first completion.
        outputs = self.llm.generate([full_prompt], self.sampling_params)
        response = outputs[0].outputs[0].text.strip()
        return strip_think(response)

    def __call__(self, observation: str) -> str:
        """Return a move for the current observation.

        Args:
            observation: Current game state as text from TextArena.

        Returns:
            A move string like ``"[A0 B0]"`` (from-square to-square), or an
            empty string if no legal moves exist.
        """
        # Parse the legal-move list out of the observation text.
        legal = extract_legal_moves(observation)
        if not legal:
            # No moves available — the game is likely over.
            return ""

        # Drop moves previously flagged as forbidden (already tried/failed);
        # if that filters everything out, fall back to the full legal list.
        forbidden = set(extract_forbidden(observation))
        legal_filtered = [m for m in legal if m not in forbidden] or legal[:]

        # Compact the observation to save prompt tokens, then wrap it with
        # the prompt pack's instructions.
        slim = slice_board_and_moves(observation)
        guidance = self.prompt_pack.guidance(slim)

        # Up to three attempts to coax a valid move out of the model.
        for attempt in range(3):
            # Primary strategy: full game context + instructions.
            raw = self._llm_once(guidance)
            m = MOVE_RE.search(raw)
            if m:
                mv = m.group(0)
                if mv in legal_filtered:
                    return mv

            # Secondary strategy (attempts 1 and 2 only): a bare, direct
            # instruction with no game context.
            if attempt > 0:
                raw2 = self._llm_once("Output exactly one legal move [A0 B0].")
                m2 = MOVE_RE.search(raw2)
                if m2:
                    mv2 = m2.group(0)
                    if mv2 in legal_filtered:
                        return mv2

        # All attempts failed — pick a random legal move so the game continues.
        return random.choice(legal_filtered)

    def cleanup(self):
        """Release the GPU memory held by the vLLM engine.

        Safe to call more than once; calls after the first are no-ops.
        """
        # Idempotency guard: the original `del self.llm` raised
        # AttributeError on a second call.
        llm = getattr(self, "llm", None)
        if llm is None:
            return
        self.llm = None
        del llm

        # Collect the now-unreferenced engine so CUDA allocations are
        # actually eligible for release before empty_cache().
        import gc
        gc.collect()

        import torch
        # Guard against CPU-only environments where empty_cache() would fail.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()