| import json |
| import os |
| import re |
| from pathlib import Path |
|
|
| import gradio as gr |
| import numpy as np |
| import pdfplumber |
| from docx import Document |
| from openai import OpenAI |
| from sentence_transformers import SentenceTransformer |
| from transformers import pipeline |
|
|
|
|
# --- Runtime configuration ---------------------------------------------------
# Each model identifier is read from an environment variable; the literal next
# to it is the fallback used when that variable is unset.
EMBEDDING_MODEL = os.environ.get(
    "EMBEDDING_MODEL",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)
LLM_MODEL = os.environ.get("LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct-1M")
HF_TOKEN = os.environ.get("HF_TOKEN")
DEFAULT_MULTILINGUAL_ASR_MODEL = os.environ.get("ASR_MODEL", "openai/whisper-small")

# Recognition profiles offered in the UI, keyed by their display label.
# Insertion order matters: the first entry is the fallback profile.
# "language" is None when no explicit language is configured for the profile.
ASR_PROFILES = {
    "English optimized - Whisper small.en": dict(
        model=os.environ.get("ASR_MODEL_EN", "openai/whisper-small.en"),
        language=None,
        description="Best default for English-only lectures and presentations.",
    ),
    "Chinese - Whisper multilingual small": dict(
        model=os.environ.get("ASR_MODEL_ZH", DEFAULT_MULTILINGUAL_ASR_MODEL),
        language="chinese",
        description="Use this for Mandarin recordings and Chinese documents.",
    ),
    "Auto detect - Whisper multilingual small": dict(
        model=os.environ.get("ASR_MODEL_AUTO", DEFAULT_MULTILINGUAL_ASR_MODEL),
        language=None,
        description="Use this when the recording language is uncertain or mixed.",
    ),
}

# Lazily populated singletons; the get_* accessors below fill them on first use.
asr_pipelines = {}  # model id -> loaded ASR pipeline
embedding_model = None
llm_client = None
|
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...). It defines the colour palette
# as CSS variables, styles the header markup (.app-shell, .app-kicker,
# .app-title, .app-subtitle, .status-*), recolours primary buttons, and
# collapses the three-column status strip to one column below 760px.
APP_CSS = """
:root {
    --brand: #0f766e;
    --brand-strong: #115e59;
    --ink: #111827;
    --muted: #64748b;
    --line: #d8ded9;
    --paper: #ffffff;
    --wash: #f6f7f2;
    --accent: #c2410c;
}

body,
.gradio-container {
    background: var(--wash) !important;
    color: var(--ink);
    font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}

.main {
    max-width: 1180px !important;
    margin: 0 auto !important;
}

.app-shell {
    padding: 28px 28px 12px;
    border-bottom: 1px solid var(--line);
}

.app-kicker {
    margin: 0 0 8px;
    color: var(--brand-strong);
    font-size: 12px;
    font-weight: 700;
    letter-spacing: 0.08em;
    text-transform: uppercase;
}

.app-title {
    margin: 0;
    color: var(--ink);
    font-size: 34px;
    line-height: 1.12;
    letter-spacing: 0;
}

.app-subtitle {
    margin: 12px 0 0;
    max-width: 780px;
    color: var(--muted);
    font-size: 16px;
    line-height: 1.6;
}

.status-strip {
    display: grid;
    grid-template-columns: repeat(3, minmax(0, 1fr));
    gap: 10px;
    margin-top: 20px;
}

.status-item {
    background: #ffffff;
    border: 1px solid var(--line);
    border-radius: 8px;
    padding: 12px 14px;
}

.status-label {
    color: var(--muted);
    font-size: 12px;
    margin-bottom: 4px;
}

.status-value {
    color: var(--ink);
    font-weight: 700;
    font-size: 14px;
}

.gradio-container .block {
    border-radius: 8px !important;
}

.gradio-container button.primary {
    background: var(--brand) !important;
    border-color: var(--brand) !important;
}

.gradio-container button.primary:hover {
    background: var(--brand-strong) !important;
    border-color: var(--brand-strong) !important;
}

textarea,
input,
.wrap {
    border-radius: 8px !important;
}

.output-panel textarea {
    font-size: 14px !important;
    line-height: 1.55 !important;
}

.correction-notes,
.evidence-panel {
    background: var(--paper);
}

@media (max-width: 760px) {
    .app-shell {
        padding: 22px 18px 8px;
    }

    .app-title {
        font-size: 28px;
    }

    .status-strip {
        grid-template-columns: 1fr;
    }
}
"""
|
|
|
|
def get_asr_pipeline(model_id: str):
    """Return the speech-recognition pipeline for ``model_id``, loading it once.

    Pipelines are cached in the module-level ``asr_pipelines`` dict so each
    checkpoint is only loaded on first use.

    Args:
        model_id: Hugging Face model identifier of the ASR checkpoint.

    Returns:
        The cached ``transformers`` automatic-speech-recognition pipeline.
    """
    if model_id not in asr_pipelines:
        asr_pipelines[model_id] = pipeline(
            "automatic-speech-recognition",
            model=model_id,
            # Whisper processes 30-second windows. Without chunking, longer
            # recordings (lectures, presentations) are cut off or rejected, so
            # let the pipeline split the audio and stitch the text together.
            chunk_length_s=30,
            # -1 selects CPU inference in transformers pipelines.
            device=-1,
        )
    return asr_pipelines[model_id]
|
|
|
|
def get_embedding_model():
    """Return the shared sentence-embedding model, creating it on first call.

    The instance is stored in the module-level ``embedding_model`` so later
    calls reuse it instead of constructing ``EMBEDDING_MODEL`` again.
    """
    global embedding_model
    if embedding_model is not None:
        return embedding_model
    embedding_model = SentenceTransformer(EMBEDDING_MODEL)
    return embedding_model
|
|
|
|
def get_llm_client():
    """Return the shared OpenAI-compatible client, or ``None`` without a token.

    The client targets the Hugging Face router endpoint and authenticates with
    ``HF_TOKEN``. It is created on first use and cached in ``llm_client``.
    """
    global llm_client
    if HF_TOKEN and llm_client is None:
        llm_client = OpenAI(
            api_key=HF_TOKEN,
            base_url="https://router.huggingface.co/v1",
        )
    # No token means no client, regardless of what is cached.
    return llm_client if HF_TOKEN else None
|
|
|
|
def read_text_file(path: Path) -> str:
    """Read a text file whose encoding is not known in advance.

    The file is read once as bytes and decoded by trying, in order: UTF-16
    (only when a UTF-16 byte-order mark is present), UTF-8 (a leading BOM is
    stripped), then GB18030. If none succeeds, the bytes are decoded as UTF-8
    with undecodable bytes dropped, so the result does not depend on the
    machine's locale.

    Args:
        path: Location of the text file.

    Returns:
        The decoded text with line endings normalised to ``\\n``.
    """
    raw = path.read_bytes()

    encodings = ["utf-8-sig", "gb18030"]
    # 0xFF/0xFE never start valid UTF-8 or GB18030 text, so a UTF-16 BOM is an
    # unambiguous signal and is tried first.
    if raw.startswith((b"\xff\xfe", b"\xfe\xff")):
        encodings.insert(0, "utf-16")

    for encoding in encodings:
        try:
            text = raw.decode(encoding)
        except UnicodeDecodeError:
            continue
        break
    else:
        # Best-effort fallback: keep whatever is decodable.
        text = raw.decode("utf-8", errors="ignore")

    # Decoding bytes skips the newline translation that text-mode reads do.
    return text.replace("\r\n", "\n").replace("\r", "\n")
|
|
|
|
def extract_document_text(file_path: str) -> str:
    """Extract plain text from a TXT, PDF, or DOCX file.

    Args:
        file_path: Path to the document; the (case-insensitive) suffix selects
            the reader.

    Returns:
        The document text with runs of spaces/tabs collapsed to one space,
        three or more consecutive newlines collapsed to two, and outer
        whitespace stripped.

    Raises:
        ValueError: If the suffix is not ``.txt``, ``.pdf`` or ``.docx``.
    """
    source = Path(file_path)
    kind = source.suffix.lower()
    if kind not in (".txt", ".pdf", ".docx"):
        raise ValueError("Only PDF, DOCX, and TXT documents are supported.")

    if kind == ".txt":
        raw = read_text_file(source)
    elif kind == ".pdf":
        with pdfplumber.open(source) as pdf:
            # Pages with no extractable text contribute an empty line.
            raw = "\n".join(page.extract_text() or "" for page in pdf.pages)
    else:
        raw = "\n".join(paragraph.text for paragraph in Document(source).paragraphs)

    collapsed = re.sub(r"[ \t]+", " ", raw)
    return re.sub(r"\n{3,}", "\n\n", collapsed).strip()
|
|
|
|
def split_into_chunks(text: str, max_chars: int = 700, overlap: int = 90) -> list[str]:
    """Split document text into retrieval chunks of at most ``max_chars``.

    The text is cut into paragraphs (blank-line separated) and then into
    sentence-like pieces. Pieces are packed greedily, joined by newlines, into
    chunks no longer than ``max_chars``. A single piece longer than
    ``max_chars`` is cut into fixed-size windows that overlap by ``overlap``
    characters.

    Sentence boundaries are ASCII ``. ! ? ; :`` followed by whitespace, or the
    full-width CJK marks ``。！？；：`` (which are normally not followed by
    whitespace).

    Args:
        text: Full document text.
        max_chars: Maximum chunk length in characters; must be positive.
        overlap: Characters shared by consecutive windows of an oversized
            piece; must satisfy ``0 <= overlap < max_chars``.

    Returns:
        Chunks in document order. Chunks shorter than 20 characters are
        dropped.

    Raises:
        ValueError: If ``max_chars`` or ``overlap`` is out of range. Otherwise
            the window step would be zero or negative, which crashes or
            silently drops text.
    """
    if max_chars <= 0:
        raise ValueError("max_chars must be a positive integer.")
    if not 0 <= overlap < max_chars:
        raise ValueError("overlap must satisfy 0 <= overlap < max_chars.")

    sentence_break = r"(?<=[.!?;:])\s+|(?<=[。！？；：])\s*"
    pieces = []
    for paragraph in re.split(r"\n\s*\n+", text):
        pieces.extend(part.strip() for part in re.split(sentence_break, paragraph.strip()))
    pieces = [piece for piece in pieces if piece]

    step = max_chars - overlap
    chunks = []
    current = ""
    for piece in pieces:
        if len(piece) > max_chars:
            # Flush what has been packed so far, then window the long piece.
            if current:
                chunks.append(current)
                current = ""
            for start in range(0, len(piece), step):
                chunks.append(piece[start : start + max_chars])
                # This window reached the end of the piece; a further window
                # would only repeat text already covered by the overlap.
                if start + max_chars >= len(piece):
                    break
            continue

        candidate = f"{current}\n{piece}" if current else piece
        if len(candidate) <= max_chars:
            current = candidate
        else:
            chunks.append(current)
            current = piece

    if current:
        chunks.append(current)

    return [chunk for chunk in chunks if len(chunk) >= 20]
|
|
|
|
def resolve_asr_profile(profile_name: str) -> dict:
    """Look up an ASR profile by its display label.

    Unknown labels fall back to the first profile declared in
    ``ASR_PROFILES``.
    """
    fallback = next(iter(ASR_PROFILES.values()))
    if profile_name in ASR_PROFILES:
        return ASR_PROFILES[profile_name]
    return fallback
|
|
|
|
def transcribe_audio(audio_path: str, profile_name: str) -> str:
    """Transcribe an audio file with the ASR profile selected in the UI.

    Args:
        audio_path: Path to the audio file to transcribe.
        profile_name: Display label of the profile; unknown labels resolve to
            the default profile.

    Returns:
        The recognised text with surrounding whitespace removed (empty when
        nothing was recognised).
    """
    profile = resolve_asr_profile(profile_name)
    model_id = str(profile["model"])

    call_kwargs = {}
    # English-only Whisper checkpoints (ids ending in ".en") reject `task` and
    # `language` at generation time, so those options are only sent to
    # multilingual checkpoints.
    if not model_id.rstrip("/").lower().endswith(".en"):
        generate_kwargs = {"task": "transcribe"}
        if profile["language"]:
            generate_kwargs["language"] = profile["language"]
        call_kwargs["generate_kwargs"] = generate_kwargs

    result = get_asr_pipeline(model_id)(audio_path, **call_kwargs)
    if isinstance(result, dict):
        return str(result.get("text", "")).strip()
    return str(result).strip()
|
|
|
|
def retrieve_contexts(raw_transcript: str, chunks: list[str], top_k: int):
    """Rank document chunks by embedding similarity to the transcript.

    Both sides are encoded with ``normalize_embeddings=True``, so the dot
    product used for scoring is the cosine similarity.

    Args:
        raw_transcript: ASR output used as the query.
        chunks: Candidate document passages.
        top_k: Maximum number of passages to return.

    Returns:
        Up to ``top_k`` tuples ``(chunk_index, score, chunk_text)`` ordered
        from most to least similar. Empty when there are no chunks or
        ``top_k`` is not positive.
    """
    # Nothing to rank; this also avoids encoding an empty batch and the wrong
    # slice a negative top_k would produce.
    if not chunks or top_k <= 0:
        return []

    model = get_embedding_model()
    doc_vectors = model.encode(chunks, normalize_embeddings=True)
    query_vector = model.encode([raw_transcript], normalize_embeddings=True)[0]
    scores = np.matmul(doc_vectors, query_vector)
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [(int(i), float(scores[i]), chunks[int(i)]) for i in top_indices]
|
|
|
|
def build_correction_prompt(raw_transcript: str, contexts) -> list[dict]:
    """Build the chat messages asking the LLM to correct the transcript.

    Args:
        raw_transcript: Uncorrected ASR output.
        contexts: Retrieved passages as ``(chunk_index, score, text)`` tuples,
            best match first.

    Returns:
        A two-message chat (system, user). The user message lists the
        passages, the correction rules, the expected JSON schema, and the raw
        transcript.
    """
    # Passages are numbered by retrieval rank (1 = best match), the same
    # numbering format_contexts shows in the UI, so passage numbers quoted in
    # the model's reasons can be matched to what the user sees.
    context_text = "\n\n".join(
        f"[Document passage {rank} | similarity {score:.3f}]\n{text}"
        for rank, (_idx, score, text) in enumerate(contexts, start=1)
    )

    system_prompt = (
        "You are a strict ASR correction assistant. Correct the transcript only when the "
        "provided document context gives clear evidence. Focus on homophones, near-sound "
        "mistakes, technical terms, names, acronyms, chapter titles, and domain-specific "
        "phrases. Preserve the original sentence structure as much as possible. Do not "
        "summarize, rewrite freely, or add information that was not spoken."
    )
    user_prompt = f"""
Correct the ASR transcript using the document passages below.

Rules:
1. Treat the raw transcript as the primary text.
2. Make only evidence-backed corrections.
3. Prefer keeping the original word when the document context is not strong enough.
4. Output JSON only. Do not output Markdown.

JSON schema:
{{
  "corrected_text": "the complete corrected transcript",
  "changes": [
    {{
      "original": "incorrect word or phrase",
      "corrected": "corrected word or phrase",
      "reason": "why the document supports this correction"
    }}
  ]
}}

Document passages:
{context_text}

Raw ASR transcript:
{raw_transcript}
""".strip()

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
|
|
|
|
def parse_json_response(text: str):
    """Parse the JSON payload of an LLM reply, tolerating surrounding chatter.

    The reply is first parsed as-is. If that fails (for example because the
    model wrapped the JSON in a Markdown fence or added commentary), the text
    is scanned for the first ``{`` from which a complete JSON object can be
    decoded; anything after that object is ignored.

    Args:
        text: Raw reply content from the language model.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If ``text`` is not a string or contains no decodable JSON.
    """
    if not isinstance(text, str):
        raise ValueError("The language model did not return valid JSON.")

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        decoder = json.JSONDecoder()

    # raw_decode stops at the end of the first complete value, so trailing
    # text (even text containing braces) cannot break the parse.
    start = text.find("{")
    while start != -1:
        try:
            parsed, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return parsed

    raise ValueError("The language model did not return valid JSON.")
|
|
|
|
def _uncorrected_result(raw_transcript: str, reason: str) -> dict:
    """Build a result that keeps the raw transcript and explains why."""
    return {
        "corrected_text": raw_transcript,
        "changes": [
            {
                "original": "LLM correction skipped",
                "corrected": "LLM correction skipped",
                "reason": reason,
            }
        ],
    }


def correct_with_llm(raw_transcript: str, contexts):
    """Ask the language model to correct the transcript using ``contexts``.

    Correction is best-effort: when no client is configured, the request
    fails, or the reply is not a JSON object, the raw transcript is returned
    unchanged together with a single "changes" entry stating the reason. The
    transcription result is therefore never lost to an LLM problem.

    Args:
        raw_transcript: Uncorrected ASR output.
        contexts: Retrieved passages as ``(chunk_index, score, text)`` tuples.

    Returns:
        A dict with ``"corrected_text"`` and ``"changes"`` as described in the
        correction prompt, or the uncorrected fallback described above.
    """
    # Local import: only the error type is needed, and only in this function.
    from openai import OpenAIError

    client = get_llm_client()
    if client is None:
        return _uncorrected_result(
            raw_transcript,
            "HF_TOKEN is not set. Add HF_TOKEN locally or in Hugging Face Spaces secrets.",
        )

    try:
        completion = client.chat.completions.create(
            model=LLM_MODEL,
            messages=build_correction_prompt(raw_transcript, contexts),
            temperature=0.1,
            max_tokens=1200,
        )
    except OpenAIError as exc:
        return _uncorrected_result(
            raw_transcript, f"The language model request failed: {exc}"
        )

    if not completion.choices:
        return _uncorrected_result(
            raw_transcript, "The language model returned an empty response."
        )

    content = completion.choices[0].message.content
    try:
        parsed = parse_json_response(content)
    except (TypeError, ValueError) as exc:
        # Raised for missing content and for malformed or truncated JSON.
        return _uncorrected_result(
            raw_transcript, f"The language model reply could not be used: {exc}"
        )

    if not isinstance(parsed, dict):
        return _uncorrected_result(
            raw_transcript, "The language model reply was not a JSON object."
        )
    return parsed
|
|
|
|
def format_contexts(contexts) -> str:
    """Render retrieved passages as Markdown for the evidence panel.

    Each passage becomes a ``### Passage N`` section (N is the 1-based rank)
    showing its similarity score and text; sections are separated by a
    horizontal rule. The chunk index in each tuple is not displayed.
    """
    sections = [
        f"### Passage {position}\nSimilarity: `{similarity:.3f}`\n\n{passage}"
        for position, (_chunk_index, similarity, passage) in enumerate(contexts, start=1)
    ]
    separator = "\n\n---\n\n"
    return separator.join(sections)
|
|
|
|
def format_changes(changes) -> str:
    """Render the LLM's change list as a Markdown bullet list.

    The list comes from model output, so its shape is not trusted: a value
    that is not a list/tuple, and entries that are not dicts, are ignored
    rather than allowed to crash the UI.

    Args:
        changes: Expected to be a list of dicts with ``"original"``,
            ``"corrected"`` and ``"reason"`` keys.

    Returns:
        One ``- `original` -> `corrected`: reason`` line per usable entry, or
        a fixed message when there is nothing to report.
    """
    no_changes = "No document-backed correction was needed."
    if not changes or not isinstance(changes, (list, tuple)):
        return no_changes

    lines = []
    for item in changes:
        if not isinstance(item, dict):
            continue
        original = item.get("original", "")
        corrected = item.get("corrected", "")
        reason = item.get("reason", "")
        lines.append(f"- `{original}` -> `{corrected}`: {reason}")

    return "\n".join(lines) if lines else no_changes
|
|
|
|
def run_app(document_file, audio_file, profile_name, top_k):
    """Run the full pipeline behind the "Transcribe and correct" button.

    Steps: extract and chunk the reference document, transcribe the audio,
    retrieve the most similar passages, and ask the LLM for evidence-backed
    corrections.

    Args:
        document_file: Path of the uploaded reference document, or ``None``.
        audio_file: Path of the uploaded/recorded audio, or ``None``.
        profile_name: Label of the selected ASR profile.
        top_k: Number of passages to retrieve (slider value).

    Returns:
        A 4-tuple: raw transcript, corrected transcript, Markdown correction
        notes, Markdown document evidence.

    Raises:
        gr.Error: For missing inputs, unreadable or empty documents, and
            audio with no recognised speech, so the UI shows a readable
            message.
    """
    if document_file is None:
        raise gr.Error("Upload a PDF, DOCX, or TXT reference document first.")
    if audio_file is None:
        raise gr.Error("Upload or record an audio sample first.")

    try:
        document_text = extract_document_text(document_file)
    except ValueError as exc:
        # Unsupported file type: the message is already user-facing.
        raise gr.Error(str(exc)) from exc
    except Exception as exc:
        # UI boundary: corrupt or unreadable files raise parser-specific
        # errors; report them instead of an opaque failure.
        raise gr.Error(f"Could not read the document: {exc}") from exc
    if not document_text:
        raise gr.Error("No text was extracted from the document. Scanned PDFs need OCR first.")

    chunks = split_into_chunks(document_text)
    if not chunks:
        raise gr.Error("The document is too short to build context.")

    raw_transcript = transcribe_audio(audio_file, profile_name)
    if not raw_transcript:
        raise gr.Error("No speech text was recognized from the audio.")

    contexts = retrieve_contexts(raw_transcript, chunks, int(top_k))
    correction = correct_with_llm(raw_transcript, contexts)

    # The correction comes from model output; fall back to the raw transcript
    # when it is not a dict or carries no usable corrected text.
    if not isinstance(correction, dict):
        correction = {}
    corrected_text = correction.get("corrected_text")
    if not isinstance(corrected_text, str) or not corrected_text.strip():
        corrected_text = raw_transcript
    changes = correction.get("changes", [])

    return (
        raw_transcript,
        corrected_text,
        format_changes(changes),
        format_contexts(contexts),
    )
|
|
|
|
# Base Gradio theme; APP_CSS is applied on top of it.
theme = gr.themes.Soft(
    primary_hue="teal",
    secondary_hue="orange",
    neutral_hue="zinc",
    radius_size="sm",
)


# UI definition: header, input controls, output panels, and the click wiring.
with gr.Blocks(
    title="Context-Aware Audio Correction",
    theme=theme,
    css=APP_CSS,
) as demo:
    # Static header; its class names are styled by APP_CSS.
    gr.HTML(
        """
        <section class="app-shell">
            <p class="app-kicker">Hugging Face ASR + document retrieval</p>
            <h1 class="app-title">Context-Aware Audio Correction</h1>
            <p class="app-subtitle">
                Upload a reference document and an audio clip. The app transcribes speech,
                retrieves matching document passages, and corrects likely ASR mistakes using
                only document-backed evidence.
            </p>
            <div class="status-strip">
                <div class="status-item">
                    <div class="status-label">ASR profiles</div>
                    <div class="status-value">English / Chinese / Auto</div>
                </div>
                <div class="status-item">
                    <div class="status-label">Context engine</div>
                    <div class="status-value">Sentence embeddings</div>
                </div>
                <div class="status-item">
                    <div class="status-label">Correction policy</div>
                    <div class="status-value">Evidence-bound</div>
                </div>
            </div>
        </section>
        """
    )

    # Inputs: files on the left, recognition settings and the action button on the right.
    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            # type="filepath" makes both components pass a path string to run_app.
            document_input = gr.File(
                label="Reference document",
                file_types=[".pdf", ".docx", ".txt"],
                type="filepath",
            )
            audio_input = gr.Audio(
                label="Audio sample",
                sources=["upload", "microphone"],
                type="filepath",
            )
        with gr.Column(scale=1, min_width=320):
            # Choices are the ASR_PROFILES labels; the default is the English profile.
            profile_input = gr.Radio(
                label="Recognition profile",
                choices=list(ASR_PROFILES.keys()),
                value="English optimized - Whisper small.en",
                info=(
                    "English uses an English-only Whisper model. Chinese and Auto use "
                    "the multilingual Whisper model."
                ),
            )
            top_k_input = gr.Slider(
                label="Document passages to retrieve",
                minimum=1,
                maximum=8,
                value=4,
                step=1,
            )
            submit_button = gr.Button("Transcribe and correct", variant="primary")

    # Outputs: raw and corrected transcripts side by side.
    with gr.Row(elem_classes=["output-panel"]):
        raw_output = gr.Textbox(label="Raw Whisper transcript", lines=9)
        corrected_output = gr.Textbox(label="Context-corrected transcript", lines=9)

    # Markdown panels filled by format_changes / format_contexts via run_app.
    changes_output = gr.Markdown(
        label="Correction notes",
        elem_classes=["correction-notes"],
    )
    contexts_output = gr.Markdown(
        label="Document evidence",
        elem_classes=["evidence-panel"],
    )

    # The output order must match the 4-tuple returned by run_app.
    submit_button.click(
        fn=run_app,
        inputs=[document_input, audio_input, profile_input, top_k_input],
        outputs=[raw_output, corrected_output, changes_output, contexts_output],
    )
|
|
|
|
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link; confirm
    # that exposing the app (and the HF_TOKEN-backed LLM calls) is intended.
    demo.launch(share=True)
|
|