| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
|
|
# Fallback (batch_size, seq_len) pairs that parse_profiles() uses when it is
# given no profile specification.
DEFAULT_PROFILES: list[tuple[int, int]] = [(1, 128), (4, 128)]
|
|
|
|
| @dataclass(frozen=True, slots=True) |
| class ShapeProfile: |
| batch_size: int |
| seq_len: int |
|
|
| @property |
| def profile_id(self) -> str: |
| return f"b{self.batch_size}_s{self.seq_len}" |
|
|
|
|
def parse_profiles(raw: str | None) -> list[ShapeProfile]:
    """Parse a comma-separated list of ``BxS`` shape profiles.

    Each item has the form ``<batch_size>x<seq_len>``, e.g. ``"1x128,4x512"``.
    Items are stripped of surrounding whitespace and lower-cased before
    parsing, so ``"4X512"`` is accepted. Empty items (for example from a
    trailing comma) are skipped.

    Args:
        raw: The profile specification. ``None`` or an empty string selects
            the entries of ``DEFAULT_PROFILES``.

    Returns:
        The default profiles in their declared order when ``raw`` is falsy;
        otherwise the parsed profiles with duplicates removed, ordered by
        sequence length and then by batch size.

    Raises:
        ValueError: If an item has no ``x`` separator, if either dimension
            is not an integer or is not positive, or if ``raw`` yields no
            items at all (e.g. only commas or whitespace).
    """
    if not raw:
        return [ShapeProfile(batch_size=b, seq_len=s) for b, s in DEFAULT_PROFILES]

    parsed: list[ShapeProfile] = []
    for item in raw.split(","):
        item = item.strip().lower()
        if not item:
            continue
        if "x" not in item:
            raise ValueError(f"Invalid profile '{item}'. Expected format BxS, e.g. 4x512")
        left, right = item.split("x", 1)
        try:
            batch_size = int(left)
            seq_len = int(right)
        except ValueError as exc:
            # int()'s own message names only the bad token; re-raise with the
            # whole item so this error matches the other validation errors.
            raise ValueError(
                f"Invalid profile '{item}'. B and S must be integers, e.g. 4x512"
            ) from exc
        if batch_size <= 0 or seq_len <= 0:
            raise ValueError(f"Invalid profile '{item}'. B and S must be positive")
        parsed.append(ShapeProfile(batch_size=batch_size, seq_len=seq_len))

    if not parsed:
        raise ValueError("No valid profiles parsed")

    # Keying on (batch_size, seq_len) collapses repeated profiles to one entry.
    unique = {(p.batch_size, p.seq_len): p for p in parsed}
    return sorted(unique.values(), key=lambda p: (p.seq_len, p.batch_size))
|
|