File size: 1,092 Bytes
53b65c7
4e97d70
 
53b65c7
4e97d70
 
53b65c7
4e97d70
 
 
 
53b65c7
4e97d70
 
 
53b65c7
4e97d70
 
 
53b65c7
4e97d70
 
53b65c7
4e97d70
 
 
 
 
53b65c7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the pretrained GPT-2 tokenizer and causal-LM weights at import time.
# NOTE(review): this downloads from the Hugging Face hub on first run and
# blocks module import until the weights are cached.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def respond(message, chat_history):
    """Generate a GPT-2 reply to *message* given the running conversation.

    Parameters
    ----------
    message : str
        The user's latest input from the textbox.
    chat_history : list[str] | None
        Flat list of alternating user/bot utterances kept in ``gr.State``.

    Returns
    -------
    tuple[list[tuple[str, str]], list[str]]
        (user, bot) pairs for the ``gr.Chatbot`` display, plus the updated
        flat history to store back into ``gr.State``.
    """
    # Append user message to chat history (handle first-call None state).
    chat_history = chat_history or []
    chat_history.append(message)

    # Concatenate the whole conversation as model context.
    input_text = " ".join(chat_history)
    input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')

    # Fix: use max_new_tokens instead of max_length — max_length counts the
    # prompt too, so as the conversation grew past 1000 tokens generation
    # shrank to zero and then raised. no_grad() skips autograd bookkeeping
    # during pure inference.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=200,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens so only the newly generated text remains.
    output_text = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    chat_history.append(output_text)

    # Bug fix: gr.Chatbot renders a list of (user, bot) tuples — the original
    # returned the bare reply string, which the Chatbot component cannot
    # display as a transcript.
    pairs = [
        (chat_history[i], chat_history[i + 1])
        for i in range(0, len(chat_history) - 1, 2)
    ]
    return pairs, chat_history

# Wire up the chat UI: the textbox feeds respond(), whose outputs update
# both the visible transcript and the hidden per-session state.
with gr.Blocks() as demo:
    # Hidden state: flat list of alternating user/bot utterances.
    history_state = gr.State([])
    transcript = gr.Chatbot()
    user_box = gr.Textbox(placeholder="Ask me anything...")
    # Enter in the textbox -> respond(message, history) -> (chatbot, history).
    user_box.submit(respond, [user_box, history_state], [transcript, history_state])

if __name__ == "__main__":
    demo.launch()