IotaCluster committed on
Commit
424339c
·
verified ·
1 Parent(s): 4b4fd48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -37
app.py CHANGED
@@ -1,55 +1,44 @@
1
  import re
2
  import gradio as gr
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration
4
 
5
- # Load the model and tokenizer
6
- model_name = "t5-small"
7
- tokenizer = T5Tokenizer.from_pretrained(model_name)
8
- model = T5ForConditionalGeneration.from_pretrained(model_name)
 
 
 
 
 
 
 
 
9
 
10
- # Function to remove confidentiality notice
11
  def remove_confidentiality(text: str) -> str:
12
- # Pattern matches the confidentiality notice starting with 'CONFIDENTIALITY NOTICE:'
13
  pattern = r"\*\*CONFIDENTIALITY NOTICE:.*"
14
- # Split text at the notice and keep only the part before it
15
- cleaned = re.split(pattern, text, flags=re.DOTALL)[0]
16
- return cleaned.strip()
17
 
18
- # Define the summarization function
19
  def summarize_text(text):
20
- # Clean the text by removing the confidentiality notice if present
21
- cleaned_text = remove_confidentiality(text)
22
- # Prepare input for summarization
23
- input_text = "summarize: " + cleaned_text.strip()
24
- input_ids = tokenizer.encode(
25
- input_text,
26
- return_tensors="pt",
27
- max_length=500,
28
- truncation=True
29
- )
30
- summary_ids = model.generate(
31
- input_ids,
32
- max_length=900,
33
- min_length=800,
34
- length_penalty=2.0,
35
- num_beams=2,
36
  early_stopping=True
37
  )
38
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
39
- return summary
40
 
41
- # Gradio interface
42
  iface = gr.Interface(
43
  fn=summarize_text,
44
  inputs=gr.Textbox(lines=15, placeholder="Paste your text here..."),
45
  outputs=gr.Textbox(label="Summary"),
46
- title="T5 Text Summarizer",
47
- description="Enter any long English text to get a summarized version using the T5 model."
48
  )
49
 
50
- # Launch
51
- def main():
52
- iface.launch()
53
-
54
  if __name__ == "__main__":
55
- main()
 
1
  import re
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
4
 
# Distilled BART checkpoint used for summarization.
model_name = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


def _default_device() -> int:
    """Return the pipeline device index: 0 (first CUDA GPU) if usable, else -1 (CPU).

    Hard-coding ``device=0`` makes start-up fail on CPU-only hosts, so the
    choice is made at runtime instead.
    """
    try:
        import torch
    except ImportError:
        # No PyTorch available to query; fall back to CPU.
        return -1
    return 0 if torch.cuda.is_available() else -1


# Build the summarization pipeline once at import time so that every call to
# `summarize_text` reuses the same loaded model and tokenizer.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device=_default_device(),
)
17
 
 
def remove_confidentiality(text: str) -> str:
    """Return *text* with any confidentiality notice cut off.

    Everything from the first ``**CONFIDENTIALITY NOTICE:`` marker through the
    end of the input is discarded, and the remaining text is stripped of
    leading/trailing whitespace. Input without the marker is returned stripped.
    """
    notice = re.search(r"\*\*CONFIDENTIALITY NOTICE:.*", text, flags=re.DOTALL)
    # No marker: keep the whole input; otherwise keep only what precedes it.
    kept = text if notice is None else text[: notice.start()]
    return kept.strip()
 
 
21
 
 
def summarize_text(text):
    """Summarize *text* using the module-level ``summarizer`` pipeline.

    Any confidentiality notice is removed first via ``remove_confidentiality``.

    Args:
        text: The raw text to summarize.

    Returns:
        The generated summary string, or ``""`` when there is nothing to
        summarize (empty/whitespace-only input, or input consisting solely of
        a confidentiality notice).
    """
    # Nothing to summarize: skip the model call entirely.
    if not text or not text.strip():
        return ""

    cleaned = remove_confidentiality(text)
    if not cleaned:
        return ""

    # The pipeline does not split long inputs into chunks; `truncation=True`
    # clips the input to the model's maximum length so over-long text is
    # summarized from its leading portion instead of failing.
    results = summarizer(
        cleaned,
        max_length=200,
        min_length=50,
        length_penalty=1.5,
        num_beams=3,
        early_stopping=True,
        truncation=True,
    )
    return results[0]["summary_text"]
 
34
 
 
# Input and output widgets for the summarizer UI.
_input_box = gr.Textbox(lines=15, placeholder="Paste your text here...")
_output_box = gr.Textbox(label="Summary")

# Interface that routes the input textbox through `summarize_text` and shows
# the returned string in the output textbox.
iface = gr.Interface(
    fn=summarize_text,
    inputs=_input_box,
    outputs=_output_box,
    title="Fast & Accurate Summarizer",
    description="Using the distilled BART model for quicker, high-quality summaries.",
)

if __name__ == "__main__":
    # Launch only when executed as a script, not when imported as a module.
    iface.launch()