# Scraped page header (not part of the program), kept as a comment:
# jade2zhong's picture / Update app.py / b16a677 verified
import json
import os
import re
from pathlib import Path
import gradio as gr
import numpy as np
import pdfplumber
from docx import Document
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Model identifiers. Each one can be overridden through the environment
# variable of the same name; the second argument is the built-in default.
EMBEDDING_MODEL = os.environ.get(
    "EMBEDDING_MODEL",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)
LLM_MODEL = os.environ.get("LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct-1M")
# Token used for the LLM client; when unset, LLM correction is skipped.
HF_TOKEN = os.environ.get("HF_TOKEN")
DEFAULT_MULTILINGUAL_ASR_MODEL = os.environ.get("ASR_MODEL", "openai/whisper-small")

# Recognition profiles offered in the UI, keyed by their display label.
# "language" is forwarded to the ASR call only when it is not None.
ASR_PROFILES = {
    "English optimized - Whisper small.en": dict(
        model=os.environ.get("ASR_MODEL_EN", "openai/whisper-small.en"),
        language=None,
        description="Best default for English-only lectures and presentations.",
    ),
    "Chinese - Whisper multilingual small": dict(
        model=os.environ.get("ASR_MODEL_ZH", DEFAULT_MULTILINGUAL_ASR_MODEL),
        language="chinese",
        description="Use this for Mandarin recordings and Chinese documents.",
    ),
    "Auto detect - Whisper multilingual small": dict(
        model=os.environ.get("ASR_MODEL_AUTO", DEFAULT_MULTILINGUAL_ASR_MODEL),
        language=None,
        description="Use this when the recording language is uncertain or mixed.",
    ),
}

# Lazily filled caches so heavy models and clients are created on first use.
asr_pipelines = {}
embedding_model = None
llm_client = None
# Custom stylesheet handed to gr.Blocks(css=...). It defines the colour
# variables, the header markup classes (app-shell, app-kicker, app-title,
# app-subtitle, status-*), the primary button colours, and a narrow-screen
# layout below 760px.
APP_CSS = """
:root {
--brand: #0f766e;
--brand-strong: #115e59;
--ink: #111827;
--muted: #64748b;
--line: #d8ded9;
--paper: #ffffff;
--wash: #f6f7f2;
--accent: #c2410c;
}
body,
.gradio-container {
background: var(--wash) !important;
color: var(--ink);
font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}
.main {
max-width: 1180px !important;
margin: 0 auto !important;
}
.app-shell {
padding: 28px 28px 12px;
border-bottom: 1px solid var(--line);
}
.app-kicker {
margin: 0 0 8px;
color: var(--brand-strong);
font-size: 12px;
font-weight: 700;
letter-spacing: 0.08em;
text-transform: uppercase;
}
.app-title {
margin: 0;
color: var(--ink);
font-size: 34px;
line-height: 1.12;
letter-spacing: 0;
}
.app-subtitle {
margin: 12px 0 0;
max-width: 780px;
color: var(--muted);
font-size: 16px;
line-height: 1.6;
}
.status-strip {
display: grid;
grid-template-columns: repeat(3, minmax(0, 1fr));
gap: 10px;
margin-top: 20px;
}
.status-item {
background: #ffffff;
border: 1px solid var(--line);
border-radius: 8px;
padding: 12px 14px;
}
.status-label {
color: var(--muted);
font-size: 12px;
margin-bottom: 4px;
}
.status-value {
color: var(--ink);
font-weight: 700;
font-size: 14px;
}
.gradio-container .block {
border-radius: 8px !important;
}
.gradio-container button.primary {
background: var(--brand) !important;
border-color: var(--brand) !important;
}
.gradio-container button.primary:hover {
background: var(--brand-strong) !important;
border-color: var(--brand-strong) !important;
}
textarea,
input,
.wrap {
border-radius: 8px !important;
}
.output-panel textarea {
font-size: 14px !important;
line-height: 1.55 !important;
}
.correction-notes,
.evidence-panel {
background: var(--paper);
}
@media (max-width: 760px) {
.app-shell {
padding: 22px 18px 8px;
}
.app-title {
font-size: 28px;
}
.status-strip {
grid-template-columns: 1fr;
}
}
"""
def get_asr_pipeline(model_id: str):
    """Return the speech-recognition pipeline for ``model_id``, building it once.

    Pipelines are memoised in the module-level ``asr_pipelines`` dict, so each
    model is only loaded the first time it is requested.
    """
    try:
        return asr_pipelines[model_id]
    except KeyError:
        pass

    # device=-1 is the transformers pipeline convention for running on CPU.
    asr_pipelines[model_id] = pipeline(
        "automatic-speech-recognition",
        model=model_id,
        device=-1,
    )
    return asr_pipelines[model_id]
def get_embedding_model():
    """Return the shared sentence-embedding model, loading it on first use."""
    global embedding_model
    if embedding_model is not None:
        return embedding_model

    embedding_model = SentenceTransformer(EMBEDDING_MODEL)
    return embedding_model
def get_llm_client():
    """Return the shared OpenAI-compatible client, or ``None`` without a token.

    The client targets the Hugging Face router endpoint and authenticates with
    ``HF_TOKEN``. It is created once and then reused from ``llm_client``.
    """
    global llm_client
    if not HF_TOKEN:
        # No credentials: callers treat None as "skip LLM correction".
        return None
    if llm_client is not None:
        return llm_client

    llm_client = OpenAI(
        base_url="https://router.huggingface.co/v1",
        api_key=HF_TOKEN,
    )
    return llm_client
def read_text_file(path: Path) -> str:
    """Read a text file, trying UTF-8 first and GB18030 second.

    Args:
        path: Location of the text file.

    Returns:
        The decoded file content. A leading UTF-8 byte-order mark is removed
        instead of being returned as a stray U+FEFF character.
    """
    # "utf-8-sig" decodes plain UTF-8 as well and strips a BOM when present.
    for encoding in ("utf-8-sig", "gb18030"):
        try:
            return path.read_text(encoding=encoding)
        except UnicodeDecodeError:
            continue
    # Last resort: platform default encoding, dropping undecodable bytes.
    return path.read_text(errors="ignore")
def extract_document_text(file_path: str) -> str:
    """Extract plain text from a TXT, PDF, or DOCX document.

    Args:
        file_path: Path to the document; the type is chosen from the
            lower-cased file suffix.

    Returns:
        The extracted text with runs of spaces/tabs collapsed to one space,
        three or more consecutive newlines reduced to two, and outer
        whitespace stripped.

    Raises:
        ValueError: If the suffix is not ``.txt``, ``.pdf``, or ``.docx``.
    """
    source = Path(file_path)
    extension = source.suffix.lower()

    if extension not in (".txt", ".pdf", ".docx"):
        raise ValueError("Only PDF, DOCX, and TXT documents are supported.")

    if extension == ".txt":
        raw = read_text_file(source)
    elif extension == ".pdf":
        with pdfplumber.open(source) as pdf:
            # Pages without extractable text yield None; keep an empty line.
            page_texts = [page.extract_text() or "" for page in pdf.pages]
        raw = "\n".join(page_texts)
    else:
        paragraphs = Document(source).paragraphs
        raw = "\n".join(paragraph.text for paragraph in paragraphs)

    collapsed = re.sub(r"[ \t]+", " ", raw)
    collapsed = re.sub(r"\n{3,}", "\n\n", collapsed)
    return collapsed.strip()
def split_into_chunks(text: str, max_chars: int = 700, overlap: int = 90) -> list[str]:
    """Split document text into retrieval chunks of at most ``max_chars`` characters.

    The text is split into paragraphs (blank-line separated) and then into
    sentence-like pieces. Pieces are greedily packed, joined by newlines, into
    chunks. A single piece longer than ``max_chars`` is cut into fixed-size
    slices whose starts are ``max_chars - overlap`` apart.

    Args:
        text: Document text to split.
        max_chars: Maximum chunk length in characters; must be positive.
        overlap: Characters shared by consecutive slices of an over-long
            piece; must satisfy ``0 <= overlap < max_chars``.

    Returns:
        The chunks in document order, omitting chunks shorter than 20 characters.

    Raises:
        ValueError: If ``max_chars`` or ``overlap`` is out of range.
    """
    if max_chars <= 0:
        raise ValueError("max_chars must be a positive integer.")
    if not 0 <= overlap < max_chars:
        # A non-positive step would either crash range() or silently drop text.
        raise ValueError("overlap must satisfy 0 <= overlap < max_chars.")

    # ASCII punctuation needs trailing whitespace to count as a boundary.
    # CJK punctuation is normally followed directly by the next sentence, so it
    # splits without whitespace, but not in front of further closing marks.
    sentence_boundary = re.compile(
        r"(?<=[.!?;:])\s+"
        r"|(?<=[。！？；：])(?![。！？；：”’」』）])\s*"
    )

    pieces = []
    for paragraph in re.split(r"\n\s*\n+", text):
        paragraph = paragraph.strip()
        if paragraph:
            pieces.extend(sentence_boundary.split(paragraph))
    pieces = [p.strip() for p in pieces if p and p.strip()]

    chunks = []
    current = ""
    step = max_chars - overlap
    for piece in pieces:
        if len(piece) > max_chars:
            if current:
                chunks.append(current)
                current = ""
            for start in range(0, len(piece), step):
                chunks.append(piece[start : start + max_chars])
                # Stop once a slice reaches the end; a later slice would only
                # repeat text already contained in this one.
                if start + max_chars >= len(piece):
                    break
            continue
        candidate = piece if not current else f"{current}\n{piece}"
        if len(candidate) <= max_chars:
            current = candidate
        else:
            chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return [chunk for chunk in chunks if len(chunk) >= 20]
def resolve_asr_profile(profile_name: str) -> dict:
    """Look up an ASR profile by label, falling back to the first profile."""
    default_profile = next(iter(ASR_PROFILES.values()))
    if profile_name in ASR_PROFILES:
        return ASR_PROFILES[profile_name]
    return default_profile
def transcribe_audio(audio_path: str, profile_name: str) -> str:
    """Transcribe an audio file with the ASR model of the selected profile.

    Args:
        audio_path: Path to the audio file.
        profile_name: Label of an entry in ``ASR_PROFILES``; unknown labels
            resolve to the first profile.

    Returns:
        The recognised text with surrounding whitespace stripped.
    """
    profile = resolve_asr_profile(profile_name)
    asr = get_asr_pipeline(profile["model"])

    # English-only Whisper checkpoints reject `task`/`language` in generate().
    # Prefer the model's own generation config; fall back to the ".en" naming
    # convention when the flag is not available.
    generation_config = getattr(getattr(asr, "model", None), "generation_config", None)
    is_multilingual = getattr(generation_config, "is_multilingual", None)
    if is_multilingual is None:
        is_multilingual = not str(profile["model"]).endswith(".en")

    generate_kwargs = {}
    if is_multilingual:
        generate_kwargs["task"] = "transcribe"
        if profile["language"]:
            generate_kwargs["language"] = profile["language"]

    # Whisper works on 30-second windows; chunking lets recordings longer than
    # that be transcribed in full.
    # NOTE(review): confirm chunked inference against the installed transformers version.
    result = asr(audio_path, chunk_length_s=30, generate_kwargs=generate_kwargs)
    if isinstance(result, dict):
        return str(result.get("text", "")).strip()
    return str(result).strip()
def retrieve_contexts(raw_transcript: str, chunks: list[str], top_k: int) -> list[tuple[int, float, str]]:
    """Rank document chunks by embedding similarity to the transcript.

    Args:
        raw_transcript: ASR transcript used as the query.
        chunks: Candidate document passages.
        top_k: Maximum number of passages to return.

    Returns:
        Up to ``top_k`` ``(chunk_index, score, chunk_text)`` tuples, highest
        score first. Empty when there are no chunks or ``top_k`` is not
        positive.
    """
    top_k = int(top_k)
    # Guard first: avoids loading the model for nothing, and keeps a negative
    # top_k from slicing "all but the last k" results.
    if not chunks or top_k <= 0:
        return []

    model = get_embedding_model()
    doc_vectors = model.encode(chunks, normalize_embeddings=True)
    query_vector = model.encode([raw_transcript], normalize_embeddings=True)[0]
    # Embeddings are requested normalised, so the dot product is used as the
    # similarity score.
    scores = np.matmul(doc_vectors, query_vector)
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [(int(i), float(scores[i]), chunks[int(i)]) for i in top_indices]
def build_correction_prompt(raw_transcript: str, contexts) -> list[dict]:
    """Build the chat messages that ask the LLM to correct a transcript.

    Args:
        raw_transcript: Transcript text produced by the ASR step.
        contexts: Iterable of ``(chunk_index, score, text)`` tuples.

    Returns:
        A two-item list: the system message with the correction policy and the
        user message with the rules, JSON schema, passages, and transcript.
    """
    passages = []
    for idx, score, text in contexts:
        passages.append(f"[Document passage {idx + 1} | similarity {score:.3f}]\n{text}")
    context_text = "\n\n".join(passages)

    system_prompt = " ".join(
        [
            "You are a strict ASR correction assistant.",
            "Correct the transcript only when the provided document context gives clear evidence.",
            "Focus on homophones, near-sound mistakes, technical terms, names, acronyms, "
            "chapter titles, and domain-specific phrases.",
            "Preserve the original sentence structure as much as possible.",
            "Do not summarize, rewrite freely, or add information that was not spoken.",
        ]
    )

    # Doubled braces are literal braces in the rendered JSON schema.
    user_template = """
Correct the ASR transcript using the document passages below.
Rules:
1. Treat the raw transcript as the primary text.
2. Make only evidence-backed corrections.
3. Prefer keeping the original word when the document context is not strong enough.
4. Output JSON only. Do not output Markdown.
JSON schema:
{{
"corrected_text": "the complete corrected transcript",
"changes": [
{{
"original": "incorrect word or phrase",
"corrected": "corrected word or phrase",
"reason": "why the document supports this correction"
}}
]
}}
Document passages:
{context_text}
Raw ASR transcript:
{raw_transcript}
"""
    user_prompt = user_template.format(
        context_text=context_text,
        raw_transcript=raw_transcript,
    ).strip()

    return [
        dict(role="system", content=system_prompt),
        dict(role="user", content=user_prompt),
    ]
def parse_json_response(text: str) -> dict:
    """Parse the JSON object contained in a language-model reply.

    The reply is parsed as-is first. If that does not yield a JSON object, the
    text is scanned for the first ``{`` from which a complete object can be
    decoded, which tolerates Markdown fences and prose around the payload.

    Args:
        text: Raw reply content; may be ``None`` when the model returned nothing.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If no JSON object can be decoded from ``text``.
    """
    error_message = "The language model did not return valid JSON."
    if not isinstance(text, str):
        raise ValueError(error_message)

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # raw_decode stops at the end of the first complete value, so text after
    # the object (even text containing braces) does not break parsing.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)

    raise ValueError(error_message)
def correct_with_llm(raw_transcript: str, contexts):
    """Ask the LLM for a document-backed correction of the transcript.

    Args:
        raw_transcript: Transcript text produced by the ASR step.
        contexts: Retrieved ``(chunk_index, score, text)`` tuples.

    Returns:
        The parsed model reply. When no LLM client is available, a dict with
        the unchanged transcript and one explanatory entry under ``"changes"``.
    """
    client = get_llm_client()
    if client is not None:
        response = client.chat.completions.create(
            model=LLM_MODEL,
            messages=build_correction_prompt(raw_transcript, contexts),
            temperature=0.1,
            max_tokens=1200,
        )
        reply = response.choices[0].message.content
        return parse_json_response(reply)

    # No client means no HF_TOKEN: report the skip through the changes list.
    skipped_entry = {
        "original": "LLM correction skipped",
        "corrected": "LLM correction skipped",
        "reason": "HF_TOKEN is not set. Add HF_TOKEN locally or in Hugging Face Spaces secrets.",
    }
    return {"corrected_text": raw_transcript, "changes": [skipped_entry]}
def format_contexts(contexts) -> str:
    """Render retrieved passages as Markdown, one numbered section per passage.

    Args:
        contexts: Iterable of ``(chunk_index, score, text)`` tuples.

    Returns:
        Markdown with sections separated by horizontal rules; an empty string
        when there are no passages.
    """
    sections = [
        f"### Passage {position}\nSimilarity: `{similarity:.3f}`\n\n{passage}"
        for position, (_chunk_index, similarity, passage) in enumerate(contexts, start=1)
    ]
    return "\n\n---\n\n".join(sections)
def format_changes(changes) -> str:
    """Render the LLM's change list as a Markdown bullet list.

    The value comes from model output, so malformed shapes are tolerated: a
    single change object is treated as a one-item list, and entries that are
    not objects are skipped instead of raising.

    Args:
        changes: Normally a list of dicts with ``original``, ``corrected``,
            and ``reason`` keys.

    Returns:
        One bullet per usable change, or a fixed message when there is none.
    """
    no_changes = "No document-backed correction was needed."
    if not changes:
        return no_changes

    if isinstance(changes, dict):
        items = [changes]
    elif isinstance(changes, (str, bytes)):
        # Free text instead of a list: there is nothing structured to render.
        items = []
    else:
        try:
            items = list(changes)
        except TypeError:
            items = []

    lines = []
    for item in items:
        if not isinstance(item, dict):
            continue
        original = item.get("original", "")
        corrected = item.get("corrected", "")
        reason = item.get("reason", "")
        lines.append(f"- `{original}` -> `{corrected}`: {reason}")
    return "\n".join(lines) if lines else no_changes
def run_app(document_file, audio_file, profile_name, top_k):
    """Run the full pipeline behind the "Transcribe and correct" button.

    Steps: extract and chunk the reference document, transcribe the audio,
    retrieve the most similar passages, then ask the LLM for a correction.

    Args:
        document_file: Path of the uploaded reference document, or ``None``.
        audio_file: Path of the uploaded or recorded audio, or ``None``.
        profile_name: Selected ASR profile label.
        top_k: Number of document passages to retrieve.

    Returns:
        ``(raw_transcript, corrected_text, correction_notes_markdown,
        evidence_markdown)``.

    Raises:
        gr.Error: If an input is missing, the document cannot be read or
            chunked, or no speech was recognised.
    """
    if document_file is None:
        raise gr.Error("Upload a PDF, DOCX, or TXT reference document first.")
    if audio_file is None:
        raise gr.Error("Upload or record an audio sample first.")

    try:
        document_text = extract_document_text(document_file)
    except ValueError as err:
        # Re-raise as gr.Error so the explanatory message reaches the UI.
        raise gr.Error(str(err)) from err
    if not document_text:
        raise gr.Error("No text was extracted from the document. Scanned PDFs need OCR first.")

    chunks = split_into_chunks(document_text)
    if not chunks:
        raise gr.Error("The document is too short to build context.")

    raw_transcript = transcribe_audio(audio_file, profile_name)
    if not raw_transcript:
        raise gr.Error("No speech text was recognized from the audio.")

    contexts = retrieve_contexts(raw_transcript, chunks, int(top_k))

    # An unusable LLM reply should not throw away the finished transcription:
    # fall back to the raw transcript and explain why in the notes panel.
    failure_note = None
    try:
        correction = correct_with_llm(raw_transcript, contexts)
    except ValueError as err:
        correction = {}
        failure_note = f"LLM correction failed ({err}). Showing the raw transcript unchanged."
    if not isinstance(correction, dict):
        correction = {}
        failure_note = (
            "LLM correction returned an unexpected structure. "
            "Showing the raw transcript unchanged."
        )

    corrected_text = correction.get("corrected_text")
    if not isinstance(corrected_text, str) or not corrected_text.strip():
        corrected_text = raw_transcript
    changes = correction.get("changes", [])
    notes = format_changes(changes) if failure_note is None else failure_note

    return (
        raw_transcript,
        corrected_text,
        notes,
        format_contexts(contexts),
    )
# Gradio theme used by the Blocks app below.
theme = gr.themes.Soft(
primary_hue="teal",
secondary_hue="orange",
neutral_hue="zinc",
radius_size="sm",
)
# UI definition: header HTML, the two upload inputs, the profile and top-k
# controls, the submit button, and the four outputs wired to run_app.
# NOTE(review): indentation was lost in this copy of the file, so which
# components sit inside which Row/Column context cannot be determined from
# here. Restore the nesting from the original app.py before running.
with gr.Blocks(
title="Context-Aware Audio Correction",
theme=theme,
css=APP_CSS,
) as demo:
# Static header; its class names are styled by APP_CSS.
gr.HTML(
"""
<section class="app-shell">
<p class="app-kicker">Hugging Face ASR + document retrieval</p>
<h1 class="app-title">Context-Aware Audio Correction</h1>
<p class="app-subtitle">
Upload a reference document and an audio clip. The app transcribes speech,
retrieves matching document passages, and corrects likely ASR mistakes using
only document-backed evidence.
</p>
<div class="status-strip">
<div class="status-item">
<div class="status-label">ASR profiles</div>
<div class="status-value">English / Chinese / Auto</div>
</div>
<div class="status-item">
<div class="status-label">Context engine</div>
<div class="status-value">Sentence embeddings</div>
</div>
<div class="status-item">
<div class="status-label">Correction policy</div>
<div class="status-value">Evidence-bound</div>
</div>
</div>
</section>
"""
)
with gr.Row():
with gr.Column(scale=1, min_width=320):
# Both inputs use type="filepath", so run_app receives file paths.
document_input = gr.File(
label="Reference document",
file_types=[".pdf", ".docx", ".txt"],
type="filepath",
)
audio_input = gr.Audio(
label="Audio sample",
sources=["upload", "microphone"],
type="filepath",
)
with gr.Column(scale=1, min_width=320):
# Choices are the ASR_PROFILES labels; the default is the English profile.
profile_input = gr.Radio(
label="Recognition profile",
choices=list(ASR_PROFILES.keys()),
value="English optimized - Whisper small.en",
info=(
"English uses an English-only Whisper model. Chinese and Auto use "
"the multilingual Whisper model."
),
)
top_k_input = gr.Slider(
label="Document passages to retrieve",
minimum=1,
maximum=8,
value=4,
step=1,
)
submit_button = gr.Button("Transcribe and correct", variant="primary")
with gr.Row(elem_classes=["output-panel"]):
raw_output = gr.Textbox(label="Raw Whisper transcript", lines=9)
corrected_output = gr.Textbox(label="Context-corrected transcript", lines=9)
changes_output = gr.Markdown(
label="Correction notes",
elem_classes=["correction-notes"],
)
contexts_output = gr.Markdown(
label="Document evidence",
elem_classes=["evidence-panel"],
)
# The output order matches the tuple returned by run_app.
submit_button.click(
fn=run_app,
inputs=[document_input, audio_input, profile_input, top_k_input],
outputs=[raw_output, corrected_output, changes_output, contexts_output],
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    # NOTE(review): share=True requests a public share link; confirm that is
    # intended wherever this runs, since the app uses HF_TOKEN for LLM calls.
    launch_options = {"share": True}
    demo.launch(**launch_options)