File size: 5,178 Bytes
26bd648
 
 
42dfadf
26bd648
 
 
 
 
 
 
4601952
26bd648
42dfadf
26bd648
af6f11c
276cd92
208dd23
af6f11c
 
dfdac42
af6f11c
208dd23
 
 
 
42dfadf
208dd23
 
 
 
 
 
 
 
26bd648
 
 
6cc7695
26bd648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f49497
 
42dfadf
26bd648
 
42dfadf
26bd648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42dfadf
26bd648
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import functools
import re

import nltk

nltk.download("punkt")

import gradio as gr
import requests
import torch
import trafilatura
from markdownify import markdownify as md
from sumy.nlp.tokenizers import Tokenizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.summarizers.text_rank import TextRankSummarizer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, AutoModelForVision2Seq

# ===== The two models available =====
# Maps the display name shown in the UI dropdown to its Hugging Face model id.
MODEL_OPTIONS = {
    "Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",
    "CLOVA-Donut-CORDv2": "naver-clova-ix/donut-base-finetuned-cord-v2"
}

# ===== Model loading =====
@functools.lru_cache(maxsize=None)
def load_model(model_name):
    """Build a transformers pipeline for *model_name* and cache it.

    The cache matters because this is called from ``rewrite_with_llm`` on
    every request; without it the full model is re-instantiated (and
    potentially re-downloaded) per call.

    Args:
        model_name: Hugging Face model id (a value from ``MODEL_OPTIONS``).

    Returns:
        An ``image-to-text`` pipeline for the Donut checkpoint, otherwise a
        CPU ``text-generation`` pipeline.
    """
    if model_name == "naver-clova-ix/donut-base-finetuned-cord-v2":
        # Donut is a vision->text model, so it gets an image-to-text pipeline.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForVision2Seq.from_pretrained(model_name)
        return pipeline("image-to-text", model=model, tokenizer=tokenizer)
    # Causal LM path: float32 on CPU (device=-1 keeps the pipeline off GPU).
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        trust_remote_code=True
    ).to("cpu")
    return pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)

# ===== Text preprocessing =====
def clean_text(text: str) -> str:
    """Collapse every run of whitespace to a single space and trim the ends."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()

def remove_duplicates(sentences):
    """Strip each sentence, drop empties and repeats, preserve first-seen order."""
    stripped = (sentence.strip() for sentence in sentences)
    # dict.fromkeys keeps insertion order, giving an order-preserving dedupe.
    return list(dict.fromkeys(s for s in stripped if s))

# ===== Automatic summarization =====
def summarize_text(text):
    """Summarize *text* with TextRank, sizing the summary to the text length.

    A Korean tokenizer is tried first, then English; if both fail or find no
    sentences, falls back to a naive punctuation-based split.

    Args:
        text: Raw document text (whitespace is normalized first).

    Returns:
        A list of summary sentences, reordered to match their position in
        the original text.
    """
    text = clean_text(text)
    length = len(text)
    # Scale the summary size with the input length.
    if length < 300:
        sentence_count = 1
    elif length < 800:
        sentence_count = 2
    elif length < 1500:
        sentence_count = 3
    else:
        sentence_count = 4

    try:
        parser = PlaintextParser.from_string(text, Tokenizer("korean"))
        if len(parser.document.sentences) == 0:
            raise ValueError("Korean tokenizer found no sentences")
    # Was a bare `except:`, which would also swallow KeyboardInterrupt/SystemExit.
    except Exception:
        try:
            parser = PlaintextParser.from_string(text, Tokenizer("english"))
            if len(parser.document.sentences) == 0:
                raise ValueError("English tokenizer found no sentences")
        except Exception:
            # Last resort: split on sentence-ending punctuation.
            sentences = re.split(r'(?<=[.!?])\s+', text)
            return sentences[:sentence_count]

    summarizer = TextRankSummarizer()
    summary_sentences = summarizer(parser.document, sentence_count)
    summary_list = [str(sentence) for sentence in summary_sentences]
    summary_list = remove_duplicates(summary_list)
    # TextRank returns sentences by rank; restore original document order.
    summary_list.sort(key=lambda s: text.find(s))
    return summary_list

# ===== LLM rewriting =====
def rewrite_with_llm(sentences, model_choice):
    """Rewrite the summary sentences more naturally using the chosen model.

    Args:
        sentences: List of summary sentences to rewrite.
        model_choice: Display name key into ``MODEL_OPTIONS``.

    Returns:
        The rewritten text, or the joined input unchanged for the Donut
        model (which only accepts images).
    """
    model_name = MODEL_OPTIONS[model_choice]
    llm_pipeline = load_model(model_name)

    joined_text = "\n".join(sentences)

    if model_choice == "CLOVA-Donut-CORDv2":
        # CLOVA Donut is image-only; plain text input is returned as-is.
        return joined_text

    prompt = f"""๋‹ค์Œ ๋ฌธ์žฅ์„ ์˜๋ฏธ๋Š” ์œ ์ง€ํ•˜๋˜, ์›๋ฌธ์— ์—†๋Š” ๋‚ด์šฉ์€ ์ ˆ๋Œ€ ์ถ”๊ฐ€ํ•˜์ง€ ๋ง๊ณ ,
๋ฌธ์žฅ๋งŒ ๋” ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ๋ฐ”๊ฟ”์ฃผ์„ธ์š”. ๋‹ค๋ฅธ ์„ค๋ช…์ด๋‚˜ ๋ถ€์—ฐ ๋ฌธ์žฅ์€ ์“ฐ์ง€ ๋งˆ์„ธ์š”.

๋ฌธ์žฅ:
{joined_text}
"""
    # Greedy decoding: `temperature=0` is invalid/ignored-with-a-warning in
    # transformers when do_sample=False, so it is omitted.
    # return_full_text=False returns only the generated continuation, which
    # is more robust than stripping the prompt out of the output afterwards.
    result = llm_pipeline(
        prompt,
        max_new_tokens=150,
        do_sample=False,
        return_full_text=False
    )
    return result[0]["generated_text"].strip()

# ===== Full pipeline =====
def extract_summarize_paraphrase(url, model_choice):
    """Fetch *url*, extract its main content, summarize it, and rewrite the summary.

    Returns a 3-tuple of (extracted markdown, summary text, rewritten text);
    on any failure, returns an error message in the first slot with
    placeholder strings in the other two.
    """
    try:
        response = requests.get(
            url, headers={"User-Agent": "Mozilla/5.0"}, timeout=10
        )
        response.raise_for_status()

        extracted_html = trafilatura.extract(
            response.text,
            output_format="html",
            include_tables=False,
            favor_recall=True
        )

        # Convert the extracted article (or the raw page, when extraction
        # came up empty) to markdown.
        markdown_text = md(extracted_html or response.text, heading_style="ATX")

        summary_sentences = summarize_text(markdown_text) or ["์š”์•ฝ ์—†์Œ"]
        paraphrased_text = rewrite_with_llm(summary_sentences, model_choice)

        return (
            markdown_text or "๋ณธ๋ฌธ ์—†์Œ",
            "\n".join(summary_sentences),
            paraphrased_text
        )

    except Exception as e:
        # Top-level boundary: surface the error in the UI instead of crashing.
        return f"์—๋Ÿฌ ๋ฐœ์ƒ: {e}", "์š”์•ฝ ์—†์Œ", "์žฌ์ž‘์„ฑ ์—†์Œ"

# ===== Gradio UI =====
# One-page app: URL + model choice in; extracted body, summary, and LLM rewrite out.
iface = gr.Interface(
    fn=extract_summarize_paraphrase,
    inputs=[
        gr.Textbox(label="URL ์ž…๋ ฅ", placeholder="https://example.com"),
        gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Qwen2.5-1.5B-Instruct", label="์žฌ์ž‘์„ฑ ๋ชจ๋ธ ์„ ํƒ")
    ],
    outputs=[
        gr.Markdown(label="์ถ”์ถœ๋œ ๋ณธ๋ฌธ"),
        gr.Textbox(label="์ž๋™ ์š”์•ฝ", lines=5),
        gr.Textbox(label="์ž๋™ ์žฌ์ž‘์„ฑ (LLM)", lines=5)
    ],
    title="ํ•œ๊ตญ์–ด ๋ณธ๋ฌธ ์ถ”์ถœ + ์ž๋™ ์š”์•ฝ + LLM ์žฌ์ž‘์„ฑ",
    description="Qwen 1.5B ๋˜๋Š” CLOVA Donut(CORDv2)๋กœ ์žฌ์ž‘์„ฑ"
)

# Launch the web server only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()