# 11st_team_space / app.py
# nangunan — Update app.py (commit 1acc3a8, verified)
import os
import zipfile
import requests
import gradio as gr
import whisper
import subprocess
import uuid
import torch
import re
import matplotlib.pyplot as plt
import language_tool_python
import difflib
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
pipeline as hf_pipeline,
)
# ────────────────────────────────────────────────────────────────
# Optional evaluation libraries. The app degrades gracefully: when a
# scoring package is missing its metric is simply reported as 0.
try:
    from rouge_score import rouge_scorer
except ImportError:
    rouge_scorer = None
    print("[Warning] rouge_score νŒ¨ν‚€μ§€κ°€ μ—†μŠ΅λ‹ˆλ‹€. pip install rouge-score")
try:
    from bert_score import score as bert_score_func
except ImportError:
    bert_score_func = None
    print("[Warning] bert-score νŒ¨ν‚€μ§€κ°€ μ—†μŠ΅λ‹ˆλ‹€. pip install bert-score")
# ────────────────────────────────────────────────────────────────
# Korean spell checking (py-hanspell) — also optional.
try:
    from hanspell import spell_checker
except ImportError:
    spell_checker = None
# ────────────────────────────────────────────────────────────────
# LanguageTool rule-based correction (English only). Initialization can
# fail for environmental reasons (e.g. missing Java), so the feature is
# disabled rather than crashing the app.
try:
    lt_tool = language_tool_python.LanguageTool('en-US')
except Exception as e:
    lt_tool = None
    print(f"[Warning] LanguageTool μ΄ˆκΈ°ν™” μ‹€νŒ¨: {e}")
# ────────────────────────────────────────────────────────────────
# FFmpeg / yt-dlp locations.
# NOTE(review): hard-coded Windows paths — this presumably only works on
# a Windows host with yt-dlp at this exact location; confirm before
# deploying anywhere else.
yt_dlp_path = "C:/Windows/System32/yt-dlp.exe"
ffmpeg_path = "C:/ProgramData/chocolatey/bin"
def download_ffmpeg(dest_bin):
    """Ensure FFmpeg binaries exist in *dest_bin*, downloading if needed.

    Downloads the gyan.dev "essentials" zip, extracts it next to
    *dest_bin*, and moves ffmpeg/ffprobe/ffplay into *dest_bin*.

    Returns:
        dest_bin once the binaries are in place.
    Raises:
        RuntimeError: if no ffmpeg.exe is found after extraction.
        requests.HTTPError: if the download fails.
    """
    # Already installed? Nothing to do.
    if os.path.isdir(dest_bin) and os.path.isfile(os.path.join(dest_bin, "ffmpeg.exe")):
        return dest_bin
    url = "https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip"
    zip_path = os.path.join(os.getcwd(), "ffmpeg.zip")
    extract_root = os.path.dirname(dest_bin)
    os.makedirs(extract_root, exist_ok=True)
    # Stream the archive to disk; timeout guards against a hung connection
    # (the original call had no timeout and could block forever).
    resp = requests.get(url, stream=True, timeout=60)
    resp.raise_for_status()
    try:
        with open(zip_path, "wb") as f:
            for chunk in resp.iter_content(8192):
                f.write(chunk)
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(extract_root)
    finally:
        # Remove the temporary zip even if extraction fails.
        if os.path.isfile(zip_path):
            os.remove(zip_path)
    for root, _, files in os.walk(extract_root):
        if "ffmpeg.exe" in files:
            os.makedirs(dest_bin, exist_ok=True)
            for fn in ("ffmpeg.exe", "ffprobe.exe", "ffplay.exe"):
                src, dst = os.path.join(root, fn), os.path.join(dest_bin, fn)
                if os.path.isfile(src):
                    os.replace(src, dst)
            return dest_bin
    raise RuntimeError("FFmpeg μ„€μΉ˜ μ‹€νŒ¨")
# Install FFmpeg if necessary and make it visible to child processes
# (Whisper shells out to ffmpeg for audio decoding).
download_ffmpeg(ffmpeg_path)
os.environ["PATH"] = ffmpeg_path + os.pathsep + os.environ.get("PATH","")
# ────────────────────────────────────────────────────────────────
# Whisper ASR model used for transcription.
asr_model = whisper.load_model("medium")
# ────────────────────────────────────────────────────────────────
# Summarization models (tokenizer/model driven directly, not via pipeline).
# Maps UI label -> Hugging Face repo id.
SUMMARY_MODELS = {
    "mT5_multilingual_XLSum": "csebuetnlp/mT5_multilingual_XLSum",
    "Pegasus XSum": "google/pegasus-xsum",
    "BART-large CNN": "facebook/bart-large-cnn",
    "DistilBART CNN": "sshleifer/distilbart-cnn-12-6"
}
# Lazy caches keyed by UI label, filled by load_summarizer().
tokenizers, models = {}, {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_summarizer(label: str):
    """Lazily load and cache the tokenizer/model pair registered under *label*.

    Subsequent calls with the same label are no-ops.
    """
    if label not in models:
        repo_id = SUMMARY_MODELS[label]
        tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(repo_id).to(device)
        seq2seq.eval()  # inference only — disable dropout etc.
        tokenizers[label] = tokenizer
        models[label] = seq2seq
# Shared ROUGE scorer — only created when the optional package imported.
# Code that reads `scorer` must stay behind an `if rouge_scorer:` guard.
if rouge_scorer:
    scorer = rouge_scorer.RougeScorer(["rouge1","rouge2","rougeL"], use_stemmer=True)
# ────────────────────────────────────────────────────────────────
# Grammar-correction back-ends. The None entries are handled outside
# Hugging Face (LanguageTool server / py-hanspell); only the Korean GEC
# model is loaded through a transformers pipeline.
GRAMMAR_MODELS = {
    "LanguageTool-en": None,
    "py-hanspell": None,
    "GEC-ν•œκ΅­μ–΄": "Soyoung97/gec_kr"
}
# Lazy cache: name -> text2text pipeline, filled by load_grammar_pipe().
grammar_pipes = {}
def load_grammar_pipe(name: str):
    """Lazily build and cache the HF text2text pipeline for *name*.

    Idempotent (matches load_summarizer): repeated calls reuse the cached
    pipeline instead of rebuilding it.

    Raises:
        ValueError: if *name* maps to None in GRAMMAR_MODELS — those
            back-ends (LanguageTool / py-hanspell) are not HF models and
            previously failed deep inside transformers with an opaque error.
    """
    if name in grammar_pipes:
        return
    repo = GRAMMAR_MODELS[name]
    if repo is None:
        raise ValueError(f"'{name}' is not backed by a Hugging Face model")
    grammar_pipes[name] = hf_pipeline(
        "text2text-generation",
        model=repo,
        tokenizer=AutoTokenizer.from_pretrained(repo),
        device=0 if torch.cuda.is_available() else -1
    )
def correct_spelling(text, max_chunk=500):
    """Run py-hanspell over *text* in sentence-aligned chunks.

    py-hanspell has a request-size limit, so the text is split on
    sentence boundaries and accumulated into chunks of at most
    *max_chunk* characters before each check call.

    Returns the corrected text, or *text* unchanged when the optional
    hanspell package is unavailable.
    """
    if not spell_checker:
        return text
    # Split on sentence terminators, keeping the delimiters as parts.
    parts = re.split(r'([.?!]\s*)', text)
    segments, current = [], ""
    for part in parts:
        if len(current) + len(part) <= max_chunk:
            current += part
        else:
            segments.append(current)
            current = part
    if current:
        segments.append(current)
    corrected = []
    for segment in segments:
        try:
            corrected.append(spell_checker.check(segment).checked)
        # Was a bare `except:` — that also swallowed KeyboardInterrupt /
        # SystemExit. Best-effort behavior is kept: on any checker failure
        # the original chunk is used as-is.
        except Exception:
            corrected.append(segment)
    return " ".join(piece.strip() for piece in corrected)
def correct_text(text, method="GEC-ν•œκ΅­μ–΄"):
    """Correct *text* with the selected back-end.

    Supported methods: "py-hanspell", "LanguageTool-en" (only when the
    tool initialized), and "GEC-ν•œκ΅­μ–΄" (default, HF pipeline applied
    sentence by sentence). Any other method returns the text unchanged.
    """
    if method == "py-hanspell":
        return correct_spelling(text)
    if method == "LanguageTool-en" and lt_tool:
        findings = lt_tool.check(text)
        return language_tool_python.utils.correct(text, findings)
    if method == "GEC-ν•œκ΅­μ–΄":
        if method not in grammar_pipes:
            load_grammar_pipe(method)
        gec = grammar_pipes[method]
        # Correct one sentence at a time, then stitch the results together.
        fixed = [
            gec(sentence, max_length=256, min_length=1, do_sample=False)[0]["generated_text"].strip()
            for sentence in re.split(r'(?<=[.?!])\s+', text)
        ]
        return " ".join(fixed)
    return text
# ────────────────────────────────────────────────────────────────
# Correction rate + diff rendering
def calculate_correction_rate(original, corrected):
    """Return the percentage of tokens changed between the two texts.

    Tokenizes by whitespace and counts changed tokens via
    difflib.SequenceMatcher opcodes, normalized by the original token
    count (minimum 1, so empty input yields 0.0).

    Fix: the original summed only the source-side span (i2 - i1), so a
    pure insertion opcode (i2 == i1) contributed nothing and add-only
    corrections reported 0%. Each non-equal opcode now counts the larger
    of its two sides.
    """
    orig_tokens = original.split()
    corr_tokens = corrected.split()
    sm = difflib.SequenceMatcher(None, orig_tokens, corr_tokens)
    diff_count = sum(
        max(i2 - i1, j2 - j1)
        for tag, i1, i2, j1, j2 in sm.get_opcodes()
        if tag != 'equal'
    )
    total = max(len(orig_tokens), 1)
    return round(100 * diff_count / total, 2)
def highlight_diff(original: str, corrected: str) -> str:
    """Render *corrected* as HTML with changed/added tokens in red.

    Word-level ndiff: tokens only in *original* are dropped, tokens only
    in *corrected* are wrapped in a red <span>, common tokens pass through.

    Fixes two defects in the previous version:
    - ndiff emits "? " intraline hint lines for similar tokens; these
      leaked into the output as stray text and are now skipped.
    - tokens are HTML-escaped before being embedded in markup (matching
      the "&lt;" escaping already done for the summary table).
    """
    html_parts = []
    for token in difflib.ndiff(original.split(), corrected.split()):
        if token.startswith("? ") or token.startswith("- "):
            continue
        word = (token[2:]
                .replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace(">", "&gt;"))
        if token.startswith("+ "):
            html_parts.append(f"<span style='color:red;'>{word}</span>")
        else:
            html_parts.append(word)
    return " ".join(html_parts)
# ────────────────────────────────────────────────────────────────
# YouTube audio download
def download_audio(url):
    """Fetch the best audio stream of *url* as mp3 via yt-dlp.

    Returns the generated local filename; raises RuntimeError with
    yt-dlp's stderr on failure.
    """
    out_file = f"yt_{uuid.uuid4().hex[:8]}.mp3"
    command = [
        yt_dlp_path,
        "-f", "bestaudio",
        "--extract-audio",
        "--audio-format", "mp3",
        "-o", out_file,
        url,
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(result.stderr)
    return out_file
def get_transcript(url, state):
    """Return (transcript, new_state) for *url*, caching per-URL in *state*.

    *state* is a dict {"url": ..., "orig": ...} or None; when the cached
    URL matches, the stored transcript is returned without re-downloading.
    """
    if state and state.get("url") == url:
        return state["orig"], state
    audio = download_audio(url)
    try:
        result = asr_model.transcribe(audio)
        orig = result.get("text", "")
    finally:
        # Delete the temp audio file even if transcription raises
        # (the original leaked the file on failure).
        if os.path.isfile(audio):
            os.remove(audio)
    return orig, {"url": url, "orig": orig}
# ────────────────────────────────────────────────────────────────
# Safe chunked summarization (direct model.generate calls, no pipeline)
def summarize_long_text(text: str, label: str, chunk_size: int = 512) -> str:
    """Summarize arbitrarily long *text* with the model registered as *label*.

    Two-pass map-reduce: the text is tokenized without truncation, split
    into token chunks of at most *chunk_size*, each chunk summarized
    independently, then the concatenated chunk summaries are summarized
    once more to produce the final result.
    """
    load_summarizer(label)
    tok = tokenizers[label]
    model= models[label]
    # Tokenize the whole text up front; chunking is done on token ids,
    # not characters, so chunks never split a token.
    enc = tok(text, return_tensors="pt", truncation=False)
    ids = enc.input_ids[0]
    summaries = []
    # Stay inside the model's context window, with a small safety margin.
    # NOTE(review): not every config defines max_position_embeddings —
    # the 1024 fallback is assumed safe for all four models; confirm.
    max_ctx = getattr(model.config, "max_position_embeddings", 1024) - 4
    chunk_size = min(chunk_size, max_ctx)
    for i in range(0, len(ids), chunk_size):
        chunk_ids = ids[i:i+chunk_size].unsqueeze(0).to(device)
        out_ids = model.generate(
            chunk_ids,
            max_new_tokens=128,
            num_beams=4,
            do_sample=False
        )
        summ = tok.decode(out_ids[0], skip_special_tokens=True)
        summaries.append(summ)
    # Reduce pass: summarize the concatenated chunk summaries (this time
    # truncating to the context window, since it should already be short).
    combined = " ".join(summaries)
    enc2 = tok(combined, return_tensors="pt", truncation=True, max_length=max_ctx).to(device)
    out_ids = model.generate(
        **enc2,
        max_new_tokens=128,
        num_beams=4,
        do_sample=False
    )
    final = tok.decode(out_ids[0], skip_special_tokens=True)
    return final
# ────────────────────────────────────────────────────────────────
def summarize_single(url, label, grammar_method, transcript_state):
    """Single-model pipeline: transcribe → correct → summarize → score.

    Returns (original transcript, correction HTML with diff highlighting,
    summary text, matplotlib figure of fidelity metrics, updated
    transcript cache state) — the order the Gradio outputs expect.
    """
    orig, new_state = get_transcript(url, transcript_state)
    corr = correct_text(orig, method=grammar_method)
    corr_rate = calculate_correction_rate(orig, corr)
    corr_html = f"<div><b>ꡐ정λ₯ :</b> {corr_rate}%</div><hr/>{highlight_diff(orig, corr)}"
    # Transcripts of 100 characters or fewer are not summarized.
    summary = summarize_long_text(corr, label) if len(corr) > 100 else "⚠️ μš”μ•½ λΆˆκ°€"
    rouge_vals=[0,0,0]
    if rouge_scorer and summary.strip():
        sc = scorer.score(orig, summary)
        rouge_vals=[sc["rouge1"].fmeasure, sc["rouge2"].fmeasure, sc["rougeL"].fmeasure]
    bert_f1=0
    if bert_score_func and summary.strip():
        # Prefer Korean BERTScore; fall back to English if the ko model fails.
        try:
            _,_,F = bert_score_func([summary],[orig],lang="ko")
        except Exception:
            _,_,F = bert_score_func([summary],[orig],lang="en")
        bert_f1=float(F.mean())
    # Bar chart of the four fidelity metrics, all on a 0..1 scale.
    fig,ax=plt.subplots()
    ax.bar(["R1","R2","RL","BERT-F1"], rouge_vals+[bert_f1])
    ax.set_ylim(0,1); ax.set_ylabel("Score"); ax.set_title("Summary Fidelity")
    plt.tight_layout()
    return orig, corr_html, summary, fig, new_state
# ────────────────────────────────────────────────────────────────
def _fidelity_note(bert_f1):
    """One-line Korean interpretation of a BERTScore F1 value."""
    if bert_f1 > 0.8:
        return "핡심 정보 잘 반영"
    if bert_f1 > 0.5:
        return "μ£Όμš” λ‚΄μš© 포함"
    return "정보 손싀 많음"
def summarize_all(url, grammar_method, transcript_state):
    """Summarize the video with every registered model and compare fidelity.

    Returns a flat list in the exact order the Gradio outputs expect:
    [original transcript, correction HTML, one plot per model,
    one interpretation string per model, aggregate HTML table, new state].

    The F1 → note mapping was previously duplicated verbatim in the
    plotting loop and the table loop; it now lives in _fidelity_note.
    """
    orig, new_state = get_transcript(url, transcript_state)
    corr = correct_text(orig, method=grammar_method)
    corr_rate = calculate_correction_rate(orig, corr)
    corr_html = f"<div><b>ꡐ정λ₯ :</b> {corr_rate}%</div><hr/>{highlight_diff(orig, corr)}"
    figs, interps, rv_list, bf_list = [], [], [], []
    summaries_plain = []
    labels = list(SUMMARY_MODELS.keys())
    for label in labels:
        summ = summarize_long_text(corr, label)
        summaries_plain.append(summ)
        rv = [0, 0, 0]
        bf = 0
        if rouge_scorer:
            sc = scorer.score(orig, summ)
            rv = [sc["rouge1"].fmeasure, sc["rouge2"].fmeasure, sc["rougeL"].fmeasure]
        if bert_score_func:
            # Prefer Korean BERTScore; fall back to English on failure.
            try:
                _, _, F = bert_score_func([summ], [orig], lang="ko")
            except Exception:
                _, _, F = bert_score_func([summ], [orig], lang="en")
            bf = float(F.mean())
        rv_list.append(rv)
        bf_list.append(bf)
        # Per-model metrics bar chart (0..1 scale).
        fig, ax = plt.subplots()
        ax.bar(["R1", "R2", "RL", "BERT-F1"], rv + [bf])
        ax.set_ylim(0, 1)
        ax.set_title(label)
        plt.tight_layout()
        figs.append(fig)
        note = _fidelity_note(bf)
        interps.append(f"{label}: {note} (F1={bf:.2f})")
    # Aggregate comparison table.
    html = "<h3>λͺ¨λΈλ³„ μš”μ•½ & Fidelity Metrics</h3>"
    html += f"<p><b>ꡐ정λ₯ :</b> {corr_rate}%</p>"
    html += "<table border='1' style='border-collapse:collapse; width:100%; table-layout:fixed;'>"
    html += "<tr><th style='width:12%'>λͺ¨λΈ</th><th style='width:58%'>μš”μ•½λ¬Έ</th><th style='width:5%'>R1</th><th style='width:5%'>R2</th><th style='width:5%'>RL</th><th style='width:7%'>BERT-F1</th><th style='width:8%'>해석</th></tr>"
    for i, label in enumerate(labels):
        r1, r2, rl = rv_list[i]
        bf = bf_list[i]
        note = _fidelity_note(bf)
        # Escape '<' so model output cannot inject markup into the table.
        summ_html = summaries_plain[i].replace("<", "&lt;")
        html += (
            f"<tr>"
            f"<td>{label}</td>"
            f"<td style='white-space:pre-wrap; word-break:break-word'>{summ_html}</td>"
            f"<td>{r1:.2f}</td><td>{r2:.2f}</td><td>{rl:.2f}</td>"
            f"<td>{bf:.2f}</td><td>{note}</td>"
            f"</tr>"
        )
    html += "</table>"
    return [orig, corr_html] + figs + interps + [html, new_state]
# ────────────────────────────────────────────────────────────────
def save_summary(url, label):
    """Transcribe, correct, summarize with *label*, and write the result to disk.

    Always re-fetches the transcript (no state cache) and always uses the
    Korean GEC corrector. Returns the path of the written text file.
    """
    transcript, _ = get_transcript(url, None)
    cleaned = correct_text(transcript, "GEC-ν•œκ΅­μ–΄")
    summary_text = summarize_long_text(cleaned, label)
    out_path = os.path.join(os.getcwd(), f"summary_{label}.txt")
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write(summary_text)
    return out_path
# ────────────────────────────────────────────────────────────────
# CSS: style the corrected-transcript HTML panes (both tabs) as
# bordered, scrollable boxes that preserve whitespace.
CUSTOM_CSS = """
#corr_box, #corr_box_all {
border: 1px solid #ccc;
padding: 10px;
border-radius: 6px;
background-color: #fafafa;
max-height: 300px;
overflow-y: auto;
white-space: pre-wrap;
}
"""
# Gradio UI: two tabs — single-model summary, and all-model comparison.
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown("## 🎬 YouTube μš”μ•½ μ„œλΉ„μŠ€ (ꡐ정 + ꡐ정λ₯  + Diff κ°•μ‘°, μ•ˆμ „ μ²­ν¬μš”μ•½)")
    with gr.Tabs():
        # Tab 1: summarize with one user-selected model.
        with gr.TabItem("단일 λͺ¨λΈ μš”μ•½"):
            url_input = gr.Textbox(label="YouTube URL")
            model_sel = gr.Dropdown(list(SUMMARY_MODELS.keys()), label="μš”μ•½ λͺ¨λΈ")
            grammar_sel = gr.Dropdown(list(GRAMMAR_MODELS.keys()), label="ꡐ정 λͺ¨λΈ", value="GEC-ν•œκ΅­μ–΄")
            # Per-session transcript cache keyed by URL (see get_transcript).
            transcript_state = gr.State(None)
            btn_single = gr.Button("μš”μ•½ μ‹€ν–‰")
            orig_tb = gr.Textbox(label="원문 μžλ§‰", lines=10)
            corr_tb = gr.HTML(label="ꡐ정 μžλ§‰ (변경점 κ°•μ‘°)", elem_id="corr_box")
            sum_tb = gr.Textbox(label="μš”μ•½ κ²°κ³Ό", lines=8)
            fidelity_plot = gr.Plot(label="Fidelity Metrics")
            save_btn = gr.Button("μš”μ•½ μ €μž₯")
            download_single = gr.File(label="λ‹€μš΄λ‘œλ“œ 파일")
            btn_single.click(
                fn=summarize_single,
                inputs=[url_input, model_sel, grammar_sel, transcript_state],
                outputs=[orig_tb, corr_tb, sum_tb, fidelity_plot, transcript_state]
            )
            save_btn.click(
                fn=save_summary,
                inputs=[url_input, model_sel],
                outputs=[download_single]
            )
        # Tab 2: run every summarization model and compare fidelity.
        with gr.TabItem("전체 λͺ¨λΈ 비ꡐ"):
            url_all = gr.Textbox(label="YouTube URL")
            grammar_sel_all = gr.Dropdown(list(GRAMMAR_MODELS.keys()), label="ꡐ정 λͺ¨λΈ", value="GEC-ν•œκ΅­μ–΄")
            transcript_state_all = gr.State(None)
            btn_all = gr.Button("λͺ¨λ‘ μ‹€ν–‰")
            orig_all = gr.Textbox(label="원문 μžλ§‰", lines=10)
            corr_all = gr.HTML(label="ꡐ정 μžλ§‰ (변경점 κ°•μ‘°)", elem_id="corr_box_all")
            # One plot + one interpretation per model, in dict order —
            # this order must match the flat list returned by summarize_all.
            plot_components, interp_components = [], []
            for label in SUMMARY_MODELS:
                plot_components.append(gr.Plot(label=f"{label} Metrics"))
                interp_components.append(gr.HTML(label=f"{label} 해석"))
            agg_table = gr.HTML(label="λͺ¨λΈλ³„ μš”μ•½ & Fidelity Metrics")
            save_all_sel = gr.Radio(list(SUMMARY_MODELS.keys()), label="μ €μž₯ λͺ¨λΈ μ§€μ •")
            save_all_btn = gr.Button("선택 μš”μ•½ μ €μž₯")
            download_all = gr.File(label="λ‹€μš΄λ‘œλ“œ 파일")
            btn_all.click(
                fn=summarize_all,
                inputs=[url_all, grammar_sel_all, transcript_state_all],
                outputs=[orig_all, corr_all] + plot_components + interp_components + [agg_table, transcript_state_all]
            )
            save_all_btn.click(
                fn=save_summary,
                inputs=[url_all, save_all_sel],
                outputs=[download_all]
            )
if __name__ == '__main__':
    # server_name="0.0.0.0" binds all interfaces; share=True also creates
    # a public Gradio tunnel link.
    demo.launch(server_name="0.0.0.0", share=True)