import os
import zipfile
import requests
import gradio as gr
import whisper
import subprocess
import uuid
import torch
import re
import matplotlib.pyplot as plt
import language_tool_python
import difflib
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline as hf_pipeline,
)

# ────────────────────────────────────────────────────────────────
# Optional evaluation libraries (absence degrades gracefully)
try:
    from rouge_score import rouge_scorer
except ImportError:
    rouge_scorer = None
    print("[Warning] rouge_score 패키지가 없습니다. pip install rouge-score")

try:
    from bert_score import score as bert_score_func
except ImportError:
    bert_score_func = None
    print("[Warning] bert-score 패키지가 없습니다. pip install bert-score")

# ────────────────────────────────────────────────────────────────
# Korean spell checking (py-hanspell); optional dependency
try:
    from hanspell import spell_checker
except ImportError:
    spell_checker = None

# ────────────────────────────────────────────────────────────────
# Rule-based grammar correction via LanguageTool (English only)
try:
    lt_tool = language_tool_python.LanguageTool('en-US')
except Exception as e:
    lt_tool = None
    print(f"[Warning] LanguageTool 초기화 실패: {e}")

# ────────────────────────────────────────────────────────────────
# FFmpeg / yt-dlp — Windows-specific absolute paths
yt_dlp_path = "C:/Windows/System32/yt-dlp.exe"
ffmpeg_path = "C:/ProgramData/chocolatey/bin"


def download_ffmpeg(dest_bin):
    """Ensure an FFmpeg installation exists under *dest_bin*.

    Downloads the gyan.dev "essentials" build, extracts it next to
    *dest_bin*, and moves the ffmpeg/ffprobe/ffplay executables into place.
    A no-op if ffmpeg.exe is already present.

    Args:
        dest_bin: Target bin directory that should contain ffmpeg.exe.

    Returns:
        *dest_bin* once ffmpeg.exe is available there.

    Raises:
        RuntimeError: if no ffmpeg.exe is found after extraction.
        requests.HTTPError: if the download request fails.
    """
    # Already installed — nothing to do.
    if os.path.isdir(dest_bin) and os.path.isfile(os.path.join(dest_bin, "ffmpeg.exe")):
        return dest_bin
    url = "https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip"
    zip_path = os.path.join(os.getcwd(), "ffmpeg.zip")
    extract_root = os.path.dirname(dest_bin)
    os.makedirs(extract_root, exist_ok=True)
    # Stream the archive to disk in chunks; the context manager closes the
    # connection even on failure (original leaked the response object).
    with requests.get(url, stream=True) as resp:
        resp.raise_for_status()
        with open(zip_path, "wb") as f:
            for chunk in resp.iter_content(8192):
                f.write(chunk)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(extract_root)
    os.remove(zip_path)
    # The archive contains a versioned top-level folder; locate its bin dir.
    for root, _, files in os.walk(extract_root):
        if "ffmpeg.exe" in files:
            os.makedirs(dest_bin, exist_ok=True)
            for fn in ("ffmpeg.exe", "ffprobe.exe", "ffplay.exe"):
                src, dst = os.path.join(root, fn), os.path.join(dest_bin, fn)
                if os.path.isfile(src):
                    os.replace(src, dst)
            return dest_bin
    raise RuntimeError("FFmpeg 설치 실패")


download_ffmpeg(ffmpeg_path)
os.environ["PATH"] = ffmpeg_path + os.pathsep + os.environ.get("PATH", "")

# ────────────────────────────────────────────────────────────────
# Whisper ASR model, loaded once at startup
asr_model = whisper.load_model("medium")

# ────────────────────────────────────────────────────────────────
# Summarization models (tokenizer/model used directly, not via pipeline)
SUMMARY_MODELS = {
    "mT5_multilingual_XLSum": "csebuetnlp/mT5_multilingual_XLSum",
    "Pegasus XSum": "google/pegasus-xsum",
    "BART-large CNN": "facebook/bart-large-cnn",
    "DistilBART CNN": "sshleifer/distilbart-cnn-12-6",
}
tokenizers, models = {}, {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_summarizer(label: str):
    """Lazily load the tokenizer/model pair for *label* into the caches.

    Args:
        label: A key of SUMMARY_MODELS.
    """
    if label in models:
        return
    repo = SUMMARY_MODELS[label]
    tok = AutoTokenizer.from_pretrained(repo, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device)
    model.eval()
    tokenizers[label] = tok
    models[label] = model


if rouge_scorer:
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)

# ────────────────────────────────────────────────────────────────
# Grammar-correction back-ends; None means "no HF model behind it"
GRAMMAR_MODELS = {
    "LanguageTool-en": None,
    "py-hanspell": None,
    "GEC-한국어": "Soyoung97/gec_kr",
}
grammar_pipes = {}


def load_grammar_pipe(name: str):
    """Load and cache an HF text2text pipeline for *name*.

    Raises:
        ValueError: if *name* is a non-HF back-end (repo is None), instead
            of the opaque TypeError hf_pipeline would raise.
    """
    repo = GRAMMAR_MODELS[name]
    if repo is None:
        raise ValueError(f"'{name}' is not backed by a Hugging Face model")
    grammar_pipes[name] = hf_pipeline(
        "text2text-generation",
        model=repo,
        tokenizer=AutoTokenizer.from_pretrained(repo),
        device=0 if torch.cuda.is_available() else -1,
    )


def correct_spelling(text, max_chunk=500):
    """Spell-check *text* with py-hanspell in sentence-aligned chunks.

    The hanspell API limits request size, so the text is split on
    sentence-ending punctuation and re-grouped into ≤ *max_chunk*-char
    segments before checking.

    Args:
        text: Input text.
        max_chunk: Maximum characters per hanspell request.

    Returns:
        The corrected text, or *text* unchanged if hanspell is unavailable.
    """
    if not spell_checker:
        return text
    parts = re.split(r'([.?!]\s*)', text)
    segs, curr = [], ""
    for p in parts:
        if len(curr) + len(p) <= max_chunk:
            curr += p
        else:
            if curr:  # fix: don't emit an empty segment when the first part overflows
                segs.append(curr)
            curr = p
    if curr:
        segs.append(curr)
    out = []
    for s in segs:
        try:
            out.append(spell_checker.check(s).checked)
        except Exception:  # best-effort: keep the original chunk on API failure
            out.append(s)
    return " ".join(o.strip() for o in out)


def correct_text(text, method="GEC-한국어"):
    """Dispatch grammar/spell correction of *text* to the selected back-end.

    Args:
        text: Input text.
        method: One of GRAMMAR_MODELS' keys.

    Returns:
        Corrected text; *text* unchanged for unknown methods or when the
        LanguageTool instance failed to initialize.
    """
    if method == "py-hanspell":
        return correct_spelling(text)
    if method == "LanguageTool-en" and lt_tool:
        matches = lt_tool.check(text)
        return language_tool_python.utils.correct(text, matches)
    if method == "GEC-한국어":
        if method not in grammar_pipes:
            load_grammar_pipe(method)
        pipe = grammar_pipes[method]
        # Correct sentence-by-sentence to stay within the model's context.
        sents = re.split(r'(?<=[.?!])\s+', text)
        corrected = []
        for sent in sents:
            gen = pipe(sent, max_length=256, min_length=1, do_sample=False)[0]["generated_text"]
            corrected.append(gen.strip())
        return " ".join(corrected)
    return text

# ────────────────────────────────────────────────────────────────
# Correction rate + diff rendering


def calculate_correction_rate(original, corrected):
    """Return the percentage of original word tokens touched by correction.

    Counts original-side tokens covered by non-'equal' opcodes of a
    word-level SequenceMatcher, relative to the original token count.
    """
    orig_tokens = original.split()
    corr_tokens = corrected.split()
    sm = difflib.SequenceMatcher(None, orig_tokens, corr_tokens)
    diff_count = sum((i2 - i1) for tag, i1, i2, j1, j2 in sm.get_opcodes() if tag != 'equal')
    total = max(len(orig_tokens), 1)  # guard against empty input
    return round(100 * diff_count / total, 2)


def highlight_diff(original: str, corrected: str) -> str:
    """Return *corrected* as HTML with inserted words wrapped in <mark>.

    Words deleted from *original* are dropped from the output.
    """
    diff = difflib.ndiff(original.split(), corrected.split())
    html_parts = []
    for token in diff:
        if token.startswith("? "):
            # fix: ndiff emits '? ' hint lines that are not tokens — skip them
            # (the original leaked them into the output).
            continue
        if token.startswith("+ "):
            html_parts.append(f"<mark>{token[2:]}</mark>")
        elif token.startswith("- "):
            continue
        else:
            html_parts.append(token[2:])
    return " ".join(html_parts)

# ────────────────────────────────────────────────────────────────
# YouTube download + transcription


def download_audio(url):
    """Download the best audio stream of *url* to a uniquely named mp3.

    Returns:
        The local mp3 filename.

    Raises:
        RuntimeError: with yt-dlp's stderr if the subprocess fails.
    """
    fname = f"yt_{uuid.uuid4().hex[:8]}.mp3"
    cmd = [yt_dlp_path, "-f", "bestaudio", "--extract-audio", "--audio-format", "mp3", "-o", fname, url]
    res = subprocess.run(cmd, capture_output=True, text=True)
    if res.returncode != 0:
        raise RuntimeError(res.stderr)
    return fname


def get_transcript(url, state):
    """Transcribe *url* with Whisper, reusing *state* as a per-URL cache.

    Args:
        url: YouTube URL.
        state: Previous state dict ({"url", "orig"}) or None.

    Returns:
        (transcript_text, new_state) tuple.
    """
    if state and state.get("url") == url:
        return state["orig"], state
    audio = download_audio(url)
    try:
        res = asr_model.transcribe(audio)
    finally:
        # fix: remove the temp mp3 even if transcription raises
        os.remove(audio)
    orig = res.get("text", "")
    return orig, {"url": url, "orig": orig}

# ────────────────────────────────────────────────────────────────
# Safe chunked summarization (calls model.generate directly)
청크 요약 (model.generate 직접 호출) def summarize_long_text(text: str, label: str, chunk_size: int = 512) -> str: load_summarizer(label) tok = tokenizers[label] model= models[label] enc = tok(text, return_tensors="pt", truncation=False) ids = enc.input_ids[0] summaries = [] max_ctx = getattr(model.config, "max_position_embeddings", 1024) - 4 chunk_size = min(chunk_size, max_ctx) for i in range(0, len(ids), chunk_size): chunk_ids = ids[i:i+chunk_size].unsqueeze(0).to(device) out_ids = model.generate( chunk_ids, max_new_tokens=128, num_beams=4, do_sample=False ) summ = tok.decode(out_ids[0], skip_special_tokens=True) summaries.append(summ) combined = " ".join(summaries) enc2 = tok(combined, return_tensors="pt", truncation=True, max_length=max_ctx).to(device) out_ids = model.generate( **enc2, max_new_tokens=128, num_beams=4, do_sample=False ) final = tok.decode(out_ids[0], skip_special_tokens=True) return final # ──────────────────────────────────────────────────────────────── def summarize_single(url, label, grammar_method, transcript_state): orig, new_state = get_transcript(url, transcript_state) corr = correct_text(orig, method=grammar_method) corr_rate = calculate_correction_rate(orig, corr) corr_html = f"
교정률: {corr_rate}%
" html+= "| 모델 | 요약문 | R1 | R2 | RL | BERT-F1 | 해석 |
|---|---|---|---|---|---|---|
| {label} | " f"{summ_html} | " f"{r1:.2f} | {r2:.2f} | {rl:.2f} | " f"{bf:.2f} | {note} | " f"