KiranRand commited on
Commit
a7277d9
·
verified ·
1 Parent(s): 2b2aae7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -15
app.py CHANGED
@@ -1,24 +1,41 @@
1
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load GPT-2 large model and tokenizer once at startup.
model_name = "openai-community/gpt2-large"  # using the gpt2-large variant
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# GPT-2 ships without a padding token; reuse EOS so generate() pads cleanly
# instead of warning and falling back implicitly.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id


def chat_with_gpt2(user_input):
    """Generate a GPT-2 continuation of ``user_input`` and return only the reply.

    The prompt is truncated to 512 tokens so it cannot exhaust the model's
    context window on its own.
    """
    # Encode the user input as a batch of one.
    inputs = tokenizer.encode(user_input, return_tensors="pt", truncation=True, max_length=512)

    # BUG FIX: max_length counts the prompt tokens too, so a long prompt left
    # little or no room for a reply; max_new_tokens bounds only the generated
    # continuation.
    outputs = model.generate(
        inputs,
        max_new_tokens=200,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens generated after the prompt — decoding everything
    # would echo the user's input back as part of the answer.
    return tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)


# Create the Gradio interface and launch it.
interface = gr.Interface(fn=chat_with_gpt2, inputs="text", outputs="text", title="Trip Planner Chatbot")
interface.launch()
 
 
 
 
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Checkpoint to serve.
model_name = "openai-community/gpt2-large"

# Load tokenizer and model once at module import time.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 has no dedicated padding token out of the box; fall back to EOS on
# both the tokenizer and the model config so generate() can pad without
# emitting warnings.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
14
# Chat function
def chat_with_gpt2(user_input):
    """Generate a sampled GPT-2 reply to ``user_input`` and return only the new text.

    The prompt is truncated to 512 tokens; generation is bounded by
    ``max_new_tokens`` so even long prompts leave room for a reply.
    """
    # Encode input (truncated so the prompt alone cannot exhaust the context).
    inputs = tokenizer.encode(user_input, return_tensors="pt", truncation=True, max_length=512)

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        # BUG FIX: the original passed max_length=250, which counts the prompt
        # tokens too; since the prompt may be up to 512 tokens, generation
        # could be left with zero budget. max_new_tokens bounds only the
        # generated continuation.
        outputs = model.generate(
            inputs,
            max_new_tokens=200,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # BUG FIX: slicing the decoded string by len(user_input) characters is
    # fragile — decode() can normalize whitespace, shifting character counts.
    # Slice off the prompt at the *token* level instead, then decode only the
    # newly generated tokens.
    reply_ids = outputs[0][inputs.shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
 
36
 
37
+ # Gradio UI
38
+ interface = gr.Interface(
39
+ fn=chat_with_gpt2,
40
+ inputs=gr.Textbox(lines=3, placeholder="Ask me about your trip..."),
41
+ outputs="