Update app.py
Browse files
app.py
CHANGED
|
@@ -2,18 +2,16 @@ import gradio as gr
|
|
| 2 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
model_name = "mrm8488/
|
| 6 |
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
| 7 |
model = GPT2LMHeadModel.from_pretrained(model_name)
|
| 8 |
|
| 9 |
tokenizer.pad_token = tokenizer.eos_token
|
| 10 |
-
model.config.pad_token_id =
|
| 11 |
-
|
| 12 |
-
# Warm-up
|
| 13 |
-
_ = model.generate(tokenizer.encode("Hello", return_tensors="pt"))
|
| 14 |
|
| 15 |
def chat_with_gpt2(user_input):
|
| 16 |
-
|
|
|
|
| 17 |
outputs = model.generate(
|
| 18 |
inputs,
|
| 19 |
max_length=250,
|
|
@@ -25,14 +23,14 @@ def chat_with_gpt2(user_input):
|
|
| 25 |
pad_token_id=tokenizer.eos_token_id
|
| 26 |
)
|
| 27 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 28 |
-
return response[len(
|
| 29 |
|
| 30 |
interface = gr.Interface(
|
| 31 |
fn=chat_with_gpt2,
|
| 32 |
-
inputs=gr.Textbox(lines=3, placeholder="Ask
|
| 33 |
outputs="text",
|
| 34 |
title="Trip Planner Chatbot",
|
| 35 |
-
description="Ask
|
| 36 |
)
|
| 37 |
|
| 38 |
interface.launch()
|
|
|
|
| 2 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
model_name = "mrm8488/gpt2-finetuned-jhegarty-texts"
|
| 6 |
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
| 7 |
model = GPT2LMHeadModel.from_pretrained(model_name)
|
| 8 |
|
| 9 |
tokenizer.pad_token = tokenizer.eos_token
|
| 10 |
+
model.config.pad_token_id = model.config.eos_token_id
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def chat_with_gpt2(user_input):
|
| 13 |
+
prompt = f"You are a helpful travel assistant. Answer concisely and informatively. Question: {user_input}\nAnswer:"
|
| 14 |
+
inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
|
| 15 |
outputs = model.generate(
|
| 16 |
inputs,
|
| 17 |
max_length=250,
|
|
|
|
| 23 |
pad_token_id=tokenizer.eos_token_id
|
| 24 |
)
|
| 25 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 26 |
+
return response[len(prompt):].strip()
|
| 27 |
|
| 28 |
interface = gr.Interface(
|
| 29 |
fn=chat_with_gpt2,
|
| 30 |
+
inputs=gr.Textbox(lines=3, placeholder="Ask about your trip..."),
|
| 31 |
outputs="text",
|
| 32 |
title="Trip Planner Chatbot",
|
| 33 |
+
description="Ask anything related to your travel plans!"
|
| 34 |
)
|
| 35 |
|
| 36 |
interface.launch()
|