Spaces:
Sleeping
Sleeping
File size: 5,178 Bytes
26bd648 42dfadf 26bd648 4601952 26bd648 42dfadf 26bd648 af6f11c 276cd92 208dd23 af6f11c dfdac42 af6f11c 208dd23 42dfadf 208dd23 26bd648 6cc7695 26bd648 7f49497 42dfadf 26bd648 42dfadf 26bd648 42dfadf 26bd648 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# Standard library
import re

# Third-party
import gradio as gr
import nltk
import requests
import torch
import trafilatura
from markdownify import markdownify as md
from sumy.nlp.tokenizers import Tokenizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.summarizers.text_rank import TextRankSummarizer
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoTokenizer,
    pipeline,
)

# sumy's Tokenizer uses NLTK's "punkt" sentence models; fetch them once at
# startup. quiet=True keeps the Space logs free of download progress noise.
nltk.download("punkt", quiet=True)
# ===== The two models available for the rewrite step =====
# Maps the human-readable dropdown label (shown in the Gradio UI) to the
# Hugging Face Hub model id passed to load_model().
MODEL_OPTIONS = {
    "Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",
    "CLOVA-Donut-CORDv2": "naver-clova-ix/donut-base-finetuned-cord-v2"
}
# ===== Model loading =====
# Cache of already-built pipelines keyed by model id, so repeated requests do
# not re-instantiate a multi-GB checkpoint on every call.
_PIPELINE_CACHE = {}


def load_model(model_name):
    """Build (or fetch from cache) an inference pipeline for ``model_name``.

    Parameters
    ----------
    model_name : str
        Hugging Face model id (one of the values of ``MODEL_OPTIONS``).

    Returns
    -------
    transformers.Pipeline
        An ``image-to-text`` pipeline for the Donut model, otherwise a
        CPU ``text-generation`` pipeline.
    """
    if model_name in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[model_name]

    if model_name == "naver-clova-ix/donut-base-finetuned-cord-v2":
        # Donut is a vision encoder-decoder; expose it as image-to-text.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForVision2Seq.from_pretrained(model_name)
        pipe = pipeline("image-to-text", model=model, tokenizer=tokenizer)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # full precision: the Space runs on CPU
            trust_remote_code=True,
        ).to("cpu")
        # device=-1 pins the pipeline to CPU.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)

    _PIPELINE_CACHE[model_name] = pipe
    return pipe
# ===== Text preprocessing =====
def clean_text(text: str) -> str:
    """Collapse every run of whitespace to a single space and trim the ends."""
    return " ".join(text.split())
def remove_duplicates(sentences):
    """Strip each sentence, dropping blanks and repeats; keep first-seen order."""
    kept = []
    already_seen = set()
    for raw in sentences:
        stripped = raw.strip()
        if not stripped or stripped in already_seen:
            continue
        already_seen.add(stripped)
        kept.append(stripped)
    return kept
# ===== Automatic summarization =====
def summarize_text(text):
    """Produce an extractive TextRank summary of ``text``.

    The number of summary sentences scales with the cleaned text length
    (1 sentence under 300 chars, up to 4 at 1500+). Korean tokenization is
    tried first, then English; if neither yields any sentences, a naive
    punctuation-based regex split is used as a last resort.

    Returns a list of summary sentences in original document order.
    """
    text = clean_text(text)
    length = len(text)
    # Scale summary size with document length.
    if length < 300:
        sentence_count = 1
    elif length < 800:
        sentence_count = 2
    elif length < 1500:
        sentence_count = 3
    else:
        sentence_count = 4

    # Try tokenizers in preference order. The original used bare `except:`,
    # which also swallows SystemExit/KeyboardInterrupt; narrow to Exception.
    parser = None
    for language in ("korean", "english"):
        try:
            candidate = PlaintextParser.from_string(text, Tokenizer(language))
            if len(candidate.document.sentences) > 0:
                parser = candidate
                break
        except Exception:
            continue

    if parser is None:
        # Fallback: naive sentence split on terminal punctuation.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return sentences[:sentence_count]

    summarizer = TextRankSummarizer()
    summary_sentences = summarizer(parser.document, sentence_count)
    summary_list = remove_duplicates([str(sentence) for sentence in summary_sentences])
    # TextRank returns sentences by rank; restore document order for readability.
    summary_list.sort(key=lambda s: text.find(s))
    return summary_list
# ===== LLM rewriting =====
def rewrite_with_llm(sentences, model_choice):
    """Paraphrase the summary sentences with the selected model.

    Parameters
    ----------
    sentences : list[str]
        Summary sentences to rewrite.
    model_choice : str
        A key of ``MODEL_OPTIONS``.

    Returns
    -------
    str
        The rewritten text, or the joined input unchanged for the Donut model.
    """
    model_name = MODEL_OPTIONS[model_choice]
    llm_pipeline = load_model(model_name)
    joined_text = "\n".join(sentences)
    if model_choice == "CLOVA-Donut-CORDv2":
        # CLOVA Donut is an image-to-text model; for plain text input we just
        # return the text as-is.
        return joined_text
    # Prompt (Korean): keep the meaning, add nothing not in the original,
    # only make the sentences more natural, no extra explanations.
    prompt = f"""다음 문장의 의미는 유지하되, 원문에 없는 내용은 절대 추가하지 말고,
문장만 더 자연스럽게 바꿔주세요. 다른 설명이나 부연 문장은 쓰지 마세요.
문장:
{joined_text}
"""
    # Greedy decoding. temperature is meaningless when do_sample=False and
    # recent transformers versions warn/raise on the combination, so omit it.
    result = llm_pipeline(prompt, max_new_tokens=150, do_sample=False)
    # text-generation output echoes the prompt; strip it to keep the rewrite.
    return result[0]["generated_text"].replace(prompt, "").strip()
# ===== Full pipeline =====
def extract_summarize_paraphrase(url, model_choice):
    """Fetch ``url``, extract its body, summarize it, and rewrite with an LLM.

    Returns a 3-tuple of (extracted markdown body, summary text, paraphrased
    text). On any failure the first element carries the error message and the
    other two carry Korean "none" placeholders — this function is the Gradio
    boundary, so it reports errors instead of raising.
    """
    headers = {"User-Agent": "Mozilla/5.0"}  # some sites block default UA
    try:
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        html_content = trafilatura.extract(
            response.text,
            output_format="html",
            include_tables=False,
            favor_recall=True,
        )
        # Fall back to converting the raw page when extraction found nothing;
        # either way the result is markdown with ATX (#) headings.
        source_html = html_content if html_content else response.text
        markdown_text = md(source_html, heading_style="ATX")

        summary_sentences = summarize_text(markdown_text)
        if not summary_sentences:
            summary_sentences = ["요약 없음"]  # "no summary"
        paraphrased_text = rewrite_with_llm(summary_sentences, model_choice)
        return (
            markdown_text or "본문 없음",  # "no body"
            "\n".join(summary_sentences),
            paraphrased_text,
        )
    except Exception as e:
        # "error occurred" / "no summary" / "no rewrite"
        return f"에러 발생: {e}", "요약 없음", "재작성 없음"
# ===== Gradio UI =====
# Labels/title are Korean: "URL input", "rewrite model selection",
# "extracted body", "auto summary", "auto rewrite (LLM)".
iface = gr.Interface(
    fn=extract_summarize_paraphrase,
    inputs=[
        gr.Textbox(label="URL 입력", placeholder="https://example.com"),
        gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="Qwen2.5-1.5B-Instruct",
            label="재작성 모델 선택",
        ),
    ],
    outputs=[
        gr.Markdown(label="추출된 본문"),
        gr.Textbox(label="자동 요약", lines=5),
        gr.Textbox(label="자동 재작성 (LLM)", lines=5),
    ],
    title="한국어 본문 추출 + 자동 요약 + LLM 재작성",
    description="Qwen 1.5B 또는 CLOVA Donut(CORDv2)로 재작성",
)

if __name__ == "__main__":
    iface.launch()