# HF Space app: YouTube audio → Whisper transcript → grammar correction →
# multi-model summarization, served through a Gradio UI.
| import os | |
| import zipfile | |
| import requests | |
| import gradio as gr | |
| import whisper | |
| import subprocess | |
| import uuid | |
| import torch | |
| import re | |
| import matplotlib.pyplot as plt | |
| import language_tool_python | |
| import difflib | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| pipeline as hf_pipeline, | |
| ) | |
# ----------------------------------------------------------------
# Optional evaluation libraries. The app still runs without them:
# ROUGE / BERTScore metrics simply stay at 0 when a package is absent.
try:
    from rouge_score import rouge_scorer
except ImportError:
    rouge_scorer = None
    print("[Warning] rouge_score ν¨ν€μ§κ° μμ΅λλ€. pip install rouge-score")
try:
    from bert_score import score as bert_score_func
except ImportError:
    bert_score_func = None
    print("[Warning] bert-score ν¨ν€μ§κ° μμ΅λλ€. pip install bert-score")
# ----------------------------------------------------------------
# Korean spell checker (py-hanspell). Optional: correct_spelling()
# returns its input unchanged when the package is missing.
try:
    from hanspell import spell_checker
except ImportError:
    spell_checker = None
# ----------------------------------------------------------------
# LanguageTool rule-based correction (English only). Initialisation can
# fail (it needs Java / a server download), so any error leaves lt_tool
# as None and the English correction path degrades to a no-op.
try:
    lt_tool = language_tool_python.LanguageTool('en-US')
except Exception as e:
    lt_tool = None
    print(f"[Warning] LanguageTool μ΄κΈ°ν μ€ν¨: {e}")
# ----------------------------------------------------------------
# External binaries.
# NOTE(review): hard-coded Windows paths — these cannot resolve on a
# non-Windows host (e.g. a Hugging Face Space runs Linux), which would
# make audio download fail at runtime; confirm the deployment target.
yt_dlp_path = "C:/Windows/System32/yt-dlp.exe"
ffmpeg_path = "C:/ProgramData/chocolatey/bin"
def download_ffmpeg(dest_bin):
    """Ensure ffmpeg binaries exist in *dest_bin* and return that path.

    If ffmpeg.exe is already present this is a no-op. Otherwise the
    gyan.dev "essentials" build is downloaded, extracted next to
    *dest_bin*, and ffmpeg/ffprobe/ffplay are moved into *dest_bin*.

    Raises:
        RuntimeError: if no ffmpeg.exe is found after extraction.
        requests.HTTPError: if the download fails.
    """
    if os.path.isdir(dest_bin) and os.path.isfile(os.path.join(dest_bin, "ffmpeg.exe")):
        return dest_bin
    url = "https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip"
    zip_path = os.path.join(os.getcwd(), "ffmpeg.zip")
    extract_root = os.path.dirname(dest_bin)
    os.makedirs(extract_root, exist_ok=True)
    # Stream to disk so the large archive is never held fully in memory;
    # context managers guarantee the connection and file handle are closed.
    with requests.get(url, stream=True) as resp:
        resp.raise_for_status()
        with open(zip_path, "wb") as f:
            for chunk in resp.iter_content(8192):
                f.write(chunk)
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(extract_root)
    finally:
        # Remove the archive even when extraction fails part-way.
        os.remove(zip_path)
    for root, _, files in os.walk(extract_root):
        if "ffmpeg.exe" in files:
            os.makedirs(dest_bin, exist_ok=True)
            for fn in ("ffmpeg.exe", "ffprobe.exe", "ffplay.exe"):
                src, dst = os.path.join(root, fn), os.path.join(dest_bin, fn)
                if os.path.isfile(src):
                    os.replace(src, dst)
            return dest_bin
    raise RuntimeError("FFmpeg μ€μΉ μ€ν¨")
# Install FFmpeg on first run and prepend it to PATH so whisper/yt-dlp
# can resolve the binaries.
download_ffmpeg(ffmpeg_path)
os.environ["PATH"] = ffmpeg_path + os.pathsep + os.environ.get("PATH","")
# ----------------------------------------------------------------
# Whisper ASR model ("medium" checkpoint); loaded once at import time.
asr_model = whisper.load_model("medium")
# ----------------------------------------------------------------
# Summarization models (tokenizer/model used directly, no HF pipeline).
# Keys are the UI labels, values are Hugging Face repo ids.
SUMMARY_MODELS = {
    "mT5_multilingual_XLSum": "csebuetnlp/mT5_multilingual_XLSum",
    "Pegasus XSum": "google/pegasus-xsum",
    "BART-large CNN": "facebook/bart-large-cnn",
    "DistilBART CNN": "sshleifer/distilbart-cnn-12-6"
}
# Lazy per-label caches filled by load_summarizer().
tokenizers, models = {}, {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_summarizer(label: str):
    """Populate the tokenizer/model caches for *label* (idempotent).

    *label* must be a key of SUMMARY_MODELS. The model is moved to the
    global device and put into eval mode before being cached.
    """
    if label in models:
        return
    repo_id = SUMMARY_MODELS[label]
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    net = AutoModelForSeq2SeqLM.from_pretrained(repo_id).to(device)
    net.eval()
    tokenizers[label], models[label] = tok, net
# Shared ROUGE scorer, built once; only defined when rouge_score imported
# successfully — every use below must stay guarded by `if rouge_scorer:`.
if rouge_scorer:
    scorer = rouge_scorer.RougeScorer(["rouge1","rouge2","rougeL"], use_stemmer=True)
# ----------------------------------------------------------------
# Grammar-correction back-ends. A value of None means "not a Hugging Face
# model" — those entries are handled specially inside correct_text();
# only the Korean GEC entry goes through load_grammar_pipe().
GRAMMAR_MODELS = {
    "LanguageTool-en": None,
    "py-hanspell": None,
    "GEC-νκ΅μ΄": "Soyoung97/gec_kr"
}
# Lazy cache of text2text pipelines keyed by GRAMMAR_MODELS label.
grammar_pipes = {}
def load_grammar_pipe(name: str):
    """Build (once) and cache the HF text2text pipeline for *name*.

    Only GRAMMAR_MODELS entries backed by a repo id can be loaded; the
    LanguageTool / hanspell entries are handled directly in correct_text().

    Raises:
        ValueError: if *name* has no Hugging Face model behind it (its
            GRAMMAR_MODELS value is None).
    """
    if name in grammar_pipes:
        # Already built — keep repeated calls cheap and idempotent.
        return
    repo = GRAMMAR_MODELS[name]
    if repo is None:
        # Previously this fell through into hf_pipeline(model=None) and
        # failed with an opaque error; fail fast with a clear message.
        raise ValueError(f"'{name}' is not backed by a Hugging Face model")
    grammar_pipes[name] = hf_pipeline(
        "text2text-generation",
        model=repo,
        tokenizer=AutoTokenizer.from_pretrained(repo),
        device=0 if torch.cuda.is_available() else -1
    )
def correct_spelling(text, max_chunk=500):
    """Spell-check Korean *text* with py-hanspell in <=*max_chunk*-char chunks.

    The text is split while keeping the [.?!] delimiters (capturing split)
    so chunks end on sentence boundaries — hanspell rejects long inputs.
    Returns *text* unchanged when py-hanspell is unavailable; a chunk whose
    check fails is kept as-is (best-effort).
    """
    if not spell_checker:
        return text
    parts = re.split(r'([.?!]\s*)', text)
    segs, curr, out = [], "", []
    for p in parts:
        if len(curr) + len(p) <= max_chunk:
            curr += p
        else:
            segs.append(curr)
            curr = p
    if curr:
        segs.append(curr)
    for s in segs:
        try:
            out.append(spell_checker.check(s).checked)
        except Exception:
            # Was a bare `except:` which also swallowed KeyboardInterrupt /
            # SystemExit; narrow it while keeping the best-effort fallback.
            out.append(s)
    return " ".join(o.strip() for o in out)
def correct_text(text, method="GEC-νκ΅μ΄"):
    """Run the selected grammar/spelling correction over *text*.

    method: a GRAMMAR_MODELS key. Unknown methods, or back-ends that are
    unavailable (e.g. LanguageTool failed to initialise), return *text*
    unchanged.
    """
    if method == "py-hanspell":
        return correct_spelling(text)
    if method == "LanguageTool-en" and lt_tool:
        return language_tool_python.utils.correct(text, lt_tool.check(text))
    if GRAMMAR_MODELS.get(method):
        # HF-backed corrector (the Korean GEC entry). Load lazily, then
        # correct sentence-by-sentence — the seq2seq model expects
        # sentence-level inputs.
        if method not in grammar_pipes:
            load_grammar_pipe(method)
        pipe = grammar_pipes[method]
        fixed = [
            pipe(s, max_length=256, min_length=1, do_sample=False)[0]["generated_text"].strip()
            for s in re.split(r'(?<=[.?!])\s+', text)
        ]
        return " ".join(fixed)
    return text
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # κ΅μ λ₯ + Diff | |
def calculate_correction_rate(original, corrected):
    """Return the percentage of tokens touched by correction (2 decimals).

    Tokens are whitespace-split words. For every non-equal diff opcode the
    larger side of the change is counted, so pure insertions into the
    corrected text are measured too — the previous `i2 - i1` counted only
    original-side tokens and reported 0% for insert-only corrections.
    The rate can exceed 100% when the correction adds many tokens.
    """
    orig_tokens = original.split()
    corr_tokens = corrected.split()
    sm = difflib.SequenceMatcher(None, orig_tokens, corr_tokens)
    diff_count = sum(
        max(i2 - i1, j2 - j1)
        for tag, i1, i2, j1, j2 in sm.get_opcodes()
        if tag != 'equal'
    )
    total = max(len(orig_tokens), 1)
    return round(100 * diff_count / total, 2)
def highlight_diff(original: str, corrected: str) -> str:
    """Render *corrected* as HTML, colouring tokens added/changed vs *original*.

    Word-level ndiff: tokens only in *corrected* ('+ ') are wrapped in a
    red span, tokens only in *original* ('- ') are dropped, and ndiff's
    '? ' guide lines are skipped — previously they fell into the plain
    branch and leaked caret/plus markers into the output. Tokens are
    HTML-escaped so subtitle text cannot break the surrounding markup.
    """
    html_parts = []
    for token in difflib.ndiff(original.split(), corrected.split()):
        if token.startswith("? "):
            continue  # intraline hint emitted by ndiff, not a real token
        word = token[2:].replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        if token.startswith("+ "):
            html_parts.append(f"<span style='color:red;'>{word}</span>")
        elif not token.startswith("- "):
            html_parts.append(word)
    return " ".join(html_parts)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # YouTube | |
def download_audio(url):
    """Download the best audio track of *url* as an mp3 via yt-dlp.

    Returns the generated local filename; raises RuntimeError carrying
    yt-dlp's stderr when the download fails.
    """
    out_name = f"yt_{uuid.uuid4().hex[:8]}.mp3"
    proc = subprocess.run(
        [yt_dlp_path, "-f", "bestaudio",
         "--extract-audio", "--audio-format", "mp3",
         "-o", out_name, url],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr)
    return out_name
def get_transcript(url, state):
    """Transcribe *url* with Whisper, reusing *state* as a one-entry cache.

    state: {"url": ..., "orig": ...} or None. When the cached URL matches,
    the download and ASR are skipped entirely.

    Returns:
        (transcript text, new cache state dict).
    """
    if state and state.get("url") == url:
        return state["orig"], state
    audio = download_audio(url)
    try:
        orig = asr_model.transcribe(audio).get("text", "")
    finally:
        # Previously the temp mp3 leaked whenever transcribe() raised.
        os.remove(audio)
    return orig, {"url": url, "orig": orig}
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # μμ ν μ²ν¬ μμ½ (model.generate μ§μ νΈμΆ) | |
def summarize_long_text(text: str, label: str, chunk_size: int = 512) -> str:
    """Summarize arbitrarily long *text* with the model behind *label*.

    The token ids are split into windows of at most *chunk_size* tokens
    (capped just below the model's positional limit), each window is
    summarized, and — when there was more than one window — the joined
    partial summaries are condensed by a second pass. A single-window
    input is returned after one pass instead of being re-summarized
    (previously even short texts were summarized twice, degrading quality).
    """
    load_summarizer(label)
    tok = tokenizers[label]
    model = models[label]
    ids = tok(text, return_tensors="pt", truncation=False).input_ids[0]
    # Headroom of 4 below the positional limit; models without absolute
    # positions fall back to the 1024 default via getattr.
    max_ctx = getattr(model.config, "max_position_embeddings", 1024) - 4
    chunk_size = min(chunk_size, max_ctx)
    summaries = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for i in range(0, len(ids), chunk_size):
            chunk_ids = ids[i:i + chunk_size].unsqueeze(0).to(device)
            out_ids = model.generate(
                chunk_ids,
                max_new_tokens=128,
                num_beams=4,
                do_sample=False
            )
            summaries.append(tok.decode(out_ids[0], skip_special_tokens=True))
        if len(summaries) == 1:
            return summaries[0]
        # Second pass: condense the per-chunk summaries into one text.
        combined = " ".join(summaries)
        enc2 = tok(combined, return_tensors="pt", truncation=True, max_length=max_ctx).to(device)
        out_ids = model.generate(
            **enc2,
            max_new_tokens=128,
            num_beams=4,
            do_sample=False
        )
    return tok.decode(out_ids[0], skip_special_tokens=True)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def summarize_single(url, label, grammar_method, transcript_state):
    """Gradio handler (single-model tab): transcribe → correct → summarize.

    Returns (original transcript, correction HTML with diff highlighting,
    summary text, fidelity matplotlib figure, updated transcript cache).
    """
    orig, new_state = get_transcript(url, transcript_state)
    corr = correct_text(orig, method=grammar_method)
    corr_rate = calculate_correction_rate(orig, corr)
    corr_html = f"<div><b>κ΅μ λ₯ :</b> {corr_rate}%</div><hr/>{highlight_diff(orig, corr)}"
    # Very short corrected transcripts are not summarized.
    summary = summarize_long_text(corr, label) if len(corr) > 100 else "β οΈ μμ½ λΆκ°"
    # Fidelity is measured against the *uncorrected* transcript.
    rouge_vals=[0,0,0]
    if rouge_scorer and summary.strip():
        sc = scorer.score(orig, summary)
        rouge_vals=[sc["rouge1"].fmeasure, sc["rouge2"].fmeasure, sc["rougeL"].fmeasure]
    bert_f1=0
    if bert_score_func and summary.strip():
        # Korean scoring first; fall back to English when the ko path fails.
        try:
            _,_,F = bert_score_func([summary],[orig],lang="ko")
        except Exception:
            _,_,F = bert_score_func([summary],[orig],lang="en")
        bert_f1=float(F.mean())
    fig,ax=plt.subplots()
    ax.bar(["R1","R2","RL","BERT-F1"], rouge_vals+[bert_f1])
    ax.set_ylim(0,1); ax.set_ylabel("Score"); ax.set_title("Summary Fidelity")
    plt.tight_layout()
    return orig, corr_html, summary, fig, new_state
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _fidelity_note(bf):
    """Human-readable interpretation of a BERTScore F1 value."""
    if bf > 0.8:
        return "ν΅μ¬ μ 보 μ λ°μ"
    if bf > 0.5:
        return "μ£Όμ λ΄μ© ν¬ν¨"
    return "μ 보 μμ€ λ§μ"

def summarize_all(url, grammar_method, transcript_state):
    """Gradio handler (comparison tab): run every summary model on *url*.

    Returns the flat list Gradio expects: [orig, corr_html] + one figure
    per model + one interpretation string per model + [aggregate HTML
    table, updated transcript cache state].
    """
    orig, new_state = get_transcript(url, transcript_state)
    corr = correct_text(orig, method=grammar_method)
    corr_rate = calculate_correction_rate(orig, corr)
    corr_html = f"<div><b>κ΅μ λ₯ :</b> {corr_rate}%</div><hr/>{highlight_diff(orig, corr)}"
    figs, interps, rv_list, bf_list = [], [], [], []
    summaries_plain = []
    labels = list(SUMMARY_MODELS.keys())
    for label in labels:
        summ = summarize_long_text(corr, label)
        summaries_plain.append(summ)
        rv = [0, 0, 0]
        bf = 0
        if rouge_scorer:
            # Fidelity against the raw transcript, not the corrected text.
            sc = scorer.score(orig, summ)
            rv = [sc["rouge1"].fmeasure, sc["rouge2"].fmeasure, sc["rougeL"].fmeasure]
        if bert_score_func:
            # Korean scoring first; English as fallback.
            try:
                _, _, F = bert_score_func([summ], [orig], lang="ko")
            except Exception:
                _, _, F = bert_score_func([summ], [orig], lang="en")
            bf = float(F.mean())
        rv_list.append(rv)
        bf_list.append(bf)
        fig, ax = plt.subplots()
        ax.bar(["R1", "R2", "RL", "BERT-F1"], rv + [bf])
        ax.set_ylim(0, 1)
        ax.set_title(label)
        plt.tight_layout()
        figs.append(fig)
        interps.append(f"{label}: {_fidelity_note(bf)} (F1={bf:.2f})")
    html = "<h3>λͺ¨λΈλ³ μμ½ & Fidelity Metrics</h3>"
    html += f"<p><b>κ΅μ λ₯ :</b> {corr_rate}%</p>"
    html += "<table border='1' style='border-collapse:collapse; width:100%; table-layout:fixed;'>"
    html += "<tr><th style='width:12%'>λͺ¨λΈ</th><th style='width:58%'>μμ½λ¬Έ</th><th style='width:5%'>R1</th><th style='width:5%'>R2</th><th style='width:5%'>RL</th><th style='width:7%'>BERT-F1</th><th style='width:8%'>ν΄μ</th></tr>"
    for i, label in enumerate(labels):
        r1, r2, rl = rv_list[i]
        bf = bf_list[i]
        # Escape the model output so it cannot break the HTML table — the
        # old code's replace("<", "<") was a no-op (a mangled &lt; escape).
        summ_html = summaries_plain[i].replace("&", "&amp;").replace("<", "&lt;")
        html += (
            f"<tr>"
            f"<td>{label}</td>"
            f"<td style='white-space:pre-wrap; word-break:break-word'>{summ_html}</td>"
            f"<td>{r1:.2f}</td><td>{r2:.2f}</td><td>{rl:.2f}</td>"
            f"<td>{bf:.2f}</td><td>{_fidelity_note(bf)}</td>"
            f"</tr>"
        )
    html += "</table>"
    return [orig, corr_html] + figs + interps + [html, new_state]
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def save_summary(url, label, grammar_method=None):
    """Summarize *url* with model *label* and write the result to a txt file.

    grammar_method: optional GRAMMAR_MODELS key to correct with before
    summarizing; None keeps the previous hard-coded behaviour by using
    correct_text's default (Korean GEC). The added parameter has a default,
    so existing two-argument callers are unaffected.

    Returns the written file's path for Gradio's File component.

    NOTE(review): the transcript cache is not reused here (state=None), so
    saving re-downloads and re-transcribes the video.
    """
    orig, _ = get_transcript(url, None)
    if grammar_method is None:
        corr = correct_text(orig)
    else:
        corr = correct_text(orig, method=grammar_method)
    summary = summarize_long_text(corr, label)
    path = os.path.join(os.getcwd(), f"summary_{label}.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(summary)
    return path
# ----------------------------------------------------------------
# CSS so the correction-HTML components render as bordered, scrollable
# boxes (elem_ids set on the gr.HTML components below).
CUSTOM_CSS = """
#corr_box, #corr_box_all {
border: 1px solid #ccc;
padding: 10px;
border-radius: 6px;
background-color: #fafafa;
max-height: 300px;
overflow-y: auto;
white-space: pre-wrap;
}
"""
# Gradio UI: tab 1 summarizes with one chosen model, tab 2 compares all
# models side by side. Each tab keeps its own transcript cache (gr.State)
# so repeated runs on the same URL skip download + ASR.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown("## π¬ YouTube μμ½ μλΉμ€ (κ΅μ + κ΅μ λ₯ + Diff κ°μ‘°, μμ μ²ν¬μμ½)")
    with gr.Tabs():
        with gr.TabItem("λ¨μΌ λͺ¨λΈ μμ½"):
            # --- single-model tab ---
            url_input = gr.Textbox(label="YouTube URL")
            model_sel = gr.Dropdown(list(SUMMARY_MODELS.keys()), label="μμ½ λͺ¨λΈ")
            grammar_sel = gr.Dropdown(list(GRAMMAR_MODELS.keys()), label="κ΅μ λͺ¨λΈ", value="GEC-νκ΅μ΄")
            transcript_state = gr.State(None)
            btn_single = gr.Button("μμ½ μ€ν")
            orig_tb = gr.Textbox(label="μλ¬Έ μλ§", lines=10)
            corr_tb = gr.HTML(label="κ΅μ μλ§ (λ³κ²½μ κ°μ‘°)", elem_id="corr_box")
            sum_tb = gr.Textbox(label="μμ½ κ²°κ³Ό", lines=8)
            fidelity_plot = gr.Plot(label="Fidelity Metrics")
            save_btn = gr.Button("μμ½ μ μ₯")
            download_single = gr.File(label="λ€μ΄λ‘λ νμΌ")
            btn_single.click(
                fn=summarize_single,
                inputs=[url_input, model_sel, grammar_sel, transcript_state],
                outputs=[orig_tb, corr_tb, sum_tb, fidelity_plot, transcript_state]
            )
            save_btn.click(
                fn=save_summary,
                inputs=[url_input, model_sel],
                outputs=[download_single]
            )
        with gr.TabItem("μ 체 λͺ¨λΈ λΉκ΅"):
            # --- all-models comparison tab ---
            url_all = gr.Textbox(label="YouTube URL")
            grammar_sel_all = gr.Dropdown(list(GRAMMAR_MODELS.keys()), label="κ΅μ λͺ¨λΈ", value="GEC-νκ΅μ΄")
            transcript_state_all = gr.State(None)
            btn_all = gr.Button("λͺ¨λ μ€ν")
            orig_all = gr.Textbox(label="μλ¬Έ μλ§", lines=10)
            corr_all = gr.HTML(label="κ΅μ μλ§ (λ³κ²½μ κ°μ‘°)", elem_id="corr_box_all")
            # One plot + one interpretation HTML per summary model; their
            # order must match the flat list summarize_all returns.
            plot_components, interp_components = [], []
            for label in SUMMARY_MODELS:
                plot_components.append(gr.Plot(label=f"{label} Metrics"))
                interp_components.append(gr.HTML(label=f"{label} ν΄μ"))
            agg_table = gr.HTML(label="λͺ¨λΈλ³ μμ½ & Fidelity Metrics")
            save_all_sel = gr.Radio(list(SUMMARY_MODELS.keys()), label="μ μ₯ λͺ¨λΈ μ§μ ")
            save_all_btn = gr.Button("μ ν μμ½ μ μ₯")
            download_all = gr.File(label="λ€μ΄λ‘λ νμΌ")
            btn_all.click(
                fn=summarize_all,
                inputs=[url_all, grammar_sel_all, transcript_state_all],
                outputs=[orig_all, corr_all] + plot_components + interp_components + [agg_table, transcript_state_all]
            )
            save_all_btn.click(
                fn=save_summary,
                inputs=[url_all, save_all_sel],
                outputs=[download_all]
            )

if __name__ == '__main__':
    # share=True additionally publishes a temporary public gradio.live URL.
    demo.launch(server_name="0.0.0.0", share=True)