ChavanN committed on
Commit
3b218c4
·
verified ·
1 Parent(s): 9dc2bde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -24
app.py CHANGED
@@ -1,33 +1,34 @@
1
  import gradio as gr
2
- from transformers import T5Tokenizer, T5ForConditionalGeneration
3
- import torch
4
 
5
- model_name = "t5-small"
 
6
 
7
- tokenizer = T5Tokenizer.from_pretrained(model_name)
8
- model = T5ForConditionalGeneration.from_pretrained(model_name)
9
- model.eval() # set model to evaluation mode
 
 
 
10
 
11
- device = torch.device("cpu") # explicitly set device to CPU
12
- model.to(device)
13
-
14
- def generate_text(input_text):
15
- input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
16
- outputs = model.generate(
17
- input_ids,
18
- max_length=100,
19
- num_beams=5,
20
- early_stopping=True
21
- )
22
- result = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
- return result
24
 
 
25
  demo = gr.Interface(
26
- fn=generate_text,
27
- inputs=gr.Textbox(lines=5, label="Input Text"),
28
- outputs=gr.Textbox(lines=5, label="Generated Text"),
29
- title="Simple T5 Text Generator (CPU)",
30
- description="Enter some text, T5 will generate a continuation or answer."
31
  )
32
 
33
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from pprint import pprint
3
+ from lmqg import TransformersQG
4
 
5
# Initialize the LMQG pipeline: one model for question generation (qg)
# and one for answer extraction (ae).
# NOTE(review): original comment said "CPU-only", but no device is selected
# here — presumably TransformersQG defaults to CPU; confirm on a GPU host.
model = TransformersQG(model='lmqg/t5-base-squad-qg', model_ae='lmqg/t5-base-squad-ae')
7
 
8
def chunk_text(text, chunk_size=450):
    """Split *text* into consecutive token-id chunks.

    The text is tokenized with the QG model's tokenizer and the resulting
    id sequence is sliced into pieces of at most *chunk_size* tokens,
    preserving order.

    Returns a list of token-id lists.
    """
    token_ids = model.tokenizer.encode(text)
    pieces = []
    for start in range(0, len(token_ids), chunk_size):
        pieces.append(token_ids[start:start + chunk_size])
    return pieces
14
 
15
# Function to process each chunk and generate QA pairs
def generate_qa_for_chunks(text):
    """Generate question-answer pairs for arbitrarily long input text.

    The text is split into token chunks (see ``chunk_text``) so each piece
    fits the model's input limit; QA pairs from every chunk are collected
    and rendered as one "Q: ...\\nA: ..." formatted string.
    """
    qa_pairs = []
    # BUG FIX: the original body rebound the name `chunk_text` to the decoded
    # string, which made `chunk_text` a *local* variable for the whole
    # function and raised UnboundLocalError on the `chunk_text(text)` call
    # above. Use a distinct name for the decoded chunk instead.
    for token_chunk in chunk_text(text):
        decoded_chunk = model.tokenizer.decode(token_chunk, skip_special_tokens=True)
        qa = model.generate_qa(decoded_chunk)
        qa_pairs.extend(qa)
    return "\n\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in qa_pairs])
 
 
 
 
24
 
25
# Gradio UI: one multiline text box in, one multiline text box out.
input_box = gr.Textbox(lines=20, label="Input Text (SAP Note or Paragraph)")
output_box = gr.Textbox(lines=30, label="Generated QA Pairs")
demo = gr.Interface(
    fn=generate_qa_for_chunks,
    inputs=input_box,
    outputs=output_box,
    title="Question Generator (LMQG - T5)",
    description="Paste text to generate question-answer pairs using lmqg/t5-base-squad-qg",
)
33
 
34
  if __name__ == "__main__":