# FineTuningModel / app.py
# Author: KB-Infinity-Tech
# Last commit: 0aab236 "Update app.py" (verified)
# Size: 1.68 kB
import os
import gradio as gr
from transformers import pipeline
# -----------------------------
# CONFIG
# -----------------------------
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "google/flan-t5-small")
FINETUNED_MODEL_ID = os.getenv("FINETUNED_MODEL_ID", "KB-Infinity-Tech/t5-samsum-mini") # <-- change to your real repo
TASK_PREFIX = os.getenv("TASK_PREFIX", "summarize: ")
HF_TOKEN = os.getenv("HF_TOKEN")
# -----------------------------
# LOAD MODELS (NEW API)
# -----------------------------
def _load_text2text(model_id):
    """Build a ``text2text-generation`` pipeline for *model_id*, passing HF_TOKEN as ``token``."""
    return pipeline("text2text-generation", model=model_id, token=HF_TOKEN)


# Both pipelines are created once, when the module is imported, and are then
# reused by summarize() for every request. Base is loaded first.
base_pipe = _load_text2text(BASE_MODEL_ID)
finetuned_pipe = _load_text2text(FINETUNED_MODEL_ID)
# -----------------------------
# FUNCTION
# -----------------------------
# Shared decoding settings: both models are run with identical greedy
# decoding and length bounds so their outputs are directly comparable.
GENERATION_KWARGS = {"max_length": 60, "min_length": 10, "do_sample": False}


def summarize(text):
    """Summarize *text* with both the base and the fine-tuned model.

    The input is prefixed with TASK_PREFIX and passed to each pipeline with
    GENERATION_KWARGS.

    Args:
        text: Dialogue to summarize, as entered in the Gradio textbox.
            ``None`` or whitespace-only input is accepted.

    Returns:
        A ``(base_summary, finetuned_summary)`` tuple of strings. Both are
        empty when *text* is ``None`` or blank; the models are not run then.
    """
    # Guard before building the prompt: ``None`` would raise TypeError on
    # concatenation, and a blank dialogue would only make the models
    # "summarize" the bare task prefix.
    if text is None or not text.strip():
        return "", ""
    prompt = TASK_PREFIX + text
    base_summary = base_pipe(prompt, **GENERATION_KWARGS)[0]["generated_text"]
    finetuned_summary = finetuned_pipe(prompt, **GENERATION_KWARGS)[0]["generated_text"]
    return base_summary, finetuned_summary
# -----------------------------
# UI
# -----------------------------
# summarize() takes the dialogue text and returns two strings; Gradio maps
# them onto the output boxes in list order (base first, fine-tuned second).
_dialogue_box = gr.Textbox(lines=8, label="Dialogue")
_summary_boxes = [
    gr.Textbox(label=f"Base ({BASE_MODEL_ID})"),
    gr.Textbox(label=f"Fine-tuned ({FINETUNED_MODEL_ID})"),
]
iface = gr.Interface(
    fn=summarize,
    inputs=_dialogue_box,
    outputs=_summary_boxes,
    title="T5 SAMSum Demo",
    description="Compare base vs fine-tuned summarization",
)
# -----------------------------
# RUN
# -----------------------------
def main():
    """Launch the Gradio comparison demo with default server settings."""
    iface.launch()


if __name__ == "__main__":
    main()