KB-Infinity-Tech commited on
Commit
bac06cc
·
verified ·
1 Parent(s): 0cb255a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -82
app.py CHANGED
@@ -1,103 +1,61 @@
1
- # import os
2
- # import gradio as gr
3
- # from transformers import pipeline
4
-
5
- # # -----------------------------
6
- # # CONFIG
7
- # # -----------------------------
8
- # BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "google/flan-t5-small")
9
- # FINETUNED_MODEL_ID = os.getenv("FINETUNED_MODEL_ID", "KB-Infinity-Tech/t5-samsum-mini") # <-- change to your real repo
10
- # TASK_PREFIX = os.getenv("TASK_PREFIX", "summarize: ")
11
-
12
- # HF_TOKEN = os.getenv("HF_TOKEN")
13
-
14
- # # -----------------------------
15
- # # LOAD MODELS (NEW API)
16
- # # -----------------------------
17
- # base_pipe = pipeline(
18
- # "text2text-generation",
19
- # model=BASE_MODEL_ID,
20
- # token=HF_TOKEN,
21
- # )
22
-
23
- # finetuned_pipe = pipeline(
24
- # "text2text-generation",
25
- # model=FINETUNED_MODEL_ID,
26
- # token=HF_TOKEN,
27
- # )
28
-
29
- # # -----------------------------
30
- # # FUNCTION
31
- # # -----------------------------
32
- # def summarize(text):
33
- # prompt = TASK_PREFIX + text
34
-
35
- # base_output = base_pipe(prompt, max_length=60, min_length=10, do_sample=False)
36
- # finetuned_output = finetuned_pipe(prompt, max_length=60, min_length=10, do_sample=False)
37
-
38
- # base_summary = base_output[0]["generated_text"]
39
- # finetuned_summary = finetuned_output[0]["generated_text"]
40
-
41
- # return base_summary, finetuned_summary
42
-
43
- # # -----------------------------
44
- # # UI
45
- # # -----------------------------
46
- # iface = gr.Interface(
47
- # fn=summarize,
48
- # inputs=gr.Textbox(lines=8, label="Dialogue"),
49
- # outputs=[
50
- # gr.Textbox(label=f"Base ({BASE_MODEL_ID})"),
51
- # gr.Textbox(label=f"Fine-tuned ({FINETUNED_MODEL_ID})"),
52
- # ],
53
- # title="T5 SAMSum Demo",
54
- # description="Compare base vs fine-tuned summarization",
55
- # )
56
-
57
- # # -----------------------------
58
- # # RUN
59
- # # -----------------------------
60
- # if __name__ == "__main__":
61
- # iface.launch()
62
-
63
-
64
- """Gradio app comparing base FLAN-T5 vs the fine-tuned checkpoint."""
65
-
66
  import os
67
-
68
  import gradio as gr
69
  from transformers import pipeline
70
 
 
 
 
71
  BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "google/flan-t5-small")
72
- FINETUNED_MODEL_ID = os.getenv("FINETUNED_MODEL_ID", "KB-Infinity-Tech/t5-samsum-mini")
73
  TASK_PREFIX = os.getenv("TASK_PREFIX", "summarize: ")
74
 
 
75
 
76
- base_pipe = pipeline("summarization", model=BASE_MODEL_ID)
77
- finetuned_pipe = pipeline("summarization", model=FINETUNED_MODEL_ID)
 
 
 
 
 
 
78
 
 
 
 
 
 
79
 
80
- def summarize(text: str) -> tuple[str, str]:
 
 
 
81
  prompt = TASK_PREFIX + text
82
- base_summary = base_pipe(prompt, max_length=60, min_length=10, do_sample=False)[0]["summary_text"]
83
- finetuned_summary = finetuned_pipe(prompt, max_length=60, min_length=10, do_sample=False)[0][
84
- "summary_text"
85
- ]
86
- return base_summary, finetuned_summary
87
 
 
 
 
 
 
 
 
88
 
 
 
 
89
  iface = gr.Interface(
90
  fn=summarize,
91
- inputs=gr.Textbox(lines=8, label="Dialogue", placeholder="Paste a conversation here"),
92
  outputs=[
93
- gr.Textbox(label=f"Base model ({BASE_MODEL_ID})"),
94
- gr.Textbox(label=f"Fine-tuned model ({FINETUNED_MODEL_ID})"),
95
  ],
96
  title="T5 SAMSum Demo",
97
- description="Compare the original FLAN-T5-small summarization with the fine-tuned checkpoint.",
98
  )
99
 
100
-
 
 
101
  if __name__ == "__main__":
102
- iface.launch()
103
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  import gradio as gr
3
  from transformers import pipeline
4
 
5
+ # -----------------------------
6
+ # CONFIG
7
+ # -----------------------------
8
  BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "google/flan-t5-small")
9
+ FINETUNED_MODEL_ID = os.getenv("FINETUNED_MODEL_ID", "KB-Infinity-Tech/t5-samsum-mini") # <-- CHANGE THIS
10
  TASK_PREFIX = os.getenv("TASK_PREFIX", "summarize: ")
11
 
12
+ HF_TOKEN = os.getenv("HF_TOKEN")
13
 
14
+ # -----------------------------
15
+ # LOAD MODELS (MODERN API)
16
+ # -----------------------------
17
+ base_pipe = pipeline(
18
+ "text2text-generation",
19
+ model=BASE_MODEL_ID,
20
+ token=HF_TOKEN,
21
+ )
22
 
23
+ finetuned_pipe = pipeline(
24
+ "text2text-generation",
25
+ model=FINETUNED_MODEL_ID,
26
+ token=HF_TOKEN,
27
+ )
28
 
29
+ # -----------------------------
30
+ # FUNCTION
31
+ # -----------------------------
32
+ def summarize(text):
33
  prompt = TASK_PREFIX + text
 
 
 
 
 
34
 
35
+ base_output = base_pipe(prompt, max_length=60, min_length=10, do_sample=False)
36
+ fine_output = finetuned_pipe(prompt, max_length=60, min_length=10, do_sample=False)
37
+
38
+ base_summary = base_output[0]["generated_text"]
39
+ fine_summary = fine_output[0]["generated_text"]
40
+
41
+ return base_summary, fine_summary
42
 
43
+ # -----------------------------
44
+ # UI
45
+ # -----------------------------
46
  iface = gr.Interface(
47
  fn=summarize,
48
+ inputs=gr.Textbox(lines=8, label="Dialogue"),
49
  outputs=[
50
+ gr.Textbox(label=f"Base ({BASE_MODEL_ID})"),
51
+ gr.Textbox(label=f"Fine-tuned ({FINETUNED_MODEL_ID})"),
52
  ],
53
  title="T5 SAMSum Demo",
54
+ description="Compare base vs fine-tuned summarization",
55
  )
56
 
57
+ # -----------------------------
58
+ # RUN
59
+ # -----------------------------
60
  if __name__ == "__main__":
61
+ iface.launch()