DarshanScripts commited on
Commit
887ae64
·
verified ·
1 Parent(s): 3cc0e2e

Upload stratego\models\vllm_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stratego/models/vllm_model.py +226 -0
stratego/models/vllm_model.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ vLLM Model Wrapper for HuggingFace Models
3
+ This file creates an AI agent that uses vLLM to run large language models from HuggingFace.
4
+ vLLM provides fast GPU-accelerated inference for LLMs.
5
+ """
6
+
7
+ # Import Python's operating system module to interact with environment variables
8
+ import os
9
+ # Import random module for fallback random move selection
10
+ import random
11
+ # Import Optional type hint for parameters that can be None
12
+ from typing import Optional
13
+ # Import vLLM's main classes: LLM for model loading, SamplingParams for generation settings
14
+ from vllm import LLM, SamplingParams
15
+
16
+ # Import the base protocol that defines what an agent should look like
17
+ from .base import AgentLike
18
+ # Import utility functions for parsing game observations and moves
19
+ from ..utils.parsing import (
20
+ extract_legal_moves, # Function: gets list of valid moves from observation text
21
+ extract_forbidden, # Function: gets list of forbidden moves
22
+ slice_board_and_moves, # Function: creates compact version of board state
23
+ strip_think, # Function: removes thinking text from model output
24
+ MOVE_RE # Regular expression: pattern to find moves like "[A0 B0]"
25
+ )
26
+ # Import prompt management classes
27
+ from ..prompts import PromptPack, get_prompt_pack
28
+
29
+
30
class VLLMAgent(AgentLike):
    """Stratego agent powered by vLLM for fast GPU inference.

    Loads any HuggingFace causal LM via vLLM and uses it to pick a move
    from the legal-move list embedded in a TextArena observation string.
    Falls back to a random legal move if the model never produces a
    valid one, so the game always continues.
    """

    def __init__(
        self,
        model_name: str,
        system_prompt: Optional[str] = None,
        prompt_pack: Optional[PromptPack | str] = None,
        temperature: float = 0.2,
        top_p: float = 0.9,
        max_tokens: int = 64,
        gpu_memory_utilization: float = 0.3,
        tensor_parallel_size: int = 1,
        download_dir: str = "/scratch/hm24/.cache/huggingface",
        **kwargs,
    ):
        """Load a HuggingFace model into GPU memory with vLLM.

        Args:
            model_name: HuggingFace model ID (e.g. "google/gemma-2-2b-it").
            system_prompt: Optional override for the prompt pack's system prompt.
            prompt_pack: A PromptPack instance, a registered pack name, or
                None to load the default pack.
            temperature: Sampling temperature (0.0 = deterministic).
            top_p: Nucleus-sampling probability mass.
            max_tokens: Maximum tokens generated per response.
            gpu_memory_utilization: Fraction of GPU memory vLLM may claim (0.0-1.0).
            tensor_parallel_size: Number of GPUs to shard the model across.
            download_dir: Directory where HuggingFace model files are cached.
            **kwargs: Extra keyword arguments forwarded verbatim to vllm.LLM.
        """
        # Kept for display/logging of which model produced a move.
        self.model_name = model_name

        # prompt_pack may be a registered name (str), None (default pack),
        # or an already-constructed PromptPack object.
        if isinstance(prompt_pack, str) or prompt_pack is None:
            self.prompt_pack: PromptPack = get_prompt_pack(prompt_pack)
        else:
            self.prompt_pack = prompt_pack

        # An explicitly supplied system prompt wins over the pack's default.
        self.system_prompt = system_prompt if system_prompt is not None else self.prompt_pack.system

        # Redirect HuggingFace caches away from the home directory.
        # NOTE(review): TRANSFORMERS_CACHE is deprecated in recent
        # transformers releases; HF_HOME is the canonical setting. Both are
        # set here to cover older library versions.
        os.environ["HF_HOME"] = download_dir
        os.environ["TRANSFORMERS_CACHE"] = download_dir

        print(f"🤖 Loading {model_name} with vLLM...")
        print(f"📁 Cache directory: {download_dir}")

        # vLLM engine: downloads the model (if needed) and loads the
        # weights onto the GPU(s), ready for batched inference.
        self.llm = LLM(
            model=model_name,
            download_dir=download_dir,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=True,  # allow custom model code from the hub
            **kwargs,
        )

        # Generation settings shared by every call. The stop strings keep
        # the model from rambling past a single move (blank line, next
        # player header, or an echoed legal-move list all end generation).
        self.sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=["\n\n", "Player", "Legal moves:"],
        )

        # Plain string: the original used an f-string with no placeholders.
        print("✅ Model loaded successfully!")

    def _llm_once(self, prompt: str) -> str:
        """Generate a single response from the model.

        Args:
            prompt: User-level text; the configured system prompt is
                prepended before generation.

        Returns:
            The model's text, whitespace-stripped and with any
            <think>-style reasoning markers removed by strip_think().
        """
        full_prompt = f"{self.system_prompt}\n\n{prompt}"
        # generate() takes a batch; we send one prompt and read the first
        # completion of the first (only) request.
        outputs = self.llm.generate([full_prompt], self.sampling_params)
        response = outputs[0].outputs[0].text.strip()
        return strip_think(response)

    @staticmethod
    def _first_valid_move(raw: str, legal_moves) -> Optional[str]:
        """Return the first move pattern found in *raw* if it is legal, else None.

        Factors out the regex-search + legality-check sequence that the
        retry loop in __call__ previously duplicated.
        """
        m = MOVE_RE.search(raw)
        if m and m.group(0) in legal_moves:
            return m.group(0)
        return None

    def __call__(self, observation: str) -> str:
        """Choose a move for the current game state.

        Args:
            observation: Current game state text from TextArena.

        Returns:
            A move string like "[A0 B0]" (from-square to-square), or ""
            when the observation contains no legal moves.
        """
        # Legal moves parsed from the observation, e.g. ["[A0 B0]", ...].
        legal = extract_legal_moves(observation)
        if not legal:
            # No moves available — the game may be over.
            return ""

        # Drop moves that were already tried and rejected; if everything
        # is forbidden, fall back to the full legal list (a copy).
        forbidden = set(extract_forbidden(observation))
        legal_filtered = [m for m in legal if m not in forbidden] or legal[:]

        # Compact board representation, wrapped with pack instructions,
        # to keep the prompt within the token budget.
        slim = slice_board_and_moves(observation)
        guidance = self.prompt_pack.guidance(slim)

        # Up to 3 attempts: full-context prompt first; from the second
        # attempt on, also try a terse direct instruction.
        for attempt in range(3):
            mv = self._first_valid_move(self._llm_once(guidance), legal_filtered)
            if mv is not None:
                return mv

            if attempt > 0:
                mv = self._first_valid_move(
                    self._llm_once("Output exactly one legal move [A0 B0]."),
                    legal_filtered,
                )
                if mv is not None:
                    return mv

        # All attempts failed: pick a random legal move so play continues.
        return random.choice(legal_filtered)

    def cleanup(self):
        """Release the vLLM engine and free GPU memory.

        Safe to call more than once, and on hosts without CUDA.
        """
        # Guard so a second cleanup() call does not raise AttributeError.
        if hasattr(self, "llm"):
            del self.llm

        # Imported lazily: torch is only needed for cache release.
        import torch

        # empty_cache() assumes an initialized CUDA context; skip on
        # CPU-only hosts instead of raising.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()