# moro_text_2 / app.py
# Hugging Face Space metadata preserved from the web view (uploader: orgoflu,
# commit 7f49497 verified) — kept as comments so the file is valid Python.
import nltk
nltk.download("punkt")
import gradio as gr
import trafilatura
import requests
from markdownify import markdownify as md
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.text_rank import TextRankSummarizer
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, AutoModelForVision2Seq
# ===== The two models available for the rewriting step =====
# Maps the human-readable dropdown label to its Hugging Face model id.
MODEL_OPTIONS = {
"Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",
"CLOVA-Donut-CORDv2": "naver-clova-ix/donut-base-finetuned-cord-v2"
}
# ===== ๋ชจ๋ธ ๋กœ๋“œ =====
def load_model(model_name):
if model_name == "naver-clova-ix/donut-base-finetuned-cord-v2":
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForVision2Seq.from_pretrained(model_name)
return pipeline("image-to-text", model=model, tokenizer=tokenizer)
else:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32,
trust_remote_code=True
).to("cpu")
return pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
# ===== Text preprocessing =====
def clean_text(text: str) -> str:
    """Collapse every run of whitespace to a single space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
def remove_duplicates(sentences):
    """Strip each sentence and drop blanks and repeats, keeping first-seen order."""
    unique = []
    seen = set()
    for sentence in sentences:
        stripped = sentence.strip()
        if not stripped or stripped in seen:
            continue
        seen.add(stripped)
        unique.append(stripped)
    return unique
# ===== ์ž๋™ ์š”์•ฝ =====
def summarize_text(text):
text = clean_text(text)
length = len(text)
if length < 300:
sentence_count = 1
elif length < 800:
sentence_count = 2
elif length < 1500:
sentence_count = 3
else:
sentence_count = 4
try:
parser = PlaintextParser.from_string(text, Tokenizer("korean"))
if len(parser.document.sentences) == 0:
raise ValueError
except:
try:
parser = PlaintextParser.from_string(text, Tokenizer("english"))
if len(parser.document.sentences) == 0:
raise ValueError
except:
sentences = re.split(r'(?<=[.!?])\s+', text)
return sentences[:sentence_count]
summarizer = TextRankSummarizer()
summary_sentences = summarizer(parser.document, sentence_count)
summary_list = [str(sentence) for sentence in summary_sentences]
summary_list = remove_duplicates(summary_list)
summary_list.sort(key=lambda s: text.find(s))
return summary_list
# ===== LLM rewriting =====
def rewrite_with_llm(sentences, model_choice):
    """Paraphrase the summary *sentences* with the selected model.

    Args:
        sentences: List of summary sentences to rewrite.
        model_choice: Key into ``MODEL_OPTIONS`` (dropdown label).

    Returns:
        The rewritten text; for the Donut choice the joined input is returned
        unchanged (Donut is an image model and cannot rewrite text).
    """
    joined_text = "\n".join(sentences)

    if model_choice == "CLOVA-Donut-CORDv2":
        # Donut is image-to-text only — skip loading the model entirely
        # (the original loaded it and then discarded the pipeline).
        return joined_text

    model_name = MODEL_OPTIONS[model_choice]
    llm_pipeline = load_model(model_name)

    prompt = f"""๋‹ค์Œ ๋ฌธ์žฅ์„ ์˜๋ฏธ๋Š” ์œ ์ง€ํ•˜๋˜, ์›๋ฌธ์— ์—†๋Š” ๋‚ด์šฉ์€ ์ ˆ๋Œ€ ์ถ”๊ฐ€ํ•˜์ง€ ๋ง๊ณ ,
๋ฌธ์žฅ๋งŒ ๋” ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ๋ฐ”๊ฟ”์ฃผ์„ธ์š”. ๋‹ค๋ฅธ ์„ค๋ช…์ด๋‚˜ ๋ถ€์—ฐ ๋ฌธ์žฅ์€ ์“ฐ์ง€ ๋งˆ์„ธ์š”.
๋ฌธ์žฅ:
{joined_text}
"""
    # do_sample=False already forces greedy decoding; recent transformers
    # versions reject temperature=0, so the flag is omitted.
    result = llm_pipeline(prompt, max_new_tokens=150, do_sample=False)
    generated = result[0]["generated_text"]
    # Strip the echoed prompt by position rather than str.replace, which would
    # also delete any later verbatim repetition of the prompt text.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    else:
        generated = generated.replace(prompt, "")
    return generated.strip()
# ===== ์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ =====
def extract_summarize_paraphrase(url, model_choice):
headers = {"User-Agent": "Mozilla/5.0"}
try:
r = requests.get(url, headers=headers, timeout=10)
r.raise_for_status()
html_content = trafilatura.extract(
r.text,
output_format="html",
include_tables=False,
favor_recall=True
)
if not html_content:
markdown_text = md(r.text, heading_style="ATX")
else:
markdown_text = md(html_content, heading_style="ATX")
summary_sentences = summarize_text(markdown_text)
if not summary_sentences:
summary_sentences = ["์š”์•ฝ ์—†์Œ"]
paraphrased_text = rewrite_with_llm(summary_sentences, model_choice)
return (
markdown_text or "๋ณธ๋ฌธ ์—†์Œ",
"\n".join(summary_sentences),
paraphrased_text
)
except Exception as e:
return f"์—๋Ÿฌ ๋ฐœ์ƒ: {e}", "์š”์•ฝ ์—†์Œ", "์žฌ์ž‘์„ฑ ์—†์Œ"
# ===== Gradio UI =====
# Components are built as named variables first, then wired into the Interface.
_url_box = gr.Textbox(label="URL ์ž…๋ ฅ", placeholder="https://example.com")
_model_dropdown = gr.Dropdown(
    choices=list(MODEL_OPTIONS.keys()),
    value="Qwen2.5-1.5B-Instruct",
    label="์žฌ์ž‘์„ฑ ๋ชจ๋ธ ์„ ํƒ",
)

iface = gr.Interface(
    fn=extract_summarize_paraphrase,
    inputs=[_url_box, _model_dropdown],
    outputs=[
        gr.Markdown(label="์ถ”์ถœ๋œ ๋ณธ๋ฌธ"),
        gr.Textbox(label="์ž๋™ ์š”์•ฝ", lines=5),
        gr.Textbox(label="์ž๋™ ์žฌ์ž‘์„ฑ (LLM)", lines=5),
    ],
    title="ํ•œ๊ตญ์–ด ๋ณธ๋ฌธ ์ถ”์ถœ + ์ž๋™ ์š”์•ฝ + LLM ์žฌ์ž‘์„ฑ",
    description="Qwen 1.5B ๋˜๋Š” CLOVA Donut(CORDv2)๋กœ ์žฌ์ž‘์„ฑ",
)
# Start the web server only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()