study_llm / app.py
Spyspook's picture
Upload 2 files
8a6a35d verified
import os
import re
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ==== настройки ====
# Можно задать через переменную окружения MODEL_ID в Settings → Repository secrets.
MODEL_ID = os.environ.get("MODEL_ID", "Spyspook/my-t5-medium-summarizer")
MAX_INPUT_LENGTH = 512 # вход в токенах
MAX_TARGET_LENGTH = 64 # длина заголовка
DEFAULT_NUM_TITLES = 3
DEFAULT_TEMPERATURE = 0.7
DEFAULT_BEAMS = 4 # для стабильности метрик можно 4; для разнообразия ставь do_sample=True
device = "cuda" if torch.cuda.is_available() else "cpu"
# ==== загрузка модели/токенизатора один раз ====
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device)
model.eval()
# Простая функция выделения первой фразы без NLTK (чтобы не тянуть ресурсы в Space)
_SENT_END_RE = re.compile(r"([.!?])\s+")
def first_sentence(text: str) -> str:
text = text.strip()
if not text:
return text
parts = _SENT_END_RE.split(text, maxsplit=1)
# parts = [before, sep, after] или просто [text]
if len(parts) >= 2:
return (parts[0] + parts[1]).strip()
return text
def generate_titles(article_text, num_titles, temperature, beams, do_sample):
if not article_text or not article_text.strip():
return []
# Префикс для T5
prefixed = "summarize: " + article_text.strip()
# Токенизация и обрезка по контексту
inputs = tokenizer(
prefixed,
return_tensors="pt",
truncation=True,
max_length=MAX_INPUT_LENGTH,
)
inputs = {k: v.to(device) for k, v in inputs.items()}
gen_kwargs = dict(
max_length=MAX_TARGET_LENGTH,
num_return_sequences=int(num_titles),
early_stopping=True,
)
# Логика генерации:
# - Если do_sample=True → семплирование (temperature, top_p),
# - иначе — детерминированный beam search.
if do_sample:
gen_kwargs.update(
dict(
do_sample=True,
temperature=float(temperature),
top_p=0.95,
num_beams=1,
)
)
else:
gen_kwargs.update(
dict(
do_sample=False,
num_beams=int(beams),
length_penalty=1.0,
)
)
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Берем первую фразу и убираем дубликаты, сохраняя порядок
seen = set()
titles = []
for t in decoded:
t1 = first_sentence(t)
if t1 and t1 not in seen:
seen.add(t1)
titles.append(t1)
# Вернем как список списков для удобной таблицы
return [[t] for t in titles]
# ==== интерфейс Gradio ====
with gr.Blocks() as demo:
gr.Markdown("## T5 Article Title Generator")
with gr.Row():
text_in = gr.Textbox(
label="Article text",
placeholder="Paste article text here…",
lines=14,
)
with gr.Row():
num_titles = gr.Slider(1, 10, value=DEFAULT_NUM_TITLES, step=1, label="Number of titles")
temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature (sampling)")
with gr.Row():
beams = gr.Slider(1, 8, value=DEFAULT_BEAMS, step=1, label="Beams (if sampling is OFF)")
do_sample = gr.Checkbox(value=True, label="Use sampling (ON) / Beam search (OFF)")
generate_btn = gr.Button("Generate")
out_table = gr.Dataframe(headers=["Title"], row_count=(0, "dynamic"), wrap=True)
generate_btn.click(
fn=generate_titles,
inputs=[text_in, num_titles, temperature, beams, do_sample],
outputs=out_table,
api_name="generate",
)
# Для HF Spaces достаточно экспортировать переменную приложения
app = demo
if __name__ == "__main__":
demo.launch()