File size: 3,357 Bytes
c31cf57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
from __future__ import annotations

import sys
import subprocess
from pathlib import Path
from urllib.request import urlopen, Request


# Hugging Face web-viewer URL of the Sarvam model implementation. This is a
# "/blob/" page URL; the downloader swaps in "/raw/" to fetch the plain file.
HF_BLOB_URL = "https://huggingface.co/sarvamai/sarvam-30b/blob/main/sarvam.py"

# Source lines spliced into vLLM's _TEXT_GENERATION_MODELS dict, one per
# architecture, each mapping the architecture name to ("sarvam", <class name>).
# The 4-space indent and trailing comma presumably mirror the existing entries
# in registry.py (NOTE(review): confirm against the pinned vLLM release).
NEW_LINES = [
    f'    "{arch}": ("sarvam", "{arch}"),\n'
    for arch in ("SarvamMoEForCausalLM", "SarvamMLAForCausalLM")
]


def run(cmd: list[str]) -> None:
    """Echo *cmd* in shell-trace style (``+ arg0 arg1 ...``), then execute it.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    rendered = " ".join(cmd)
    print("+ " + rendered)
    subprocess.check_call(cmd)


def pip_install_vllm() -> None:
    """Install vLLM 0.15.0 with the pip of the interpreter running this script.

    The version is pinned exactly because the registry patch below relies on a
    specific registry.py layout.
    """
    pip_cmd = [sys.executable, "-m", "pip", "install"]
    pip_cmd.append("vllm==0.15.0")
    run(pip_cmd)


def find_vllm_dir() -> Path:
    """Locate the installed vLLM package directory without importing vLLM.

    Returns:
        Resolved path of the directory that contains ``vllm/__init__.py``.

    Raises:
        ModuleNotFoundError: if no ``vllm`` package can be found on ``sys.path``.
    """
    import importlib
    import importlib.util

    # vLLM is typically installed by a pip *subprocess* after this interpreter
    # started; drop the import system's cached directory listings so the newly
    # written package is visible.
    importlib.invalidate_caches()

    # find_spec only locates the package; unlike ``import vllm`` it does not
    # execute vLLM's __init__ (and whatever that imports) just to get a path.
    spec = importlib.util.find_spec("vllm")
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError("No module named 'vllm'", name="vllm")

    vllm_dir = Path(spec.origin).resolve().parent
    print(f"Detected vLLM package dir: {vllm_dir}")
    return vllm_dir


def patch_text_generation_models(
    registry_path: Path, new_lines: list[str] | None = None
) -> None:
    """Add registry entry lines to vLLM's ``_TEXT_GENERATION_MODELS`` dict.

    The file is edited textually. It finds the line whose stripped text is
    ``_TEXT_GENERATION_MODELS = {`` and locates that dict's closing line by
    counting ``{``/``}`` characters. The entry lines are inserted just before
    the closing line.

    Only entries whose quoted key (the text before the first ``:``) does not
    already appear in the file are inserted. Re-running the patch, or running
    it on a partially patched file, therefore never duplicates a key.

    Args:
        registry_path: Path to vLLM's ``model_executor/models/registry.py``.
        new_lines: Dict-entry source lines to insert, each a complete
            ``"Key": ("module", "Class"),`` line ending in a newline.
            Defaults to the module-level ``NEW_LINES``.

    Raises:
        FileNotFoundError: if ``registry_path`` does not exist.
        RuntimeError: if the dict header or its closing brace cannot be found.
    """
    if not registry_path.exists():
        raise FileNotFoundError(f"registry.py not found at: {registry_path}")

    entries = NEW_LINES if new_lines is None else new_lines

    text = registry_path.read_text(encoding="utf-8")
    lines = text.splitlines(keepends=True)

    # Idempotency is checked per entry, not "all present, else insert all".
    # Otherwise a file that already holds one of the keys would get it twice.
    missing = [
        entry for entry in entries if entry.split(":", 1)[0].strip() not in text
    ]
    if not missing:
        print("registry.py already contains Sarvam entries. Skipping patch.")
        return

    # Find the start of the _TEXT_GENERATION_MODELS dict.
    header = "_TEXT_GENERATION_MODELS = {"
    start_idx = next(
        (i for i, line in enumerate(lines) if line.strip() == header), None
    )
    if start_idx is None:
        raise RuntimeError(
            "Could not find '_TEXT_GENERATION_MODELS = {' in registry.py. "
            "vLLM version/layout may differ."
        )

    # Naive brace matching that counts every '{' and '}' character. Braces
    # inside string literals or comments within the dict would throw it off.
    depth = 0
    end_idx = None
    for j in range(start_idx, len(lines)):
        depth += lines[j].count("{") - lines[j].count("}")
        if j > start_idx and depth == 0:
            end_idx = j
            break

    if end_idx is None:
        raise RuntimeError("Failed to find end of _TEXT_GENERATION_MODELS dict (brace matching).")

    # Insert just before the closing-brace line. This assumes the last
    # existing entry already ends with a trailing comma (TODO confirm for the
    # pinned vLLM release).
    lines[end_idx:end_idx] = missing

    registry_path.write_text("".join(lines), encoding="utf-8")
    print(f"Patched _TEXT_GENERATION_MODELS in: {registry_path}")


def download_sarvam_py(
    dst: Path, url: str | None = None, *, timeout: float = 60.0
) -> None:
    """Download the Sarvam model implementation and write it to *dst*.

    Args:
        dst: Destination file. Missing parent directories are created.
        url: Source URL, defaulting to ``HF_BLOB_URL``. Pass a URL pinned to a
            specific revision (``.../blob/<commit>/sarvam.py``) to avoid
            installing whatever is currently on ``main``. Any ``/blob/``
            segment is rewritten to ``/raw/`` so the plain file contents are
            fetched instead of the HTML viewer page.
        timeout: Seconds to wait on the connection. Without it, ``urlopen``
            falls back to the global socket default, which by default blocks
            indefinitely on a stalled connection.

    Raises:
        urllib.error.URLError: on network or HTTP failures (``HTTPError`` is
            a subclass).
        RuntimeError: if the downloaded payload is empty. Nothing is written
            in that case.
    """
    source = HF_BLOB_URL if url is None else url
    # Use /raw/ to download file contents, not HTML
    raw_url = source.replace("/blob/", "/raw/")
    print(f"Downloading sarvam.py from: {raw_url}")

    req = Request(raw_url, headers={"User-Agent": "vllm-hotpatch-script"})
    with urlopen(req, timeout=timeout) as resp:
        data = resp.read()

    # Writing an empty body would install a sarvam.py that defines none of
    # the classes the registry points at; fail loudly instead.
    if not data:
        raise RuntimeError(f"Downloaded an empty file from: {raw_url}")

    dst.parent.mkdir(parents=True, exist_ok=True)
    dst.write_bytes(data)
    print(f"Wrote: {dst}")


def main() -> None:
    """Install vLLM, drop in the Sarvam model module, and register its architectures.

    The module is downloaded *before* registry.py is edited. A failed download
    then leaves vLLM's registry untouched, instead of pointing it at a
    ``sarvam`` module that does not exist.
    """
    pip_install_vllm()

    vllm_dir = find_vllm_dir()
    models_dir = vllm_dir / "model_executor" / "models"
    registry_path = models_dir / "registry.py"
    sarvam_path = models_dir / "sarvam.py"

    # Fail fast on an unexpected package layout before writing anything.
    if not registry_path.exists():
        raise FileNotFoundError(f"registry.py not found at: {registry_path}")

    download_sarvam_py(sarvam_path)
    patch_text_generation_models(registry_path)

    print("\nDone.")
    print(f"- Registry patched: {registry_path}")
    print(f"- Sarvam module installed: {sarvam_path}")


if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0
    # unless main() raises.
    sys.exit(main())