Gausson committed on
Commit
68151bc
·
verified ·
1 Parent(s): 5116c77

Update custom_generate/generate.py

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +4 -56
custom_generate/generate.py CHANGED
@@ -1,23 +1,6 @@
1
- def debug_imports():
2
- import sys
3
- import os
4
- import inspect
5
-
6
- print("\n===== 导入调试信息 =====")
7
- print(f"当前工作目录: {os.getcwd()}")
8
- print(f"脚本路径: {os.path.abspath(__file__)}")
9
- print(f"脚本所在目录: {os.path.dirname(os.path.abspath(__file__))}")
10
- print(f"父目录: {os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
11
- print(f"Python路径(sys.path):")
12
- for p in sys.path:
13
- print(f" - {p}")
14
-
15
-
16
- print("=======================\n")
17
-
18
- # 在脚本开头调用
19
- debug_imports()
20
-
21
 
22
 
23
  import torch
@@ -30,28 +13,10 @@ import torch.nn as nn
30
  from transformers.modeling_utils import PreTrainedModel
31
 
32
 
33
- # try:
34
- # from functions_2_patch import _validate_model_kwargs, llama_atten_forward
35
- # from monkey_patching_utils import monkey_patching
36
- # from sep_cache_utils import SepCache
37
- # except :
38
- # from ..functions_2_patch import _validate_model_kwargs, llama_atten_forward
39
- # from ..monkey_patching_utils import monkey_patching
40
- # from ..sep_cache_utils import SepCache
41
-
42
-
43
  from .functions_2_patch import _validate_model_kwargs, llama_atten_forward
44
  from .monkey_patching_utils import monkey_patching
45
  from .sep_cache_utils import SepCache
46
 
47
- # except :
48
- # from ..functions_2_patch import _validate_model_kwargs, llama_atten_forward
49
- # from ..monkey_patching_utils import monkey_patching
50
- # from ..sep_cache_utils import SepCache
51
-
52
-
53
-
54
-
55
 
56
  UNSUPPORTED_GENERATION_ARGS = [
57
  "cache_implementation", # cache-related arguments, here we always use SepCache
@@ -110,23 +75,6 @@ def generate(model,
110
 
111
  **kwargs
112
  ):
113
-
114
- debug_imports()
115
-
116
- import sys
117
- import os
118
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
119
- # from utils.sep_cache_utils import SepCache
120
-
121
- print(f"__file__: {__file__}")
122
- print(f"os.path.abspath(__file__): {os.path.abspath(__file__)}")
123
-
124
- # from ..functions_2_patch import _validate_model_kwargs, llama_atten_forward
125
- # from ..monkey_patching_utils import monkey_patching
126
- # from ..sep_cache_utils import SepCache
127
-
128
-
129
-
130
  """Custom generate function for SepCache.
131
 
132
  A cache as described in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094). In the training phase,
@@ -229,7 +177,7 @@ def generate(model,
229
  """
230
 
231
  # 0. Monkey Patching for the `update` function of `SepCache`
232
- # model_layers = monkey_patching(model, model_atten_forward=llama_atten_forward, verbose=monkey_patch_verbose)
233
 
234
  # 1. General sanity checks
235
  # 1.a. A few arguments are not allowed, especially arguments that control caches.
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  import torch
 
13
  from transformers.modeling_utils import PreTrainedModel
14
 
15
 
 
 
 
 
 
 
 
 
 
 
16
  from .functions_2_patch import _validate_model_kwargs, llama_atten_forward
17
  from .monkey_patching_utils import monkey_patching
18
  from .sep_cache_utils import SepCache
19
 
 
 
 
 
 
 
 
 
20
 
21
  UNSUPPORTED_GENERATION_ARGS = [
22
  "cache_implementation", # cache-related arguments, here we always use SepCache
 
75
 
76
  **kwargs
77
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  """Custom generate function for SepCache.
79
 
80
  A cache as described in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094). In the training phase,
 
177
  """
178
 
179
  # 0. Monkey Patching for the `update` function of `SepCache`
180
+ model_layers = monkey_patching(model, model_atten_forward=llama_atten_forward, verbose=monkey_patch_verbose)
181
 
182
  # 1. General sanity checks
183
  # 1.a. A few arguments are not allowed, especially arguments that control caches.