kto ya committed on
Commit
912f8e2
·
verified ·
1 Parent(s): 2f38200

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ import gradio as gr
4
+
5
+ # Check if CUDA is available and set the device
6
+ use_cuda = torch.cuda.is_available()
7
+ device = torch.device("cuda" if use_cuda else "cpu")
8
+
9
+ # Load the tokenizer and model
10
+ t5_tokenizer = transformers.GPT2Tokenizer.from_pretrained("SiberiaSoft/SiberianPersonaFred-2")
11
+ t5_model = transformers.T5ForConditionalGeneration.from_pretrained("SiberiaSoft/SiberianPersonaFred-2").to(device)
12
+
13
+ # Define the function to generate responses
14
+ def generate_response(dialog):
15
+ dialog = dialog.strip()
16
+ if len(dialog) == 0:
17
+ return "Please enter a message."
18
+
19
+ dialog = dialog[0].upper() + dialog[1:]
20
+ prompt = '<SC6>Я культурный гопник, ем семечки, люблю играть в доту 2 на любимом герое инвокере, говорю без матов и маты не люблю недавно у меня был такой диалог:' + '\nТы: ' + dialog + '\nЯ: <extra_id_0>'
21
+
22
+ input_ids = t5_tokenizer(prompt, return_tensors='pt').input_ids.to(device)
23
+ out_ids = t5_model.generate(input_ids=input_ids, do_sample=True, temperature=0.9, max_new_tokens=512, top_p=0.85, top_k=2, repetition_penalty=1.2)
24
+
25
+ t5_output = t5_tokenizer.decode(out_ids[0][1:])
26
+ if '</s>' in t5_output:
27
+ t5_output = t5_output[:t5_output.find('</s>')].strip()
28
+ t5_output = t5_output.replace('<extra_id_0>', '').strip()
29
+ t5_output = t5_output.split('Собеседник')[0].strip()
30
+
31
+ return t5_output
32
+
33
+ # Create a Gradio interface
34
+ iface = gr.Interface(fn=generate_response,
35
+ inputs="text",
36
+ outputs="text",
37
+ title="Siberian Persona Chatbot",
38
+ description="A chatbot that responds with a Siberian persona.")
39
+
40
+ # Launch the interface
41
+ iface.launch()