| | import gradio as gr |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer |
| | from threading import Thread |
| |
|
| | |
# Load the tokenizer for the "Finisha-llm/nekolina" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("Finisha-llm/nekolina")

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the causal language model and move it onto the selected device.
model = AutoModelForCausalLM.from_pretrained("Finisha-llm/nekolina").to(device)
| |
|
| |
|
| | |
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation when a stop token is emitted.

    Args:
        stop_ids: Iterable of token ids that end generation. Defaults to
            ``(2,)``, the value this class previously hard-coded
            (presumably the tokenizer's end-of-sequence id -- TODO confirm
            against the tokenizer's vocabulary).
    """

    def __init__(self, stop_ids=(2,)):
        super().__init__()
        # Normalised once here so the per-step check is a plain set lookup
        # instead of rebuilding a list on every generation step.
        self.stop_ids = frozenset(int(stop_id) for stop_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is a stop id.

        Only ``input_ids[0]`` is inspected, so any additional rows in the
        batch are ignored. ``scores`` and ``kwargs`` are accepted but unused.
        """
        return int(input_ids[0][-1]) in self.stop_ids
| |
|
| |
|
| |
|
| | |
def predict(message, history):
    """Stream the model's reply to ``message`` as progressively longer strings.

    Args:
        message: The new user message; concatenated into the prompt as text.
        history: Earlier turns. Each entry is indexed as ``entry[0]`` and
            ``entry[1]`` (presumably the user text and the bot reply --
            verify against the Gradio history format in use).

    Yields:
        The reply accumulated so far, once per chunk received from the
        streamer. Iteration ends without yielding as soon as the accumulated
        text contains ``'</s>'``.
    """
    # The pending turn is appended with an empty reply slot for the model to fill.
    turns = history + [[message, ""]]
    stop_criterion = StopOnTokens()

    # Render each turn as "\nQuestion:<q></s>\nReponse:<r>" and separate turns with "<s>".
    rendered_turns = []
    for turn in turns:
        rendered_turns.append("\nQuestion:" + turn[0] + "</s>" + "\nReponse:" + turn[1])
    prompt = "<s>".join(rendered_turns)

    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=10.,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Start from the tokenized inputs, then layer the generation options on top.
    generate_kwargs = dict(model_inputs)
    generate_kwargs.update(
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop_criterion]),
    )

    # generate() runs in a separate thread so this generator can consume the
    # streamer while text is still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    reply_so_far = ""
    for chunk in streamer:
        reply_so_far += chunk
        if '</s>' in reply_so_far:
            # End-of-sequence marker leaked into the text: stop streaming.
            return
        yield reply_so_far
| |
|
| |
|
| |
|
| |
|
| | |
# Build the chat UI around predict() and start serving it immediately.
demo = gr.ChatInterface(
    predict,
    title="nekolina_chatBot",
    description="Ask nekolina any questions in nekolien",
    examples=['Ti eta?', 'In nekolien : Quiela nombra avst coeura pievia?'],
)
demo.launch()