app.py (CHANGED)
@@ -1,55 +1,139 @@
+import re
+import math
 import gradio as gr
+import torch
 from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration

-#
-MODEL_NAME = "
+# ✅ Public KoBART model
+MODEL_NAME = "gogamza/kobart-base-v2"
+
 tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
 model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)

-#
-
-
-
-
-    # Tokenize
-    inputs = tokenizer(
-        [text],
-        max_length=1024,
-        truncation=True,
-        return_tensors="pt"
+# Apply dynamic quantization on CPU
+try:
+    model = torch.quantization.quantize_dynamic(
+        model, {torch.nn.Linear}, dtype=torch.qint8
     )
+except Exception:
+    pass

-
-
-
-
-
-
-
-    )
+model.eval()
+
+# ===== Utility functions =====
+def normalize_text(text: str) -> str:
+    return re.sub(r"\s+", " ", text).strip()
+
+def split_into_sentences(text: str):
+    text = text.replace("\n", " ")
+    parts = re.split(r"(?<=[\.!?])\s+", text)
+    return [p.strip() for p in parts if p.strip()]
+
+def token_length(s: str) -> int:
+    return len(tokenizer.encode(s, add_special_tokens=False))
+
+def chunk_by_tokens(sentences, max_tokens=900):
+    chunks, cur, cur_tokens = [], [], 0
+    for s in sentences:
+        tl = token_length(s)
+        if tl > max_tokens:
+            # Flush the sentences accumulated so far before splitting the
+            # over-long sentence, so they are not silently dropped.
+            if cur:
+                chunks.append(" ".join(cur))
+            piece_size = max(200, int(len(s) * (max_tokens / tl)))
+            for i in range(0, len(s), piece_size):
+                sub = s[i:i+piece_size]
+                if sub.strip():
+                    chunks.append(sub.strip())
+            cur, cur_tokens = [], 0
+            continue
+        if cur_tokens + tl <= max_tokens:
+            cur.append(s)
+            cur_tokens += tl
+        else:
+            if cur:
+                chunks.append(" ".join(cur))
+            cur, cur_tokens = [s], tl
+    if cur:
+        chunks.append(" ".join(cur))
+    return chunks

-
-
-
+# ===== Summarization functions =====
+def summarize_raw(text: str, min_len: int, max_len: int) -> str:
+    inputs = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
+    with torch.no_grad():
+        summary_ids = model.generate(
+            inputs["input_ids"],
+            num_beams=4,
+            min_length=min_len,
+            max_length=max_len,
+            early_stopping=True,
+            no_repeat_ngram_size=3
+        )
+    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+def apply_style_prompt(text: str, mode: str, final: bool = False) -> str:
+    if mode == "concise":
+        inst = "Summarize the following Korean text concisely, keeping only the key points."
+    elif mode == "explanatory":
+        inst = "Summarize the following Korean text so it is easy to understand, preserving the context."
+    else:
+        inst = "Summarize the following Korean text as bullet points, keeping only the key points."
+    if final:
+        inst += " This summary is the final version."
+    return f"{inst}\n\n[Text]\n{text}"
+
+def postprocess(summary: str, mode: str) -> str:
+    s = summary.strip()
+    s = re.sub(r"\s+", " ", s)
+    if mode == "bullets":
+        bullets = re.split(r"\s*[-•]\s*", s)
+        bullets = [b.strip() for b in bullets if b.strip()]
+        if len(bullets) > 1:
+            s = "\n".join([f"- {b}" for b in bullets])
+        else:
+            parts = re.split(r"(?<=[\.!?])\s+", s)
+            parts = [p.strip() for p in parts if p.strip()]
+            s = "\n".join([f"- {p}" for p in parts])
+    return s
+
+def summarize_long(text: str, target_chars: int, mode: str):
+    text = normalize_text(text)
+    if not text:
+        return "⚠️ Please enter text to summarize."
+    approx_tokens = token_length(text)
+    if approx_tokens <= 1000:
+        min_len = max(60, int(target_chars * 0.4 / 2))
+        max_len = max(120, int(target_chars * 0.8 / 2))
+        return postprocess(summarize_raw(apply_style_prompt(text, mode), min_len, max_len), mode)
+    sentences = split_into_sentences(text)
+    chunks = chunk_by_tokens(sentences, max_tokens=900)
+    partial_summaries = []
+    budget_total = int(target_chars * 1.5)
+    per_chunk_chars = max(250, budget_total // max(1, len(chunks)))
+    for c in chunks:
+        min_len = max(50, int(per_chunk_chars * 0.4 / 2))
+        max_len = max(100, int(per_chunk_chars * 0.9 / 2))
+        psum = summarize_raw(apply_style_prompt(c, mode), min_len, max_len)
+        partial_summaries.append(psum)
+    merged = normalize_text(" ".join(partial_summaries))
+    final_min = max(80, int(target_chars * 0.45 / 2))
+    final_max = max(160, int(target_chars * 1.05 / 2))
+    return postprocess(summarize_raw(apply_style_prompt(merged, mode, final=True), final_min, final_max), mode)
+
+# ===== Gradio UI =====
+def ui_summarize(text, target_len, style):
+    mode = {"Concise": "concise", "Explanatory": "explanatory", "Key bullets": "bullets"}[style]
+    return summarize_long(text, int(target_len), mode)

-# 3. Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("## KoBART Korean Summarizer (
+    gr.Markdown("## KoBART Korean Summarizer (public model gogamza/kobart-base-v2)")
     with gr.Row():
         with gr.Column():
-            input_text = gr.Textbox(
-
-
-                placeholder="Enter the Korean text to summarize here."
-            )
-            min_len = gr.Slider(50, 500, value=100, step=10, label="Minimum summary length")
-            max_len = gr.Slider(100, 1500, value=300, step=10, label="Maximum summary length")
+            input_text = gr.Textbox(label="Source text", lines=16)
+            style = gr.Radio(["Concise", "Explanatory", "Key bullets"], value="Concise", label="Summary style")
+            target_len = gr.Slider(300, 1500, value=1000, step=50, label="Target summary length (characters)")
             btn = gr.Button("Summarize")
         with gr.Column():
-            output_text = gr.Textbox(label="Summary", lines=
-
-    btn.click(summarize, inputs=[input_text, min_len, max_len], outputs=output_text)
+            output_text = gr.Textbox(label="Summary", lines=16)
+    btn.click(ui_summarize, inputs=[input_text, target_len, style], outputs=output_text)

-# 4. Run
 if __name__ == "__main__":
     demo.launch()
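
A quick way to exercise the new chunked pipeline outside the Gradio UI is sketched below. This is a hypothetical usage example, not part of the commit: it assumes the file above is saved as app.py so summarize_long and the chunking helpers can be imported, and importing it downloads gogamza/kobart-base-v2 on first run.

# Hypothetical local sanity check; assumes the file above is saved as app.py.
from app import split_into_sentences, chunk_by_tokens, summarize_long

sample = "..."  # paste a long Korean article here

# Inspect how the text would be chunked before summarization (~900 tokens per chunk).
print(len(chunk_by_tokens(split_into_sentences(sample), max_tokens=900)))

# Map-reduce style summary: per-chunk summaries are merged and summarized again.
print(summarize_long(sample, target_chars=500, mode="concise"))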