KiranRand commited on
Commit
3425308
·
verified ·
1 Parent(s): a8fb4ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -2,18 +2,16 @@ import gradio as gr
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
  import torch
4
 
5
- model_name = "mrm8488/GPT-2-finetuned-on-travel"
6
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
7
  model = GPT2LMHeadModel.from_pretrained(model_name)
8
 
9
  tokenizer.pad_token = tokenizer.eos_token
10
- model.config.pad_token_id = tokenizer.eos_token_id
11
-
12
- # Warm-up
13
- _ = model.generate(tokenizer.encode("Hello", return_tensors="pt"))
14
 
15
  def chat_with_gpt2(user_input):
16
- inputs = tokenizer.encode(user_input, return_tensors="pt", truncation=True, max_length=512)
 
17
  outputs = model.generate(
18
  inputs,
19
  max_length=250,
@@ -25,14 +23,14 @@ def chat_with_gpt2(user_input):
25
  pad_token_id=tokenizer.eos_token_id
26
  )
27
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
- return response[len(user_input):].strip()
29
 
30
  interface = gr.Interface(
31
  fn=chat_with_gpt2,
32
- inputs=gr.Textbox(lines=3, placeholder="Ask me about your trip..."),
33
  outputs="text",
34
  title="Trip Planner Chatbot",
35
- description="Ask me travel questions or get suggestions for your trip!"
36
  )
37
 
38
  interface.launch()
 
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
  import torch
4
 
5
+ model_name = "mrm8488/gpt2-finetuned-jhegarty-texts"
6
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
7
  model = GPT2LMHeadModel.from_pretrained(model_name)
8
 
9
  tokenizer.pad_token = tokenizer.eos_token
10
+ model.config.pad_token_id = model.config.eos_token_id
 
 
 
11
 
12
  def chat_with_gpt2(user_input):
13
+ prompt = f"You are a helpful travel assistant. Answer concisely and informatively. Question: {user_input}\nAnswer:"
14
+ inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
15
  outputs = model.generate(
16
  inputs,
17
  max_length=250,
 
23
  pad_token_id=tokenizer.eos_token_id
24
  )
25
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
26
+ return response[len(prompt):].strip()
27
 
28
  interface = gr.Interface(
29
  fn=chat_with_gpt2,
30
+ inputs=gr.Textbox(lines=3, placeholder="Ask about your trip..."),
31
  outputs="text",
32
  title="Trip Planner Chatbot",
33
+ description="Ask anything related to your travel plans!"
34
  )
35
 
36
  interface.launch()