from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import gradio as gr
import re
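# load the base model and tokenizer from the Hugging Face Hub, then apply the
# PEFT adapter weights published in the same "osiria/primo" repository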
tokenizer = AutoTokenizer.from_pretrained("osiria/primo")
model = AutoModelForCausalLM.from_pretrained("osiria/primo")
model = PeftModel.from_pretrained(model, "osiria/primo")
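# Prime wraps the model with a custom token-by-token decoding loop that adds
# repetition guards, token substitutions, and n-gram blocking on top of plain
# greedy/top-k decoding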
class Prime:

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def _check_sublist(self, lst, sub_lst, sep = " "):
        # True if sub_lst occurs as a contiguous run inside lst
        # (substring match on the joined strings, a simple n-gram check)
        lst = sep.join(map(str, lst))
        sub_lst = sep.join(map(str, sub_lst))
        return sub_lst in lst

    def _exclude_sublist(self, lst, sub_lst, sep = " "):
        # drop every occurrence of sub_lst from lst, preserving the element type
        l_type = type(lst[0])
        lst = sep.join(map(str, lst))
        sub_lst = sep.join(map(str, sub_lst))
        lst = re.sub(r"\s+", " ", lst.replace(sub_lst, "")).strip().split(sep)
        return list(map(l_type, lst))
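    # decode one token at a time from the top-k candidates; the id-based
    # defaults (excluded, resample_tokens, replace_tokens, stop_tokens) are
    # specific to this model's vocabulary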
    def generate(self, prompt, message = "", sep = " [AI]", max_tokens = 100, excluded = [[40, 19]],
                 lookback = 5, resample_tokens = [27793], replace_tokens = {11302: 23318},
                 stop_tokens = [239],
                 sample = False,
                 top_k = 5):
        if message:
            prompt = message + ". " + prompt
        # normalize curly quotes to plain ASCII quotes
        prompt = prompt.replace("“", '"').replace("”", '"').replace("’", "'")
        if not sample:
            top_k = 2
        # the fallback branches below index the runner-up candidate, so at
        # least two candidates are always needed
        top_k = max(top_k, 2)
        tokens = self.tokenizer.encode("[HUMAN] " + prompt + sep)
        tokens_generated = []
        checkpoint = 0
        device = next(self.model.parameters()).device
        while tokens[-1] not in stop_tokens and len(tokens_generated) < max_tokens:
            with torch.no_grad():
                output = self.model(input_ids = torch.tensor([tokens]).to(device)).logits[0, -1]
            output = torch.softmax(output, dim = 0)
            candidates = torch.topk(output, k = top_k)
            if sample:
                indices = candidates.indices
                scores = candidates.values
                next_token = indices[torch.multinomial(scores, 1)[0].item()]
            else:
                next_token = candidates.indices[0]
            next_token = next_token.item()
            sub_tokens = tokens_generated[-lookback:] + [next_token]
            if next_token in resample_tokens:
                next_token = candidates.indices[1]
                next_token = next_token.item()
            if len(tokens_generated) >= (lookback + 1) and next_token in tokens_generated[-2:]:
                # immediate repetition: fall back to the runner-up candidate
                next_token = candidates.indices[1]
                next_token = next_token.item()
            elif len(tokens_generated) >= lookback and self._check_sublist(tokens_generated, sub_tokens):
                # the next token would repeat an n-gram already generated:
                # roll back to the last recorded sentence boundary, or switch
                # to sampling with the runner-up candidate
                if checkpoint:
                    tokens = tokens[:checkpoint]
                    break
                else:
                    next_token = candidates.indices[1]
                    next_token = next_token.item()
                    sample = True
            if next_token in replace_tokens:
                next_token = replace_tokens[next_token]
            tokens = tokens + [next_token]
            tokens_generated = tokens_generated + [next_token]
            if next_token == 5:
                # token id 5 (presumably sentence-final punctuation in this
                # vocabulary) marks a safe rollback point
                checkpoint = len(tokens)
        for ex_lst in excluded:
            tokens = self._exclude_sublist(tokens, ex_lst)
        output = self.tokenizer.decode(tokens, skip_special_tokens = True)
        output = output.split(sep)[-1].strip()
        if not output:
            return output
        output = output[0].upper() + output[1:]
        if output[-1] == self.tokenizer.decode(stop_tokens[0]):
            output = output[:-1]
        # render enumerations ("1. ... 2. ...") as HTML bullet lists
        if len(re.findall(r"\d\.", output)) > 1:
            output = re.sub(r"\d\.", "<br>•", output)
            output = re.sub(r"^<br>", "", output)
        return output
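# inference-only setup: eval mode disables dropout; the model and its inputs
# must share a device, so the model is moved to the GPU when one is available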
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
prime = Prime(tokenizer = tokenizer, model = model)
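# a minimal usage sketch outside the web UI (hypothetical prompt):
#   prime.generate("Qual è la capitale d'Italia?", max_tokens = 50)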
def process_input(user_input, max_tokens, sample, top_k, message):
    return prime.generate(prompt = user_input, message = message,
                          max_tokens = max_tokens, sample = sample,
                          top_k = top_k)
header = '''--------------------------------------------------------------------------------------------------
<style>
.vertical-text {
    writing-mode: vertical-lr;
    text-orientation: upright;
    background-color: red;
}
</style>
<center>
<body>
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
</body>
</center>
<br>
<center><img src="file/primo.png" width="100"></center>
'''
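# Gradio 3.x chat UI; the labels are Italian: "opzioni" = options,
# "massimo numero di token" = maximum number of tokens, "campionamento" =
# sampling, "creatività" = creativity, "messaggio di sistema" = system message,
# "pulisci conversazione" = clear conversation, "richiesta" = request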
import time
| with gr.Blocks(title="primo", css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="md", spacing_size="md")) as interface: | |
| gr.Markdown(header) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("<b>opzioni</b>") | |
| max_tokens = gr.Slider(1, 250, value=150, label="massimo numero di token", info="scegli un limite tra 1 e 250") | |
| sample = gr.Checkbox(label="campionamento") | |
| top_k = gr.Slider(1, 5, step=1, value=1, label="creatività", info="scegli un livello tra 1 e 5") | |
| message = gr.Textbox(label="messaggio di sistema", value = "") | |
| clear = gr.Button("pulisci conversazione") | |
| with gr.Column(scale=8): | |
| chatbot = gr.Chatbot(label = "prime").style(height=600) | |
| msg = gr.Textbox(label = "richiesta") | |
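            # user() appends the new turn and locks the textbox; bot() streams
            # the model reply back into the chat one character at a time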
            def user(user_message, history):
                # clear the input box, disable it while generating, and append
                # the user turn to the chat history
                return gr.update(value = "", interactive = False), history + [[user_message, None]]

            def bot(history, message, max_tokens, sample, top_k):
                bot_message = process_input(history[-1][0], message = message, max_tokens = max_tokens,
                                            sample = sample, top_k = top_k)
                history[-1][1] = ""
                for character in bot_message:
                    history[-1][1] += character
                    time.sleep(0.05)
                    yield history
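            # event wiring: a submit first runs user(), then bot(); the textbox
            # is re-enabled once the streamed reply completes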
            response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue = False).then(
                bot, [chatbot, message, max_tokens, sample, top_k], chatbot
            )
            response.then(lambda: gr.update(interactive = True), None, [msg], queue = False)
            clear.click(lambda: None, None, chatbot, queue = False)
        with gr.Column(scale=1):
            gr.Markdown("<b>attenzione</b>")
            # Italian disclaimer: the model may behave unpredictably on prompts
            # far from its pre-training or fine-tuning and, given the
            # probabilistic nature of generation, may occasionally produce
            # biased or offensive content about gender, ethnicity, ideology,
            # or political and religious beliefs; outputs should be used with
            # caution and never in contexts that require the generated text to
            # be correct or truthful
            gr.Markdown("il modello potrebbe comportarsi in maniera imprevista nel caso in cui riceva prompt troppo lontani dal suo pre-training o fine-tuning e, per via della natura probabilistica del meccanismo di generazione, potrebbe occasionalmente produrre contenuti distorti o offensivi in relazione a tematiche come il genere, le etnie, le ideologie, e le convinzioni politiche o religiose<br><br>per via di queste limitazioni, il modello e i suoi output dovrebbero essere usati con cautela, e non dovrebbero essere coinvolti in contesti che richiedono che il testo generato sia corretto o veritiero")
interface.queue()
interface.launch()