vangru commited on
Commit
ad7b5db
·
verified ·
1 Parent(s): 8ccb87b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -26
app.py CHANGED
@@ -1,33 +1,64 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
4
- # Load summarization pipeline
5
- summarizer = pipeline(
6
- "summarization",
7
- model="facebook/bart-large-cnn"
8
- )
9
 
10
- def summarize_text(text):
 
 
 
 
 
 
 
 
11
  if not text.strip():
12
  return "Please enter some text to summarize."
13
-
14
- summary = summarizer(
15
  text,
16
- max_length=150,
17
- min_length=40,
18
- do_sample=False
 
 
 
 
 
 
 
 
19
  )
20
-
21
- return summary[0]['summary_text']
22
-
23
- # Gradio Interface
24
- demo = gr.Interface(
25
- fn=summarize_text,
26
- inputs=gr.Textbox(lines=15, placeholder="Paste your text here..."),
27
- outputs="text",
28
- title="Advanced BART Text Summarizer",
29
- description="Summarize long text using facebook/bart-large-cnn"
30
- )
31
-
32
- if __name__ == "__main__":
33
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
 
5
+ MODEL_NAME = "facebook/bart-large-cnn"
 
 
 
 
6
 
7
+ # Load model and tokenizer
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
10
+
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ model = model.to(device)
13
+
14
+
15
+ def summarize(text, max_length, min_length):
16
  if not text.strip():
17
  return "Please enter some text to summarize."
18
+
19
+ inputs = tokenizer(
20
  text,
21
+ return_tensors="pt",
22
+ max_length=1024,
23
+ truncation=True
24
+ ).to(device)
25
+
26
+ summary_ids = model.generate(
27
+ inputs["input_ids"],
28
+ num_beams=4,
29
+ max_length=max_length,
30
+ min_length=min_length,
31
+ early_stopping=True
32
  )
33
+
34
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
35
+ return summary
36
+
37
+
38
+ with gr.Blocks(title="Advanced BART Summarizer") as demo:
39
+ gr.Markdown("# Advanced BART Text Summarizer")
40
+ gr.Markdown("Summarization using facebook/bart-large-cnn")
41
+
42
+ input_text = gr.Textbox(
43
+ lines=12,
44
+ placeholder="Enter long article or paragraph here..."
45
+ )
46
+
47
+ with gr.Row():
48
+ max_len = gr.Slider(50, 300, value=150, label="Max Summary Length")
49
+ min_len = gr.Slider(20, 100, value=40, label="Min Summary Length")
50
+
51
+ output_text = gr.Textbox(
52
+ lines=8,
53
+ label="Generated Summary"
54
+ )
55
+
56
+ summarize_btn = gr.Button("Summarize")
57
+
58
+ summarize_btn.click(
59
+ summarize,
60
+ inputs=[input_text, max_len, min_len],
61
+ outputs=output_text
62
+ )
63
+
64
+ demo.launch()