File size: 12,892 Bytes
bcdf9fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
#!/usr/bin/env python
# Copyright 2024 Bytedance
# Apache-2.0
#
# VERL + vLLM inference with runtime LoRA (no merge).
# - Wraps a LoRA .pt into a PEFT adapter and attaches via rollout.lora_modules
# - Mixed precision defaults for H100: dtype=bf16, kv_cache_dtype=fp8_e5m2
# - Pins max_model_len, max_num_batched_tokens, sets swap_space
# - Uses OmegaConf.open_dict to add keys safely (no "not in struct" errors)
# - Prevents FSDP from trying to load LoRA .pt as a full model
import os
import ast
import json
import hydra
import numpy as np
import ray
import torch
from pathlib import Path
from pprint import pprint
# Quiet logs: only fill in defaults, so values the user already exported win.
os.environ.setdefault("NCCL_DEBUG", "WARN")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
# vLLM CuMem allocator is incompatible with expandable_segments, so the whole
# allocator config is dropped when it turns them on.
_bad = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments:True" in _bad:
    print(f"[fix] Removing incompatible PYTORCH_CUDA_ALLOC_CONF={_bad}")
    # The key must exist here: _bad is non-empty only when the variable is set.
    del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
import pandas as pd
from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker
# ---------------- LoRA helpers ----------------
# Fallback LoRA target module names, used when none of them can be detected
# in the checkpoint's parameter names.
DEFAULT_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "up_proj",
    "gate_proj",
    "down_proj",
]
def _infer_lengths_and_defaults(config):
    """Ensure rollout/data/trainer/ray_init sections exist and fill H100 defaults.

    Only keys the user did not set are filled in, so CLI/YAML overrides win.
    The prompt/response length defaults are written back into the config so
    that every consumer (e.g. the tokenizer ``max_length`` in ``main_task``)
    sees exactly the values used to size ``max_model_len``.

    Args:
        config: Hydra/OmegaConf config; mutated in place.
    """
    # Ensure nested structs exist (open_dict lifts struct mode so keys can be added).
    with open_dict(config):
        for section in ("rollout", "data", "trainer", "ray_init"):
            if section not in config:
                config[section] = OmegaConf.create()

    # Defaults that work on a single H100
    with open_dict(config.rollout):
        config.rollout.setdefault("dtype", "bfloat16")  # weights/activations
        config.rollout.setdefault("kv_cache_dtype", "fp8_e5m2")  # KV cache precision
        config.rollout.setdefault("tensor_model_parallel_size", 1)
        config.rollout.setdefault("enable_chunked_prefill", True)
        config.rollout.setdefault("swap_space", 8)  # GB of host swap for KV
        config.rollout.setdefault("gpu_memory_utilization", 0.62)  # adjust 0.60~0.75 if needed

        # Pin lengths to avoid vLLM over-reserving KV cache. The length defaults
        # are stored (not just read) so max_model_len stays consistent with them.
        config.rollout.setdefault("prompt_length", 1024)
        config.rollout.setdefault("response_length", 128)
        need = int(config.rollout.prompt_length) + int(config.rollout.response_length)

        # setdefault() would keep an explicit ``null`` and silently skip the pin,
        # so a None value is treated the same as a missing key here.
        if config.rollout.get("max_model_len", None) is None:
            config.rollout["max_model_len"] = need
        if config.rollout.get("max_num_batched_tokens", None) is None:
            # Without chunked prefill vLLM requires this to be >= max_model_len,
            # which may have been set larger than ``need`` by the user.
            config.rollout["max_num_batched_tokens"] = max(
                need, int(config.rollout.max_model_len)
            )
        # Users may pass +rollout.quantization={fp8|awq|gptq} to shrink weights further
        # We don't force it here.

    with open_dict(config.data):
        config.data.setdefault("batch_size", 1)
        config.data.setdefault("n_samples", 1)
        config.data.setdefault("prompt_key", "prompt")
    with open_dict(config.trainer):
        config.trainer.setdefault("n_gpus_per_node", 1)
        config.trainer.setdefault("nnodes", 1)
    with open_dict(config.ray_init):
        config.ray_init.setdefault("num_cpus", 4)
def _infer_lora_rank_from_state(sd):
for k, v in sd.items():
if k.endswith("lora_A.weight") and hasattr(v, "dim") and v.dim() == 2:
return int(v.shape[0])
return None
def _list_target_modules_from_state(sd):
found = set()
for k in sd.keys():
if "lora_A.weight" in k or "lora_B.weight" in k:
if ".q_proj." in k: found.add("q_proj")
if ".k_proj." in k: found.add("k_proj")
if ".v_proj." in k: found.add("v_proj")
if ".o_proj." in k: found.add("o_proj")
if ".up_proj." in k: found.add("up_proj")
if ".gate_proj." in k: found.add("gate_proj")
if ".down_proj." in k: found.add("down_proj")
return sorted(found)
def _write_adapter_config(adapter_dir, r, alpha, target_modules, dropout=0.0):
cfg = {
"peft_type": "LORA",
"auto_mapping": None,
"base_model_name_or_path": "",
"bias": "none",
"inference_mode": True,
"lora_alpha": int(alpha),
"lora_dropout": float(dropout),
"r": int(r),
"target_modules": target_modules,
"task_type": "CAUSAL_LM",
}
with open(os.path.join(adapter_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
json.dump(cfg, f, ensure_ascii=False, indent=2)
def _wrap_lora_pt_as_peft(adapter_pt_path: str, out_dir: str,
                          fallback_rank=32, fallback_alpha=16):
    """Convert a raw LoRA ``.pt`` state dict into a PEFT adapter directory.

    Writes ``adapter_config.json`` and ``adapter_model.bin`` into ``out_dir``
    (created if missing).

    Args:
        adapter_pt_path: path to a ``torch.save``-d state dict, optionally
            wrapped as ``{"state_dict": ...}``.
        out_dir: destination directory for the PEFT files.
        fallback_rank: rank used when no 2-D ``lora_A.weight`` reveals it.
        fallback_alpha: LoRA alpha written to the config. It cannot be read from
            the weights; NOTE(review): if the adapter was trained with a
            different alpha, the runtime scaling will differ — confirm against
            the training config.

    Returns:
        ``(rank, target_modules)`` exactly as written to ``adapter_config.json``.

    Raises:
        TypeError: if the file does not contain a dict-like state dict.
    """
    os.makedirs(out_dir, exist_ok=True)
    print(f"[lora] Loading LoRA state from: {adapter_pt_path}")
    # SECURITY: torch.load unpickles the file unless weights_only is in effect
    # (the default only on newer PyTorch); load only trusted checkpoints.
    sd = torch.load(adapter_pt_path, map_location="cpu")
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if not isinstance(sd, dict):
        # Fail with a clear message instead of an AttributeError on .items() below.
        raise TypeError(
            f"[lora] Expected a state dict in {adapter_pt_path}, got {type(sd).__name__}"
        )
    r = _infer_lora_rank_from_state(sd) or int(fallback_rank)
    # Copy the fallback so the shared module-level list never ends up in the config.
    tmods = _list_target_modules_from_state(sd) or list(DEFAULT_TARGET_MODULES)
    print(f"[lora] inferred rank={r}, target_modules={tmods}")
    _write_adapter_config(out_dir, r=r, alpha=fallback_alpha, target_modules=tmods)
    torch.save(sd, os.path.join(out_dir, "adapter_model.bin"))
    return r, tmods
# max_lora_rank values accepted by vLLM's LoRA config; an adapter whose rank
# exceeds the configured max_lora_rank cannot be loaded.
_VLLM_MAX_LORA_RANKS = (8, 16, 32, 64, 128, 256, 320, 512)


def _round_up_lora_rank(rank):
    """Return the smallest allowed max_lora_rank >= ``rank`` (``rank`` itself if above all)."""
    rank = int(rank)
    return next((c for c in _VLLM_MAX_LORA_RANKS if c >= rank), rank)


def _maybe_attach_lora_adapter(config):
    """Attach a LoRA adapter directory to the vLLM rollout (runtime LoRA, no merge).

    The ``.pt`` comes from ``config.lora.pt_path`` if set, else from
    ``config.model.load_param_path``. It is wrapped as a PEFT directory and
    registered via ``rollout.lora_modules``. ``rollout.max_loras`` and
    ``rollout.max_lora_rank`` are filled in so the engine can host it, and
    ``model.load_param`` is switched off. Without a usable file the config is
    left untouched and the base model runs.

    Args:
        config: Hydra/OmegaConf config; mutated in place.
    """
    # Accept either +lora.pt_path or model.load_param_path as a hint
    lora_pt = None
    if "lora" in config and getattr(config.lora, "pt_path", ""):
        lora_pt = config.lora.pt_path
    elif getattr(config.model, "load_param_path", ""):
        lora_pt = config.model.load_param_path
    if not lora_pt:
        print("[lora] No LoRA .pt provided; running base model only.")
        return
    if not Path(lora_pt).is_file():
        # A path was given but is unusable: report that rather than "not provided".
        print(f"[lora] LoRA path {lora_pt!r} is not a file; running base model only.")
        return

    # NOTE(review): fixed location, so concurrent jobs on one host overwrite
    # each other's adapter files.
    adapter_dir = os.path.join("/tmp", "lora_adapter_vllm")
    r, _ = _wrap_lora_pt_as_peft(lora_pt, adapter_dir, fallback_rank=32, fallback_alpha=16)

    # Ensure rollout keys exist and add LoRA knobs required by vLLM
    with open_dict(config):
        if "rollout" not in config:
            config["rollout"] = OmegaConf.create()
    with open_dict(config.rollout):
        config.rollout.setdefault("max_loras", 1)
        # Round the inferred rank up to an allowed value, and never keep a
        # user-set (or null) max_lora_rank that is below the adapter's rank.
        required_rank = _round_up_lora_rank(r)
        current_rank = config.rollout.get("max_lora_rank", None)
        if current_rank is None or int(current_rank) < int(r):
            if current_rank is not None:
                print(
                    f"[lora] Raising rollout.max_lora_rank {current_rank} -> "
                    f"{required_rank} to fit rank={r}."
                )
            config.rollout["max_lora_rank"] = required_rank
        config.rollout["lora_modules"] = [{"path": adapter_dir, "scale": 1.0}]
    print(f"[lora] Attached PEFT adapter: {adapter_dir} (rank={r})")

    # CRITICAL: don't let FSDP try to load the LoRA .pt as a full state dict
    with open_dict(config.model):
        if getattr(config.model, "load_param", False):
            print("[lora] Disabling model.load_param to avoid FSDP load_state_dict mismatch.")
            config.model["load_param"] = False
# ---------------- Hydra entry ----------------
@hydra.main(config_path="config", config_name="infer", version_base=None)
def main(config):
    """Hydra entry point: fill config defaults, start Ray if needed, run ``main_task``.

    Ray is initialised only when it has not been initialised in this process
    yet. Workers get quiet NCCL/tokenizer settings and an empty
    ``PYTORCH_CUDA_ALLOC_CONF``.
    """
    _infer_lengths_and_defaults(config)
    if not ray.is_initialized():
        # Env for Ray workers; the empty allocator config keeps vLLM's allocator happy.
        worker_env = {
            "TOKENIZERS_PARALLELISM": "true",
            "NCCL_DEBUG": "WARN",
            "PYTORCH_CUDA_ALLOC_CONF": "",
        }
        ray.init(runtime_env={"env_vars": worker_env}, num_cpus=config.ray_init.num_cpus)
    ray.get(main_task.remote(config))
@ray.remote(num_cpus=1)
def main_task(config):
    """Generate responses for every prompt in a parquet dataset and write JSONL.

    The run proceeds in order:
    1. Attach the optional LoRA adapter.
    2. Build a left-padding tokenizer for the base model.
    3. Start the rollout worker group.
    4. Generate ``data.n_samples`` responses per prompt.
    5. Write ``data.output_path`` as JSON lines. Each row's ``response`` is the
       list of its samples, parsed with ``ast.literal_eval`` when possible and
       otherwise kept as text.

    Args:
        config: Hydra/OmegaConf config (resolved here; mutated by the LoRA step).

    Raises:
        ValueError: if ``data.n_samples`` is < 1, or != 1 with temperature 0.
        KeyError: if the dataset lacks the ``data.prompt_key`` column.
    """
    print("[worker] PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
    pprint(OmegaConf.to_container(config, resolve=True))
    OmegaConf.resolve(config)
    # Build LoRA adapter if provided
    _maybe_attach_lora_adapter(config)
    # Optionally pre-gen dataset schema if your repo provides it (deliberately best effort).
    try:
        from prompts.infer_prompt import infer_dataset
        infer_dataset(
            model_name=config.model.path,
            data_path=os.path.dirname(os.path.dirname(config.data.path)),
        )
    except Exception as e:
        print(f"[info] infer_dataset() skipped: {e}")

    # ---- Tokenizer from base model
    local_path = copy_to_local(config.model.path)
    trust_remote_code = getattr(config.model, "trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---- Sampling checks (explicit raises: asserts are stripped under ``python -O``)
    n_samples = int(config.data.n_samples)
    if float(config.rollout.temperature) == 0.0 and n_samples != 1:
        raise ValueError("When temperature=0, n_samples must be 1.")
    if n_samples < 1:
        raise ValueError("n_samples should always >= 1")

    # ---- Load dataset
    dataset = pd.read_parquet(config.data.path)
    prompt_key = getattr(config.data, "prompt_key", "prompt")
    if prompt_key not in dataset.columns:
        raise KeyError(f"Dataset missing column '{prompt_key}'")
    # Parquet list cells arrive as arrays; convert them to plain Python lists.
    chat_lst = [
        chat.tolist() if hasattr(chat, "tolist") else chat
        for chat in dataset[prompt_key].tolist()
    ]

    # ---- Worker group (vLLM inside Rollout)
    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    print("[debug] rollout.lora_modules =", config.rollout.get("lora_modules", None))
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()  # vLLM spins up; adapter used if set in rollout.lora_modules

    total = len(dataset)
    bs = int(config.data.batch_size)
    num_batch = -(-total // bs)  # ceil(total / bs)
    slots = [[] for _ in range(n_samples)]  # slots[sample][row]
    for b in range(num_batch):
        print(f"[{b+1}/{num_batch}] Start to process.")
        batch_chat = chat_lst[b * bs : (b + 1) * bs]
        # NOTE(review): HF tokenizers truncate on the right by default, which would
        # cut the generation prompt of over-long chats; consider
        # tokenizer.truncation_side = "left" — confirm the desired behavior.
        inputs = tokenizer.apply_chat_template(
            batch_chat,
            add_generation_prompt=True,
            padding=True,
            truncation=True,
            max_length=int(config.rollout.prompt_length),
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
        data = DataProto.from_dict(batch_dict)
        # Pad so the batch splits evenly across workers; unpadded again after generation.
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)
        print(f"[{b+1}/{num_batch}] Start to generate.")
        for n in range(n_samples):
            output_padded = wg.generate_sequences(data_padded)
            output = unpad_dataproto(output_padded, pad_size=pad_size)
            texts = []
            for i in range(len(output)):
                item = output[i]
                # attention_mask spans prompt + response; its tail counts response tokens.
                pl = item.batch["prompts"].shape[-1]
                valid_len = item.batch["attention_mask"][pl:].sum()
                resp_ids = item.batch["responses"][:valid_len]
                s = tokenizer.decode(resp_ids, skip_special_tokens=True)
                print(f"[raw] Response {i}: {s!r}")
                # Keep only the text after the reasoning section, if present.
                ix = s.find("</think>")
                if ix != -1:
                    s = s[ix + len("</think>") :].lstrip()
                print(f"Response {i}: {s!r}")
                try:
                    texts.append(ast.literal_eval(s))
                except Exception:
                    # Best effort: anything that is not a Python literal stays as text.
                    texts.append(s)
            slots[n].extend(texts)

    # Regroup [sample][row] -> [row][sample] in pure Python. np.array(..., dtype=object)
    # would expand equal-length parsed list responses into extra dimensions and
    # make a 2-axis transpose fail.
    outputs = [list(row) for row in zip(*slots)]
    dataset["response"] = outputs
    keep = ["file_id", "vt", "gt", "response"]
    cols = [c for c in keep if c in dataset.columns]
    if cols:
        dataset = dataset[cols]
    out_path = config.data.output_path
    out_dir = os.path.dirname(out_path)
    if out_dir:  # a bare filename has no directory part, and makedirs("") raises
        makedirs(out_dir, exist_ok=True)
    dataset.to_json(out_path, orient="records", lines=True, force_ascii=False)
    print(f"[done] Wrote: {out_path}")
if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator supplies the config.
    main()
|