# ChatGPTWebsite / app.py (Hugging Face Space source file)
# Last commit: "Update app.py" by Udyan (ea46755, verified); file size 985 bytes.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Load the BlenderBot checkpoint once, at import time, so every chat request
# reuses the same tokenizer and model instead of reloading them.
model_name = "facebook/blenderbot-400M-distill"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)
def _recent_turns(history, max_exchanges=2):
    """Return the text of the most recent chat turns, oldest first.

    Gradio passes ``history`` either as ``[user, bot]`` pairs (the "tuples"
    format) or as ``{"role": ..., "content": ...}`` dicts (the "messages"
    format); both are accepted.

    Args:
        history: Prior conversation as supplied by ``gr.ChatInterface``.
        max_exchanges: Number of user/bot exchanges to keep.

    Returns:
        A list of turn strings. In the pairs format an exchange is kept only
        when both sides are non-empty strings; in the messages format any
        message without non-empty string content (e.g. a file) is skipped.
    """
    if not history or max_exchanges <= 0:
        return []

    if isinstance(history[0], dict):
        texts = [msg.get("content") for msg in history]
        texts = [text for text in texts if isinstance(text, str) and text]
        # One exchange = one user message + one bot reply.
        return texts[-2 * max_exchanges:]

    turns = []
    for pair in history[-max_exchanges:]:
        user_text, bot_text = pair[0], pair[1]
        if (
            isinstance(user_text, str) and user_text
            and isinstance(bot_text, str) and bot_text
        ):
            turns += [user_text, bot_text]
    return turns


def chat_function(message, history):
    """Generate a BlenderBot reply to ``message`` using recent ``history``.

    Only the last two exchanges of ``history`` are used as context. They are
    joined with the new message using single spaces.

    Args:
        message: The user's new message.
        history: Prior conversation from ``gr.ChatInterface``, in either the
            "tuples" or the "messages" format.

    Returns:
        The decoded model reply with special tokens removed.
    """
    input_text = " ".join([*_recent_turns(history, max_exchanges=2), message])

    # The default truncation side is "right", which would cut off the END of
    # the prompt -- i.e. the new message the bot is supposed to answer. Drop
    # the oldest context instead. (Idempotent; safe to set on every call.)
    tokenizer.truncation_side = "left"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        # NOTE(review): 128 presumably matches blenderbot-400M-distill's
        # position-embedding limit -- confirm via
        # model.config.max_position_embeddings.
        max_length=128,
    )

    with torch.no_grad():  # inference only; no gradients needed
        outputs = model.generate(**inputs, max_new_tokens=60)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Wrap chat_function in Gradio's chat UI and start the web server at import
# time (the script is expected to be run directly).
demo = gr.ChatInterface(
    chat_function,
    title="BlenderBot Chat",
    description="Ask me anything!",
)
demo.launch()