lityops commited on
Commit
89793b6
·
verified ·
1 Parent(s): 43c6414

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -6
app.py CHANGED
@@ -1,13 +1,26 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
 
4
- MODEL_ID = 'lityops/Abstractive-Style-Summarizer'
 
5
 
6
- summarizer = pipeline("summarization", model=MODEL_ID)
 
 
 
 
 
 
 
 
7
 
8
  def generate_summary(text, style):
9
- if not text or len(text.strip()) < 50:
10
- return "Input must at least be 50 words long"
 
 
11
 
12
  input_text = f"Summarize {style}: {text}"
13
  input_words = len(text.split())
@@ -48,6 +61,12 @@ def generate_summary(text, style):
48
  )
49
  return output[0]["summary_text"]
50
 
 
 
 
 
 
 
51
  custom_css = """
52
  #header {text-align: center; margin-bottom: 25px;}
53
  .gradio-container {max-width: 95% !important;}
@@ -73,7 +92,7 @@ with gr.Blocks() as demo:
73
  )
74
  with gr.Row():
75
  clear_btn = gr.Button("Clear Input")
76
- submit_btn = gr.Button("Process Summary", variant="primary")
77
 
78
  with gr.Column(scale=1):
79
  output_box = gr.Textbox(
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
4
+ from peft import PeftModel
5
 
6
+ ADAPTER_HUB = "lityops/Abstractive-Style-Summarizer"
7
+ BASE_MODEL_NAME = "google/flan-t5-base"
8
 
9
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
10
+ model = PeftModel.from_pretrained(base_model, ADAPTER_HUB)
11
+ tokenizer = AutoTokenizer.from_pretrained(ADAPTER_HUB)
12
+
13
+ summarizer = pipeline(
14
+ "summarization",
15
+ model=model,
16
+ tokenizer=tokenizer
17
+ )
18
 
19
  def generate_summary(text, style):
20
+ if not text or len(text.split()) < 100:
21
+ return "Input must at least be 100 words long"
22
+ if len(text.split()) > 512:
23
+ return "Input must at most be 512 words long"
24
 
25
  input_text = f"Summarize {style}: {text}"
26
  input_words = len(text.split())
 
61
  )
62
  return output[0]["summary_text"]
63
 
64
+ custom_css = """
65
+ #header {text-align: center; margin-bottom: 25px;}
66
+ .gradio-container {max-width: 1000px !important;}
67
+ footer {display: none !important;}
68
+ """
69
+
70
  custom_css = """
71
  #header {text-align: center; margin-bottom: 25px;}
72
  .gradio-container {max-width: 95% !important;}
 
92
  )
93
  with gr.Row():
94
  clear_btn = gr.Button("Clear Input")
95
+ submit_btn = gr.Button("Generate Summary", variant="primary")
96
 
97
  with gr.Column(scale=1):
98
  output_box = gr.Textbox(