# parapluie/ppluie.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from .templates import make_prompt, check_template_name, print_templates, supported_templates
from .config import model_dict
from torch.nn import CrossEntropyLoss
LOSS_FCT = CrossEntropyLoss(reduction='mean')
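
# ParaPLUIE scores a (reference, hypothesis) pair by comparing the model's
# likelihood of a prompt built for a "Yes" answer against the same prompt
# built for a "No" answer; with reduction='mean', the cross-entropy is later
# rescaled by the sequence length to recover each prompt's total NLL.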

class ppluie:
    def __init__(
        self,
        model,
        device="cuda:0",
        template="FS-DIRECT",
        use_chat_template=True,
        half_mode=True,
        n_right_special_tokens=1
    ):
        self.device = device
        self.use_chat_tmplt = use_chat_template
        # assert model in supported_models, model + " not supported, supported models: " + str(supported_models)
        if model not in model_dict.keys():
            print("You called ParaPLUIE with " + model
                  + ".\nParaPLUIE has been tested with " + str(list(model_dict.keys()))
                  + "\nUsing it with another model could lead to unexpected behaviour.")
            self.model = AutoModelForCausalLM.from_pretrained(model)
            self.tokenizer = AutoTokenizer.from_pretrained(model, padding_side='left')
            self.n_right_special_tokens = n_right_special_tokens
        else:
            # use the per-model configuration from config.model_dict
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dict[model]["path"],
                trust_remote_code=model_dict[model]["trust_remote_code"]
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_dict[model]["path"],
                trust_remote_code=model_dict[model]["trust_remote_code"],
                padding_side='left'
            )
            self.n_right_special_tokens = model_dict[model]["r_spe_tokens"]
            self.use_chat_tmplt = model_dict[model]["use_chat_tmplt"]
        if half_mode:
            self.model = self.model.half()
        self.model = self.model.eval()
        self.model = self.model.to(self.device)
        self.setTemplate(template)

    def show_templates(self):
        print_templates()

    def show_available_models(self):
        print("LLMs tested with ParaPLUIE:")
        for k in model_dict.keys():
            print(k)

    def setTemplate(self, template: str):
        check_template_name(template)
        self.template = template

    def stringify_prompt(self, prompt):
        # flatten a chat-style prompt (a list of {"role", "content"} dicts)
        # into a single string, joining the message contents with spaces
        text = ""
        for v in prompt[:-1]:
            text += v["content"] + " "
        text += prompt[-1]["content"]
        return text
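
    # For illustration (hypothetical messages), stringify_prompt on
    #   [{"role": "user", "content": "Sentence A."},
    #    {"role": "user", "content": "Is it a paraphrase? Answer:"}]
    # returns "Sentence A. Is it a paraphrase? Answer:".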

    def vraisemblance(self, promptY, promptN):  # possible optimisation: score both prompts in a single inference
        # TODO: model check that Guillaume has not implemented
        if self.use_chat_tmplt:
            encodedsY = self.tokenizer.apply_chat_template(promptY, return_tensors="pt", padding=False).to(self.device)
            encodedsN = self.tokenizer.apply_chat_template(promptN, return_tensors="pt", padding=False).to(self.device)
            input_model = encodedsY[:, :-self.n_right_special_tokens]
            # drop the leading BOS token and the trailing end-of-sequence special token(s)
            encodedsY = encodedsY[:, 1:-self.n_right_special_tokens]
            encodedsN = encodedsN[:, 1:-self.n_right_special_tokens]
        else:
            encodedsY = self.tokenizer(self.stringify_prompt(promptY), return_tensors="pt", padding=False)["input_ids"].to(self.device)
            encodedsN = self.tokenizer(self.stringify_prompt(promptN), return_tensors="pt", padding=False)["input_ids"].to(self.device)
            input_model = encodedsY
            # drop the leading BOS token
            encodedsY = encodedsY[:, 1:]
            encodedsN = encodedsN[:, 1:]
        outputs = self.model(
            input_ids=input_model[:, :-1],
            return_dict=True
        )
        n_tokens = len(input_model[0])
        logits = outputs.logits.squeeze().float()
        # mean cross-entropy times the sequence length recovers the summed
        # negative log-likelihood of each target sequence
        loss_yes = LOSS_FCT(logits, encodedsY.view(-1)) * n_tokens
        loss_no = LOSS_FCT(logits, encodedsN.view(-1)) * n_tokens
        loss = loss_no.item() - loss_yes.item()
        return loss
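
    # Note: vraisemblance returns NLL(No prompt) - NLL(Yes prompt), so a
    # positive score means the model finds the "Yes" continuation more likely,
    # i.e. it leans towards judging the pair a paraphrase.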

    def check_end_tokens_tmpl(self):
        # debug helper: print the last 10 token ids of each prompt so the
        # trailing special tokens added by the template can be inspected
        prompt_yes, prompt_no = make_prompt(
            self.template,
            "this is a test",
            "this is a test",
            self.model,      # if intermediate generation is needed
            self.tokenizer,  # if intermediate generation is needed
            self.device      # if intermediate generation is needed
        )
        if self.use_chat_tmplt:
            enc = self.tokenizer.apply_chat_template(prompt_yes, return_tensors="pt", padding=False)[0][-10:]
        else:
            enc = self.tokenizer(self.stringify_prompt(prompt_yes), return_tensors="pt", padding=False)["input_ids"][0][-10:]
        print("Yes prompt:")
        print(enc)
        for t in enc:
            print(t, " - ", self.tokenizer.decode(t))
        if self.use_chat_tmplt:
            enc = self.tokenizer.apply_chat_template(prompt_no, return_tensors="pt", padding=False)[0][-10:]
        else:
            enc = self.tokenizer(self.stringify_prompt(prompt_no), return_tensors="pt", padding=False)["input_ids"][0][-10:]
        print("No prompt:")
        print(enc)
        for t in enc:
            print(t, " - ", self.tokenizer.decode(t))

    def __call__(self, reference, hypothese, logger=None):
        prompt_yes, prompt_no = make_prompt(
            self.template,
            reference,
            hypothese,
            self.model,      # if intermediate generation is needed
            self.tokenizer,  # if intermediate generation is needed
            self.device      # if intermediate generation is needed
        )
        score = self.vraisemblance(
            promptY=prompt_yes,
            promptN=prompt_no)
        # self.check_end_tokens_tmpl()
        if logger:
            logger.info('Start entry')
            logger.info('Template: ' + self.template)
            logger.info('Yes prompt: ' + str(prompt_yes))
            logger.info('No prompt: ' + str(prompt_no))
            logger.info('Score: ' + str(score))
            logger.info('End entry')
        return score

    def get_all_templates(self):
        return supported_templates
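
# Minimal usage sketch, left commented out because it needs a downloaded
# checkpoint and (by default) a CUDA device. "model-name" is a placeholder
# for a model_dict key or any causal-LM checkpoint id, not a real name:
#
# if __name__ == "__main__":
#     metric = ppluie("model-name", device="cuda:0", template="FS-DIRECT")
#     metric.show_templates()
#     score = metric("The cat sat on the mat.",
#                    "A cat was sitting on the mat.")
#     print(score)  # positive -> the model leans towards "paraphrase"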