KB-Infinity-Tech committed on
Commit
0aab236
·
verified ·
1 Parent(s): d11c3c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -18
app.py CHANGED
@@ -1,39 +1,63 @@
1
- """Gradio app comparing base FLAN-T5 vs the fine-tuned checkpoint."""
2
-
3
  import os
4
-
5
  import gradio as gr
6
  from transformers import pipeline
7
 
 
 
 
8
  BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "google/flan-t5-small")
9
- FINETUNED_MODEL_ID = os.getenv("FINETUNED_MODEL_ID", "KB-Infinity-Tech/t5-samsum-mini")
10
  TASK_PREFIX = os.getenv("TASK_PREFIX", "summarize: ")
11
 
 
12
 
13
- base_pipe = pipeline("summarization", model=BASE_MODEL_ID)
14
- finetuned_pipe = pipeline("summarization", model=FINETUNED_MODEL_ID)
 
 
 
 
 
 
15
 
 
 
 
 
 
16
 
17
- def summarize(text: str) -> tuple[str, str]:
 
 
 
18
  prompt = TASK_PREFIX + text
19
- base_summary = base_pipe(prompt, max_length=60, min_length=10, do_sample=False)[0]["summary_text"]
20
- finetuned_summary = finetuned_pipe(prompt, max_length=60, min_length=10, do_sample=False)[0][
21
- "summary_text"
22
- ]
23
- return base_summary, finetuned_summary
24
 
 
 
 
 
 
25
 
 
 
 
 
 
26
  iface = gr.Interface(
27
  fn=summarize,
28
- inputs=gr.Textbox(lines=8, label="Dialogue", placeholder="Paste a conversation here"),
29
  outputs=[
30
- gr.Textbox(label=f"Base model ({BASE_MODEL_ID})"),
31
- gr.Textbox(label=f"Fine-tuned model ({FINETUNED_MODEL_ID})"),
32
  ],
33
  title="T5 SAMSum Demo",
34
- description="Compare the original FLAN-T5-small summarization with the fine-tuned checkpoint.",
35
  )
36
 
37
-
 
 
38
  if __name__ == "__main__":
39
- iface.launch()
 
 
 
 
 
"""Module setup: configuration constants and eager pipeline construction."""

import os

import gradio as gr
from transformers import pipeline

# -----------------------------
# CONFIG
# -----------------------------
# All settings are overridable via environment variables so the Space can be
# repointed at other checkpoints without a code change.
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "google/flan-t5-small")
FINETUNED_MODEL_ID = os.getenv("FINETUNED_MODEL_ID", "KB-Infinity-Tech/t5-samsum-mini")  # <-- change to your real repo
# Prefix prepended to every input; FLAN-T5 expects a task instruction.
TASK_PREFIX = os.getenv("TASK_PREFIX", "summarize: ")

# Optional Hugging Face access token (None when unset); needed only if the
# fine-tuned repo is private.
HF_TOKEN = os.getenv("HF_TOKEN")

# -----------------------------
# LOAD MODELS (NEW API)
# -----------------------------
# Both pipelines are built eagerly at import time, so the first request is
# fast but startup downloads/loads both checkpoints.
# NOTE(review): "text2text-generation" pipelines return "generated_text"
# keys, not "summary_text" — the consumer below must match.
base_pipe = pipeline(
    "text2text-generation",
    model=BASE_MODEL_ID,
    token=HF_TOKEN,
)

finetuned_pipe = pipeline(
    "text2text-generation",
    model=FINETUNED_MODEL_ID,
    token=HF_TOKEN,
)
 
29
+ # -----------------------------
30
+ # FUNCTION
31
+ # -----------------------------
32
+ def summarize(text):
33
  prompt = TASK_PREFIX + text
 
 
 
 
 
34
 
35
+ base_output = base_pipe(prompt, max_length=60, min_length=10, do_sample=False)
36
+ finetuned_output = finetuned_pipe(prompt, max_length=60, min_length=10, do_sample=False)
37
+
38
+ base_summary = base_output[0]["generated_text"]
39
+ finetuned_summary = finetuned_output[0]["generated_text"]
40
 
41
+ return base_summary, finetuned_summary
42
+
43
+ # -----------------------------
44
+ # UI
45
+ # -----------------------------
46
  iface = gr.Interface(
47
  fn=summarize,
48
+ inputs=gr.Textbox(lines=8, label="Dialogue"),
49
  outputs=[
50
+ gr.Textbox(label=f"Base ({BASE_MODEL_ID})"),
51
+ gr.Textbox(label=f"Fine-tuned ({FINETUNED_MODEL_ID})"),
52
  ],
53
  title="T5 SAMSum Demo",
54
+ description="Compare base vs fine-tuned summarization",
55
  )
56
 
57
+ # -----------------------------
58
+ # RUN
59
+ # -----------------------------
60
  if __name__ == "__main__":
61
+ iface.launch()
62
+
63
+