jeevana commited on
Commit
388efe5
·
verified ·
1 Parent(s): 189e034

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -5
app.py CHANGED
@@ -2,18 +2,40 @@ import gradio as gr
2
  import torch
3
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
 
 
 
 
5
 
6
  # Load the fine-tuned model and tokenizer
7
  my_model = GPT2LMHeadModel.from_pretrained("jeevana/GenerativeQnASystem")
8
  my_tokenizer = GPT2Tokenizer.from_pretrained("jeevana/GenerativeQnASystem")
9
 
10
- # def generative_qna(input):
11
- # response = generate_response(my_model, my_tokenizer, input)
12
- # return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def generative_qna(input):
15
- print(input)
16
- return input
 
 
 
 
17
 
18
  app = gr.Interface(fn=generative_qna, inputs=[gr.Textbox(label="Question", lines=3)],
19
  outputs=[gr.Textbox(label="Answer", lines=6)],
 
2
  import torch
3
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
 
5
+ checkpoint = "gpt2"
6
+ tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
7
+
8
 
9
  # Load the fine-tuned model and tokenizer
10
  my_model = GPT2LMHeadModel.from_pretrained("jeevana/GenerativeQnASystem")
11
  my_tokenizer = GPT2Tokenizer.from_pretrained("jeevana/GenerativeQnASystem")
12
 
13
+ def generate_response(model, tokenizer, prompt):
14
+ input_ids = tokenizer.encode(prompt, return_tensors="pt",truncation=True, max_length=1000)
15
+ # Create the attention mask and pad token id
16
+ attention_mask = torch.ones_like(input_ids)
17
+ pad_token_id = tokenizer.eos_token_id
18
+
19
+ output = model.generate(
20
+ input_ids,
21
+ max_new_tokens=70,
22
+ min_new_tokens = 1,
23
+ num_return_sequences=1,
24
+ attention_mask=attention_mask,
25
+ pad_token_id=pad_token_id
26
+ )
27
+ qna = tokenizer.decode(output[0], skip_special_tokens=True)
28
+ answer = qna[len(prompt)+9: ]
29
+ return answer
30
+
31
 
32
  def generative_qna(input):
33
+ response = generate_response(my_model, my_tokenizer, input)
34
+ return response
35
+
36
+ # def generative_qna(input):
37
+ # print(input)
38
+ # return input
39
 
40
  app = gr.Interface(fn=generative_qna, inputs=[gr.Textbox(label="Question", lines=3)],
41
  outputs=[gr.Textbox(label="Answer", lines=6)],