moelanoby committed on
Commit b00ef0f · verified · 1 Parent(s): 42bcacc

Create alm_qwen.py

Files changed (1)
  1. alm_qwen.py +347 -0
alm_qwen.py ADDED
@@ -0,0 +1,347 @@
+ # --- START OF FILE alm_qwen_hf.py ---
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from typing import List, Tuple, Dict, Any
+ import os
+ import json
+
+ # Assuming ALM.py is in the same directory or accessible on PYTHONPATH
+ from ALM import AttentionLinkedMemory
+
+ class QwenGenerator(nn.Module):
+     def __init__(self, model_name_or_path: str, device="cuda", tokenizer_path: str = None):
+         super().__init__()
+         self.device = device
+         self.model_name_or_path = model_name_or_path  # Stored for saving config
+         self.tokenizer_path = tokenizer_path if tokenizer_path else model_name_or_path
+
+         print(f"Loading Qwen model from: {self.model_name_or_path}...")
+         print(f"Loading Qwen tokenizer from: {self.tokenizer_path}...")
+
+         # Standard (full-precision/auto-dtype) loading; requires more GPU memory.
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.model_name_or_path,
+             torch_dtype="auto",
+             device_map="auto",
+             trust_remote_code=True
+         )
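+         # Optional (untested sketch, not part of the original file): BitsAndBytesConfig
+         # is imported above but unused; a memory-saving 4-bit load could look like this:
+         #
+         # quant_config = BitsAndBytesConfig(
+         #     load_in_4bit=True,
+         #     bnb_4bit_quant_type="nf4",
+         #     bnb_4bit_compute_dtype=torch.bfloat16,
+         # )
+         # self.model = AutoModelForCausalLM.from_pretrained(
+         #     self.model_name_or_path,
+         #     quantization_config=quant_config,
+         #     device_map="auto",
+         #     trust_remote_code=True,
+         # )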
+         self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)
+
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+         # Left padding keeps prompt tokens adjacent to the generated tokens,
+         # which is required for batched generation with decoder-only models.
+         self.tokenizer.padding_side = "left"
+
+         print(f"Qwen model and tokenizer loaded. Model device: {self.model.device}")
+
+     def format_prompt(self, query: str, context_snippets: List[str]) -> str:
+         if context_snippets:
+             context_str = "\n".join(f"- {cs}" for cs in context_snippets)
+             # Qwen-specific (ChatML) chat format, with retrieved context
+             final_prompt_str = "<|im_start|>system\nYou are a helpful assistant. Use the provided context to answer the user's query. If the context is insufficient, say so.<|im_end|>\n"
+             final_prompt_str += "<|im_start|>user\n"
+             final_prompt_str += f"Context:\n{context_str}\n\n"
+             final_prompt_str += f"Query:\n{query}\n<|im_end|>\n<|im_start|>assistant\n"
+         else:
+             # Qwen-specific chat format without context
+             final_prompt_str = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+             final_prompt_str += "<|im_start|>user\n"
+             final_prompt_str += f"Query:\n{query}\n<|im_end|>\n<|im_start|>assistant\n"
+         return final_prompt_str
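+     # Note (suggestion, not in the original file): newer transformers tokenizers can
+     # build this same ChatML layout from the model's own chat template, which is less
+     # brittle than hand-assembled strings:
+     #
+     # messages = [
+     #     {"role": "system", "content": "You are a helpful assistant."},
+     #     {"role": "user", "content": f"Query:\n{query}"},
+     # ]
+     # prompt = self.tokenizer.apply_chat_template(
+     #     messages, tokenize=False, add_generation_prompt=True
+     # )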
+
+     def generate(self, prompts: List[str], max_new_tokens: int = 150, **kwargs) -> List[str]:
+         self.model.eval()
+         inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=2048)
+         inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 # Pop the sampling parameters so they are not also passed via **kwargs,
+                 # which would raise "got multiple values for keyword argument".
+                 do_sample=kwargs.pop("do_sample", True),
+                 temperature=kwargs.pop("temperature", 0.7),
+                 top_p=kwargs.pop("top_p", 0.9),
+                 **kwargs
+             )
+
+         # Slice off the (left-padded) prompt so only newly generated tokens are decoded.
+         decoded_outputs = []
+         for i, output_ids in enumerate(outputs):
+             prompt_len = inputs['input_ids'][i].shape[0]
+             generated_ids = output_ids[prompt_len:]
+             decoded_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+             decoded_outputs.append(decoded_text.strip())
+
+         return decoded_outputs
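+     # Minimal usage sketch (hypothetical values, not part of the original file):
+     # gen = QwenGenerator("Qwen/Qwen2.5-0.5B-Instruct")
+     # prompt = gen.format_prompt("What is attention?", ["Attention weighs token interactions."])
+     # print(gen.generate([prompt], max_new_tokens=64)[0])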
+
+     def save_pretrained(self, save_directory: str):
+         """Saves the Qwen model and tokenizer to a directory."""
+         model_save_path = os.path.join(save_directory, "qwen_model")
+         tokenizer_save_path = os.path.join(save_directory, "qwen_tokenizer")
+
+         print(f"Saving Qwen model to {model_save_path}")
+         self.model.save_pretrained(model_save_path)
+         print(f"Saving Qwen tokenizer to {tokenizer_save_path}")
+         self.tokenizer.save_pretrained(tokenizer_save_path)
+
+ class ALMQwenModel_HF(nn.Module):
+     def __init__(self,
+                  alm_config: Dict[str, Any],
+                  qwen_model_name_or_path: str,  # HF model name or local path
+                  qwen_tokenizer_path: str = None,  # Optional separate path for the tokenizer
+                  device: str = "cuda" if torch.cuda.is_available() else "cpu",
+                  top_k_buckets: int = 3,
+                  top_k_items_per_bucket: int = 2):
+         super().__init__()
+         self.device = device
+         # All of the following are stored so save_model() can serialize the configuration.
+         self.alm_config = alm_config
+         self.qwen_model_name_or_path = qwen_model_name_or_path
+         self.qwen_tokenizer_path = qwen_tokenizer_path
+         self.top_k_buckets = top_k_buckets
+         self.top_k_items_per_bucket = top_k_items_per_bucket
+
+         self.alm_layer = AttentionLinkedMemory(**alm_config).to(device)
+         self.qwen_generator = QwenGenerator(
+             model_name_or_path=qwen_model_name_or_path,
+             device=device,
+             tokenizer_path=qwen_tokenizer_path
+         )
+
+     def forward(self,
+                 query_texts: List[str],
+                 query_embeddings_for_alm: torch.Tensor,    # (batch, query_dim)
+                 memory_item_embeddings: torch.Tensor,      # (batch, num_buckets, items_per_bucket, memory_dim)
+                 memory_text_items: List[List[List[str]]],  # [batch][bucket][item] -> snippet text
+                 memory_mask: torch.Tensor = None           # (batch, num_buckets, items_per_bucket), bool
+                 ) -> Tuple[List[str], torch.Tensor, torch.Tensor]:
+         self.alm_layer.eval()
+         query_embeddings_for_alm = query_embeddings_for_alm.to(self.device)
+         memory_item_embeddings = memory_item_embeddings.to(self.device)
+         if memory_mask is not None:
+             memory_mask = memory_mask.to(self.device)
+
+         with torch.no_grad():
+             _, bucket_att_weights, item_att_weights = self.alm_layer(
+                 query_embeddings_for_alm, memory_item_embeddings, memory_mask
+             )
+
+         # Use the ALM attention weights to select the top buckets, then the top items
+         # within each selected bucket, and collect their text snippets for the prompt.
+         batch_retrieved_texts: List[List[str]] = []
+         for b_idx in range(len(query_texts)):
+             retrieved_for_sample: List[str] = []
+             current_bucket_weights = bucket_att_weights[b_idx]
+             _, top_bucket_indices = torch.topk(current_bucket_weights,
+                                                k=min(self.top_k_buckets, current_bucket_weights.size(0)))
+
+             for bucket_idx in top_bucket_indices:
+                 bucket_idx_item = bucket_idx.item()
+                 current_item_weights = item_att_weights[b_idx, bucket_idx_item, :]
+
+                 if memory_mask is not None:
+                     item_m = memory_mask[b_idx, bucket_idx_item, :]
+                     current_item_weights = current_item_weights.masked_fill(item_m == 0, -float('inf'))
+
+                 num_valid_items = (current_item_weights > -float('inf')).sum().item()
+                 if num_valid_items == 0:
+                     continue
+
+                 _, top_item_indices_in_bucket = torch.topk(current_item_weights,
+                                                            k=min(self.top_k_items_per_bucket, num_valid_items))
+
+                 for item_idx_in_bucket in top_item_indices_in_bucket:
+                     item_idx_in_bucket_item = item_idx_in_bucket.item()
+                     if memory_mask is not None and not memory_mask[b_idx, bucket_idx_item, item_idx_in_bucket_item]:
+                         continue
+                     try:
+                         text_content = memory_text_items[b_idx][bucket_idx_item][item_idx_in_bucket_item]
+                         if text_content:
+                             retrieved_for_sample.append(text_content)
+                     except IndexError:
+                         print(f"Warning: IndexError accessing memory_text_items[{b_idx}][{bucket_idx_item}][{item_idx_in_bucket_item}]")
+                         continue
+             # Deduplicate while preserving retrieval order.
+             batch_retrieved_texts.append(list(dict.fromkeys(retrieved_for_sample)))
+
+         prompts_for_qwen = []
+         for i, q_text in enumerate(query_texts):
+             prompt = self.qwen_generator.format_prompt(q_text, batch_retrieved_texts[i])
+             prompts_for_qwen.append(prompt)
+
+         generated_answers = self.qwen_generator.generate(prompts_for_qwen)
+         return generated_answers, bucket_att_weights, item_att_weights
+
+     def save_model(self, save_directory: str):
+         """Saves the entire ALMQwenModel_HF to the specified directory."""
+         os.makedirs(save_directory, exist_ok=True)
+
+         # 1. Save the ALM layer state_dict
+         alm_state_dict_path = os.path.join(save_directory, "alm_layer_state_dict.pth")
+         torch.save(self.alm_layer.state_dict(), alm_state_dict_path)
+         print(f"ALM layer state_dict saved to {alm_state_dict_path}")
+
+         # 2. Save the QwenGenerator (model and tokenizer)
+         qwen_save_path = os.path.join(save_directory, "qwen_generator")
+         os.makedirs(qwen_save_path, exist_ok=True)
+         self.qwen_generator.save_pretrained(qwen_save_path)
+         print(f"Qwen generator (model & tokenizer) saved in {qwen_save_path}")
+
+         # 3. Save the ALMQwenModel_HF configuration
+         config = {
+             "alm_config": self.alm_config,
+             # Relative paths keep the saved directory portable.
+             "qwen_model_name_or_path": "qwen_generator/qwen_model",
+             "qwen_tokenizer_path": "qwen_generator/qwen_tokenizer",
+             "top_k_buckets": self.top_k_buckets,
+             "top_k_items_per_bucket": self.top_k_items_per_bucket
+         }
+         config_path = os.path.join(save_directory, "alm_qwen_hf_config.json")
+         with open(config_path, 'w') as f:
+             json.dump(config, f, indent=4)
+         print(f"ALMQwenModel_HF configuration saved to {config_path}")
+         print(f"Model saved successfully to {save_directory}")
+
+     @classmethod
+     def load_model(cls, load_directory: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+         """Loads an ALMQwenModel_HF from the specified directory."""
+         print(f"Loading model from {load_directory}...")
+
+         # 1. Load the ALMQwenModel_HF configuration
+         config_path = os.path.join(load_directory, "alm_qwen_hf_config.json")
+         if not os.path.exists(config_path):
+             raise FileNotFoundError(f"Configuration file not found: {config_path}")
+         with open(config_path, 'r') as f:
+             config = json.load(f)
+
+         alm_config = config["alm_config"]
+         # Resolve the saved relative paths into absolute paths for the Qwen model and tokenizer.
+         qwen_model_path = os.path.join(load_directory, config["qwen_model_name_or_path"])
+         qwen_tokenizer_path = os.path.join(load_directory, config["qwen_tokenizer_path"])
+         top_k_buckets = config["top_k_buckets"]
+         top_k_items_per_bucket = config["top_k_items_per_bucket"]
+
+         # 2. Instantiate the model (QwenGenerator loads its components from the resolved paths)
+         model = cls(
+             alm_config=alm_config,
+             qwen_model_name_or_path=qwen_model_path,
+             qwen_tokenizer_path=qwen_tokenizer_path,
+             device=device,
+             top_k_buckets=top_k_buckets,
+             top_k_items_per_bucket=top_k_items_per_bucket
+         )
+         print("ALMQwenModel_HF structure initialized.")
+
+         # 3. Load the ALM layer state_dict
+         alm_state_dict_path = os.path.join(load_directory, "alm_layer_state_dict.pth")
+         if not os.path.exists(alm_state_dict_path):
+             raise FileNotFoundError(f"ALM state_dict not found: {alm_state_dict_path}")
+
+         # Move the ALM layer to the target device before loading its state_dict.
+         model.alm_layer.to(device)
+         state_dict = torch.load(alm_state_dict_path, map_location=device)
+         model.alm_layer.load_state_dict(state_dict)
+         print(f"ALM layer state_dict loaded from {alm_state_dict_path}")
+
+         # The Qwen model was already placed by QwenGenerator via device_map="auto", which may
+         # shard it across several devices; `device` only records the primary device used for
+         # the ALM layer and the input tensors.
+         model.device = device
+
+         print(f"Model loaded successfully from {load_directory} and placed on device: {device}")
+         return model
+
+
+ # ========================= Example Usage for ALM-Qwen with Hugging Face =========================
+ if __name__ == "__main__":
+     print("\n--- Testing ALM-Qwen with Hugging Face Qwen ---")
+
+     _device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {_device}")
+
+     # --- Hyperparameters ---
+     _batch_size = 1  # Kept small for faster testing
+     _alm_query_dim = 128
+     _alm_memory_dim = 64
+     _alm_embed_dim = 256
+     _alm_num_heads = 8
+     _alm_output_dim = 128
+     _num_kb_buckets = 3
+     _max_kb_items_per_bucket = 5
+
+     _alm_config_example = {
+         'query_dim': _alm_query_dim,
+         'memory_dim': _alm_memory_dim,
+         'embed_dim': _alm_embed_dim,
+         'num_heads': _alm_num_heads,
+         'output_dim': _alm_output_dim,
+         'dropout_rate': 0.0
+     }
+
+     _query_texts_for_qwen = ["What is attention in LLMs?"]
+     _query_embeddings_for_alm = torch.randn(_batch_size, _alm_query_dim)
+     _kb_memory_item_embeddings = torch.randn(_batch_size, _num_kb_buckets, _max_kb_items_per_bucket, _alm_memory_dim)
+     _kb_memory_text_items: List[List[List[str]]] = []
+     for b in range(_batch_size):
+         batch_sample_text = []
+         for i in range(_num_kb_buckets):
+             bucket_texts = [f"Doc {b+1}-B{i+1}-I{j+1}: info snippet {j}." for j in range(_max_kb_items_per_bucket)]
+             batch_sample_text.append(bucket_texts)
+         _kb_memory_text_items.append(batch_sample_text)
+     _kb_memory_mask = torch.ones(_batch_size, _num_kb_buckets, _max_kb_items_per_bucket, dtype=torch.bool)
+     _kb_memory_mask[:, :, -1:] = False  # Mask out the last item in every bucket
+
+     _qwen_model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # A small model keeps the test lightweight
+
+     try:
+         # --- Instantiate the Original Model ---
+         print("\n--- Creating and testing original model ---")
+         original_model = ALMQwenModel_HF(
+             alm_config=_alm_config_example,
+             qwen_model_name_or_path=_qwen_model_name,
+             device=_device,
+             top_k_buckets=2,
+             top_k_items_per_bucket=1
+         )
+
+         # Optional dummy forward pass to ensure everything is initialized
+         # (catches lazy initializations, though none are expected here).
+         _ = original_model(
+             _query_texts_for_qwen, _query_embeddings_for_alm, _kb_memory_item_embeddings,
+             _kb_memory_text_items, _kb_memory_mask
+         )
+         print("Original model created and tested with a dummy pass.")
+
+         # --- Save the Model ---
+         save_dir = "./saved_alm_qwen_model"
+         print(f"\n--- Saving model to {save_dir} ---")
+         original_model.save_model(save_dir)
+
+         # --- Load the Model ---
+         print(f"\n--- Loading model from {save_dir} ---")
+         # Pass the target device explicitly when loading.
+         loaded_model = ALMQwenModel_HF.load_model(save_dir, device=_device)
+         print("Model loaded successfully.")
+
+         # --- Test the Loaded Model ---
+         print("\n--- Testing loaded model ---")
+         generated_answers, _, _ = loaded_model(
+             _query_texts_for_qwen,
+             _query_embeddings_for_alm,
+             _kb_memory_item_embeddings,
+             _kb_memory_text_items,
+             _kb_memory_mask
+         )
+         print("\n--- Results from Loaded Model ---")
+         for i in range(len(_query_texts_for_qwen)):
+             print(f"Query {i+1}: {_query_texts_for_qwen[i]}")
+             print(f"  Generated Answer {i+1}: {generated_answers[i]}")
+             print("-" * 30)
+
+         print("\nSave and Load test completed.")
+
+     except ImportError as e:
+         print(f"ImportError: {e}. Check that ALM.py and the transformers dependencies are available.")
+     except Exception as e:
+         print(f"An error occurred: {e}")
+         import traceback
+         traceback.print_exc()
+
+ # --- END OF FILE alm_qwen_hf.py ---