KB-Infinity-Tech commited on
Commit
1510136
·
verified ·
1 Parent(s): bac06cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -34
app.py CHANGED
@@ -2,60 +2,37 @@ import os
2
  import gradio as gr
3
  from transformers import pipeline
4
 
5
- # -----------------------------
6
- # CONFIG
7
- # -----------------------------
8
- BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "google/flan-t5-small")
9
- FINETUNED_MODEL_ID = os.getenv("FINETUNED_MODEL_ID", "KB-Infinity-Tech/t5-samsum-mini") # <-- CHANGE THIS
10
- TASK_PREFIX = os.getenv("TASK_PREFIX", "summarize: ")
11
 
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
 
14
- # -----------------------------
15
- # LOAD MODELS (MODERN API)
16
- # -----------------------------
17
  base_pipe = pipeline(
18
- "text2text-generation",
19
  model=BASE_MODEL_ID,
20
  token=HF_TOKEN,
21
  )
22
 
23
  finetuned_pipe = pipeline(
24
- "text2text-generation",
25
  model=FINETUNED_MODEL_ID,
26
  token=HF_TOKEN,
27
  )
28
 
29
- # -----------------------------
30
- # FUNCTION
31
- # -----------------------------
32
  def summarize(text):
33
- prompt = TASK_PREFIX + text
34
 
35
- base_output = base_pipe(prompt, max_length=60, min_length=10, do_sample=False)
36
- fine_output = finetuned_pipe(prompt, max_length=60, min_length=10, do_sample=False)
37
 
38
- base_summary = base_output[0]["generated_text"]
39
- fine_summary = fine_output[0]["generated_text"]
40
 
41
- return base_summary, fine_summary
42
-
43
- # -----------------------------
44
- # UI
45
- # -----------------------------
46
  iface = gr.Interface(
47
  fn=summarize,
48
- inputs=gr.Textbox(lines=8, label="Dialogue"),
49
- outputs=[
50
- gr.Textbox(label=f"Base ({BASE_MODEL_ID})"),
51
- gr.Textbox(label=f"Fine-tuned ({FINETUNED_MODEL_ID})"),
52
- ],
53
- title="T5 SAMSum Demo",
54
- description="Compare base vs fine-tuned summarization",
55
  )
56
 
57
- # -----------------------------
58
- # RUN
59
- # -----------------------------
60
  if __name__ == "__main__":
61
  iface.launch()
 
2
  import gradio as gr
3
  from transformers import pipeline
4
 
5
+ BASE_MODEL_ID = "google/flan-t5-small"
6
+ FINETUNED_MODEL_ID = "kbsha/t5-samsum-mini" # change to yours
 
 
 
 
7
 
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
 
 
 
 
10
  base_pipe = pipeline(
11
+ "text-generation",
12
  model=BASE_MODEL_ID,
13
  token=HF_TOKEN,
14
  )
15
 
16
  finetuned_pipe = pipeline(
17
+ "text-generation",
18
  model=FINETUNED_MODEL_ID,
19
  token=HF_TOKEN,
20
  )
21
 
 
 
 
22
  def summarize(text):
23
+ prompt = "summarize: " + text
24
 
25
+ base = base_pipe(prompt, max_new_tokens=60)[0]["generated_text"]
26
+ fine = finetuned_pipe(prompt, max_new_tokens=60)[0]["generated_text"]
27
 
28
+ return base, fine
 
29
 
 
 
 
 
 
30
  iface = gr.Interface(
31
  fn=summarize,
32
+ inputs=gr.Textbox(lines=8),
33
+ outputs=[gr.Textbox(), gr.Textbox()],
34
+ title="T5 Summarizer",
 
 
 
 
35
  )
36
 
 
 
 
37
  if __name__ == "__main__":
38
  iface.launch()