"""ALM-Qwen integration: wraps a Hugging Face Qwen causal LM with an
Attention-Linked Memory (ALM) retrieval layer, and provides save/load helpers."""

import json
import os
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

from ALM import AttentionLinkedMemory


class QwenGenerator(nn.Module):
    """Thin wrapper around a Hugging Face Qwen causal LM for batched, prompt-based generation."""

    def __init__(self, model_name_or_path: str, device: str = "cuda", tokenizer_path: str = None):
        super().__init__()
        self.device = device
        self.model_name_or_path = model_name_or_path
        self.tokenizer_path = tokenizer_path if tokenizer_path else model_name_or_path

        print(f"Loading Qwen model from: {self.model_name_or_path}...")
        print(f"Loading Qwen tokenizer from: {self.tokenizer_path}...")

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Left padding keeps generated tokens contiguous with the prompt when batching.
        self.tokenizer.padding_side = "left"

        print(f"Qwen model and tokenizer loaded. Model device: {self.model.device}")

    def format_prompt(self, query: str, context_snippets: List[str]) -> str:
        """Build a ChatML prompt, optionally prefixed with retrieved context snippets."""
        if context_snippets:
            context_str = "\n".join(f"- {cs}" for cs in context_snippets)
            final_prompt_str = (
                "<|im_start|>system\nYou are a helpful assistant. Use the provided context to "
                "answer the user's query. If the context is insufficient, say so.<|im_end|>\n"
            )
            final_prompt_str += "<|im_start|>user\n"
            final_prompt_str += f"Context:\n{context_str}\n\n"
            final_prompt_str += f"Query:\n{query}\n<|im_end|>\n<|im_start|>assistant\n"
        else:
            final_prompt_str = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            final_prompt_str += "<|im_start|>user\n"
            final_prompt_str += f"Query:\n{query}\n<|im_end|>\n<|im_start|>assistant\n"
        return final_prompt_str
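
    # Note: a minimal alternative sketch, not used above. Recent transformers tokenizers
    # expose `apply_chat_template`, which can build an equivalent ChatML prompt from a
    # message list, roughly:
    #
    #   messages = [
    #       {"role": "system", "content": "You are a helpful assistant."},
    #       {"role": "user", "content": f"Context:\n{context_str}\n\nQuery:\n{query}"},
    #   ]
    #   prompt = self.tokenizer.apply_chat_template(messages, tokenize=False,
    #                                               add_generation_prompt=True)
    #
    # format_prompt keeps the manual string building so the exact prompt layout stays explicit.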

    def generate(self, prompts: List[str], max_new_tokens: int = 150, **kwargs) -> List[str]:
        self.model.eval()
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=2048)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # Sampling defaults that callers can override; popping them from kwargs avoids
        # passing the same keyword to model.generate() twice.
        gen_kwargs = {
            "do_sample": kwargs.pop("do_sample", True),
            "temperature": kwargs.pop("temperature", 0.7),
            "top_p": kwargs.pop("top_p", 0.9),
        }
        gen_kwargs.update(kwargs)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **gen_kwargs
            )

        decoded_outputs = []
        for i, output_ids in enumerate(outputs):
            # With left padding, generated tokens start right after the (padded) prompt.
            prompt_len = inputs['input_ids'][i].shape[0]
            generated_ids = output_ids[prompt_len:]
            decoded_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            decoded_outputs.append(decoded_text.strip())

        return decoded_outputs

    def save_pretrained(self, save_directory: str):
        """Saves the Qwen model and tokenizer to a directory."""
        model_save_path = os.path.join(save_directory, "qwen_model")
        tokenizer_save_path = os.path.join(save_directory, "qwen_tokenizer")

        print(f"Saving Qwen model to {model_save_path}")
        self.model.save_pretrained(model_save_path)
        print(f"Saving Qwen tokenizer to {tokenizer_save_path}")
        self.tokenizer.save_pretrained(tokenizer_save_path)
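

# Standalone usage sketch for QwenGenerator (illustrative only; the checkpoint name matches
# the one used in the __main__ test below and can be swapped for any local Qwen chat model):
#
#   gen = QwenGenerator("Qwen/Qwen2.5-0.5B-Instruct")
#   prompt = gen.format_prompt("What is attention?", ["Attention scores token-to-token relevance."])
#   print(gen.generate([prompt], max_new_tokens=64)[0])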


class ALMQwenModel_HF(nn.Module):
    """ALM retrieval layer plus a Qwen generator: per query, retrieve the top memory
    snippets via ALM attention weights, then condition Qwen generation on them."""

    def __init__(self,
                 alm_config: Dict[str, Any],
                 qwen_model_name_or_path: str,
                 qwen_tokenizer_path: str = None,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu",
                 top_k_buckets: int = 3,
                 top_k_items_per_bucket: int = 2):
        super().__init__()
        self.device = device
        self.alm_config = alm_config
        self.qwen_model_name_or_path = qwen_model_name_or_path
        self.qwen_tokenizer_path = qwen_tokenizer_path
        self.top_k_buckets = top_k_buckets
        self.top_k_items_per_bucket = top_k_items_per_bucket

        self.alm_layer = AttentionLinkedMemory(**alm_config).to(device)
        self.qwen_generator = QwenGenerator(
            model_name_or_path=qwen_model_name_or_path,
            device=device,
            tokenizer_path=qwen_tokenizer_path
        )

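    # Assumed AttentionLinkedMemory interface (defined in the local ALM module; this sketch
    # reflects how it is used below, not its actual implementation):
    #   AttentionLinkedMemory(query_dim, memory_dim, embed_dim, num_heads, output_dim, dropout_rate)
    #   forward(query_embeddings, memory_item_embeddings, memory_mask)
    #       -> (output, bucket_att_weights, item_att_weights)
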
    def forward(self,
                query_texts: List[str],
                query_embeddings_for_alm: torch.Tensor,
                memory_item_embeddings: torch.Tensor,
                memory_text_items: List[List[List[str]]],
                memory_mask: torch.Tensor = None
                ) -> Tuple[List[str], torch.Tensor, torch.Tensor]:
        """Retrieve context via the ALM layer, then generate answers with Qwen.

        Expected shapes (batch size B, N buckets, M items per bucket):
          query_embeddings_for_alm: (B, query_dim)
          memory_item_embeddings:   (B, N, M, memory_dim)
          memory_mask:              (B, N, M) boolean, True for valid items
        The ALM layer is expected to return bucket attention weights of shape (B, N)
        and item attention weights of shape (B, N, M).
        """
        self.alm_layer.eval()
        query_embeddings_for_alm = query_embeddings_for_alm.to(self.device)
        memory_item_embeddings = memory_item_embeddings.to(self.device)
        if memory_mask is not None:
            memory_mask = memory_mask.to(self.device)

        with torch.no_grad():
            _, bucket_att_weights, item_att_weights = self.alm_layer(
                query_embeddings_for_alm, memory_item_embeddings, memory_mask
            )

        # Select the top buckets per query, then the top items inside each selected bucket.
        batch_retrieved_texts: List[List[str]] = []
        for b_idx in range(len(query_texts)):
            retrieved_for_sample: List[str] = []
            current_bucket_weights = bucket_att_weights[b_idx]
            _, top_bucket_indices = torch.topk(current_bucket_weights,
                                               k=min(self.top_k_buckets, current_bucket_weights.size(0)))

            for bucket_idx in top_bucket_indices:
                bucket_idx_item = bucket_idx.item()
                current_item_weights = item_att_weights[b_idx, bucket_idx_item, :]

                if memory_mask is not None:
                    item_m = memory_mask[b_idx, bucket_idx_item, :]
                    current_item_weights = current_item_weights.masked_fill(item_m == 0, -float('inf'))

                num_valid_items = (current_item_weights > -float('inf')).sum().item()
                if num_valid_items == 0:
                    continue

                _, top_item_indices_in_bucket = torch.topk(current_item_weights,
                                                           k=min(self.top_k_items_per_bucket, num_valid_items))

                for item_idx_in_bucket in top_item_indices_in_bucket:
                    item_idx_in_bucket_item = item_idx_in_bucket.item()
                    if memory_mask is not None and not memory_mask[b_idx, bucket_idx_item, item_idx_in_bucket_item]:
                        continue
                    try:
                        text_content = memory_text_items[b_idx][bucket_idx_item][item_idx_in_bucket_item]
                        if text_content:
                            retrieved_for_sample.append(text_content)
                    except IndexError:
                        print(f"Warning: IndexError accessing memory_text_items[{b_idx}][{bucket_idx_item}][{item_idx_in_bucket_item}]")
                        continue
            # De-duplicate while preserving retrieval order.
            batch_retrieved_texts.append(list(dict.fromkeys(retrieved_for_sample)))

        prompts_for_qwen = []
        for i, q_text in enumerate(query_texts):
            prompt = self.qwen_generator.format_prompt(q_text, batch_retrieved_texts[i])
            prompts_for_qwen.append(prompt)

        generated_answers = self.qwen_generator.generate(prompts_for_qwen)
        return generated_answers, bucket_att_weights, item_att_weights

    def save_model(self, save_directory: str):
        """Saves the entire ALMQwenModel_HF to the specified directory."""
        os.makedirs(save_directory, exist_ok=True)

        # ALM layer weights.
        alm_state_dict_path = os.path.join(save_directory, "alm_layer_state_dict.pth")
        torch.save(self.alm_layer.state_dict(), alm_state_dict_path)
        print(f"ALM layer state_dict saved to {alm_state_dict_path}")

        # Qwen model and tokenizer.
        qwen_save_path = os.path.join(save_directory, "qwen_generator")
        os.makedirs(qwen_save_path, exist_ok=True)
        self.qwen_generator.save_pretrained(qwen_save_path)
        print(f"Qwen generator (model & tokenizer) saved in {qwen_save_path}")

        # Configuration; Qwen paths are stored relative to save_directory so the
        # saved folder can be moved or renamed.
        config = {
            "alm_config": self.alm_config,
            "qwen_model_name_or_path": "qwen_generator/qwen_model",
            "qwen_tokenizer_path": "qwen_generator/qwen_tokenizer",
            "top_k_buckets": self.top_k_buckets,
            "top_k_items_per_bucket": self.top_k_items_per_bucket
        }
        config_path = os.path.join(save_directory, "alm_qwen_hf_config.json")
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)
        print(f"ALMQwenModel_HF configuration saved to {config_path}")
        print(f"Model saved successfully to {save_directory}")
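
    # Resulting on-disk layout (for reference; matches the paths used above):
    #
    #   save_directory/
    #     alm_layer_state_dict.pth
    #     alm_qwen_hf_config.json
    #     qwen_generator/
    #       qwen_model/        (Hugging Face model files)
    #       qwen_tokenizer/    (tokenizer files)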

    @classmethod
    def load_model(cls, load_directory: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Loads an ALMQwenModel_HF from the specified directory."""
        print(f"Loading model from {load_directory}...")

        # Read the saved configuration.
        config_path = os.path.join(load_directory, "alm_qwen_hf_config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Configuration file not found: {config_path}")
        with open(config_path, 'r') as f:
            config = json.load(f)

        alm_config = config["alm_config"]
        qwen_model_path = os.path.join(load_directory, config["qwen_model_name_or_path"])
        qwen_tokenizer_path = os.path.join(load_directory, config["qwen_tokenizer_path"])
        top_k_buckets = config["top_k_buckets"]
        top_k_items_per_bucket = config["top_k_items_per_bucket"]

        # Rebuild the model; this also loads the saved Qwen weights and tokenizer.
        model = cls(
            alm_config=alm_config,
            qwen_model_name_or_path=qwen_model_path,
            qwen_tokenizer_path=qwen_tokenizer_path,
            device=device,
            top_k_buckets=top_k_buckets,
            top_k_items_per_bucket=top_k_items_per_bucket
        )
        print("ALMQwenModel_HF structure initialized.")

        # Restore the trained ALM layer weights.
        alm_state_dict_path = os.path.join(load_directory, "alm_layer_state_dict.pth")
        if not os.path.exists(alm_state_dict_path):
            raise FileNotFoundError(f"ALM state_dict not found: {alm_state_dict_path}")

        model.alm_layer.to(device)
        state_dict = torch.load(alm_state_dict_path, map_location=device)
        model.alm_layer.load_state_dict(state_dict)
        print(f"ALM layer state_dict loaded from {alm_state_dict_path}")

        model.device = device

        print(f"Model loaded successfully from {load_directory} and placed on device: {device}")
        return model
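
    # For reference, the alm_qwen_hf_config.json written by save_model looks roughly like
    # this (values correspond to the example run in __main__ below):
    #
    #   {
    #     "alm_config": {"query_dim": 128, "memory_dim": 64, "embed_dim": 256,
    #                    "num_heads": 8, "output_dim": 128, "dropout_rate": 0.0},
    #     "qwen_model_name_or_path": "qwen_generator/qwen_model",
    #     "qwen_tokenizer_path": "qwen_generator/qwen_tokenizer",
    #     "top_k_buckets": 2,
    #     "top_k_items_per_bucket": 1
    #   }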


if __name__ == "__main__":
    print("\n--- Testing ALM-Qwen with Hugging Face Qwen ---")

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {_device}")

    _batch_size = 1
    _alm_query_dim = 128
    _alm_memory_dim = 64
    _alm_embed_dim = 256
    _alm_num_heads = 8
    _alm_output_dim = 128
    _num_kb_buckets = 3
    _max_kb_items_per_bucket = 5

    _alm_config_example = {
        'query_dim': _alm_query_dim,
        'memory_dim': _alm_memory_dim,
        'embed_dim': _alm_embed_dim,
        'num_heads': _alm_num_heads,
        'output_dim': _alm_output_dim,
        'dropout_rate': 0.0
    }

    _query_texts_for_qwen = ["What is attention in LLMs?"]
    _query_embeddings_for_alm = torch.randn(_batch_size, _alm_query_dim)
    _kb_memory_item_embeddings = torch.randn(_batch_size, _num_kb_buckets, _max_kb_items_per_bucket, _alm_memory_dim)
    _kb_memory_text_items: List[List[List[str]]] = []
    for b in range(_batch_size):
        batch_sample_text = []
        for i in range(_num_kb_buckets):
            bucket_texts = [f"Doc {b+1}-B{i+1}-I{j+1}: info snippet {j}." for j in range(_max_kb_items_per_bucket)]
            batch_sample_text.append(bucket_texts)
        _kb_memory_text_items.append(batch_sample_text)
    _kb_memory_mask = torch.ones(_batch_size, _num_kb_buckets, _max_kb_items_per_bucket, dtype=torch.bool)
    _kb_memory_mask[:, :, -1:] = False

    _qwen_model_name = "Qwen/Qwen2.5-0.5B-Instruct"

    try:
        print("\n--- Creating and testing original model ---")
        original_model = ALMQwenModel_HF(
            alm_config=_alm_config_example,
            qwen_model_name_or_path=_qwen_model_name,
            device=_device,
            top_k_buckets=2,
            top_k_items_per_bucket=1
        )

        _ = original_model(
            _query_texts_for_qwen, _query_embeddings_for_alm, _kb_memory_item_embeddings,
            _kb_memory_text_items, _kb_memory_mask
        )
        print("Original model created and tested with a dummy pass.")

        save_dir = "./saved_alm_qwen_model"
        print(f"\n--- Saving model to {save_dir} ---")
        original_model.save_model(save_dir)

        print(f"\n--- Loading model from {save_dir} ---")
        loaded_model = ALMQwenModel_HF.load_model(save_dir, device=_device)
        print("Model loaded successfully.")

        print("\n--- Testing loaded model ---")
        generated_answers, _, _ = loaded_model(
            _query_texts_for_qwen,
            _query_embeddings_for_alm,
            _kb_memory_item_embeddings,
            _kb_memory_text_items,
            _kb_memory_mask
        )
        print("\n--- Results from Loaded Model ---")
        for i in range(len(_query_texts_for_qwen)):
            print(f"Query {i+1}: {_query_texts_for_qwen[i]}")
            print(f"  Generated Answer {i+1}: {generated_answers[i]}")
            print("-" * 30)

        print("\nSave and Load test completed.")

    except ImportError as e:
        print(f"ImportError: {e}.")
    except Exception as e:
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()