nangunan commited on
Commit
cda63cc
Β·
verified Β·
1 Parent(s): 8b391db

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -407
app.py DELETED
@@ -1,407 +0,0 @@
1
- import os
2
- import zipfile
3
- import requests
4
- import gradio as gr
5
- import whisper
6
- import subprocess
7
- import uuid
8
- import torch
9
- import re
10
- import matplotlib.pyplot as plt
11
- import language_tool_python
12
- import difflib
13
- from transformers import (
14
- AutoTokenizer,
15
- AutoModelForSeq2SeqLM,
16
- pipeline as hf_pipeline,
17
- )
18
-
19
# ────────────────────────────────────────────────────────────────
# Optional evaluation libraries.  Each import is feature-detected: on
# ImportError the name is set to None and later code checks truthiness
# before using it, so the app degrades gracefully.
try:
    from rouge_score import rouge_scorer
except ImportError:
    rouge_scorer = None
    print("[Warning] rouge_score νŒ¨ν‚€μ§€κ°€ μ—†μŠ΅λ‹ˆλ‹€. pip install rouge-score")

try:
    from bert_score import score as bert_score_func
except ImportError:
    bert_score_func = None
    print("[Warning] bert-score νŒ¨ν‚€μ§€κ°€ μ—†μŠ΅λ‹ˆλ‹€. pip install bert-score")

# ────────────────────────────────────────────────────────────────
# Korean spell checking (py-hanspell); optional like the metrics above.
try:
    from hanspell import spell_checker
except ImportError:
    spell_checker = None

# ────────────────────────────────────────────────────────────────
# Rule-based grammar correction via LanguageTool (English only).
# Initialization can fail for reasons other than a missing package
# (e.g. no Java runtime), hence the broad Exception here.
try:
    lt_tool = language_tool_python.LanguageTool('en-US')
except Exception as e:
    lt_tool = None
    print(f"[Warning] LanguageTool μ΄ˆκΈ°ν™” μ‹€νŒ¨: {e}")
48
# ────────────────────────────────────────────────────────────────
# FFmpeg / yt-dlp locations.
# NOTE(review): hard-coded Windows paths — this script only runs as-is on a
# Windows host with yt-dlp.exe at the location below; confirm before deploy.
yt_dlp_path = r"C:/Windows/System32/yt-dlp.exe"
ffmpeg_path = r"C:/ffmpeg/bin"
52
def download_ffmpeg(dest_bin):
    """Ensure FFmpeg executables exist in *dest_bin* and return that path.

    When ffmpeg.exe is already present the directory is returned untouched.
    Otherwise the gyan.dev "essentials" build is downloaded, unpacked next to
    *dest_bin*, and the executables are moved into place.

    Raises:
        RuntimeError: if no ffmpeg.exe can be found after extraction.
    """
    already_installed = os.path.isdir(dest_bin) and os.path.isfile(
        os.path.join(dest_bin, "ffmpeg.exe")
    )
    if already_installed:
        return dest_bin

    url = "https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip"
    zip_path = os.path.join(os.getcwd(), "ffmpeg.zip")
    extract_root = os.path.dirname(dest_bin)
    os.makedirs(extract_root, exist_ok=True)

    # Stream the archive to disk in 8 KiB chunks, then unpack and discard it.
    resp = requests.get(url, stream=True)
    resp.raise_for_status()
    with open(zip_path, "wb") as fh:
        for chunk in resp.iter_content(8192):
            fh.write(chunk)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(extract_root)
    os.remove(zip_path)

    # The archive contains a versioned top-level folder; walk to find bin/.
    for root, _, files in os.walk(extract_root):
        if "ffmpeg.exe" not in files:
            continue
        os.makedirs(dest_bin, exist_ok=True)
        for exe in ("ffmpeg.exe", "ffprobe.exe", "ffplay.exe"):
            src = os.path.join(root, exe)
            dst = os.path.join(dest_bin, exe)
            if os.path.isfile(src):
                os.replace(src, dst)
        return dest_bin
    raise RuntimeError("FFmpeg μ„€μΉ˜ μ‹€νŒ¨")
72
-
73
# Install FFmpeg if missing and make it visible to subprocesses/whisper.
download_ffmpeg(ffmpeg_path)
os.environ["PATH"] = ffmpeg_path + os.pathsep + os.environ.get("PATH","")

# ────────────────────────────────────────────────────────────────
# Whisper ASR model (loaded once at startup; "medium" checkpoint).
asr_model = whisper.load_model("medium")

# ────────────────────────────────────────────────────────────────
# Summarization models — tokenizer/model used directly (no HF pipeline).
# Maps UI label -> Hugging Face repo id.
SUMMARY_MODELS = {
    "mT5_multilingual_XLSum": "csebuetnlp/mT5_multilingual_XLSum",
    "Pegasus XSum": "google/pegasus-xsum",
    "BART-large CNN": "facebook/bart-large-cnn",
    "DistilBART CNN": "sshleifer/distilbart-cnn-12-6"
}
# Lazily filled caches keyed by UI label (see load_summarizer).
tokenizers, models = {}, {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
90
-
91
def load_summarizer(label: str):
    """Lazily load and cache the tokenizer/model pair registered as *label*.

    Populates the module-level ``tokenizers`` and ``models`` dicts; a second
    call with the same label is a no-op.
    """
    if label in models:
        return  # already cached
    repo = SUMMARY_MODELS[label]
    tokenizer = AutoTokenizer.from_pretrained(repo, use_fast=False)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device)
    seq2seq.eval()  # inference only — disable dropout etc.
    tokenizers[label] = tokenizer
    models[label] = seq2seq
100
-
101
# Build the shared ROUGE scorer once, only when the package is available.
if rouge_scorer:
    scorer = rouge_scorer.RougeScorer(["rouge1","rouge2","rougeL"], use_stemmer=True)

# ────────────────────────────────────────────────────────────────
# Grammar correction backends.  None means "handled by a dedicated library"
# (LanguageTool / py-hanspell) rather than an HF model repo.
GRAMMAR_MODELS = {
    "LanguageTool-en": None,
    "py-hanspell": None,
    "GEC-ν•œκ΅­μ–΄": "Soyoung97/gec_kr"
}
# Lazily filled pipeline cache (see load_grammar_pipe).
grammar_pipes = {}
112
-
113
def load_grammar_pipe(name: str):
    """Build and cache a text2text-generation pipeline for grammar model *name*.

    Only meaningful for entries of ``GRAMMAR_MODELS`` that map to an HF repo;
    the result is stored in the module-level ``grammar_pipes`` dict.
    """
    repo = GRAMMAR_MODELS[name]
    gpu_index = 0 if torch.cuda.is_available() else -1
    pipe = hf_pipeline(
        "text2text-generation",
        model=repo,
        tokenizer=AutoTokenizer.from_pretrained(repo),
        device=gpu_index,
    )
    grammar_pipes[name] = pipe
121
-
122
def correct_spelling(text, max_chunk=500):
    """Spell-check *text* with py-hanspell, chunking to respect its size limit.

    The text is split on sentence-ending punctuation and regrouped into
    segments of at most *max_chunk* characters before each hanspell call.
    Returns *text* unchanged when py-hanspell is not installed; on a failed
    check the original segment is kept (best-effort correction).
    """
    if not spell_checker:
        return text
    # Split while keeping the ".?!" delimiter tokens so they are re-attached.
    parts = re.split(r'([.?!]\s*)', text)
    segs, curr = [], ""
    for p in parts:
        if len(curr) + len(p) <= max_chunk:
            curr += p
        else:
            segs.append(curr)
            curr = p
    if curr:
        segs.append(curr)
    out = []
    for s in segs:
        # BUGFIX: was a bare `except:` which also swallowed SystemExit and
        # KeyboardInterrupt; narrowed to Exception while keeping the
        # deliberate best-effort fallback to the unchecked segment.
        try:
            out.append(spell_checker.check(s).checked)
        except Exception:
            out.append(s)
    return " ".join(o.strip() for o in out)
134
-
135
def correct_text(text, method="GEC-ν•œκ΅­μ–΄"):
    """Correct *text* with the selected backend; return it unchanged otherwise.

    Backends: "py-hanspell" (Korean spelling), "LanguageTool-en" (English
    rules, only when the tool initialized), "GEC-ν•œκ΅­μ–΄" (HF seq2seq model,
    loaded lazily and applied sentence by sentence).
    """
    if method == "py-hanspell":
        return correct_spelling(text)

    if method == "LanguageTool-en" and lt_tool:
        matches = lt_tool.check(text)
        return language_tool_python.utils.correct(text, matches)

    if method == "GEC-ν•œκ΅­μ–΄":
        if method not in grammar_pipes:
            load_grammar_pipe(method)
        pipe = grammar_pipes[method]
        fixed = []
        # Sentence-by-sentence keeps each generation within the model limit.
        for sentence in re.split(r'(?<=[.?!])\s+', text):
            result = pipe(sentence, max_length=256, min_length=1, do_sample=False)
            fixed.append(result[0]["generated_text"].strip())
        return " ".join(fixed)

    return text
152
-
153
# ────────────────────────────────────────────────────────────────
# Correction rate + diff helpers
def calculate_correction_rate(original, corrected):
    """Percentage (0-100, 2 decimals) of original tokens touched by correction.

    Counts, on the *original* side only, tokens covered by non-equal
    SequenceMatcher opcodes; pure insertions therefore contribute nothing.
    """
    src = original.split()
    dst = corrected.split()
    matcher = difflib.SequenceMatcher(None, src, dst)
    changed = 0
    for tag, i1, i2, _j1, _j2 in matcher.get_opcodes():
        if tag != 'equal':
            changed += i2 - i1  # tokens consumed from the original side
    return round(100 * changed / max(len(src), 1), 2)
162
-
163
def highlight_diff(original: str, corrected: str) -> str:
    """Render *corrected* as HTML with words added/replaced wrapped in red spans.

    Words removed from *original* are dropped silently; unchanged words pass
    through as plain text.

    NOTE(review): tokens are interpolated into HTML unescaped — acceptable for
    ASR output, but worth escaping if the input can ever contain markup.
    """
    html_parts = []
    for token in difflib.ndiff(original.split(), corrected.split()):
        if token.startswith("? "):
            # BUGFIX: ndiff emits "? " guide lines for near-matching tokens;
            # the old else-branch leaked their "^"/"-" hint characters (and a
            # trailing newline) straight into the HTML output. Skip them.
            continue
        if token.startswith("- "):
            continue  # word only in the original — omit from the rendering
        if token.startswith("+ "):
            html_parts.append(f"<span style='color:red;'>{token[2:]}</span>")
        else:
            html_parts.append(token[2:])
    return " ".join(html_parts)
174
-
175
# ────────────────────────────────────────────────────────────────
# YouTube download
def download_audio(url):
    """Download the best audio track of *url* as an mp3 via yt-dlp.

    Returns the randomly named local file; raises RuntimeError carrying
    yt-dlp's stderr when the subprocess exits non-zero.
    """
    out_name = f"yt_{uuid.uuid4().hex[:8]}.mp3"
    cmd = [
        yt_dlp_path,
        "-f", "bestaudio",
        "--extract-audio",
        "--audio-format", "mp3",
        "-o", out_name,
        url,
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr)
    return out_name
183
-
184
def get_transcript(url, state):
    """Transcribe *url* with Whisper, reusing a cached transcript when possible.

    *state* is ``None`` or a dict ``{"url": ..., "orig": ...}`` kept in a
    Gradio State; a matching URL skips both the download and the ASR pass.
    Returns ``(transcript_text, new_state)``.
    """
    if state and state.get("url") == url:
        return state["orig"], state  # cache hit
    audio_path = download_audio(url)
    result = asr_model.transcribe(audio_path)
    transcript = result.get("text", "")
    os.remove(audio_path)  # the temporary mp3 is no longer needed
    return transcript, {"url": url, "orig": transcript}
192
-
193
# ────────────────────────────────────────────────────────────────
# Safe chunked summarization (direct model.generate, no HF pipeline)
def summarize_long_text(text: str, label: str, chunk_size: int = 512) -> str:
    """Summarize arbitrarily long *text* with the model registered as *label*.

    The tokenized input is cut into windows that fit the model context, each
    window is summarized independently, then the concatenated partial
    summaries are summarized once more into the final result.
    """
    load_summarizer(label)
    tok = tokenizers[label]
    model = models[label]

    token_ids = tok(text, return_tensors="pt", truncation=False).input_ids[0]

    # Keep a small safety margin below the model's positional limit.
    max_ctx = getattr(model.config, "max_position_embeddings", 1024) - 4
    step = min(chunk_size, max_ctx)

    partials = []
    for start in range(0, len(token_ids), step):
        window = token_ids[start:start + step].unsqueeze(0).to(device)
        generated = model.generate(
            window,
            max_new_tokens=128,
            num_beams=4,
            do_sample=False,
        )
        partials.append(tok.decode(generated[0], skip_special_tokens=True))

    # Second pass: condense the per-chunk summaries into one summary.
    merged = " ".join(partials)
    enc2 = tok(merged, return_tensors="pt", truncation=True, max_length=max_ctx).to(device)
    generated = model.generate(
        **enc2,
        max_new_tokens=128,
        num_beams=4,
        do_sample=False,
    )
    return tok.decode(generated[0], skip_special_tokens=True)
228
-
229
# ────────────────────────────────────────────────────────────────
def summarize_single(url, label, grammar_method, transcript_state):
    """Gradio handler: transcribe -> correct -> summarize with one model.

    Returns (original transcript, correction-diff HTML, summary text,
    fidelity matplotlib figure, updated transcript cache state) — this order
    must match the `outputs=` wiring of the single-model tab.
    """
    orig, new_state = get_transcript(url, transcript_state)
    corr = correct_text(orig, method=grammar_method)
    corr_rate = calculate_correction_rate(orig, corr)
    corr_html = f"<div><b>ꡐ정λ₯ :</b> {corr_rate}%</div><hr/>{highlight_diff(orig, corr)}"

    # Very short corrected transcripts are not worth summarizing.
    summary = summarize_long_text(corr, label) if len(corr) > 100 else "⚠️ μš”μ•½ λΆˆκ°€"

    # ROUGE of the summary against the raw transcript (reference-free proxy).
    rouge_vals=[0,0,0]
    if rouge_scorer and summary.strip():
        sc = scorer.score(orig, summary)
        rouge_vals=[sc["rouge1"].fmeasure, sc["rouge2"].fmeasure, sc["rougeL"].fmeasure]

    # BERTScore F1: try Korean first, fall back to English on any failure.
    bert_f1=0
    if bert_score_func and summary.strip():
        try:
            _,_,F = bert_score_func([summary],[orig],lang="ko")
        except Exception:
            _,_,F = bert_score_func([summary],[orig],lang="en")
        bert_f1=float(F.mean())

    # Bar chart of the four fidelity metrics, returned as a Gradio Plot.
    fig,ax=plt.subplots()
    ax.bar(["R1","R2","RL","BERT-F1"], rouge_vals+[bert_f1])
    ax.set_ylim(0,1); ax.set_ylabel("Score"); ax.set_title("Summary Fidelity")
    plt.tight_layout()

    return orig, corr_html, summary, fig, new_state
257
-
258
# ────────────────────────────────────────────────────────────────
def summarize_all(url, grammar_method, transcript_state):
    """Gradio handler: run every registered summarization model and compare.

    Returns a flat list matching the comparison tab's `outputs=` wiring:
    [orig, corr_html] + one figure per model + one interpretation string per
    model + [aggregate HTML table, updated cache state].
    """
    orig, new_state = get_transcript(url, transcript_state)
    corr = correct_text(orig, method=grammar_method)
    corr_rate = calculate_correction_rate(orig, corr)
    corr_html = f"<div><b>ꡐ정λ₯ :</b> {corr_rate}%</div><hr/>{highlight_diff(orig, corr)}"

    figs, interps, rv_list, bf_list = [], [], [], []
    summaries_plain = []
    labels = list(SUMMARY_MODELS.keys())

    for label in labels:
        summ = summarize_long_text(corr, label)
        summaries_plain.append(summ)

        # ROUGE / BERTScore against the raw transcript; zero when the
        # optional metric packages are missing.
        rv=[0,0,0]; bf=0
        if rouge_scorer:
            sc = scorer.score(orig, summ)
            rv=[sc["rouge1"].fmeasure, sc["rouge2"].fmeasure, sc["rougeL"].fmeasure]
        if bert_score_func:
            # Korean scoring first, English as the fallback.
            try:
                _,_,F = bert_score_func([summ],[orig],lang="ko")
            except Exception:
                _,_,F = bert_score_func([summ],[orig],lang="en")
            bf=float(F.mean())
        rv_list.append(rv); bf_list.append(bf)

        # Per-model metric bar chart.
        fig,ax=plt.subplots()
        ax.bar(["R1","R2","RL","BERT-F1"], rv+[bf])
        ax.set_ylim(0,1); ax.set_title(label)
        plt.tight_layout(); figs.append(fig)

        # Human-readable verdict derived from the BERT-F1 score.
        note="정보 손싀 많음"
        if bf>0.8: note="핡심 정보 잘 반영"
        elif bf>0.5: note="μ£Όμš” λ‚΄μš© 포함"
        interps.append(f"{label}: {note} (F1={bf:.2f})")

    # Aggregate comparison table; summary text is minimally HTML-escaped.
    html = "<h3>λͺ¨λΈλ³„ μš”μ•½ & Fidelity Metrics</h3>"
    html+= f"<p><b>ꡐ정λ₯ :</b> {corr_rate}%</p>"
    html+= "<table border='1' style='border-collapse:collapse; width:100%; table-layout:fixed;'>"
    html+= "<tr><th style='width:12%'>λͺ¨λΈ</th><th style='width:58%'>μš”μ•½λ¬Έ</th><th style='width:5%'>R1</th><th style='width:5%'>R2</th><th style='width:5%'>RL</th><th style='width:7%'>BERT-F1</th><th style='width:8%'>해석</th></tr>"

    for i,label in enumerate(labels):
        r1,r2,rl = rv_list[i]
        bf = bf_list[i]
        # Verdict recomputed here (duplicated from the loop above).
        note = "정보 손싀 많음"
        if bf>0.8: note="핡심 정보 잘 반영"
        elif bf>0.5: note="μ£Όμš” λ‚΄μš© 포함"

        summ_html = summaries_plain[i].replace("<", "&lt;")
        html+= (
            f"<tr>"
            f"<td>{label}</td>"
            f"<td style='white-space:pre-wrap; word-break:break-word'>{summ_html}</td>"
            f"<td>{r1:.2f}</td><td>{r2:.2f}</td><td>{rl:.2f}</td>"
            f"<td>{bf:.2f}</td><td>{note}</td>"
            f"</tr>"
        )
    html+="</table>"

    return [orig, corr_html] + figs + interps + [html, new_state]
319
-
320
# ────────────────────────────────────────────────────────────────
def save_summary(url, label):
    """Regenerate the summary for *url* with model *label* and save it.

    Always uses the "GEC-ν•œκ΅­μ–΄" corrector and writes the summary to
    ``summary_<label>.txt`` in the working directory.  Returns the file
    path (handed to a Gradio File component for download).
    """
    transcript, _ = get_transcript(url, None)
    corrected = correct_text(transcript, "GEC-ν•œκ΅­μ–΄")
    summary_text = summarize_long_text(corrected, label)
    out_path = os.path.join(os.getcwd(), f"summary_{label}.txt")
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write(summary_text)
    return out_path
329
-
330
# ────────────────────────────────────────────────────────────────
# CSS: render the corrected-transcript HTML inside a bordered, scrollable box.
# The ids match the elem_id of the two gr.HTML components below.
CUSTOM_CSS = """
#corr_box, #corr_box_all {
    border: 1px solid #ccc;
    padding: 10px;
    border-radius: 6px;
    background-color: #fafafa;
    max-height: 300px;
    overflow-y: auto;
    white-space: pre-wrap;
}
"""
343
-
344
# Gradio UI definition.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown("## 🎬 YouTube μš”μ•½ μ„œλΉ„μŠ€ (ꡐ정 + ꡐ정λ₯  + Diff κ°•μ‘°, μ•ˆμ „ μ²­ν¬μš”μ•½)")

    with gr.Tabs():
        # Tab 1: summarize with a single user-selected model.
        with gr.TabItem("단일 λͺ¨λΈ μš”μ•½"):
            url_input = gr.Textbox(label="YouTube URL")
            model_sel = gr.Dropdown(list(SUMMARY_MODELS.keys()), label="μš”μ•½ λͺ¨λΈ")
            grammar_sel = gr.Dropdown(list(GRAMMAR_MODELS.keys()), label="ꡐ정 λͺ¨λΈ", value="GEC-ν•œκ΅­μ–΄")
            transcript_state = gr.State(None)  # per-session transcript cache
            btn_single = gr.Button("μš”μ•½ μ‹€ν–‰")

            orig_tb = gr.Textbox(label="원문 μžλ§‰", lines=10)
            corr_tb = gr.HTML(label="ꡐ정 μžλ§‰ (변경점 κ°•μ‘°)", elem_id="corr_box")
            sum_tb = gr.Textbox(label="μš”μ•½ κ²°κ³Ό", lines=8)
            fidelity_plot = gr.Plot(label="Fidelity Metrics")
            save_btn = gr.Button("μš”μ•½ μ €μž₯")
            download_single = gr.File(label="λ‹€μš΄λ‘œλ“œ 파일")

            # Output order must match summarize_single's return tuple.
            btn_single.click(
                fn=summarize_single,
                inputs=[url_input, model_sel, grammar_sel, transcript_state],
                outputs=[orig_tb, corr_tb, sum_tb, fidelity_plot, transcript_state]
            )
            save_btn.click(
                fn=save_summary,
                inputs=[url_input, model_sel],
                outputs=[download_single]
            )

        # Tab 2: run and compare every registered summarization model.
        with gr.TabItem("전체 λͺ¨λΈ 비ꡐ"):
            url_all = gr.Textbox(label="YouTube URL")
            grammar_sel_all = gr.Dropdown(list(GRAMMAR_MODELS.keys()), label="ꡐ정 λͺ¨λΈ", value="GEC-ν•œκ΅­μ–΄")
            transcript_state_all = gr.State(None)
            btn_all = gr.Button("λͺ¨λ‘ μ‹€ν–‰")

            orig_all = gr.Textbox(label="원문 μžλ§‰", lines=10)
            corr_all = gr.HTML(label="ꡐ정 μžλ§‰ (변경점 κ°•μ‘°)", elem_id="corr_box_all")

            # One plot + one interpretation per model, in SUMMARY_MODELS
            # order; this must match the flat list summarize_all returns.
            plot_components, interp_components = [], []
            for label in SUMMARY_MODELS:
                plot_components.append(gr.Plot(label=f"{label} Metrics"))
                interp_components.append(gr.HTML(label=f"{label} 해석"))

            agg_table = gr.HTML(label="λͺ¨λΈλ³„ μš”μ•½ & Fidelity Metrics")
            save_all_sel = gr.Radio(list(SUMMARY_MODELS.keys()), label="μ €μž₯ λͺ¨λΈ μ§€μ •")
            save_all_btn = gr.Button("선택 μš”μ•½ μ €μž₯")
            download_all = gr.File(label="λ‹€μš΄λ‘œλ“œ 파일")

            btn_all.click(
                fn=summarize_all,
                inputs=[url_all, grammar_sel_all, transcript_state_all],
                outputs=[orig_all, corr_all] + plot_components + interp_components + [agg_table, transcript_state_all]
            )
            save_all_btn.click(
                fn=save_summary,
                inputs=[url_all, save_all_sel],
                outputs=[download_all]
            )
403
-
404
if __name__ == '__main__':
    # Bind to localhost only; Gradio picks a free port automatically.
    demo.launch(server_name="127.0.0.1")
    # Or fully automatic: demo.launch()