KB-Infinity-Tech committed
Commit 433126f · verified · 1 Parent(s): 1510136

Update app.py

Files changed (1)
  1. app.py +37 -21
app.py CHANGED
@@ -1,37 +1,53 @@
-import os
 import gradio as gr
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
 
-BASE_MODEL_ID = "google/flan-t5-small"
-FINETUNED_MODEL_ID = "kbsha/t5-samsum-mini" # change to yours
-
-HF_TOKEN = os.getenv("HF_TOKEN")
-
-base_pipe = pipeline(
-    "text-generation",
-    model=BASE_MODEL_ID,
-    token=HF_TOKEN,
-)
-
-finetuned_pipe = pipeline(
-    "text-generation",
-    model=FINETUNED_MODEL_ID,
-    token=HF_TOKEN,
-)
-
+# -----------------------------
+# LOAD MODELS
+# -----------------------------
+BASE_MODEL = "google/flan-t5-small"
+FINETUNED_MODEL = "KB-Infinity-Tech/t5-samsum-mini"
+
+base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
+
+fine_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL)
+fine_model = AutoModelForSeq2SeqLM.from_pretrained(FINETUNED_MODEL)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+base_model.to(device)
+fine_model.to(device)
+
+# -----------------------------
+# GENERATE FUNCTION
+# -----------------------------
 def summarize(text):
     prompt = "summarize: " + text
 
-    base = base_pipe(prompt, max_new_tokens=60)[0]["generated_text"]
-    fine = finetuned_pipe(prompt, max_new_tokens=60)[0]["generated_text"]
+    # Base model
+    inputs = base_tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
+    outputs = base_model.generate(**inputs, max_new_tokens=60)
+    base_summary = base_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Fine-tuned model
+    inputs2 = fine_tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
+    outputs2 = fine_model.generate(**inputs2, max_new_tokens=60)
+    fine_summary = fine_tokenizer.decode(outputs2[0], skip_special_tokens=True)
 
-    return base, fine
-
+    return base_summary, fine_summary
+
+# -----------------------------
+# UI
+# -----------------------------
 iface = gr.Interface(
     fn=summarize,
-    inputs=gr.Textbox(lines=8),
-    outputs=[gr.Textbox(), gr.Textbox()],
-    title="T5 Summarizer",
+    inputs=gr.Textbox(lines=8, label="Dialogue"),
+    outputs=[
+        gr.Textbox(label="Base Model (FLAN-T5)"),
+        gr.Textbox(label="Fine-Tuned Model"),
+    ],
+    title="🧠 T5 Summarization Compare",
+    description="Compare base FLAN-T5 vs your fine-tuned SAMSum model",
 )
 
 if __name__ == "__main__":
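
Note on the change: the old code passed FLAN-T5 (an encoder-decoder model) to the "text-generation" pipeline, which targets causal LMs and rejects T5's config; the new code avoids that by calling AutoModelForSeq2SeqLM and generate() directly. The hunk ends at the if __name__ == "__main__": guard, so the guard body (typically iface.launch() in a Gradio Space) is not shown here. Below is a minimal, hypothetical smoke test for the updated summarize() outside the UI; the example dialogue is invented, and it assumes app.py is importable and both checkpoints download (importing app.py loads both models, so the first run is slow).

# Hypothetical local smoke test (not part of this commit).
# Assumes app.py is on the import path; iface.launch() is presumed to sit
# under the __main__ guard, so importing does not start the web server.
from app import summarize

# Invented SAMSum-style dialogue, for illustration only.
dialogue = (
    "Amanda: I baked cookies. Do you want some?\n"
    "Jerry: Sure!\n"
    "Amanda: I'll bring you some tomorrow :-)"
)

base_summary, fine_summary = summarize(dialogue)
print("Base FLAN-T5:", base_summary)
print("Fine-tuned: ", fine_summary)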