"""Gradio app: transcribe audio with Whisper, then correct the transcript using
passages retrieved from a user-supplied reference document."""

import json
import os
import re
from pathlib import Path

import gradio as gr
import numpy as np
import pdfplumber
from docx import Document
from openai import OpenAI, OpenAIError
from sentence_transformers import SentenceTransformer
from transformers import pipeline

EMBEDDING_MODEL = os.getenv(
    "EMBEDDING_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)
LLM_MODEL = os.getenv("LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct-1M")
HF_TOKEN = os.getenv("HF_TOKEN")
DEFAULT_MULTILINGUAL_ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-small")

# Length (seconds) of the windows the ASR pipeline cuts audio into. Chunked
# inference lets recordings longer than a single Whisper window be transcribed.
ASR_CHUNK_LENGTH_S = 30

# Output-token budget for the correction call. The reply repeats the whole
# transcript inside JSON, so the budget grows with transcript length, bounded
# below by the previous fixed value and above by a configurable cap.
LLM_MIN_OUTPUT_TOKENS = 1200
LLM_MAX_OUTPUT_TOKENS = int(os.getenv("LLM_MAX_OUTPUT_TOKENS", "4096"))

# Transcripts longer than this are embedded in windows during retrieval so the
# embedding model's input limit does not hide the later parts of a recording.
QUERY_WINDOW_CHARS = 400
QUERY_WINDOW_OVERLAP = 60

# Sentence boundaries: ASCII punctuation followed by whitespace, or CJK
# full-width punctuation, which is usually not followed by a space.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?;:])\s+|(?<=[。！？；：])\s*")

ASR_PROFILES = {
    "English optimized - Whisper small.en": {
        "model": os.getenv("ASR_MODEL_EN", "openai/whisper-small.en"),
        "language": None,
        "description": "Best default for English-only lectures and presentations.",
    },
    "Chinese - Whisper multilingual small": {
        "model": os.getenv("ASR_MODEL_ZH", DEFAULT_MULTILINGUAL_ASR_MODEL),
        "language": "chinese",
        "description": "Use this for Mandarin recordings and Chinese documents.",
    },
    "Auto detect - Whisper multilingual small": {
        "model": os.getenv("ASR_MODEL_AUTO", DEFAULT_MULTILINGUAL_ASR_MODEL),
        "language": None,
        "description": "Use this when the recording language is uncertain or mixed.",
    },
}
# The first profile is both the UI default and the fallback for unknown names.
DEFAULT_ASR_PROFILE = next(iter(ASR_PROFILES))

# Lazily-initialised, process-wide model handles.
asr_pipelines = {}
embedding_model = None
llm_client = None

APP_CSS = """
:root {
  --brand: #0f766e;
  --brand-strong: #115e59;
  --ink: #111827;
  --muted: #64748b;
  --line: #d8ded9;
  --paper: #ffffff;
  --wash: #f6f7f2;
  --accent: #c2410c;
}

body,
.gradio-container {
  background: var(--wash) !important;
  color: var(--ink);
  font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}

.main {
  max-width: 1180px !important;
  margin: 0 auto !important;
}

.app-shell {
  padding: 28px 28px 12px;
  border-bottom: 1px solid var(--line);
}

.app-kicker {
  margin: 0 0 8px;
  color: var(--brand-strong);
  font-size: 12px;
  font-weight: 700;
  letter-spacing: 0.08em;
  text-transform: uppercase;
}

.app-title {
  margin: 0;
  color: var(--ink);
  font-size: 34px;
  line-height: 1.12;
  letter-spacing: 0;
}

.app-subtitle {
  margin: 12px 0 0;
  max-width: 780px;
  color: var(--muted);
  font-size: 16px;
  line-height: 1.6;
}

.status-strip {
  display: grid;
  grid-template-columns: repeat(3, minmax(0, 1fr));
  gap: 10px;
  margin-top: 20px;
}

.status-item {
  background: #ffffff;
  border: 1px solid var(--line);
  border-radius: 8px;
  padding: 12px 14px;
}

.status-label {
  color: var(--muted);
  font-size: 12px;
  margin-bottom: 4px;
}

.status-value {
  color: var(--ink);
  font-weight: 700;
  font-size: 14px;
}

.gradio-container .block {
  border-radius: 8px !important;
}

.gradio-container button.primary {
  background: var(--brand) !important;
  border-color: var(--brand) !important;
}

.gradio-container button.primary:hover {
  background: var(--brand-strong) !important;
  border-color: var(--brand-strong) !important;
}

textarea,
input,
.wrap {
  border-radius: 8px !important;
}

.output-panel textarea {
  font-size: 14px !important;
  line-height: 1.55 !important;
}

.correction-notes,
.evidence-panel {
  background: var(--paper);
}

@media (max-width: 760px) {
  .app-shell {
    padding: 22px 18px 8px;
  }

  .app-title {
    font-size: 28px;
  }

  .status-strip {
    grid-template-columns: 1fr;
  }
}
"""

# Header markup; the classes are styled by the .app-* / .status-* rules in APP_CSS.
APP_HEADER_HTML = """
<section class="app-shell">
  <p class="app-kicker">Hugging Face ASR + document retrieval</p>
  <h1 class="app-title">Context-Aware Audio Correction</h1>
  <p class="app-subtitle">
    Upload a reference document and an audio clip. The app transcribes speech,
    retrieves matching document passages, and corrects likely ASR mistakes using
    only document-backed evidence.
  </p>
  <div class="status-strip">
    <div class="status-item">
      <div class="status-label">ASR profiles</div>
      <div class="status-value">English / Chinese / Auto</div>
    </div>
    <div class="status-item">
      <div class="status-label">Context engine</div>
      <div class="status-value">Sentence embeddings</div>
    </div>
    <div class="status-item">
      <div class="status-label">Correction policy</div>
      <div class="status-value">Evidence-bound</div>
    </div>
  </div>
</section>
"""


def get_asr_pipeline(model_id: str):
    """Return a cached CPU speech-recognition pipeline for ``model_id``.

    Each model is loaded at most once per process and memoised in
    ``asr_pipelines``. ``chunk_length_s`` enables chunked inference so audio
    longer than a single Whisper window is transcribed in full.
    """
    if model_id not in asr_pipelines:
        asr_pipelines[model_id] = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            device=-1,
            chunk_length_s=ASR_CHUNK_LENGTH_S,
        )
    return asr_pipelines[model_id]


def get_embedding_model():
    """Return the shared SentenceTransformer, loading it on first use."""
    global embedding_model
    if embedding_model is None:
        embedding_model = SentenceTransformer(EMBEDDING_MODEL)
    return embedding_model


def get_llm_client():
    """Return the shared Hugging Face router client, or None when HF_TOKEN is unset."""
    global llm_client
    if not HF_TOKEN:
        return None
    if llm_client is None:
        llm_client = OpenAI(
            base_url="https://router.huggingface.co/v1",
            api_key=HF_TOKEN,
        )
    return llm_client


def read_text_file(path: Path) -> str:
    """Decode a plain-text document.

    Tries UTF-8 (tolerating a leading BOM), then GB18030 for legacy Chinese
    files. As a last resort, decodes as UTF-8 with undecodable bytes replaced,
    so a mis-encoded upload still yields usable text.
    """
    for encoding in ("utf-8-sig", "gb18030"):
        try:
            return path.read_text(encoding=encoding)
        except UnicodeDecodeError:
            continue
    return path.read_text(encoding="utf-8", errors="replace")


def _docx_text(path: Path) -> str:
    """Return a DOCX file's paragraph text followed by its table-row text."""
    doc = Document(str(path))
    lines = [paragraph.text for paragraph in doc.paragraphs]
    for table in doc.tables:
        for row in table.rows:
            # Merged cells are reported once per spanned column; dedupe them.
            cells = dict.fromkeys(cell.text.strip() for cell in row.cells)
            row_text = " | ".join(text for text in cells if text)
            if row_text:
                lines.append(row_text)
    return "\n".join(lines)


def extract_document_text(file_path: str) -> str:
    """Extract normalised text from a PDF, DOCX, or TXT reference document.

    Runs of spaces/tabs collapse to one space, and three or more newlines
    collapse to a single blank line.

    Raises:
        ValueError: if the file extension is not .pdf, .docx, or .txt.
    """
    path = Path(file_path)
    suffix = path.suffix.lower()
    if suffix == ".txt":
        text = read_text_file(path)
    elif suffix == ".pdf":
        with pdfplumber.open(path) as pdf:
            # Pages without a text layer (e.g. scans) contribute an empty string.
            text = "\n".join(page.extract_text() or "" for page in pdf.pages)
    elif suffix == ".docx":
        text = _docx_text(path)
    else:
        raise ValueError("Only PDF, DOCX, and TXT documents are supported.")

    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()


def split_into_chunks(text: str, max_chars: int = 700, overlap: int = 90) -> list[str]:
    """Split ``text`` into retrieval chunks of at most ``max_chars`` characters.

    Paragraphs are split into sentences (ASCII or CJK punctuation), and
    consecutive sentences are packed into chunks joined by newlines. A single
    sentence longer than ``max_chars`` is cut into windows that overlap by
    ``overlap`` characters. Chunks shorter than 20 characters are dropped.

    Raises:
        ValueError: if ``max_chars`` is not positive or ``overlap >= max_chars``
            (the window stride would be zero or negative).
    """
    if max_chars <= 0:
        raise ValueError("max_chars must be positive.")
    if overlap >= max_chars:
        raise ValueError("overlap must be smaller than max_chars.")
    step = max_chars - overlap

    pieces = []
    for paragraph in re.split(r"\n\s*\n+", text):
        paragraph = paragraph.strip()
        if paragraph:
            pieces.extend(_SENTENCE_BOUNDARY.split(paragraph))
    pieces = [p.strip() for p in pieces if p and p.strip()]

    chunks = []
    current = ""
    for piece in pieces:
        if len(piece) > max_chars:
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(
                piece[start : start + max_chars] for start in range(0, len(piece), step)
            )
            continue
        candidate = f"{current}\n{piece}" if current else piece
        if len(candidate) <= max_chars:
            current = candidate
        else:
            chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return [chunk for chunk in chunks if len(chunk) >= 20]


def resolve_asr_profile(profile_name: str) -> dict:
    """Return the ASR profile named ``profile_name``, or the default profile."""
    return ASR_PROFILES.get(profile_name, ASR_PROFILES[DEFAULT_ASR_PROFILE])


def transcribe_audio(audio_path: str, profile_name: str) -> str:
    """Transcribe ``audio_path`` with the Whisper model of the chosen profile."""
    profile = resolve_asr_profile(profile_name)
    asr = get_asr_pipeline(profile["model"])

    generate_kwargs = {}
    # English-only Whisper checkpoints (e.g. *.en) have is_multilingual=False
    # and recent transformers releases reject `task`/`language` for them, so
    # these are only forwarded to multilingual models.
    generation_config = getattr(asr.model, "generation_config", None)
    if getattr(generation_config, "is_multilingual", True):
        generate_kwargs["task"] = "transcribe"
        if profile["language"]:
            generate_kwargs["language"] = profile["language"]

    result = asr(audio_path, generate_kwargs=generate_kwargs)
    if isinstance(result, dict):
        return str(result.get("text", "")).strip()
    return str(result).strip()


def retrieve_contexts(raw_transcript: str, chunks: list[str], top_k: int):
    """Return the ``top_k`` chunks most similar to the transcript.

    Long transcripts are split into windows that are embedded separately, and
    each chunk keeps its best score over all windows. Sentence-embedding models
    encode only a bounded number of tokens, so a single long query would
    otherwise reflect just the start of the recording.

    Returns:
        A list of ``(chunk_index, cosine_similarity, chunk_text)`` tuples in
        descending similarity order.
    """
    model = get_embedding_model()
    doc_vectors = model.encode(chunks, normalize_embeddings=True)

    if len(raw_transcript) <= QUERY_WINDOW_CHARS:
        queries = [raw_transcript]
    else:
        queries = split_into_chunks(
            raw_transcript, max_chars=QUERY_WINDOW_CHARS, overlap=QUERY_WINDOW_OVERLAP
        ) or [raw_transcript]
    query_vectors = model.encode(queries, normalize_embeddings=True)

    # Vectors are L2-normalised, so the dot product is cosine similarity.
    scores = np.max(np.matmul(doc_vectors, query_vectors.T), axis=1)
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [(int(i), float(scores[i]), chunks[int(i)]) for i in top_indices]


def build_correction_prompt(raw_transcript: str, contexts) -> list[dict]:
    """Build the chat messages asking the LLM for evidence-bound corrections."""
    context_text = "\n\n".join(
        f"[Document passage {idx + 1} | similarity {score:.3f}]\n{text}"
        for idx, score, text in contexts
    )
    system_prompt = (
        "You are a strict ASR correction assistant. Correct the transcript only when the "
        "provided document context gives clear evidence. Focus on homophones, near-sound "
        "mistakes, technical terms, names, acronyms, chapter titles, and domain-specific "
        "phrases. Preserve the original sentence structure as much as possible. Do not "
        "summarize, rewrite freely, or add information that was not spoken."
    )
    user_prompt = f"""
Correct the ASR transcript using the document passages below.

Rules:
1. Treat the raw transcript as the primary text.
2. Make only evidence-backed corrections.
3. Prefer keeping the original word when the document context is not strong enough.
4. Output JSON only. Do not output Markdown.

JSON schema:
{{
  "corrected_text": "the complete corrected transcript",
  "changes": [
    {{
      "original": "incorrect word or phrase",
      "corrected": "corrected word or phrase",
      "reason": "why the document supports this correction"
    }}
  ]
}}

Document passages:
{context_text}

Raw ASR transcript:
{raw_transcript}
""".strip()
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def parse_json_response(text: str) -> dict:
    """Parse the model reply into a dict.

    Surrounding prose or code fences are tolerated by extracting the outermost
    ``{...}`` span.

    Raises:
        ValueError: if the reply is empty, or no JSON object can be recovered.
    """
    if not text:
        raise ValueError("The language model returned an empty response.")
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", text, flags=re.S)
        if not match:
            raise ValueError("The language model did not return valid JSON.") from None
        try:
            data = json.loads(match.group(0))
        except json.JSONDecodeError as err:
            raise ValueError("The language model did not return valid JSON.") from err
    if not isinstance(data, dict):
        raise ValueError("The language model did not return a JSON object.")
    return data


def _skipped_correction(raw_transcript: str, reason: str) -> dict:
    """Return a correction result that keeps the raw transcript and records why."""
    return {
        "corrected_text": raw_transcript,
        "changes": [
            {
                "original": "LLM correction skipped",
                "corrected": "LLM correction skipped",
                "reason": reason,
            }
        ],
    }


def correct_with_llm(raw_transcript: str, contexts):
    """Ask the LLM to correct ``raw_transcript`` using the retrieved ``contexts``.

    Returns a dict with ``corrected_text`` and ``changes``. When no token is
    configured, the request fails, the reply is truncated, or the reply cannot
    be parsed, the raw transcript is returned with an explanatory change note
    instead of raising.
    """
    client = get_llm_client()
    if client is None:
        return _skipped_correction(
            raw_transcript,
            "HF_TOKEN is not set. Add HF_TOKEN locally or in Hugging Face Spaces secrets.",
        )

    # The reply echoes the full transcript; budget roughly one token per
    # character plus room for the change list (an assumed, generous estimate).
    max_tokens = min(
        LLM_MAX_OUTPUT_TOKENS,
        max(LLM_MIN_OUTPUT_TOKENS, len(raw_transcript) + 800),
    )
    try:
        completion = client.chat.completions.create(
            model=LLM_MODEL,
            messages=build_correction_prompt(raw_transcript, contexts),
            temperature=0.1,
            max_tokens=max_tokens,
        )
    except OpenAIError as err:
        return _skipped_correction(
            raw_transcript, f"The correction model request failed: {err}"
        )

    if not completion.choices:
        return _skipped_correction(
            raw_transcript, "The correction model returned no choices."
        )
    choice = completion.choices[0]
    if choice.finish_reason == "length":
        return _skipped_correction(
            raw_transcript,
            "The correction model ran out of output tokens before finishing; "
            "the transcript may be too long to correct in one request.",
        )
    try:
        return parse_json_response(choice.message.content)
    except ValueError as err:
        return _skipped_correction(raw_transcript, str(err))


def format_contexts(contexts) -> str:
    """Render the retrieved passages as Markdown, ranked from most similar."""
    blocks = []
    for rank, (_idx, score, text) in enumerate(contexts, start=1):
        blocks.append(f"### Passage {rank}\nSimilarity: `{score:.3f}`\n\n{text}")
    return "\n\n---\n\n".join(blocks)


def format_changes(changes) -> str:
    """Render the LLM change list as a Markdown bullet list.

    Items that are not dicts (malformed model output) are shown verbatim
    instead of raising.
    """
    if not changes:
        return "No document-backed correction was needed."

    lines = []
    for item in changes:
        if not isinstance(item, dict):
            lines.append(f"- {item}")
            continue
        original = item.get("original", "")
        corrected = item.get("corrected", "")
        reason = item.get("reason", "")
        lines.append(f"- `{original}` -> `{corrected}`: {reason}")
    return "\n".join(lines)


def run_app(document_file, audio_file, profile_name, top_k):
    """Gradio callback: extract, chunk, transcribe, retrieve, correct, and format.

    Returns ``(raw_transcript, corrected_text, changes_markdown,
    contexts_markdown)``. Raises ``gr.Error`` for user-fixable input problems.
    """
    if document_file is None:
        raise gr.Error("Upload a PDF, DOCX, or TXT reference document first.")
    if audio_file is None:
        raise gr.Error("Upload or record an audio sample first.")

    try:
        document_text = extract_document_text(document_file)
    except ValueError as err:
        raise gr.Error(str(err)) from err
    if not document_text:
        raise gr.Error("No text was extracted from the document. Scanned PDFs need OCR first.")

    chunks = split_into_chunks(document_text)
    if not chunks:
        raise gr.Error("The document is too short to build context.")

    raw_transcript = transcribe_audio(audio_file, profile_name)
    if not raw_transcript:
        raise gr.Error("No speech text was recognized from the audio.")

    contexts = retrieve_contexts(raw_transcript, chunks, int(top_k))
    correction = correct_with_llm(raw_transcript, contexts)

    corrected_text = correction.get("corrected_text")
    # Guard against a model reply that omits or blanks the transcript.
    if not isinstance(corrected_text, str) or not corrected_text.strip():
        corrected_text = raw_transcript
    changes = correction.get("changes", [])
    if not isinstance(changes, list):
        changes = []

    return (
        raw_transcript,
        corrected_text,
        format_changes(changes),
        format_contexts(contexts),
    )


theme = gr.themes.Soft(
    primary_hue="teal",
    secondary_hue="orange",
    neutral_hue="zinc",
    radius_size="sm",
)

with gr.Blocks(
    title="Context-Aware Audio Correction",
    theme=theme,
    css=APP_CSS,
) as demo:
    gr.HTML(APP_HEADER_HTML)

    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            document_input = gr.File(
                label="Reference document",
                file_types=[".pdf", ".docx", ".txt"],
                type="filepath",
            )
            audio_input = gr.Audio(
                label="Audio sample",
                sources=["upload", "microphone"],
                type="filepath",
            )
        with gr.Column(scale=1, min_width=320):
            profile_input = gr.Radio(
                label="Recognition profile",
                choices=list(ASR_PROFILES.keys()),
                value=DEFAULT_ASR_PROFILE,
                info=(
                    "English uses an English-only Whisper model. Chinese and Auto use "
                    "the multilingual Whisper model."
                ),
            )
            top_k_input = gr.Slider(
                label="Document passages to retrieve",
                minimum=1,
                maximum=8,
                value=4,
                step=1,
            )
            submit_button = gr.Button("Transcribe and correct", variant="primary")

    with gr.Row(elem_classes=["output-panel"]):
        raw_output = gr.Textbox(label="Raw Whisper transcript", lines=9)
        corrected_output = gr.Textbox(label="Context-corrected transcript", lines=9)

    changes_output = gr.Markdown(
        label="Correction notes",
        elem_classes=["correction-notes"],
    )
    contexts_output = gr.Markdown(
        label="Document evidence",
        elem_classes=["evidence-panel"],
    )

    submit_button.click(
        fn=run_app,
        inputs=[document_input, audio_input, profile_input, top_k_input],
        outputs=[raw_output, corrected_output, changes_output, contexts_output],
    )


if __name__ == "__main__":
    demo.launch(share=True)