"""
vLLM Model Wrapper for HuggingFace Models
This file creates an AI agent that uses vLLM to run large language models from HuggingFace.
vLLM provides fast GPU-accelerated inference for LLMs.
"""
# Import Python's operating system module to interact with environment variables
import os
# Import random module for fallback random move selection
import random
# Import Optional type hint for parameters that can be None
from typing import Optional
# Import vLLM's main classes: LLM for model loading, SamplingParams for generation settings
from vllm import LLM, SamplingParams
# Import the base protocol that defines what an agent should look like
from .base import AgentLike
# Import utility functions for parsing game observations and moves
from ..utils.parsing import (
extract_legal_moves, # Function: gets list of valid moves from observation text
extract_forbidden, # Function: gets list of forbidden moves
slice_board_and_moves, # Function: creates compact version of board state
strip_think, # Function: removes thinking text from model output
MOVE_RE # Regular expression: pattern to find moves like "[A0 B0]"
)
# Import prompt management classes
from ..prompts import PromptPack, get_prompt_pack
class VLLMAgent(AgentLike):
    """Stratego-playing agent backed by a vLLM inference engine.

    Loads any HuggingFace causal LM via vLLM and exposes the ``AgentLike``
    callable interface: ``agent(observation) -> move string``.
    """

    def __init__(
        self,
        model_name: str,
        system_prompt: Optional[str] = None,
        prompt_pack: Optional[PromptPack | str] = None,
        temperature: float = 0.2,
        top_p: float = 0.9,
        max_tokens: int = 64,
        gpu_memory_utilization: float = 0.3,
        tensor_parallel_size: int = 1,
        download_dir: str = "/scratch/hm24/.cache/huggingface",
        **kwargs,
    ):
        """Load *model_name* into GPU memory and configure sampling.

        Args:
            model_name: HuggingFace model ID (e.g. "google/gemma-2-2b-it").
            system_prompt: Optional override for the prompt pack's system prompt.
            prompt_pack: A ``PromptPack`` instance, a registered pack name, or
                ``None`` to select the default pack.
            temperature: Sampling temperature (0.0 = deterministic).
            top_p: Nucleus-sampling probability mass.
            max_tokens: Cap on generated tokens per response.
            gpu_memory_utilization: Fraction of GPU memory vLLM may claim (0.0-1.0).
            tensor_parallel_size: Number of GPUs to shard the model across.
            download_dir: Cache directory for model weights.
            **kwargs: Forwarded verbatim to ``vllm.LLM``.
        """
        # Kept for display (e.g. reporting which model produced a move).
        self.model_name = model_name

        # Accept either a pack name / None (resolved via the registry) or a
        # ready-made PromptPack object.
        if isinstance(prompt_pack, str) or prompt_pack is None:
            self.prompt_pack: PromptPack = get_prompt_pack(prompt_pack)
        else:
            self.prompt_pack = prompt_pack

        # Explicit override wins; otherwise use the pack's system prompt.
        self.system_prompt = system_prompt if system_prompt is not None else self.prompt_pack.system

        # Point HuggingFace caches at scratch storage (instead of $HOME)
        # before the model download is triggered below.
        os.environ["HF_HOME"] = download_dir
        os.environ["TRANSFORMERS_CACHE"] = download_dir

        print(f"🤖 Loading {model_name} with vLLM...")
        print(f"📁 Cache directory: {download_dir}")

        # Instantiating LLM downloads (if needed) and loads the model weights.
        self.llm = LLM(
            model=model_name,
            download_dir=download_dir,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=True,  # allow custom model code from the hub
            **kwargs,
        )

        # Generation settings shared by every call; the stop strings cut the
        # model off before it starts hallucinating the next game turn.
        self.sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=["\n\n", "Player", "Legal moves:"],
        )

        print("✅ Model loaded successfully!")

    def _llm_once(self, prompt: str) -> str:
        """Generate one response for *prompt* (system prompt prepended).

        Returns the generated text with surrounding whitespace and any
        <think>-style reasoning markers stripped.
        """
        full_prompt = f"{self.system_prompt}\n\n{prompt}"
        # generate() is batched; we submit a single prompt and take its
        # first (and only) completion.
        outputs = self.llm.generate([full_prompt], self.sampling_params)
        response = outputs[0].outputs[0].text.strip()
        return strip_think(response)

    def _extract_valid_move(self, text: str, allowed: list) -> Optional[str]:
        """Return the first MOVE_RE match in *text* if it is in *allowed*, else None."""
        m = MOVE_RE.search(text)
        if m and m.group(0) in allowed:
            return m.group(0)
        return None

    def __call__(self, observation: str) -> str:
        """Produce a move like "[A0 B0]" for the given TextArena observation.

        Strategy: prompt the model (with retries), validate the reply against
        the legal-move list, and fall back to a random legal move so the game
        always continues.
        """
        legal = extract_legal_moves(observation)
        if not legal:
            # No legal moves (game likely over) — nothing sensible to emit.
            return ""

        # Drop moves that were already tried and rejected; if that empties the
        # list, fall back to the full legal set rather than returning nothing.
        forbidden = set(extract_forbidden(observation))
        legal_filtered = [m for m in legal if m not in forbidden] or legal[:]

        # Compact the observation to save tokens, then wrap it with the
        # prompt pack's instructions.
        slim = slice_board_and_moves(observation)
        guidance = self.prompt_pack.guidance(slim)

        for attempt in range(3):
            # Primary strategy: full game context plus guidance.
            mv = self._extract_valid_move(self._llm_once(guidance), legal_filtered)
            if mv is not None:
                return mv

            # Secondary strategy (retries only): a bare, direct instruction.
            if attempt > 0:
                raw = self._llm_once("Output exactly one legal move [A0 B0].")
                mv = self._extract_valid_move(raw, legal_filtered)
                if mv is not None:
                    return mv

        # All attempts failed — pick a random legal move to keep the game going.
        return random.choice(legal_filtered)

    def cleanup(self):
        """Release the vLLM engine and return GPU memory to the driver.

        Safe to call more than once (or after a failed ``__init__``).
        """
        # Guard: self.llm may be absent if __init__ raised or cleanup already ran.
        if hasattr(self, "llm"):
            del self.llm
        # Collect first so the engine's now-unreferenced CUDA tensors are
        # actually freed; only then can empty_cache() return that memory.
        import gc
        gc.collect()
        import torch
        torch.cuda.empty_cache()