trohrbaugh committed on
Commit
4a382cf
·
verified ·
1 Parent(s): 831268f

Upload patch_iquestcoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. patch_iquestcoder.py +126 -0
patch_iquestcoder.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Patch IQuestLab/IQuest-Coder models for transformers 5.x compatibility.
4
+
5
+ Fixes the meta-device RoPE bug where accelerate zeros out inv_freq during
6
+ model initialization, causing the model to produce only newlines/garbage.
7
+
8
+ Usage:
9
+ python patch_iquestcoder.py
10
+
11
+ This will find and patch all cached IQuest-Coder modeling files automatically.
12
+ Run this AFTER downloading the model (e.g. after a failed heretic run or
13
+ after running `huggingface-cli download IQuestLab/IQuest-Coder-V1-40B-Instruct`).
14
+ """
15
+
16
+ import glob
17
+ import os
18
+ import re
19
+ import sys
20
+
21
# Pattern to find the forward method that needs patching.
#
# Group 1 is the decorated ``forward`` header (both decorators plus the ``def``
# line, including their leading whitespace). Group 2 is the first body line
# that builds ``inv_freq_expanded``. The named group ``indent`` (nested inside
# group 2) captures that body line's leading whitespace so the inserted code
# can be indented to match the target file instead of assuming a fixed width.
# Leading whitespace is matched with ``[ \t]*`` rather than literal spaces, so
# the pattern does not depend on the exact indentation of the modeling file.
ORIGINAL_PATTERN = re.compile(
    r'([ \t]*@torch\.no_grad\(\)\n'
    r'[ \t]*@dynamic_rope_update\n'
    r'[ \t]*def forward\(self, x: torch\.Tensor, position_ids: torch\.Tensor\)'
    r' -> Tuple\[torch\.Tensor, torch\.Tensor\]:\n)'
    r'((?P<indent>[ \t]*)inv_freq_expanded = self\.inv_freq\[None, :, None\]\.float\(\)'
    r'\.expand\(position_ids\.shape\[0\], -1, 1\)\.to\(x\.device\))'
)

# Substitution template for ``ORIGINAL_PATTERN``: re-emit the header (group 1),
# insert the lazy-recompute guard at the body's own indentation, then re-emit
# the original first body line (group 2). The lines nested under the ``if``
# get four extra spaces on top of the captured indent.
REPLACEMENT = (
    r'\1'
    r'\g<indent># Lazy recompute: accelerate meta-device init leaves inv_freq as zeros\n'
    r'\g<indent>if self.inv_freq is not None and self.inv_freq.numel() > 0 and (self.inv_freq == 0).all():\n'
    r'\g<indent>    inv_freq, self.attention_scaling = self.rope_init_fn(self.config, None)\n'
    r'\g<indent>    self.inv_freq = inv_freq.to(device=x.device, dtype=self.inv_freq.dtype)\n'
    r'\g<indent>    self.original_inv_freq = self.inv_freq\n'
    r'\2'
)

# Check string to see if already patched (it is part of the inserted comment)
PATCH_MARKER = "Lazy recompute: accelerate meta-device init"
42
+
43
# Search locations for cached model files (recursive glob patterns).
# find_model_files() extends this list in place with env-var-derived patterns.
SEARCH_PATHS = [
    os.path.expanduser("~/.cache/huggingface/hub/models--IQuestLab--*/**/modeling_iquestcoder.py"),
    "/llm/huggingface/modules/transformers_modules/IQuestLab/**/modeling_iquestcoder.py",
    # Common alternate HF cache locations
    "/data/huggingface/**/modeling_iquestcoder.py",
    "/scratch/**/modeling_iquestcoder.py",
]


def find_model_files():
    """Find all cached IQuest-Coder modeling files.

    Every pattern in ``SEARCH_PATHS`` is globbed recursively. Before globbing,
    one pattern is added to ``SEARCH_PATHS`` for each cache-related
    environment variable that is set. The list is extended in place (so the
    caller can report what was searched), but a pattern is only added if it is
    not already present, so repeated calls, or several variables pointing at
    the same directory, do not grow the list with duplicates.

    Returns:
        A list of matching paths in discovery order, with entries that
        resolve to the same real file (e.g. via symlinks) reported only once.
        The first path found for each real file is the one kept.
    """
    # Also check HF_HOME / TRANSFORMERS_CACHE style env vars
    # (HF_HUB_CACHE is the current name for HUGGINGFACE_HUB_CACHE).
    for env_var in ("HF_HOME", "TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE", "HF_HUB_CACHE"):
        val = os.environ.get(env_var)
        if not val:
            continue
        pattern = os.path.join(val, "**/modeling_iquestcoder.py")
        if pattern not in SEARCH_PATHS:
            SEARCH_PATHS.append(pattern)

    found = []
    for pattern in SEARCH_PATHS:
        found.extend(glob.glob(pattern, recursive=True))

    # Deduplicate (resolve symlinks) while preserving discovery order
    seen = set()
    unique = []
    for f in found:
        real = os.path.realpath(f)
        if real not in seen:
            seen.add(real)
            unique.append(f)
    return unique
74
+
75
+
76
def patch_file(filepath):
    """Apply the RoPE lazy-recompute patch to a single modeling file.

    The file is read as UTF-8. If it already contains ``PATCH_MARKER`` it is
    left alone. Otherwise ``ORIGINAL_PATTERN`` is substituted with
    ``REPLACEMENT`` and, when at least one substitution happened, the file is
    rewritten in place. A one-line status is printed in every case.

    Args:
        filepath: Path of the modeling file to patch.

    Returns:
        True if the file was rewritten; False if it was already patched or
        the pattern was not found.
    """
    with open(filepath, "r", encoding="utf-8") as src:
        original_text = src.read()

    # The marker is part of the inserted comment, so its presence means a
    # previous run already handled this file.
    already_patched = PATCH_MARKER in original_text
    if already_patched:
        print(f" SKIP (already patched): {filepath}")
        return False

    patched_text, hits = ORIGINAL_PATTERN.subn(REPLACEMENT, original_text)
    if not hits:
        print(f" WARN (pattern not found — may need manual patching): {filepath}")
        return False

    with open(filepath, "w", encoding="utf-8") as dst:
        dst.write(patched_text)
    print(f" OK (patched {hits} location(s)): {filepath}")
    return True
95
+
96
+
97
def main():
    """Command-line entry point: locate cached modeling files and patch them.

    Prints a banner, then searches for files. If none are found, the searched
    patterns are listed and the process exits with status 1. Otherwise each
    file is passed to ``patch_file`` and a summary line is printed.
    """
    print("IQuest-Coder RoPE patch for transformers 5.x")
    print("=" * 50)
    print()

    targets = find_model_files()
    if not targets:
        for line in (
            "No IQuest-Coder model files found in cache.",
            "Download the model first, then re-run this script.",
            "",
            "Searched:",
        ):
            print(line)
        for location in SEARCH_PATHS:
            print(f" {location}")
        sys.exit(1)

    print(f"Found {len(targets)} file(s):\n")
    # patch_file() prints its own per-file status; only count the rewrites.
    patched_count = sum(1 for path in targets if patch_file(path))

    print()
    summary = (
        f"Done — patched {patched_count} file(s)."
        if patched_count
        else "No files needed patching."
    )
    print(summary)


if __name__ == "__main__":
    main()