rohanphadke commited on
Commit
421b7c3
·
1 Parent(s): f60e304

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -32
app.py CHANGED
@@ -8,47 +8,25 @@
8
 
9
  import gradio as gr
10
  import torch
11
- from transformers import AutoModelForCausalLM, AutoTokenizer
12
 
13
- # Load the pre-trained model and tokenizer
14
- model_name = 'microsoft/DialoGPT-medium'
15
- model = AutoModelForCausalLM.from_pretrained(model_name)
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
-
18
- # Set the device to GPU if available, otherwise use CPU
19
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
- model = model.to(device)
21
-
22
- # Define a function to generate a response given a list of user inputs
23
- def generate_response(user_inputs):
24
- # Tokenize the user inputs
25
- input_ids = tokenizer.encode(user_inputs, return_tensors='pt').to(device)
26
-
27
- # Generate a response
28
- with torch.no_grad():
29
- output = model.generate(input_ids, max_length=100, num_return_sequences=1)
30
-
31
- # Decode the generated output
32
- response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
33
- return response
34
 
35
  # Define the chatbot function
36
- def chatbot(input_text):
37
- # Append user input to the chat history
38
- history = [input_text]
39
-
40
- # Generate a response
41
- response = generate_response(history)
42
 
43
- # Append the user input and generated response to the chat history
44
- history.append(response)
45
 
46
- # Return the response
47
  return response
 
48
 
49
  # Set up the Gradio interface
50
  iface = gr.Interface(
51
- fn=chatbot,
52
  inputs=gr.inputs.Textbox(placeholder="Enter your message..."),
53
  outputs="text",
54
  title="Conversational Chatbot",
 
8
 
9
  import gradio as gr
10
  import torch
11
+ from transformers import Conversation, pipeline
12
 
13
+ # Load the conversational pipeline
14
+ model_name = "facebook/blenderbot-400M-distill"
15
+ chatbot = pipeline("conversational", model=model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Define the chatbot function
18
+ def generate_response(input_text):
19
+ conversation = Conversation()
20
+ conversation.add_user_input(input_text)
 
 
 
21
 
22
+ response = chatbot(conversation)
 
23
 
 
24
  return response
25
+ # .choices[0]['message']['content']
26
 
27
  # Set up the Gradio interface
28
  iface = gr.Interface(
29
+ fn=generate_response,
30
  inputs=gr.inputs.Textbox(placeholder="Enter your message..."),
31
  outputs="text",
32
  title="Conversational Chatbot",