Xenobd committed on
Commit
060755b
·
verified ·
1 Parent(s): 835001c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -21
app.py CHANGED
@@ -1,36 +1,41 @@
1
import torch
import gradio as gr
from transformers import pipeline

# CPU inference: device=-1 selects CPU in the transformers pipeline API.
device = -1

# Summarization pipeline. float32 is used explicitly because bfloat16 is
# not reliably supported on all CPUs.
text_summary = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-12-6",
    device=device,
    torch_dtype=torch.float32,
)

# Optional speed-up on PyTorch >= 2.x; best-effort, so failures are ignored.
# FIX: was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit
# and would make the app un-interruptible during startup. Catch Exception only.
try:
    text_summary.model = torch.compile(text_summary.model)
except Exception:
    pass


def summary(input_text):
    """Summarize *input_text* and return the summary string.

    Runs the summarization pipeline under no_grad (inference only);
    truncation=True guards against inputs longer than the model limit.
    """
    with torch.no_grad():
        output = text_summary(
            input_text, max_length=512, min_length=30, truncation=True
        )
    return output[0]['summary_text']


gr.close_all()

demo = gr.Interface(
    fn=summary,
    inputs=[gr.Textbox(label="Input text to summarize", lines=6)],
    outputs=[gr.Textbox(label="Summarized text", lines=4)],
    title="@GenAILearniverse Project 1: Text Summarizer",
    description="THIS APPLICATION WILL BE USED TO SUMMARIZE THE TEXT",
)

demo.launch()
 
1
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import gradio as gr
import torch

model_name = "sshleifer/distilbart-cnn-12-6"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# FIX: the previous code passed `load_in_8bit=True` with `device_map="auto"`
# and claimed it was CPU-friendly. bitsandbytes 8-bit quantization requires a
# CUDA GPU, so on a CPU-only host `from_pretrained` raises at load time.
# Additionally, handing an accelerate-dispatched model (`device_map="auto"`)
# to `pipeline(..., device=-1)` raises a ValueError in transformers ("model
# has been loaded with accelerate and cannot be moved to a specific device").
# Load plainly in float32 instead — correct on CPU everywhere.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
)

# Create pipeline, explicitly pinned to CPU (device=-1).
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)


# Function for Gradio
def summary_ui(text):
    """Summarize *text* and return the summary string.

    truncation=True guards against inputs longer than the model's limit.
    """
    output = summarizer(text, max_length=512, min_length=30, truncation=True)
    return output[0]['summary_text']


gr.close_all()

# Gradio interface.
# NOTE(review): the title/description still advertise 8-bit quantization,
# which is no longer accurate after the fix above — consider rewording.
demo = gr.Interface(
    fn=summary_ui,
    inputs=[gr.Textbox(label="Input text to summarize", lines=6)],
    outputs=[gr.Textbox(label="Summarized text", lines=4)],
    title="8-bit CPU Text Summarizer",
    description="Summarize your text fast on CPU using 8-bit quantization",
)

demo.launch()