|
|
import torch |
|
|
|
|
|
from peft import ( |
|
|
LoraConfig, |
|
|
get_peft_model, |
|
|
) |
|
|
|
|
|
from torch import nn |
|
|
from transformers import AutoModelForCausalLM |
|
|
|
|
|
from .bert import CosNorm_Classifier |
|
|
|
|
|
activation_map = {'relu': nn.ReLU(), 'tanh': nn.Tanh()} |
|
|
|
|
|
class LLAMA_lora_Disaware(nn.Module): |
|
|
def __init__(self, args): |
|
|
super().__init__() |
|
|
self.num_labels = args.num_labels |
|
|
self.llama = AutoModelForCausalLM.from_pretrained( |
|
|
args.llama_model, |
|
|
return_dict=True, |
|
|
load_in_8bit=False, |
|
|
device_map=args.device, |
|
|
low_cpu_mem_usage=True, |
|
|
torch_dtype=torch.float16, |
|
|
) |
|
|
self.llama.config.pad_token_id = 0 |
|
|
self.llama.config.bos_token_id = 1 |
|
|
self.llama.config.eos_token_id = 2 |
|
|
|
|
|
target_modules=[ "q_proj", "v_proj"] |
|
|
config = LoraConfig( |
|
|
r=4, |
|
|
lora_alpha=8, |
|
|
target_modules=target_modules, |
|
|
lora_dropout=0.05, |
|
|
bias="none", |
|
|
task_type="CAUSAL_LM", |
|
|
) |
|
|
print("lora", config) |
|
|
self.llama = get_peft_model(self.llama, config) |
|
|
hidden_dropout_prob = 0.1 |
|
|
hidden_size = self.llama.config.hidden_size |
|
|
hidden_size_2 = hidden_size // 2 |
|
|
self.dense = nn.Linear(hidden_size, hidden_size).half() |
|
|
self.activation = activation_map[args.activation] |
|
|
self.dropout = nn.Dropout(hidden_dropout_prob).half() |
|
|
self.dense = self.dense.to(args.device) |
|
|
self.activation = self.activation.to(args.device) |
|
|
self.dropout = self.dropout.to(args.device) |
|
|
|
|
|
self.cosnorm_classifier = CosNorm_Classifier( |
|
|
hidden_size, args.num_labels, args.scale, args.device) |
|
|
|
|
|
|
|
|
def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None, |
|
|
feature_ext=False, mode=None, loss_fct=None, centroids=None, dist_infos = None): |
|
|
outputs = self.llama(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True ) |
|
|
encoded_layer_ = outputs.hidden_states[-1].mean(dim=1) |
|
|
|
|
|
|
|
|
pooled_output = self.dense(encoded_layer_) |
|
|
pooled_output = self.activation(pooled_output) |
|
|
pooled_output = self.dropout(pooled_output) |
|
|
|
|
|
x = pooled_output |
|
|
|
|
|
if feature_ext: |
|
|
return pooled_output |
|
|
|
|
|
else: |
|
|
|
|
|
feat_size = x.shape[1] |
|
|
batch_size = x.shape[0] |
|
|
|
|
|
f_expand = x.unsqueeze(1).expand(-1, self.num_labels, -1) |
|
|
centroids_expand = centroids.unsqueeze(0).expand(batch_size, -1, -1) |
|
|
dist_cur = torch.norm(f_expand - centroids_expand, 2, 2) |
|
|
values_nn, labels_nn = torch.sort(dist_cur, 1) |
|
|
|
|
|
nearest_centers = centroids[labels_nn[:, 0]] |
|
|
dist_denominator = torch.norm(x - nearest_centers, 2, 1) |
|
|
second_nearest_centers = centroids[labels_nn[:, 1]] |
|
|
dist_numerator = torch.norm(x - second_nearest_centers, 2, 1) |
|
|
|
|
|
dist_info = dist_numerator - dist_denominator |
|
|
dist_info = torch.exp(dist_info) |
|
|
scalar = dist_info |
|
|
|
|
|
reachability = scalar.unsqueeze(1).expand(-1, feat_size) |
|
|
x = reachability * pooled_output |
|
|
|
|
|
logits = self.cosnorm_classifier(x) |
|
|
|
|
|
if mode == 'train': |
|
|
loss = loss_fct(logits, labels) |
|
|
return loss |
|
|
|
|
|
elif mode == 'eval': |
|
|
return pooled_output, logits |