File size: 985 Bytes
3f53a53
469baf9
122aed7
a1a3bd9
469baf9
ea46755
469baf9
 
a1a3bd9
17fffd5
ea46755
469baf9
ea46755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122aed7
ea46755
 
 
 
 
 
 
469baf9
 
ea46755
17fffd5
 
469baf9
17fffd5
 
3f53a53
ea46755
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Hugging Face Hub checkpoint used for both the tokenizer and the seq2seq model.
model_name = "facebook/blenderbot-400M-distill"

# Both are loaded once, at import time, and shared by every reply generated
# by this script (tokenizer is created first, then the model).
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

_MAX_EXCHANGES = 2  # prior complete user/bot exchanges kept as context
# Encoder input budget in tokens (same limit as before; presumably the
# model's positional limit -- confirm against the model config).
_MAX_INPUT_TOKENS = 128
_MAX_NEW_TOKENS = 60  # cap on the length of each generated reply


def _history_pairs(history):
    """Yield ``(user, bot)`` pairs from a Gradio chat history.

    Gradio's ``ChatInterface`` passes history either as ``[user, bot]`` pairs
    ("tuples" format) or as ``{"role": ..., "content": ...}`` dicts
    ("messages" format); both are accepted. In the dict format a user message
    is paired with the first assistant message that follows it. ``None`` is
    treated as an empty history.
    """
    pending_user = None
    for entry in history or ():
        if isinstance(entry, dict):
            role = entry.get("role")
            if role == "user":
                pending_user = entry.get("content")
            elif role == "assistant" and pending_user is not None:
                yield pending_user, entry.get("content")
                pending_user = None
        else:
            yield entry[0], entry[1]


def chat_function(message, history):
    """Generate a BlenderBot reply to *message*, using recent chat history.

    Up to ``_MAX_EXCHANGES`` of the most recent complete user/bot exchanges
    are prepended to the new message as context.

    Args:
        message: The user's new message text.
        history: Prior turns as supplied by ``gr.ChatInterface``, in either
            the "tuples" or the "messages" format.

    Returns:
        The decoded reply, with special tokens removed.
    """
    exchanges = [
        (user, bot)
        for user, bot in _history_pairs(history)
        # Skip unanswered turns and non-text entries (e.g. uploaded files),
        # which cannot be concatenated into a text prompt.
        if user and bot and isinstance(user, str) and isinstance(bot, str)
    ][-_MAX_EXCHANGES:]

    # Mirror transformers' BlenderBot conversation format: user turns carry a
    # leading space and turns are separated by two spaces.
    turns = []
    for user, bot in exchanges:
        turns.extend((" " + user, bot))
    turns.append(" " + message)
    prompt = "  ".join(turns)

    # Tokenize without truncation, then keep only the most recent tokens.
    # The tokenizer's default truncation side is "right", which would cut off
    # the new message -- the part the reply must answer. Slicing from the left
    # keeps the end of the sequence (including any end-of-sequence token the
    # tokenizer appends) intact. The tokenizer may warn about over-long input
    # here; the slice below makes that harmless.
    encoded = tokenizer(prompt, return_tensors="pt")
    inputs = {
        name: tensor[:, -_MAX_INPUT_TOKENS:]
        for name, tensor in encoded.items()
    }

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=_MAX_NEW_TOKENS)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Expose chat_function through Gradio's chat UI and start serving it.
demo = gr.ChatInterface(
    chat_function,
    title="BlenderBot Chat",
    description="Ask me anything!",
)

# NOTE(review): this launches as soon as the module is imported; consider an
# `if __name__ == "__main__":` guard if the file is ever imported or run under
# Gradio's reload mode.
demo.launch()