orgoflu commited on
Commit
4e89892
·
verified ·
1 Parent(s): 7668c1e
Files changed (1) hide show
  1. app.py +121 -37
app.py CHANGED
@@ -1,55 +1,139 @@
 
 
1
  import gradio as gr
 
2
  from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
3
 
4
- # 1. ๋ชจ๋ธ & ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
5
- MODEL_NAME = "gangyeolkim/kobart-korean-summarizer-v2"
 
6
  tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
7
  model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)
8
 
9
- # 2. ์š”์•ฝ ํ•จ์ˆ˜
10
- def summarize(text, min_len, max_len):
11
- if not text.strip():
12
- return "โš ๏ธ ์š”์•ฝํ•  ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”."
13
-
14
- # ํ† ํฐํ™”
15
- inputs = tokenizer(
16
- [text],
17
- max_length=1024,
18
- truncation=True,
19
- return_tensors="pt"
20
  )
 
 
21
 
22
- # ๋ชจ๋ธ ์ถ”๋ก 
23
- summary_ids = model.generate(
24
- inputs["input_ids"],
25
- num_beams=4,
26
- min_length=min_len,
27
- max_length=max_len,
28
- early_stopping=True
29
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # ๋””์ฝ”๋”ฉ
32
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
33
- return summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # 3. Gradio UI
36
  with gr.Blocks() as demo:
37
- gr.Markdown("## ๐Ÿ“ KoBART ํ•œ๊ตญ์–ด ์š”์•ฝ๊ธฐ (CPU ์ตœ์ ํ™” ๊ฐ€๋Šฅ)")
38
  with gr.Row():
39
  with gr.Column():
40
- input_text = gr.Textbox(
41
- label="์›๋ฌธ ์ž…๋ ฅ (์ตœ๋Œ€ 2000์ž)",
42
- lines=15,
43
- placeholder="์—ฌ๊ธฐ์— ์š”์•ฝํ•  ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”."
44
- )
45
- min_len = gr.Slider(50, 500, value=100, step=10, label="์ตœ์†Œ ์š”์•ฝ ๊ธธ์ด")
46
- max_len = gr.Slider(100, 1500, value=300, step=10, label="์ตœ๋Œ€ ์š”์•ฝ ๊ธธ์ด")
47
  btn = gr.Button("์š”์•ฝ ์‹คํ–‰")
48
  with gr.Column():
49
- output_text = gr.Textbox(label="์š”์•ฝ ๊ฒฐ๊ณผ", lines=15)
50
-
51
- btn.click(summarize, inputs=[input_text, min_len, max_len], outputs=output_text)
52
 
53
- # 4. ์‹คํ–‰
54
  if __name__ == "__main__":
55
  demo.launch()
 
1
import re
import math
import gradio as gr
import torch
from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration

# Public KoBART checkpoint used for summarization.
MODEL_NAME = "gogamza/kobart-base-v2"

tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)

# Apply dynamic int8 quantization to all Linear layers for faster CPU
# inference. Best-effort: if the installed torch build does not support
# dynamic quantization, fall back silently to the float model.
try:
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
except Exception:
    pass

model.eval()  # inference mode: disables dropout / training-only behavior
22
+
23
# ===== Utility helpers =====
def normalize_text(text: str) -> str:
    """Collapse every run of whitespace into a single space and trim the ends."""
    return " ".join(text.split())
26
+
27
def split_into_sentences(text: str):
    """Split *text* into sentences at whitespace that follows '.', '!' or '?'.

    Newlines are flattened to spaces first; empty fragments are dropped.
    """
    flattened = text.replace("\n", " ")
    sentences = []
    for piece in re.split(r"(?<=[\.!?])\s+", flattened):
        piece = piece.strip()
        if piece:
            sentences.append(piece)
    return sentences
31
+
32
def token_length(s: str) -> int:
    """Number of KoBART tokens in *s*, excluding special tokens."""
    token_ids = tokenizer.encode(s, add_special_tokens=False)
    return len(token_ids)
34
+
35
def chunk_by_tokens(sentences, max_tokens=900):
    """Greedily pack *sentences* into chunks of at most *max_tokens* tokens.

    A sentence whose own token count exceeds *max_tokens* is split into
    roughly budget-sized character slices (at least 200 chars each).

    Args:
        sentences: list of sentence strings (see split_into_sentences).
        max_tokens: per-chunk token budget, measured with token_length.

    Returns:
        List of chunk strings; original sentence order is preserved.
    """
    chunks, cur, cur_tokens = [], [], 0
    for s in sentences:
        tl = token_length(s)
        if tl > max_tokens:
            # BUG FIX: flush the sentences accumulated so far instead of
            # silently discarding them when an oversized sentence appears.
            if cur:
                chunks.append(" ".join(cur))
            # Slice the oversized sentence by characters, scaled so each
            # piece lands near the token budget.
            piece_size = max(200, int(len(s) * (max_tokens / tl)))
            for i in range(0, len(s), piece_size):
                sub = s[i:i + piece_size]
                if sub.strip():
                    chunks.append(sub.strip())
            cur, cur_tokens = [], 0
            continue
        if cur_tokens + tl <= max_tokens:
            cur.append(s)
            cur_tokens += tl
        else:
            if cur:
                chunks.append(" ".join(cur))
            cur, cur_tokens = [s], tl
    if cur:
        chunks.append(" ".join(cur))
    return chunks
57
 
58
+ # ===== ์š”์•ฝ ํ•จ์ˆ˜ =====
59
def summarize_raw(text: str, min_len: int, max_len: int) -> str:
    """Run one beam-search summarization pass over *text*.

    Input is truncated to 1024 tokens; generation uses 4 beams with
    trigram repetition blocked.
    """
    encoded = tokenizer([text], max_length=1024, truncation=True, return_tensors="pt")
    generate_kwargs = dict(
        num_beams=4,
        min_length=min_len,
        max_length=max_len,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output_ids = model.generate(encoded["input_ids"], **generate_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
71
+
72
def apply_style_prompt(text: str, mode: str, final: bool = False) -> str:
    """Prepend a style-specific Korean instruction header to *text*."""
    instructions = {
        "concise": "๋‹ค์Œ ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ ํ•ต์‹ฌ๋งŒ ๊ฐ„๊ฒฐํ•˜๊ฒŒ ์š”์•ฝํ•˜์„ธ์š”.",
        "explanatory": "๋‹ค์Œ ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ ๋งฅ๋ฝ์„ ๋ณด์กดํ•˜๋ฉฐ ์ดํ•ดํ•˜๊ธฐ ์‰ฝ๊ฒŒ ์š”์•ฝํ•˜์„ธ์š”.",
    }
    # Any mode other than the two above falls back to the bullet instruction.
    inst = instructions.get(mode, "๋‹ค์Œ ํ•œ๊ตญ์–ด ํ…์ŠคํŠธ๋ฅผ bullet ํ˜•ํƒœ๋กœ ํ•ต์‹ฌ๋งŒ ์š”์•ฝํ•˜์„ธ์š”.")
    if final:
        inst += " ์ด ์š”์•ฝ์€ ์ตœ์ข…๋ณธ์ž…๋‹ˆ๋‹ค."
    return f"{inst}\n\n[ํ…์ŠคํŠธ]\n{text}"
82
+
83
def postprocess(summary: str, mode: str) -> str:
    """Normalize whitespace; for bullet mode, re-render as '- ' prefixed lines."""
    flat = re.sub(r"\s+", " ", summary.strip())
    if mode != "bullets":
        return flat
    # Prefer splitting on existing bullet markers; if that yields a single
    # item, fall back to sentence boundaries instead.
    items = [b.strip() for b in re.split(r"\s*[-โ€ข]\s*", flat) if b.strip()]
    if len(items) <= 1:
        items = [p.strip() for p in re.split(r"(?<=[\.!?])\s+", flat) if p.strip()]
    return "\n".join(f"- {b}" for b in items)
96
+
97
def summarize_long(text: str, target_chars: int, mode: str):
    """Summarize *text* toward roughly *target_chars* characters.

    Short inputs (<= 1000 tokens) get a single summarization pass; longer
    inputs are chunked, summarized per chunk, merged, then summarized once
    more (map-reduce style). Returns the post-processed summary string,
    or a Korean warning message when the input is empty.
    """
    text = normalize_text(text)
    if not text:
        return "โš ๏ธ ์š”์•ฝํ•  ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”."
    approx_tokens = token_length(text)
    if approx_tokens <= 1000:
        # Single-pass path: token budgets derived from the character target
        # (the /2 factor presumably converts characters to tokens — TODO confirm
        # against the KoBART tokenizer's average chars-per-token).
        min_len = max(60, int(target_chars * 0.4 / 2))
        max_len = max(120, int(target_chars * 0.8 / 2))
        return postprocess(summarize_raw(apply_style_prompt(text, mode), min_len, max_len), mode)
    # Map step: split into ~900-token chunks and summarize each independently.
    sentences = split_into_sentences(text)
    chunks = chunk_by_tokens(sentences, max_tokens=900)
    partial_summaries = []
    # Allow 1.5x the final target across all partial summaries, with a
    # floor of 250 characters per chunk.
    budget_total = int(target_chars * 1.5)
    per_chunk_chars = max(250, budget_total // max(1, len(chunks)))
    for c in chunks:
        min_len = max(50, int(per_chunk_chars * 0.4 / 2))
        max_len = max(100, int(per_chunk_chars * 0.9 / 2))
        psum = summarize_raw(apply_style_prompt(c, mode), min_len, max_len)
        partial_summaries.append(psum)
    # Reduce step: merge the partial summaries and compress them once more,
    # flagging the prompt as the final pass.
    merged = normalize_text(" ".join(partial_summaries))
    final_min = max(80, int(target_chars * 0.45 / 2))
    final_max = max(160, int(target_chars * 1.05 / 2))
    return postprocess(summarize_raw(apply_style_prompt(merged, mode, final=True), final_min, final_max), mode)
120
+
121
+ # ===== Gradio UI =====
122
def ui_summarize(text, target_len, style):
    """Gradio callback: map the Korean style label to a mode key and summarize."""
    if style == "๊ฐ„๊ฒฐํ˜•":
        mode = "concise"
    elif style == "์„ค๋ช…ํ˜•":
        mode = "explanatory"
    elif style == "ํ•ต์‹ฌ bullet":
        mode = "bullets"
    else:
        # Match the original dict-indexing behavior for unknown labels.
        raise KeyError(style)
    return summarize_long(text, int(target_len), mode)
125
 
 
126
with gr.Blocks() as demo:
    gr.Markdown("## ๐Ÿ“ KoBART ํ•œ๊ตญ์–ด ์š”์•ฝ๊ธฐ (๊ณต๊ฐœ ๋ชจ๋ธ gogamza/kobart-base-v2)")
    with gr.Row():
        with gr.Column():
            # Left column: source text plus summary controls.
            input_text = gr.Textbox(label="์›๋ฌธ ์ž…๋ ฅ", lines=16)
            style = gr.Radio(["๊ฐ„๊ฒฐํ˜•", "์„ค๋ช…ํ˜•", "ํ•ต์‹ฌ bullet"], value="๊ฐ„๊ฒฐํ˜•", label="์š”์•ฝ ์Šคํƒ€์ผ")
            target_len = gr.Slider(300, 1500, value=1000, step=50, label="๋ชฉํ‘œ ์š”์•ฝ ๊ธธ์ด(๋ฌธ์ž)")
            btn = gr.Button("์š”์•ฝ ์‹คํ–‰")
        with gr.Column():
            # Right column: summary output.
            output_text = gr.Textbox(label="์š”์•ฝ ๊ฒฐ๊ณผ", lines=16)
    # Wire the button to the callback; note input order is (text, length, style).
    btn.click(ui_summarize, inputs=[input_text, target_len, style], outputs=output_text)

if __name__ == "__main__":
    demo.launch()