KiranRand committed on
Commit
aa52321
·
verified ·
1 Parent(s): 70474f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -2,22 +2,24 @@ import gradio as gr
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
  import torch
4
 
5
- # Load model and tokenizer
6
- model_name = "openai-community/gpt2-large"
7
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
8
  model = GPT2LMHeadModel.from_pretrained(model_name)
9
 
10
  # Set padding token (GPT-2 doesn't have one by default)
11
  tokenizer.pad_token = tokenizer.eos_token
12
- model.config.pad_token_id = model.config.eos_token_id
13
 
14
  # Chat function
15
  def chat_with_gpt2(user_input):
 
16
  inputs = tokenizer.encode(user_input, return_tensors="pt", truncation=True, max_length=512)
17
-
 
18
  outputs = model.generate(
19
  inputs,
20
- max_length=250, # shortened for performance
21
  num_return_sequences=1,
22
  do_sample=True,
23
  temperature=0.7,
@@ -25,11 +27,12 @@ def chat_with_gpt2(user_input):
25
  top_p=0.95,
26
  pad_token_id=tokenizer.eos_token_id
27
  )
28
-
 
29
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
30
  return response[len(user_input):].strip()
31
 
32
- # Gradio UI
33
  interface = gr.Interface(
34
  fn=chat_with_gpt2,
35
  inputs=gr.Textbox(lines=3, placeholder="Ask me about your trip..."),
@@ -38,4 +41,5 @@ interface = gr.Interface(
38
  description="Ask anything related to your travel or itinerary!"
39
  )
40
 
 
41
  interface.launch()
 
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Load the community-hosted GPT-2 checkpoint and its tokenizer.
model_name = "openai-community/gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 ships without a padding token; reuse EOS so generate() can pad.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
 
14
# Chat function
def chat_with_gpt2(user_input):
    """Generate a GPT-2 continuation of *user_input* and return only the new text.

    Parameters:
        user_input (str): prompt text from the Gradio textbox.

    Returns:
        str: the sampled continuation, with the prompt removed and
        surrounding whitespace stripped.
    """
    # Tokenize and encode user input (truncate to fit the context window).
    inputs = tokenizer.encode(
        user_input, return_tensors="pt", truncation=True, max_length=512
    )

    # Inference only — no_grad skips autograd bookkeeping for free speed/memory.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_length=250,  # counts prompt + generated tokens combined
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            # NOTE(review): the original file has one additional sampling kwarg
            # here (file line 26, hidden by the diff hunk) — restore it on merge.
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode ONLY the newly generated token ids. The original sliced the
    # decoded string with response[len(user_input):], which mis-cuts whenever
    # decode(encode(x)) != x (whitespace/Unicode normalization by the tokenizer).
    generated_ids = outputs[0][inputs.shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
34
 
35
+ # Gradio interface
36
  interface = gr.Interface(
37
  fn=chat_with_gpt2,
38
  inputs=gr.Textbox(lines=3, placeholder="Ask me about your trip..."),
 
41
  description="Ask anything related to your travel or itinerary!"
42
  )
43
 
44
# Start the Gradio web app (blocks until the server is stopped).
interface.launch()