import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

from .templates import make_prompt, check_template_name, print_templates, supported_templates
from .config import model_dict

LOSS_FCT = CrossEntropyLoss(reduction='mean')
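# Note: with reduction='mean', CrossEntropyLoss averages the per-token NLL,
# so multiplying by the sequence length recovers the summed NLL. Sketch
# (illustration only, not executed by the metric):
#
#     logits = torch.randn(5, 100)           # 5 positions, vocab of 100
#     targets = torch.randint(0, 100, (5,))  # gold next tokens
#     summed_nll = LOSS_FCT(logits, targets) * 5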

class ppluie:
    def __init__(
            self, 
            model,
            device = "cuda:0",            
            template = "FS-DIRECT",
            use_chat_template = True,
            half_mode = True,
            n_right_special_tokens = 1
        ):

        self.device = device
        self.use_chat_tmplt = use_chat_template

        if model not in model_dict:
            print(
                "You called ParaPLUIE with " + model + ".\n"
                "ParaPLUIE has been tested with " + str(list(model_dict.keys())) + ".\n"
                "Using it with another model may lead to unexpected behaviour."
            )
            self.model = AutoModelForCausalLM.from_pretrained(model)
            self.tokenizer = AutoTokenizer.from_pretrained(model, padding_side='left')
            self.n_right_special_tokens = n_right_special_tokens
        # use the bundled configuration for known models
        else: 
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dict[model]["path"],
                trust_remote_code = model_dict[model]["trust_remote_code"]
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_dict[model]["path"],
                trust_remote_code = model_dict[model]["trust_remote_code"],
                padding_side = 'left'
            )
            self.n_right_special_tokens = model_dict[model]["r_spe_tokens"]
            self.use_chat_tmplt = model_dict[model]["use_chat_tmplt"]
        
        if half_mode:
            self.model = self.model.half()
        self.model = self.model.eval()
        self.model = self.model.to(self.device)
        self.setTemplate(template)    

    def show_templates(self):
        print_templates()

    def show_available_models(self):
        print("LLM tested with PPLUIE: ")
        for k in (model_dict.keys()):
            print(k)

    def setTemplate(self, template: str):
        check_template_name(template)
        self.template = template

    def stringify_prompt(self, prompt):
        # Join the "content" fields of a chat-style message list into a single
        # space-separated string (used when no chat template is applied).
        return " ".join(msg["content"] for msg in prompt)

    def vraisemblance(self, promptY, promptN):
        # The two prompts are assumed to differ only in their final answer
        # token, so a single forward pass over the shared prefix scores both.
        # TODO: model sanity check not implemented yet
        if self.use_chat_tmplt:
            encodedsY = self.tokenizer.apply_chat_template(promptY, return_tensors="pt", padding=False).to(self.device)
            encodedsN = self.tokenizer.apply_chat_template(promptN, return_tensors="pt", padding=False).to(self.device)
            input_model = encodedsY[:,:-self.n_right_special_tokens]
            # drop the leading BOS token and the trailing end-of-sequence special token(s)
            encodedsY = encodedsY[:,1:-self.n_right_special_tokens]
            encodedsN = encodedsN[:,1:-self.n_right_special_tokens]
        else:
            encodedsY = self.tokenizer(self.stringify_prompt(promptY), return_tensors="pt", padding=False)["input_ids"].to(self.device)
            encodedsN = self.tokenizer(self.stringify_prompt(promptN), return_tensors="pt", padding=False)["input_ids"].to(self.device)
            input_model = encodedsY
            # drop the leading BOS token
            encodedsY = encodedsY[:,1:]
            encodedsN = encodedsN[:,1:]

        with torch.no_grad():
            outputs = self.model(
                input_ids=input_model[:,:-1],
                return_dict=True
            )

        n_tokens = len(input_model[0])
        logits = outputs["logits"].squeeze().float()

        # summed negative log-likelihood of each prompt under the model
        loss_yes = LOSS_FCT(logits, encodedsY.view(-1)) * n_tokens
        loss_no = LOSS_FCT(logits, encodedsN.view(-1)) * n_tokens

        # keep the components so __call__ can log them
        self.last_loss_yes = loss_yes.item()
        self.last_loss_no = loss_no.item()

        return self.last_loss_no - self.last_loss_yes
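
    # Interpretation: vraisemblance returns loss_no - loss_yes, a difference of
    # summed NLLs. A positive score means the model assigns higher likelihood
    # to the "Yes" prompt, i.e. it judges the pair to be paraphrases.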
    
    def check_end_tokens_tmpl(self):
        # Debug helper: print the last 10 tokens of the encoded "Yes" and "No"
        # prompts so the number of trailing special tokens can be verified.
        prompt_yes, prompt_no = make_prompt(
            self.template,
            "this is a test",
            "this is a test",
            self.model, # if intermediate generation is needed
            self.tokenizer, # if intermediate generation is needed
            self.device # if intermediate generation is needed
        )

        for name, prompt in (("Yes prompt:", prompt_yes), ("No prompt:", prompt_no)):
            if self.use_chat_tmplt:
                enc = self.tokenizer.apply_chat_template(prompt, return_tensors="pt", padding=False)[0][-10:]
            else:
                enc = self.tokenizer(self.stringify_prompt(prompt), return_tensors="pt", padding=False)["input_ids"][0][-10:]

            print(name)
            print(enc)
            for t in enc:
                print(t, " - ", self.tokenizer.decode(t))
    

    def __call__(self, reference, hypothese, logger=None):
        prompt_yes, prompt_no = make_prompt(
            self.template,
            reference,
            hypothese,
            self.model, # if intermediate generation is needed
            self.tokenizer, # if intermediate generation is needed
            self.device # if intermediate generation is needed
        )

        score = self.vraisemblance(
            promptY=prompt_yes,
            promptN=prompt_no)
        # self.check_end_tokens_tmpl()

        if logger:
            logger.info('Start entry')
            logger.info('Template: '+self.template)
            logger.info('Yes prompt : '+str(prompt_yes))
            logger.info('Yes vraisemblance : '+str(self.last_loss_yes))
            logger.info('No prompt : '+str(prompt_no))
            logger.info('No vraisemblance : '+str(self.last_loss_no))
            logger.info('Score : '+str(score))
            logger.info('End entry')
            
        return score
    
    def get_all_templates(self):
        return supported_templates
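

# Minimal usage sketch (illustration; "parapluie" as the package name and
# "mistral-7B" as the model key are assumptions, use your actual install
# and a key from model_dict):
#
#     from parapluie import ppluie
#
#     metric = ppluie(model="mistral-7B", device="cuda:0", template="FS-DIRECT")
#     score = metric("The cat sat on the mat.", "A cat was sitting on the mat.")
#     # score > 0: the model finds "Yes" (paraphrase) more likely than "No".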