import os import re import torch import gradio as gr from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # ==== настройки ==== # Можно задать через переменную окружения MODEL_ID в Settings → Repository secrets. MODEL_ID = os.environ.get("MODEL_ID", "Spyspook/my-t5-medium-summarizer") MAX_INPUT_LENGTH = 512 # вход в токенах MAX_TARGET_LENGTH = 64 # длина заголовка DEFAULT_NUM_TITLES = 3 DEFAULT_TEMPERATURE = 0.7 DEFAULT_BEAMS = 4 # для стабильности метрик можно 4; для разнообразия ставь do_sample=True device = "cuda" if torch.cuda.is_available() else "cpu" # ==== загрузка модели/токенизатора один раз ==== tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device) model.eval() # Простая функция выделения первой фразы без NLTK (чтобы не тянуть ресурсы в Space) _SENT_END_RE = re.compile(r"([.!?])\s+") def first_sentence(text: str) -> str: text = text.strip() if not text: return text parts = _SENT_END_RE.split(text, maxsplit=1) # parts = [before, sep, after] или просто [text] if len(parts) >= 2: return (parts[0] + parts[1]).strip() return text def generate_titles(article_text, num_titles, temperature, beams, do_sample): if not article_text or not article_text.strip(): return [] # Префикс для T5 prefixed = "summarize: " + article_text.strip() # Токенизация и обрезка по контексту inputs = tokenizer( prefixed, return_tensors="pt", truncation=True, max_length=MAX_INPUT_LENGTH, ) inputs = {k: v.to(device) for k, v in inputs.items()} gen_kwargs = dict( max_length=MAX_TARGET_LENGTH, num_return_sequences=int(num_titles), early_stopping=True, ) # Логика генерации: # - Если do_sample=True → семплирование (temperature, top_p), # - иначе — детерминированный beam search. if do_sample: gen_kwargs.update( dict( do_sample=True, temperature=float(temperature), top_p=0.95, num_beams=1, ) ) else: gen_kwargs.update( dict( do_sample=False, num_beams=int(beams), length_penalty=1.0, ) ) with torch.no_grad(): outputs = model.generate(**inputs, **gen_kwargs) decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True) # Берем первую фразу и убираем дубликаты, сохраняя порядок seen = set() titles = [] for t in decoded: t1 = first_sentence(t) if t1 and t1 not in seen: seen.add(t1) titles.append(t1) # Вернем как список списков для удобной таблицы return [[t] for t in titles] # ==== интерфейс Gradio ==== with gr.Blocks() as demo: gr.Markdown("## T5 Article Title Generator") with gr.Row(): text_in = gr.Textbox( label="Article text", placeholder="Paste article text here…", lines=14, ) with gr.Row(): num_titles = gr.Slider(1, 10, value=DEFAULT_NUM_TITLES, step=1, label="Number of titles") temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature (sampling)") with gr.Row(): beams = gr.Slider(1, 8, value=DEFAULT_BEAMS, step=1, label="Beams (if sampling is OFF)") do_sample = gr.Checkbox(value=True, label="Use sampling (ON) / Beam search (OFF)") generate_btn = gr.Button("Generate") out_table = gr.Dataframe(headers=["Title"], row_count=(0, "dynamic"), wrap=True) generate_btn.click( fn=generate_titles, inputs=[text_in, num_titles, temperature, beams, do_sample], outputs=out_table, api_name="generate", ) # Для HF Spaces достаточно экспортировать переменную приложения app = demo if __name__ == "__main__": demo.launch()