Xenobd commited on
Commit
a6fb062
·
verified ·
1 Parent(s): f9851b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -41
app.py CHANGED
@@ -1,44 +1,50 @@
1
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, BitsAndBytesConfig
2
  import gradio as gr
3
-
4
- model_name = "sshleifer/distilbart-cnn-12-6"
5
-
6
- # Load tokenizer
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
-
9
- # Setup 8-bit quantization
10
- bnb_config = BitsAndBytesConfig(
11
- load_in_8bit=True
12
- )
13
-
14
- # Load model
15
- model = AutoModelForSeq2SeqLM.from_pretrained(
16
- model_name,
17
- quantization_config=bnb_config,
18
- device_map="auto" # automatically maps to CPU/GPU
19
- )
20
-
21
- # Create pipeline WITHOUT device argument
22
- summarizer = pipeline(
23
- "summarization",
24
- model=model,
25
- tokenizer=tokenizer
26
- )
27
-
28
- # Function for Gradio
29
- def summary_ui(text):
30
- output = summarizer(text, max_length=512, min_length=30, truncation=True)
31
- return output[0]['summary_text']
32
-
33
- gr.close_all()
34
-
35
- # Gradio interface
36
- demo = gr.Interface(
37
- fn=summary_ui,
38
- inputs=[gr.Textbox(label="Input text to summarize", lines=6)],
39
- outputs=[gr.Textbox(label="Summarized text", lines=4)],
40
- title="8-bit CPU Text Summarizer",
41
- description="Summarize your text fast on CPU using 8-bit quantization"
 
 
 
 
 
 
 
42
  )
43
 
44
- demo.launch()
 
 
1
  import gradio as gr
2
+ from optimum.onnxruntime import ORTModelForSeq2SeqLM
3
+ from transformers import AutoTokenizer, pipeline
4
+
5
+ # Load ONNX model
6
# Load ONNX model
def create_fast_summarizer():
    """Build a CPU summarization pipeline backed by a quantized ONNX BART model.

    Returns a transformers `pipeline("summarization", ...)` wrapping
    `onnx-community/bart-large-cnn-ONNX` with its 4-bit-quantized encoder and
    decoder graphs, running on the CPU execution provider.
    """
    model = ORTModelForSeq2SeqLM.from_pretrained(
        "onnx-community/bart-large-cnn-ONNX",
        encoder_file_name="encoder_model_q4.onnx",
        decoder_file_name="decoder_model_q4.onnx",
        provider="CPUExecutionProvider",
        # IO binding is a CUDA-only optimization in optimum; enabling it with
        # CPUExecutionProvider is rejected / has no benefit, so keep it off.
        use_io_binding=False
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "onnx-community/bart-large-cnn-ONNX",
        use_fast=True
    )
    return pipeline(
        "summarization",
        model=model,
        tokenizer=tokenizer,
        device=-1  # -1 = CPU; matches the CPU execution provider above
    )
24
# Construct the pipeline once at import time so every Gradio request
# reuses the same loaded model instead of re-downloading/re-initializing it.
summarizer = create_fast_summarizer()
27
+ # Summarize function with prompt + tuned params
28
def summarize_text(text):
    """Summarize *text* with the module-level ONNX pipeline.

    A short instruction is prepended to steer the summary toward key events
    and context. Beam search (no sampling) keeps the output deterministic.

    Args:
        text: Raw article/news text to summarize.

    Returns:
        The generated summary string.
    """
    prompt = "Summarize the key events, including casualties and political context:\n" + text
    result = summarizer(
        prompt,
        max_length=160,
        # NOTE(review): min_length=80 may exceed very short inputs and trigger
        # generation warnings — acceptable for the article-length use case.
        min_length=80,
        do_sample=False,
        num_beams=6,
        length_penalty=1.5,
        early_stopping=True,
        # BART-large-cnn accepts at most 1024 tokens; without truncation long
        # articles raise an error. (The previous revision of this file set
        # truncation=True — restoring it here.)
        truncation=True
    )
    return result[0]['summary_text']
40
+
41
# Assemble the web UI: one multi-line text box in, plain text out.
_input_box = gr.Textbox(lines=15, placeholder="Paste your text here...")

app = gr.Interface(
    fn=summarize_text,
    inputs=_input_box,
    outputs="text",
    title="ONNX Summarizer 🚀",
    description="Paste any news or article text and get a concise, context-rich summary.",
)

# Start the local Gradio server.
app.launch()