| import functools | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| import polars as pl | |
| import torch | |
class DumpLoader:
    """Load objects from a directory of ``*.pt`` dump files by their metadata.

    The loader is enabled only when the ``SGLANG_DUMP_LOADER_DIR`` environment
    variable is set at construction time. In that case the directory is
    indexed once via ``read_meta`` and later lookups go through ``find_row``.
    """

    def __init__(self) -> None:
        directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
        self._enable = directory is not None
        # `_directory` and `_df` only exist when the loader is enabled.
        if self._enable:
            self._directory = Path(directory)
            self._df = read_meta(directory)

    def enable(self) -> bool:
        """Return True when ``SGLANG_DUMP_LOADER_DIR`` was set at construction."""
        return self._enable

    def load(self, name, **kwargs: Any) -> Any:
        """Load the single dumped object matching ``name`` and ``kwargs``.

        The query is ``name``, the dumper's current ``_forward_pass_id`` and
        every extra keyword argument, each matched against the column of the
        same name in the metadata table.

        Args:
            name: Value to match against the ``name`` column.
            **kwargs: Additional ``column=value`` conditions.

        Returns:
            Whatever ``torch.load`` returns for the matching file.

        Raises:
            AssertionError: If the loader is disabled or no row matches.
        """
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Imported at call time rather than at module import.
        # NOTE(review): presumably to avoid an import cycle - confirm.
        from sglang.srt.debug_utils.dumper import dumper

        forward_pass_id = dumper._forward_pass_id
        conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
        row = find_row(self._df, conditions=conditions)
        # Report forward_pass_id too: it is part of the query, and a mismatch
        # there is invisible if only name/kwargs are printed.
        assert (
            row is not None
        ), f"DumpLoader cannot find row given query {name=} {forward_pass_id=} {kwargs=} {self._directory=}"

        path = self._directory / row["filename"]

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects. Only point SGLANG_DUMP_LOADER_DIR at dumps you trust.
        output = torch.load(path, weights_only=False)

        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
        )
        return output
def read_meta(directory):
    """Build a metadata table describing every ``*.pt`` file in ``directory``.

    Each file stem is read as ``key=value`` segments joined by ``___``. Every
    key becomes a column, and a ``filename`` column holds the file's name.
    The ``forward_pass_id``, ``rank`` and ``dump_index`` columns are cast to
    integers; all other columns keep the string values parsed from the name.

    Args:
        directory: Path (``str`` or ``Path``) of the dump directory.

    Returns:
        A ``polars.DataFrame`` with one row per ``*.pt`` file.

    Raises:
        AssertionError: If ``directory`` is not an existing directory.
        ValueError: If a stem segment contains no ``=``.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    # Sorted so the row order does not depend on filesystem enumeration order.
    for p in sorted(directory.glob("*.pt")):
        # Split on the first "=" only, so values that themselves contain "="
        # are kept intact instead of breaking the key/value unpacking.
        full_kwargs = dict(kv.split("=", 1) for kv in p.stem.split("___"))
        rows.append(
            {
                "filename": p.name,
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df
def find_row(df, conditions: Dict[str, Any]):
    """Return the unique row of ``df`` that satisfies every condition.

    Args:
        df: ``polars.DataFrame`` to search.
        conditions: Mapping of column name to required value. Each value is
            converted with ``_cast_to_polars_dtype`` to suit the column's
            dtype before the equality comparison.

    Returns:
        The matching row as a ``dict``, or ``None`` when nothing matches.
        A condition on a column that ``df`` does not have also yields ``None``.

    Raises:
        AssertionError: If more than one row matches.
    """
    schema = df.schema

    # A condition on a column that does not exist cannot match any row.
    # Report it as "not found" rather than failing on the schema lookup.
    if any(col not in schema for col in conditions):
        return None

    predicates = [
        pl.col(col) == _cast_to_polars_dtype(value, schema[col])
        for col, value in conditions.items()
    ]

    # reduce() without an initial value raises on an empty sequence, so an
    # empty condition set is handled by simply not filtering.
    if predicates:
        df_sub = df.filter(functools.reduce(lambda a, b: a & b, predicates))
    else:
        df_sub = df

    assert len(df_sub) <= 1, f"find_row matched {len(df_sub)} rows for {conditions=}"
    return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
def _cast_to_polars_dtype(value, target_dtype):
    """Convert ``value`` to the Python type that suits ``target_dtype``.

    64/32-bit signed and unsigned integer dtypes use ``int``, 64/32-bit float
    dtypes use ``float``, ``pl.Boolean`` uses ``bool`` and ``pl.String`` uses
    ``str``. Any other dtype returns ``value`` unchanged.

    Note that the boolean case is plain ``bool(value)``, i.e. Python
    truthiness: any non-empty string, including ``"False"``, becomes ``True``.
    """
    integer_dtypes = (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32)
    if target_dtype in integer_dtypes:
        return int(value)

    float_dtypes = (pl.Float64, pl.Float32)
    if target_dtype in float_dtypes:
        return float(value)

    if target_dtype == pl.Boolean:
        return bool(value)

    if target_dtype == pl.String:
        return str(value)

    # No conversion is defined for this dtype.
    return value
# Module-level shared instance, created when this module is imported.
# Construction reads SGLANG_DUMP_LOADER_DIR and, if it is set, indexes that
# directory immediately.
dump_loader = DumpLoader()
Xet Storage Details
- Size:
- 2.63 kB
- Xet hash:
- 1b25e8bb6e2b6d5aec4daa5602ee3b791d6e6c3918009aca112e5848737da8e7
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.