File size: 1,738 Bytes
79ceb46
433126f
 
79ceb46
433126f
 
 
 
 
79ceb46
433126f
 
79ceb46
433126f
 
79ceb46
433126f
 
 
79ceb46
433126f
 
 
bac06cc
1510136
0aab236
433126f
 
 
 
 
 
 
 
 
bac06cc
433126f
bac06cc
433126f
 
 
79ceb46
 
433126f
 
 
 
 
 
 
79ceb46
 
 
bac06cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# -----------------------------
# LOAD MODELS
# -----------------------------
BASE_MODEL = "google/flan-t5-small"
FINETUNED_MODEL = "KB-Infinity-Tech/t5-samsum-mini"

# Run on the GPU when PyTorch reports CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reference model: the untuned FLAN-T5 checkpoint.
base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

# Model under comparison: the SAMSum fine-tuned checkpoint.
fine_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL)
fine_model = AutoModelForSeq2SeqLM.from_pretrained(FINETUNED_MODEL).to(device)

# -----------------------------
# GENERATE FUNCTION
# -----------------------------
def _generate_summary(tokenizer, model, prompt, max_new_tokens=60):
    """Run one seq2seq model on *prompt* and return its decoded output.

    Args:
        tokenizer: Tokenizer paired with *model*.
        model: Seq2seq model that has been moved to the module-level ``device``.
        prompt: Full prompt text, including the task prefix.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # truncation=True caps over-long prompts at the tokenizer's maximum input
    # length; the tensors must live on the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def summarize(text):
    """Summarize *text* with both the base and the fine-tuned model.

    Args:
        text: Dialogue to summarize. ``None`` is treated as empty text.

    Returns:
        A ``(base_summary, fine_summary)`` tuple of strings, in the same order
        as the UI output boxes.
    """
    # Guard against a None input, which would otherwise raise TypeError on
    # string concatenation; regular strings are passed through unchanged.
    prompt = "summarize: " + (text or "")

    base_summary = _generate_summary(base_tokenizer, base_model, prompt)
    fine_summary = _generate_summary(fine_tokenizer, fine_model, prompt)

    return base_summary, fine_summary

# -----------------------------
# UI
# -----------------------------
# One dialogue input box and one output box per model. The output boxes are
# ordered like the (base, fine-tuned) tuple that summarize() returns.
_dialogue_box = gr.Textbox(lines=8, label="Dialogue")
_summary_boxes = [
    gr.Textbox(label="Base Model (FLAN-T5)"),
    gr.Textbox(label="Fine-Tuned Model"),
]

iface = gr.Interface(
    fn=summarize,
    inputs=_dialogue_box,
    outputs=_summary_boxes,
    title="🧠 T5 Summarization Compare",
    description="Compare base FLAN-T5 vs your fine-tuned SAMSum model",
)


def main():
    """Launch the Gradio comparison app defined by ``iface``."""
    iface.launch()


if __name__ == "__main__":
    main()