KB-Infinity-Tech commited on
Commit
79ceb46
·
verified ·
1 Parent(s): 915b94c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio app comparing base FLAN-T5 vs the fine-tuned checkpoint."""
2
+
3
+ import os
4
+
5
+ import gradio as gr
6
+ from transformers import pipeline
7
+
8
+ BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "google/flan-t5-small")
9
+ FINETUNED_MODEL_ID = os.getenv("FINETUNED_MODEL_ID", "KB-Infinity-Tech/t5-samsum-mini")
10
+ TASK_PREFIX = os.getenv("TASK_PREFIX", "summarize: ")
11
+
12
+
13
+ base_pipe = pipeline("summarization", model=BASE_MODEL_ID)
14
+ finetuned_pipe = pipeline("summarization", model=FINETUNED_MODEL_ID)
15
+
16
+
17
+ def summarize(text: str) -> tuple[str, str]:
18
+ prompt = TASK_PREFIX + text
19
+ base_summary = base_pipe(prompt, max_length=60, min_length=10, do_sample=False)[0]["summary_text"]
20
+ finetuned_summary = finetuned_pipe(prompt, max_length=60, min_length=10, do_sample=False)[0][
21
+ "summary_text"
22
+ ]
23
+ return base_summary, finetuned_summary
24
+
25
+
26
+ iface = gr.Interface(
27
+ fn=summarize,
28
+ inputs=gr.Textbox(lines=8, label="Dialogue", placeholder="Paste a conversation here"),
29
+ outputs=[
30
+ gr.Textbox(label=f"Base model ({BASE_MODEL_ID})"),
31
+ gr.Textbox(label=f"Fine-tuned model ({FINETUNED_MODEL_ID})"),
32
+ ],
33
+ title="T5 SAMSum Demo",
34
+ description="Compare the original FLAN-T5-small summarization with the fine-tuned checkpoint.",
35
+ )
36
+
37
+
38
+ if __name__ == "__main__":
39
+ iface.launch()