|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import ast
|
| import json
|
| import hydra
|
| import numpy as np
|
| import ray
|
| import torch
|
| from pathlib import Path
|
| from pprint import pprint
|
|
|
|
|
# Default NCCL logging to WARN and allow tokenizer parallelism, but never override
# values the caller already exported.
os.environ.setdefault("NCCL_DEBUG", "WARN")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")

# Drop an allocator config that enables expandable segments; this script treats it as
# incompatible. PYTORCH_CUDA_ALLOC_CONF is consulted when the CUDA allocator starts,
# so clearing it after `import torch` above is still effective.
_bad = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if _bad.find("expandable_segments:True") >= 0:
    print(f"[fix] Removing incompatible PYTORCH_CUDA_ALLOC_CONF={_bad}")
    # The substring match above guarantees the variable is set, so plain `del` is safe.
    del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
|
|
|
| import pandas as pd
|
| from omegaconf import OmegaConf, open_dict
|
|
|
| from verl import DataProto
|
| from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
|
| from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
| from verl.utils import hf_tokenizer
|
| from verl.utils.fs import copy_to_local
|
| from verl.utils.hdfs_io import makedirs
|
| from verl.utils.model import compute_position_id_with_mask
|
| from verl.workers.fsdp_workers import ActorRolloutRefWorker
|
|
|
|
|
|
|
# Projection-module names written to adapter_config.json when no LoRA key in the
# checkpoint matches any of the known projection names.
DEFAULT_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "up_proj",
    "gate_proj",
    "down_proj",
]
|
|
|
| def _infer_lengths_and_defaults(config):
|
| """Ensure rollout/data keys exist and set reasonable H100 defaults."""
|
|
|
| with open_dict(config):
|
| if "rollout" not in config:
|
| config["rollout"] = OmegaConf.create()
|
| if "data" not in config:
|
| config["data"] = OmegaConf.create()
|
| if "trainer" not in config:
|
| config["trainer"] = OmegaConf.create()
|
| if "ray_init" not in config:
|
| config["ray_init"] = OmegaConf.create()
|
|
|
|
|
| with open_dict(config.rollout):
|
|
|
| config.rollout.setdefault("dtype", "bfloat16")
|
| config.rollout.setdefault("kv_cache_dtype", "fp8_e5m2")
|
| config.rollout.setdefault("tensor_model_parallel_size", 1)
|
| config.rollout.setdefault("enable_chunked_prefill", True)
|
| config.rollout.setdefault("swap_space", 8)
|
| config.rollout.setdefault("gpu_memory_utilization", 0.62)
|
|
|
|
|
| pl = int(config.rollout.get("prompt_length", 1024))
|
| rl = int(config.rollout.get("response_length", 128))
|
| need = int(pl + rl)
|
| config.rollout.setdefault("max_model_len", need)
|
| config.rollout.setdefault("max_num_batched_tokens", need)
|
|
|
|
|
|
|
|
|
| with open_dict(config.data):
|
| config.data.setdefault("batch_size", 1)
|
| config.data.setdefault("n_samples", 1)
|
| config.data.setdefault("prompt_key", "prompt")
|
|
|
| with open_dict(config.trainer):
|
| config.trainer.setdefault("n_gpus_per_node", 1)
|
| config.trainer.setdefault("nnodes", 1)
|
|
|
| with open_dict(config.ray_init):
|
| config.ray_init.setdefault("num_cpus", 4)
|
|
|
def _infer_lora_rank_from_state(sd):
    """Return the LoRA rank implied by *sd*, or None if it cannot be determined.

    Scans keys in insertion order and takes the leading dimension of the first
    2-D tensor-like value whose key ends with ``lora_A.weight``. Values without a
    ``dim()`` method are ignored.
    """
    candidate_ranks = (
        int(tensor.shape[0])
        for name, tensor in sd.items()
        if name.endswith("lora_A.weight")
        and hasattr(tensor, "dim")
        and tensor.dim() == 2
    )
    return next(candidate_ranks, None)
|
|
|
def _list_target_modules_from_state(sd):
    """Return, sorted, the known projection names that carry LoRA weights in *sd*.

    Only keys containing ``lora_A.weight`` or ``lora_B.weight`` are considered, and a
    projection counts only when it appears as a dotted path component (``.q_proj.``),
    so e.g. ``.gate_up_proj.`` does not match ``up_proj``.
    """
    known = ("q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj")
    lora_keys = [key for key in sd if "lora_A.weight" in key or "lora_B.weight" in key]
    hits = {name for key in lora_keys for name in known if f".{name}." in key}
    return sorted(hits)
|
|
|
def _write_adapter_config(adapter_dir, r, alpha, target_modules, dropout=0.0):
    """Write a minimal PEFT ``adapter_config.json`` for a LoRA adapter into *adapter_dir*.

    ``r`` and ``alpha`` are coerced to int and ``dropout`` to float. The file is
    UTF-8 JSON with 2-space indentation and non-ASCII characters left unescaped.
    The base model path is left empty and ``inference_mode`` is set to True.
    """
    payload = dict(
        peft_type="LORA",
        auto_mapping=None,
        base_model_name_or_path="",
        bias="none",
        inference_mode=True,
        lora_alpha=int(alpha),
        lora_dropout=float(dropout),
        r=int(r),
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    config_path = Path(adapter_dir) / "adapter_config.json"
    config_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
|
def _wrap_lora_pt_as_peft(adapter_pt_path: str, out_dir: str,
                          fallback_rank=32, fallback_alpha=16):
    """Repackage a raw LoRA ``.pt`` state dict as a PEFT-style adapter directory.

    Writes ``adapter_config.json`` and ``adapter_model.bin`` into *out_dir*, creating it
    if needed.
    - The rank is inferred from the first 2-D ``lora_A.weight`` tensor and falls back to
      *fallback_rank*.
    - Target modules are inferred from the key names and fall back to a copy of
      ``DEFAULT_TARGET_MODULES``.
    - Alpha is never read from the checkpoint, so *fallback_alpha* is always written as
      ``lora_alpha``.

    Args:
        adapter_pt_path: File produced by ``torch.save`` holding a state dict, optionally
            wrapped as ``{"state_dict": ...}``.
        out_dir: Destination directory for the adapter files.
        fallback_rank: Rank used when none can be inferred.
        fallback_alpha: Value written as ``lora_alpha``.
            NOTE(review): confirm this matches the alpha used in training, since it
            sets the adapter's scaling.

    Returns:
        ``(rank, target_modules)`` exactly as written to ``adapter_config.json``.

    Raises:
        TypeError: If the loaded checkpoint (after unwrapping) is not a dict.
    """
    os.makedirs(out_dir, exist_ok=True)
    print(f"[lora] Loading LoRA state from: {adapter_pt_path}")
    # SECURITY: torch.load unpickles the file. On torch < 2.6, where weights_only
    # defaults to False, a malicious checkpoint can execute arbitrary code, so only
    # load trusted files.
    sd = torch.load(adapter_pt_path, map_location="cpu")
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if not isinstance(sd, dict):
        raise TypeError(
            f"Expected a LoRA state dict in {adapter_pt_path}, got {type(sd).__name__}"
        )

    inferred = _infer_lora_rank_from_state(sd)
    r = inferred or int(fallback_rank)
    if not inferred:
        print(f"[lora] WARNING: no 2-D lora_A.weight found; using fallback rank={r}")
    # Copy the default list so callers cannot mutate the module-level constant.
    tmods = _list_target_modules_from_state(sd) or list(DEFAULT_TARGET_MODULES)
    print(f"[lora] inferred rank={r}, target_modules={tmods}")

    _write_adapter_config(out_dir, r=r, alpha=fallback_alpha, target_modules=tmods)
    # out_dir may be reused across runs. vLLM's loader prefers adapter_model.safetensors
    # over adapter_model.bin, so a leftover safetensors file would silently shadow the
    # weights written below.
    stale = os.path.join(out_dir, "adapter_model.safetensors")
    if os.path.isfile(stale):
        os.remove(stale)
    torch.save(sd, os.path.join(out_dir, "adapter_model.bin"))
    return r, tmods
|
|
|
def _maybe_attach_lora_adapter(config):
    """Attach a LoRA adapter to the vLLM rollout (runtime LoRA) when one is configured.

    Checks ``config.lora.pt_path`` first, then ``config.model.load_param_path``.

    If the path is an existing file:
    - converts it to a PEFT adapter directory under ``/tmp/lora_adapter_vllm``;
    - registers it in ``config.rollout.lora_modules``;
    - makes sure ``config.rollout.max_lora_rank`` is at least the adapter's rank;
    - turns off ``config.model.load_param``.

    Otherwise it prints why and leaves *config* untouched. Mutates *config* in place and
    returns None.
    """
    lora_pt = None
    if "lora" in config and getattr(config.lora, "pt_path", ""):
        lora_pt = config.lora.pt_path
    elif getattr(config.model, "load_param_path", ""):
        lora_pt = config.model.load_param_path

    if not lora_pt:
        print("[lora] No LoRA .pt provided; running base model only.")
        return
    if not Path(lora_pt).is_file():
        # A configured path that does not exist is most likely a mistake; report it
        # instead of claiming that no checkpoint was provided.
        print(f"[lora] WARNING: LoRA checkpoint not found: {lora_pt}; running base model only.")
        return

    # NOTE(review): /tmp is node-local. With trainer.nnodes > 1, rollout workers on other
    # nodes presumably cannot see this directory; confirm, or point it at shared storage.
    adapter_dir = os.path.join("/tmp", "lora_adapter_vllm")
    r, _ = _wrap_lora_pt_as_peft(lora_pt, adapter_dir, fallback_rank=32, fallback_alpha=16)
    r = int(r)

    with open_dict(config):
        if "rollout" not in config:
            config["rollout"] = OmegaConf.create()
    with open_dict(config.rollout):
        config.rollout.setdefault("max_loras", 1)
        # vLLM rejects adapters whose rank exceeds max_lora_rank. setdefault alone would
        # keep a smaller (or null) preset value, so raise it to fit this adapter.
        current = config.rollout.get("max_lora_rank")
        if current is None or int(current) < r:
            config.rollout["max_lora_rank"] = r
        config.rollout["lora_modules"] = [{"path": adapter_dir, "scale": 1.0}]
    print(f"[lora] Attached PEFT adapter: {adapter_dir} (rank={r})")

    # The LoRA weights are now served through the adapter, so do not also load the
    # checkpoint as model params (per the message below, that hits an FSDP
    # load_state_dict mismatch).
    with open_dict(config.model):
        if getattr(config.model, "load_param", False):
            print("[lora] Disabling model.load_param to avoid FSDP load_state_dict mismatch.")
            config.model["load_param"] = False
|
|
|
|
|
|
|
@hydra.main(config_path="config", config_name="infer", version_base=None)
def main(config):
    """Hydra entry point: fill in config defaults, start Ray, and run ``main_task`` remotely.

    Ray is initialised only if it is not already running. In that case workers inherit
    the driver's NCCL_DEBUG and TOKENIZERS_PARALLELISM and get an empty
    PYTORCH_CUDA_ALLOC_CONF. Blocks until ``main_task`` finishes.
    """
    _infer_lengths_and_defaults(config)

    if not ray.is_initialized():
        ray.init(
            runtime_env={"env_vars": {
                # Forward the driver's values, which were already defaulted at import
                # time, so that user settings such as NCCL_DEBUG=INFO reach the workers.
                "TOKENIZERS_PARALLELISM": os.environ.get("TOKENIZERS_PARALLELISM", "true"),
                "NCCL_DEBUG": os.environ.get("NCCL_DEBUG", "WARN"),
                # Deliberately empty: workers start with no allocator overrides.
                "PYTORCH_CUDA_ALLOC_CONF": "",
            }},
            num_cpus=config.ray_init.num_cpus,
        )

    ray.get(main_task.remote(config))
|
|
|
def _strip_think(text):
    """Return the text after the first ``</think>`` tag, left-stripped.

    Text without the tag is returned unchanged.
    """
    marker = "</think>"
    ix = text.find(marker)
    if ix == -1:
        return text
    return text[ix + len(marker):].lstrip()


def _parse_response(text):
    """Return ``ast.literal_eval(text)`` when *text* is a Python literal, else *text* itself."""
    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the exceptions literal_eval documents for malformed input;
        # free-form model output is kept as a plain string.
        return text


def _transpose_slots(slots):
    """Regroup ``slots[sample_idx][row_idx]`` into ``rows[row_idx][sample_idx]``.

    Uses ``zip`` rather than ``np.array(slots, dtype=object)``: NumPy turns
    equal-length list/tuple responses into extra array axes, which breaks a
    2-axis transpose.
    """
    return [list(row) for row in zip(*slots)]


def _write_jsonl(dataset, out_path):
    """Write *dataset* to *out_path* as JSON Lines, creating the parent directory if there is one."""
    out_dir = os.path.dirname(out_path)
    # A bare filename has no directory part, and makedirs("") would raise.
    if out_dir:
        makedirs(out_dir, exist_ok=True)
    dataset.to_json(out_path, orient="records", lines=True, force_ascii=False)


@ray.remote(num_cpus=1)
def main_task(config):
    """Generate responses for every prompt in ``config.data.path`` and write them as JSONL.

    Runs as a Ray task. Steps:
    1. Resolve the config and optionally attach a LoRA adapter.
    2. Build a left-padding tokenizer.
    3. Start the rollout worker group.
    4. Generate ``config.data.n_samples`` responses per prompt in batches of
       ``config.data.batch_size``.
    5. Post-process each response (strip up to ``</think>``, then try to parse it as a
       Python literal).
    6. Write ``config.data.output_path``.

    Raises:
        ValueError: If ``n_samples`` or ``batch_size`` is invalid.
        KeyError: If the dataset lacks the configured prompt column.
    """
    print("[worker] PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
    pprint(OmegaConf.to_container(config, resolve=True))
    OmegaConf.resolve(config)

    _maybe_attach_lora_adapter(config)

    # Best-effort dataset-preparation hook. Any failure, including a missing module,
    # is reported and skipped.
    try:
        from prompts.infer_prompt import infer_dataset

        infer_dataset(
            model_name=config.model.path,
            data_path=os.path.dirname(os.path.dirname(config.data.path)),
        )
    except Exception as e:
        print(f"[info] infer_dataset() skipped: {e}")

    local_path = copy_to_local(config.model.path)
    trust_remote_code = getattr(config.model, "trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    # Pad on the left so prompts are right-aligned within each batch.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Validate with exceptions rather than asserts, which are stripped under `python -O`.
    # Done before the (expensive) worker start-up.
    n_samples = int(config.data.n_samples)
    if float(config.rollout.temperature) == 0.0 and n_samples != 1:
        raise ValueError("When temperature=0, n_samples must be 1.")
    if n_samples < 1:
        raise ValueError("n_samples should always >= 1")
    bs = int(config.data.batch_size)
    if bs < 1:
        raise ValueError(f"data.batch_size must be >= 1, got {bs}")

    dataset = pd.read_parquet(config.data.path)
    prompt_key = getattr(config.data, "prompt_key", "prompt")
    if prompt_key not in dataset.columns:
        raise KeyError(f"Dataset missing column '{prompt_key}'")
    # Nested values may come back from parquet as array-like objects; convert them
    # to plain Python lists for the chat template.
    chat_lst = [
        chat.tolist() if hasattr(chat, "tolist") else chat
        for chat in dataset[prompt_key].tolist()
    ]

    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    print("[debug] rollout.lora_modules =", config.rollout.get("lora_modules", None))
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()

    total = len(dataset)
    num_batch = -(-total // bs)  # ceiling division
    # slots[n] collects the n-th sample for every prompt, in dataset order.
    slots = [[] for _ in range(n_samples)]

    for b in range(num_batch):
        print(f"[{b+1}/{num_batch}] Start to process.")
        batch_chat = chat_lst[b * bs : (b + 1) * bs]

        inputs = tokenizer.apply_chat_template(
            batch_chat,
            add_generation_prompt=True,
            padding=True,
            truncation=True,
            max_length=int(config.rollout.prompt_length),
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}

        data = DataProto.from_dict(batch_dict)
        # Pad the batch to a multiple of the worker count; undone after generation.
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)

        print(f"[{b+1}/{num_batch}] Start to generate.")
        for n in range(n_samples):
            output_padded = wg.generate_sequences(data_padded)
            output = unpad_dataproto(output_padded, pad_size=pad_size)
            texts = []
            for i in range(len(output)):
                item = output[i]
                prompt_len = item.batch["prompts"].shape[-1]
                # The mask past the prompt length covers the response part; its sum is
                # taken as the number of valid response tokens.
                valid_len = item.batch["attention_mask"][prompt_len:].sum()
                resp_ids = item.batch["responses"][:valid_len]
                s = tokenizer.decode(resp_ids, skip_special_tokens=True)
                print(f"[raw] Response {i}: {s!r}")
                s = _strip_think(s)
                print(f"Response {i}: {s!r}")
                texts.append(_parse_response(s))
            slots[n].extend(texts)

    # One list of n_samples responses per dataset row.
    dataset["response"] = _transpose_slots(slots)

    keep = ["file_id", "vt", "gt", "response"]
    cols = [c for c in keep if c in dataset.columns]
    if cols:
        dataset = dataset[cols]

    out_path = config.data.output_path
    _write_jsonl(dataset, out_path)
    print(f"[done] Wrote: {out_path}")
|
|
|
# Script entry point: the @hydra.main decorator on main() composes the config from the
# "config/infer" config and CLI overrides, then passes it in as `config`.
if __name__ == "__main__":
    main()
|
|
|