import gc
import asyncio
import base64
import io
import json
import os
import re
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from threading import Lock, Thread
from typing import Optional

import gradio as gr
import numpy as np
import torch
from fastapi.responses import HTMLResponse
from pypinyin import Style, lazy_pinyin
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Compatibility shim: define `is_torch_fx_available` on transformers' import_utils
# when the installed transformers version no longer provides it (presumably
# needed by the model's remote code loaded with trust_remote_code=True).
try:
    import transformers.utils.import_utils as transformers_import_utils

    if not hasattr(transformers_import_utils, "is_torch_fx_available"):
        transformers_import_utils.is_torch_fx_available = lambda: True
except Exception:
    pass

try:
    import spaces
except Exception:

    class _SpacesFallback:
        """No-op stand-in for the Hugging Face `spaces` package."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Mimic `spaces.GPU` in both bare and parameterized decorator forms.

            `@spaces.GPU` passes the function directly; `@spaces.GPU(duration=...)`
            expects a decorator back. Either way the function is returned unchanged.
            """
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()


_TRUTHY_ENV_VALUES = {"1", "true", "yes", "y"}


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var `name` (or `default` if unset) is a truthy token."""
    return os.getenv(name, default).strip().lower() in _TRUTHY_ENV_VALUES


# ---------------------------------------------------------------------------
# Model / TTS configuration (all overridable through environment variables)
# ---------------------------------------------------------------------------
DEFAULT_MODEL_ID = "Alphaplasti/ToneBridge-MiniCPM4.1-8B"
MODEL_ID = os.getenv("MODEL_ID", DEFAULT_MODEL_ID).strip() or DEFAULT_MODEL_ID
HF_TOKEN = (os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") or "").strip() or None

TTS_PROVIDER = os.getenv("TTS_PROVIDER", "edge").strip().lower() or "edge"
DEFAULT_TTS_MODEL_ID = "openbmb/VoxCPM2"
TTS_MODEL_ID = os.getenv("TTS_MODEL_ID", DEFAULT_TTS_MODEL_ID).strip() or DEFAULT_TTS_MODEL_ID
# Server-side TTS defaults to off when the browser is the TTS provider.
DEFAULT_ENABLE_SERVER_TTS = "false" if TTS_PROVIDER == "browser" else "true"
ENABLE_SERVER_TTS = _env_flag("ENABLE_SERVER_TTS", DEFAULT_ENABLE_SERVER_TTS)
SERVER_TTS_ENABLED = ENABLE_SERVER_TTS and TTS_PROVIDER != "browser"
TTS_MAX_CHARS = int(os.getenv("TTS_MAX_CHARS", "180"))

EDGE_TTS_VOICE = os.getenv("EDGE_TTS_VOICE", "zh-CN-YunjianNeural").strip()
EDGE_TTS_RATE = os.getenv("EDGE_TTS_RATE", "+0%").strip()
EDGE_TTS_PITCH = os.getenv("EDGE_TTS_PITCH", "+0Hz").strip()
EDGE_TTS_VOLUME = os.getenv("EDGE_TTS_VOLUME", "+0%").strip()
EDGE_TTS_KARAOKE_DURATION_FACTOR = float(os.getenv("EDGE_TTS_KARAOKE_DURATION_FACTOR", "0.86"))

VOXCPM_VOICE_STYLE = os.getenv(
    "VOXCPM_VOICE_STYLE",
    "A calm adult male Mandarin teacher in his 30s or 40s, warm low-pitched voice, natural conversational speed, clear Standard Mandarin, not childlike, not female",
).strip()
VOXCPM_CFG_VALUE = float(os.getenv("VOXCPM_CFG_VALUE", "2.0"))
VOXCPM_INFERENCE_TIMESTEPS = int(os.getenv("VOXCPM_INFERENCE_TIMESTEPS", "6"))
VOXCPM_RETRY_BADCASE = _env_flag("VOXCPM_RETRY_BADCASE", "false")
VOXCPM_OUTPUT_SAMPLE_RATE = int(os.getenv("VOXCPM_OUTPUT_SAMPLE_RATE", "24000"))

MAX_INPUT_CHARS = int(os.getenv("MAX_INPUT_CHARS", "1200"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "220"))
LOAD_IN_4BIT = _env_flag("LOAD_IN_4BIT", "true")
PRELOAD_MODEL = _env_flag("PRELOAD_MODEL", "true")

# ---------------------------------------------------------------------------
# Usage-metrics storage (JSONL file, optionally mirrored to a Hub repo)
# ---------------------------------------------------------------------------
SPACE_DIR = Path(__file__).resolve().parent
METRICS_FILE = Path(os.getenv("METRICS_FILE", "tonebridge_usage_metrics.jsonl"))
METRICS_REPO_SYNC = _env_flag("METRICS_REPO_SYNC", "false")
METRICS_REPO_ID = (
    os.getenv("METRICS_REPO_ID") or os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID") or ""
).strip()
# Repo paths always use forward slashes; absolute local paths collapse to the file name.
DEFAULT_METRICS_REPO_PATH = (
    METRICS_FILE.name if METRICS_FILE.is_absolute() else str(METRICS_FILE).replace("\\", "/")
)
METRICS_REPO_PATH = os.getenv("METRICS_REPO_PATH", DEFAULT_METRICS_REPO_PATH).strip().lstrip("/")
HF_METRICS_TOKEN = (
    os.getenv("HF_METRICS_TOKEN") or os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") or ""
).strip() or None
METRICS_LOCK = Lock()
metrics_sync_error: Optional[str] = None

# Lazily-populated model state (see load_model).
tokenizer = None
model = None
load_error: Optional[str] = None
tts_model = None
tts_load_error: Optional[str] = None

app = gr.Server()

# NOTE(review): the {context}/{tone}/{correction_style}/{sentence} placeholders
# are presumably filled with str.format elsewhere in the module — not visible here.
SYSTEM_PROMPT = """You are ToneBridge, a Mandarin Chinese teacher for beginner learners.
Your task is to correct ONE student Chinese sentence according to the selected context and tone.
Your default behavior is conservative minimal correction.
Do not create a richer new sentence.
Do not improve style just because another phrasing is possible.
Do not shorten, expand, or rewrite a correct sentence.
Preserve the student's meaning, length, intention, and punctuation style as much as possible.
Never add information that is absent from the original sentence.
When in doubt, choose no correction.

Inputs:
Context: {context}
Tone: {tone}
Correction style: {correction_style}
Student sentence: {sentence}

Correction decision rule:
A correction is allowed ONLY if the original sentence has a clear problem:
- wrong character
- wrong word
- missing necessary word
- extra incorrect word
- wrong measure word
- wrong word order
- wrong grammar pattern
- tone/politeness inappropriate for the selected context

If the sentence is understandable, grammatical, and natural enough for the selected context, do NOT correct it.
Acceptable variants are not errors.
A more formal, shorter, smoother, or more common version is NOT a correction if the original is already acceptable.

Important anti-overcorrection rules:
- Do not remove 一 from 有一只猫 only to make it more casual. 有一只猫 and 有只猫 can both be correct.
- Do not add 的 or 色 only to make an adjective-noun phrase sound more standard if the original is already acceptable.
- Do not change basic location patterns such as "A 在 B 的旁边" if they are correct and natural.
- Do not change word order unless the original word order is actually wrong.
- Do not mark "word order" unless the corrected sentence visibly changes the order of words.
- Do not correct punctuation-only issues unless punctuation creates real confusion.
- Do not replace a correct casual sentence with a formal sentence unless the selected context requires formality.
- For a casual or friendly tone, do not use 您 or 您好. Use 你 / 你好.
- For a teacher, client, manager, or very formal context, 您 may be appropriate.

Error type consistency:
- If the corrected sentence is identical to the original, Error type must be "none".
- If Error type is "none", the corrected sentence must be identical to the original.
- If you replace one Chinese character with another that has the same or very close pinyin, Error type should be "character/input-method mistake", not politeness.
- If you cannot explain the correction by pointing to a clear visible problem, return no correction.

Output rules:
Return exactly 5 short lines.
Use exactly these labels in this order.
Do not use markdown.
Do not output pinyin.
Do not output translations.
Do not output hidden reasoning, chain-of-thought, or tags.
Explanations must be only in English.
Why and Tip must be English sentences.
Do not explain in Chinese.
You may mention isolated Chinese words or characters inside English explanations only when necessary.
Chinese sentences must stay in Chinese characters.
Give only one corrected sentence.
Add at most one gentle emoji in Why or Tip, never inside Chinese sentences.

Allowed Error type values:
none
character/input-method mistake
wrong character
wrong word
missing word
extra word
measure word
word order
grammar
tone

Required format:
Original sentence:
Corrected sentence:
Error type:
Why:
Tip:

For a correct sentence:
Original sentence:
Corrected sentence:
Error type: none
Why: This sentence is correct and natural. 😊
Tip: Keep it as it is.

Examples:

Input: 红桌子上有一只猫
Output:
Original sentence: 红桌子上有一只猫
Corrected sentence: 红桌子上有一只猫
Error type: none
Why: This sentence is correct and natural. 😊
Tip: 有只猫 is only a casual variant, not a correction.

Input: 桌子上猫有一只
Output:
Original sentence: 桌子上猫有一只
Corrected sentence: 桌子上有一只猫
Error type: word order
Why: In this location pattern, use place + 有 + object. 😊
Tip: Put 有 before the thing that exists.

Input: 我想喝谁
Output:
Original sentence: 我想喝谁
Corrected sentence: 我想喝水
Error type: character/input-method mistake
Why: 谁 and 水 have close pinyin, but 水 means water. 😊
Tip: Check same-sound characters when typing.
"""

# The five labels the model is asked to emit, in order.
_RESULT_LABELS = ("Original sentence", "Corrected sentence", "Error type", "Why", "Tip")


# ---------------------------------------------------------------------------
# Text helpers
# ---------------------------------------------------------------------------
def normalize_space(text: str) -> str:
    """Collapse every whitespace run to one space and trim; None becomes ""."""
    return re.sub(r"\s+", " ", (text or "").strip())


def has_chinese(text: str) -> bool:
    """Return True if `text` contains at least one CJK Unified Ideograph."""
    return re.search(r"[\u4e00-\u9fff]", text or "") is not None


def to_pinyin(text: str) -> str:
    """Return space-separated tone-marked pinyin for `text`."""
    return " ".join(lazy_pinyin(text or "", style=Style.TONE))


def should_add_pinyin_for_line(line: str) -> bool:
    """Only the Original/Corrected sentence lines get a pinyin sub-line."""
    labels = ("Original sentence", "Corrected sentence")
    return any(label in (line or "") for label in labels)


def chinese_segments(text: str):
    """Return stripped runs of Chinese characters/punctuation that contain Chinese."""
    pattern = r"[\u4e00-\u9fff,。!?、;:“”‘’()《》〈〉…—\s]+"
    return [seg.strip() for seg in re.findall(pattern, text or "") if has_chinese(seg)]


def add_pinyin_under_chinese(text: str) -> str:
    """Drop blank lines and add an italic pinyin line under sentence lines.

    Every kept line is followed by an empty line; sentence lines also get a
    `*pinyin / pinyin*` line listing each Chinese segment's pinyin.
    """
    enriched = []
    for line in (text or "").splitlines():
        clean = line.strip()
        if not clean:
            continue
        enriched.append(clean)
        segments = chinese_segments(line) if should_add_pinyin_for_line(line) else []
        if segments:
            enriched.append("*" + " / ".join(to_pinyin(seg) for seg in segments) + "*")
        enriched.append("")
    return "\n".join(enriched).strip()


def add_section_emojis(text: str) -> str:
    """Prefix each result label at line start with its emoji (`Why:` -> `💡 Why :`)."""
    labels = {
        "Original sentence": "📝 Original sentence",
        "Corrected sentence": "✅ Corrected sentence",
        "Error type": "🔎 Error type",
        "Why": "💡 Why",
        "Tip": "🌱 Tip",
    }
    out = text or ""
    for source, target in labels.items():
        out = re.sub(rf"(?m)^(\s*){re.escape(source)}\s*:", rf"\1{target} :", out)
    return out


def normalize_model_markdown(text: str) -> str:
    """Turn literal escape sequences into real whitespace and put each label on its own line."""
    out = (text or "").strip()
    # The model sometimes emits the two-character sequence backslash-n instead of a newline.
    out = out.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\t", " ")
    for label in _RESULT_LABELS:
        out = re.sub(rf"\s+(?={re.escape(label)}\s*:)", "\n", out)
    return out.strip()


def clean_corrected_sentence_value(value: str) -> str:
    """Reduce a "Corrected sentence" value to the single sentence it should hold.

    Strips `*`, cuts off any label that leaked onto the same line, and drops
    trailing text after the first sentence-final punctuation mark.
    """
    text = normalize_space((value or "").replace("*", ""))
    text = re.split(r"\s+(?:Original sentence|Error type|Why|Tip)\s*:", text, maxsplit=1)[0].strip()
    # Accept both full-width (。！？) and ASCII (!?) sentence terminators.
    extra_after_sentence = re.match(r"^(.+?[。！？!?])(?=\s*[\u4e00-\u9fffA-Za-z])", text)
    if extra_after_sentence:
        text = extra_after_sentence.group(1)
    return text.strip()


def clean_correction_output(text: str) -> str:
    """Normalize model output and sanitize its "Corrected sentence" line."""
    cleaned = []
    for line in normalize_model_markdown(text).splitlines():
        match = re.match(r"^(Corrected sentence\s*:\s*)(.+)$", line.strip(), flags=re.I)
        if match:
            cleaned.append(match.group(1) + clean_corrected_sentence_value(match.group(2)))
        else:
            cleaned.append(line)
    return "\n".join(cleaned).strip()


def wrap_result(markdown: str) -> str:
    """Return the trimmed markdown, or a placeholder when it is empty."""
    return markdown.strip() if markdown else "No correction was produced."


def final_result(markdown: str) -> str:
    """Full display pipeline: clean output, add emojis, add pinyin, wrap."""
    friendly = add_section_emojis(clean_correction_output(markdown))
    return wrap_result(add_pinyin_under_chinese(friendly))


def utc_now_iso() -> str:
    """Current UTC time as ISO-8601 with a trailing `Z`."""
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


def extract_result_field(markdown: str, label: str) -> str:
    """Return the value after `label:` in a correction result, or "".

    A label at the start of a line (optionally preceded by emoji/markdown
    symbols) is preferred, so that e.g. "tip:" inside the Why sentence is not
    mistaken for the Tip field. If no line starts with the label, fall back
    to the first occurrence anywhere on a line.
    """
    text = normalize_model_markdown(markdown or "")
    escaped = re.escape(label)
    match = re.search(rf"(?im)^[^\w\n]*{escaped}\s*:\s*(.+)$", text) or re.search(
        rf"(?im)^.*?{escaped}\s*:\s*(.+)$", text
    )
    if not match:
        return ""
    value = match.group(1).replace("*", "").strip()
    if label.lower() == "corrected sentence":
        return clean_corrected_sentence_value(value)
    return normalize_space(value)


def is_mostly_chinese_explanation(text: str) -> bool:
    """True when `text` has >= 6 Chinese characters and more of them than Latin letters."""
    value = text or ""
    chinese_count = len(re.findall(r"[\u4e00-\u9fff]", value))
    latin_count = len(re.findall(r"[A-Za-z]", value))
    return chinese_count >= 6 and chinese_count > latin_count


def english_feedback_fallback(error_type: str, label: str) -> str:
    """Canned English Why/Tip text for an error type (used when the model answered in Chinese).

    `label` is "Why" or "Tip" (case-insensitive); anything else is treated as "Why".
    """
    kind = normalize_space(error_type).lower()
    is_tip = label.lower() == "tip"
    if "none" in kind:
        return "Keep it as it is." if is_tip else "This sentence is correct and natural."
    if "character" in kind or "input" in kind:
        return (
            "When typing, check characters with similar pronunciation."
            if is_tip
            else "One character changes the meaning; the corrected sentence uses the intended word."
        )
    if "order" in kind:
        return (
            "Practice the same sentence pattern with one small change at a time."
            if is_tip
            else "The correction fixes the word order so the Mandarin pattern is clearer."
        )
    if "measure" in kind:
        return (
            "Pair nouns with their usual measure words."
            if is_tip
            else "The correction uses a measure word that fits the noun better."
        )
    if "tone" in kind or "register" in kind or "polite" in kind:
        return (
            "Match the wording to the relationship and situation."
            if is_tip
            else "The correction makes the tone fit the selected context better."
        )
    # "missing word" / "extra word" must be checked before the generic "word"
    # branch, whose text ("replaces a word") would misdescribe them.
    if "missing" in kind:
        return (
            "Check that every word the sentence pattern needs is present."
            if is_tip
            else "The correction adds a word that the sentence needs."
        )
    if "extra" in kind:
        return (
            "Remove words that do not add meaning to the pattern."
            if is_tip
            else "The correction removes a word that does not belong in the sentence."
        )
    if "word" in kind:
        return (
            "Check the meaning of each key word before sending."
            if is_tip
            else "The correction replaces a word that does not fit the intended meaning."
        )
    return (
        "Practice the sentence pattern with one small change at a time."
        if is_tip
        else "The correction fixes a grammar issue while keeping the original meaning."
    )


def build_plain_correction_output(
    original_sentence: str,
    corrected_sentence: str,
    error_type: str,
    why: str,
    tip: str,
) -> str:
    """Format the five result fields as plain `Label: value` lines."""
    return "\n".join(
        [
            f"Original sentence: {original_sentence}",
            f"Corrected sentence: {corrected_sentence}",
            f"Error type: {error_type or 'none'}",
            f"Why: {why}",
            f"Tip: {tip}",
        ]
    )


def generate_english_feedback_repair(
    original_sentence: str,
    corrected_sentence: str,
    error_type: str,
    why: str,
    tip: str,
) -> str:
    """Ask the loaded model to rewrite Why/Tip in English; return "" on any failure."""
    if model is None or tokenizer is None:
        return ""
    messages = [
        {
            "role": "system",
            "content": (
                "Rewrite Mandarin correction feedback. Keep Original sentence, Corrected sentence, "
                "and Error type unchanged. Rewrite only Why and Tip in beginner-friendly English. "
                "Do not explain in Chinese. Do not output pinyin. Return exactly the same five labels."
            ),
        },
        {
            "role": "user",
            "content": build_plain_correction_output(
                original_sentence,
                corrected_sentence,
                error_type,
                why,
                tip,
            ),
        },
    ]
    try:
        try:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        except TypeError:
            # Chat templates that do not accept `enable_thinking`.
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        inputs = tokenizer([text], return_tensors="pt").to(model.device)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=False,
                use_cache=True,
                repetition_penalty=1.05,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Keep only the newly generated tokens, not the echoed prompt.
        generated = outputs[0][inputs["input_ids"].shape[-1]:]
        repaired = tokenizer.decode(generated, skip_special_tokens=True).strip()
        del inputs, outputs, generated
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        # NOTE(review): strip_thinking is not defined in this part of the file;
        # presumably defined later in the module.
        return strip_thinking(repaired)
    except Exception:
        return ""


def ensure_english_feedback(answer: str, original_sentence: str) -> str:
    """Return `answer`, or a rebuilt version with English Why/Tip if they are mostly Chinese.

    Tries a model rewrite first and falls back to canned text per error type.
    """
    why = extract_result_field(answer, "Why")
    tip = extract_result_field(answer, "Tip")
    if not (is_mostly_chinese_explanation(why) or is_mostly_chinese_explanation(tip)):
        return answer

    original = extract_result_field(answer, "Original sentence") or original_sentence
    corrected = extract_result_field(answer, "Corrected sentence") or original
    error_type = extract_result_field(answer, "Error type") or "none"
    repaired = generate_english_feedback_repair(original, corrected, error_type, why, tip)
    repaired_why = extract_result_field(repaired, "Why") or why
    repaired_tip = extract_result_field(repaired, "Tip") or tip
    if not repaired_why or is_mostly_chinese_explanation(repaired_why):
        repaired_why = english_feedback_fallback(error_type, "Why")
    if not repaired_tip or is_mostly_chinese_explanation(repaired_tip):
        repaired_tip = english_feedback_fallback(error_type, "Tip")
    return build_plain_correction_output(
        original,
        corrected,
        error_type,
        repaired_why,
        repaired_tip,
    )


# ---------------------------------------------------------------------------
# Usage metrics
# ---------------------------------------------------------------------------
def metrics_file_path() -> Path:
    """Absolute metrics path; relative METRICS_FILE values resolve against SPACE_DIR."""
    return METRICS_FILE if METRICS_FILE.is_absolute() else SPACE_DIR / METRICS_FILE


def sync_usage_metrics_to_repo(commit_message: str) -> None:
    """Upload the metrics file to the Hub Space repo when METRICS_REPO_SYNC is on.

    Never raises: configuration problems and upload errors are recorded in the
    module-level `metrics_sync_error`, which is cleared after a successful upload.
    """
    global metrics_sync_error
    if not METRICS_REPO_SYNC:
        return
    path = metrics_file_path()
    if not path.exists():
        return
    if not METRICS_REPO_ID:
        metrics_sync_error = "Metrics repo sync is enabled, but METRICS_REPO_ID or SPACE_ID is missing."
        return
    if not HF_METRICS_TOKEN:
        metrics_sync_error = "Metrics repo sync is enabled, but HF_METRICS_TOKEN or HF_TOKEN is missing."
        return
    try:
        from huggingface_hub import upload_file

        upload_file(
            path_or_fileobj=str(path),
            path_in_repo=METRICS_REPO_PATH or path.name,
            repo_id=METRICS_REPO_ID,
            repo_type="space",
            token=HF_METRICS_TOKEN,
            commit_message=commit_message,
        )
        metrics_sync_error = None
    except Exception as exc:
        metrics_sync_error = f"Metrics repo sync failed: {exc}"


def read_usage_records_unlocked() -> list[dict]:
    """Read all dict records from the JSONL metrics file (caller must hold METRICS_LOCK).

    Blank lines, malformed JSON, and non-object lines are skipped.
    """
    path = metrics_file_path()
    if not path.exists():
        return []
    records = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(record, dict):
                records.append(record)
    return records


def write_usage_records_unlocked(records: list[dict]) -> None:
    """Rewrite the metrics file with `records` (caller must hold METRICS_LOCK).

    Writes to a sibling temp file and swaps it in with os.replace, so a crash
    mid-write cannot truncate or destroy the existing metrics history.
    """
    path = metrics_file_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as handle:
            for record in records:
                handle.write(json.dumps(record, ensure_ascii=False, sort_keys=True) + "\n")
        os.replace(tmp_path, path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def append_usage_record(record: dict) -> None:
    """Append one JSON line to the metrics file, then sync it to the repo.

    The sync runs while METRICS_LOCK is held, so the upload sees a file that
    no other locked writer is modifying.
    """
    path = metrics_file_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    with METRICS_LOCK:
        with path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(record, ensure_ascii=False, sort_keys=True) + "\n")
        sync_usage_metrics_to_repo("Update ToneBridge usage metrics")


def update_usage_evaluation(request_id: str, evaluation: str) -> Optional[dict]:
    """Set `evaluation` and `evaluated_at` on the first record with `request_id`.

    Returns the updated record, or None when no record matches.
    """
    with METRICS_LOCK:
        records = read_usage_records_unlocked()
        updated_record = None
        for record in records:
            if record.get("request_id") == request_id:
                record["evaluation"] = evaluation
                record["evaluated_at"] = utc_now_iso()
                updated_record = record
                break
        if updated_record is not None:
            write_usage_records_unlocked(records)
    if updated_record is not None:
        sync_usage_metrics_to_repo("Update ToneBridge feedback metrics")
    return updated_record


def metric_public_view(record: dict) -> dict:
    """Project a stored metrics record onto the fixed set of publicly shown fields."""
    return {
        "request_id": record.get("request_id", ""),
        "created_at": record.get("created_at", ""),
        "original_sentence": record.get("original_sentence", ""),
        "corrected_sentence": record.get("corrected_sentence", ""),
        "evaluation": record.get("evaluation"),
        "generation_time_seconds": record.get("generation_time_seconds"),
        "status": record.get("status", ""),
        "context": record.get("context", ""),
        "target_tone": record.get("target_tone", ""),
        "correction_mode": record.get("correction_mode", ""),
        "error_type": record.get("error_type", ""),
        "model_id": record.get("model_id", ""),
    }


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def device_label() -> str:
    """Human-readable description of CUDA device 0, or a CPU notice."""
    if torch.cuda.is_available():
        name = torch.cuda.get_device_name(0)
        mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        return f"GPU: {name} ({mem_gb:.1f} GB)"
    return "CPU: no CUDA GPU detected"


def load_model():
    """Load the tokenizer and model into module globals; no-op if already loaded.

    Uses 4-bit NF4 quantization when LOAD_IN_4BIT is set and CUDA is present.
    On failure, stores a message in `load_error` and resets both globals to None
    instead of raising.
    """
    global tokenizer, model, load_error
    if model is not None and tokenizer is not None:
        return
    try:
        cuda_available = torch.cuda.is_available()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
        load_kwargs = {
            "torch_dtype": "auto",
            "device_map": "auto",
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
            "token": HF_TOKEN,
        }
        if LOAD_IN_4BIT and cuda_available:
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        try:
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                attn_implementation="sdpa",
                **load_kwargs,
            )
        except Exception:
            # Retry with the model's default attention implementation.
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                **load_kwargs,
            )
        model.eval()
        load_error = None
    except Exception as exc:
        load_error = f"Model load failed: {exc}"
        tokenizer = None
        model = None


if PRELOAD_MODEL:
    load_model()


# ---------------------------------------------------------------------------
# Prompt guidance snippets
# ---------------------------------------------------------------------------
def correction_mode_guidance(correction_mode: str) -> str:
    """Instruction text for the chosen correction mode (anything else means minimal)."""
    if correction_mode == "Natural correction":
        return (
            "Natural correction: make the sentence sound natural for the chosen context, "
            "but only if the original is actually unnatural, incorrect, or socially inappropriate. "
            "If the original is already correct and natural, keep it unchanged."
        )
    return (
        "Minimal correction: change only the characters, grammar, or word order that are necessary. "
        "Do not rewrite the sentence if a small correction is enough."
    )


def context_tone_guidance(context: str, target_tone: str) -> str:
    """Extra instruction for WeChat contexts (stricter when the tone is friendly)."""
    context_key = normalize_space(context).lower()
    tone_key = normalize_space(target_tone).lower()
    if context_key == "wechat" and "friendly" in tone_key:
        return (
            "WeChat + Friendly: treat the sentence like a short instant message. "
            "Be concise, direct, and casual. If the original sounds formal, literary, ceremonial, "
            "or like an invitation letter, correct it as a tone/register issue. "
            "Avoid stiff phrases such as 敬请, 阁下, 拨冗, 莅临, 寒舍 unless the user explicitly wants formal wording. "
            "Prefer everyday wording with 你, 有空, 方便, 一下, 吗, or 吧 when appropriate. "
            "The corrected sentence should usually be short."
        )
    if context_key == "wechat":
        return (
            "WeChat context: prefer concise instant-message wording. "
            "Avoid ceremonial or overly literary phrasing unless the target tone is explicitly formal."
        )
    return "No extra context-specific rule."
# Context-tone profiles selectable from the UI. Each one supplies the context
# description, tone, correction style and extra instruction for the prompt.
CONTEXT_TONE_PROFILES = {
    "friendly-informal": {
        "context": "friendly everyday conversation with a friend or close person",
        "tone": "informal friendly",
        "correction_style": "tone-aware",
        "instruction": (
            "Keep the sentence simple, natural, and friendly. Prefer everyday spoken wording. "
            "Use \u4f60 when a pronoun is needed. Avoid \u60a8, \u662f\u5426, ceremonial, literary, or stiff formal wording."
        ),
    },
    "work-informal": {
        "context": "workplace message to a colleague or familiar coworker",
        "tone": "informal professional",
        "correction_style": "tone-aware",
        "instruction": (
            "Keep the sentence clear, polite, and work-appropriate without sounding stiff. "
            "Avoid slang, but do not over-formalize if the original is already natural."
        ),
    },
    "work-formal": {
        "context": "workplace message to a manager, client, teacher, or formal contact",
        "tone": "formal professional",
        "correction_style": "tone-aware",
        "instruction": (
            "Use respectful, professional wording when needed. \u60a8 and \u8bf7 may be appropriate. "
            "Avoid overly casual phrasing if the relationship requires formality."
        ),
    },
    "wechat-informal": {
        "context": "WeChat message to a friend or close contact",
        "tone": "informal instant message",
        "correction_style": "tone-aware",
        "instruction": (
            "Prefer short, direct instant-message wording. Use \u4f60, \u6709\u7a7a, \u65b9\u4fbf, "
            "\u4e00\u4e0b, \u5417, or \u5427 when appropriate. Avoid \u60a8, \u662f\u5426, "
            "\u656c\u8bf7, \u9601\u4e0b, \u62e8\u5197, \u8385\u4e34, and invitation-letter style."
        ),
    },
    "wechat-formal": {
        "context": "WeChat message in a professional or formal relationship",
        "tone": "formal concise instant message",
        "correction_style": "tone-aware",
        "instruction": (
            "Keep the message concise like WeChat, but respectful. \u8bf7 and \u60a8 may be appropriate. "
            "Avoid both casual slang and overly ceremonial letter-style wording."
        ),
    },
}

# Legacy / alternative keys (including French spellings) mapped to profile keys.
CONTEXT_TONE_ALIASES = {
    "amical-informel": "friendly-informal",
    "amis-informel": "friendly-informal",
    "friends": "friendly-informal",
    "family": "friendly-informal",
    "friendly": "friendly-informal",
    "work": "work-formal",
    "work-informel": "work-informal",
    "work-formel": "work-formal",
    "wechat": "wechat-informal",
    "wechat-informel": "wechat-informal",
    "wechat-formel": "wechat-formal",
}


def normalize_context_tone(value: str) -> str:
    """Map a user-supplied context value to a CONTEXT_TONE_PROFILES key.

    The value is lower-cased, and underscores and whitespace runs become
    hyphens. Aliases are resolved; unknown values fall back to
    "friendly-informal".
    """
    key = normalize_space(value).lower().replace("_", "-")
    key = re.sub(r"\s+", "-", key)
    return CONTEXT_TONE_ALIASES.get(key, key if key in CONTEXT_TONE_PROFILES else "friendly-informal")


def context_tone_profile(value: str) -> dict:
    """Return a copy of the resolved profile, with its key added under "key"."""
    key = normalize_context_tone(value)
    profile = dict(CONTEXT_TONE_PROFILES[key])
    profile["key"] = key
    return profile


def build_user_prompt(context: str, sentence: str, target_tone: str = "", correction_mode: str = "") -> str:
    """Build the user-turn prompt for one correction request.

    Only ``context`` drives the wording: it is resolved with
    ``context_tone_profile``, which also supplies tone and correction style.
    ``target_tone`` and ``correction_mode`` are accepted for call-site
    compatibility but are not used.
    """
    profile = context_tone_profile(context)
    sentence = (sentence or "").strip()
    # The "Error type" instruction is required: generate_correction parses that line
    # for the usage metrics, and the "none" rule below refers to it.
    return f"""Selected context-tone: {profile["key"]}
Context: {profile["context"]}
Tone: {profile["tone"]}
Correction style: {profile["correction_style"]}
Profile instruction: {profile["instruction"]}
Explanation language: English only

Student's Chinese sentence:
{sentence}

Before correcting, decide whether the sentence is already correct, natural, and appropriate for the selected context-tone.
If it is correct, keep exactly the same sentence in "Corrected sentence".
In that case, use "none" as the error type and explain simply that the sentence is correct.
Correct the sentence while preserving its intention and length.
Prefer the smallest possible correction.
Do not turn a short sentence into a long sentence.
The "Corrected sentence" line must contain only one Chinese sentence.
Do not add a second option, leftover characters, notes, vocabulary, or pinyin after it.
Do not add names, emotions, encouragement, or information that was not in the original sentence.
Do not replace a correct sentence with a paraphrase.
If you replace one Chinese character with another character that has the same or very close pinyin, mention in "Why" that it is probably a character/input-method mistake.
All explanations, titles, and tips must be in English.
Add one line "Error type" with a short category: character/input mistake, grammar, word order, tone/register, naturalness, or none.
Use real line breaks between sections.
Do not output escaped newline characters like \\n.
Do not write a long paragraph.
Maximum 5 short lines.

Now correct the input sentence.
/no_think"""


@spaces.GPU(duration=90)
def _generate_correction_gpu(
    context: str,
    sentence: str,
    target_tone: str,
    correction_mode: str = "tone-aware",
) -> str:
    """Run the LLM correction for one sentence and return the rendered result.

    Validation failures and model-load failures are returned through
    ``wrap_result``. A successful generation goes through ``strip_thinking``,
    then ``ensure_english_feedback``, then ``final_result``. Errors raised by
    generation itself propagate to the caller. GPU tensors are released in
    every case.
    """
    sentence = (sentence or "").strip()
    if not sentence:
        return wrap_result("Add a Chinese sentence first.")
    if len(sentence) > MAX_INPUT_CHARS:
        return wrap_result(f"The sentence is too long ({len(sentence)} characters). Current limit: {MAX_INPUT_CHARS}.")

    load_model()
    if load_error:
        return wrap_result(load_error)
    if model is None or tokenizer is None:
        return wrap_result("The model is not available.")

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_user_prompt(context, sentence, target_tone, correction_mode)},
    ]
    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
    except TypeError:
        # Chat templates that do not accept enable_thinking.
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    inputs = outputs = generated = None
    try:
        inputs = tokenizer([text], return_tensors="pt").to(model.device)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                use_cache=True,
                repetition_penalty=1.05,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Keep only the newly generated tokens (drop the echoed prompt).
        generated = outputs[0][inputs["input_ids"].shape[-1]:]
        answer = tokenizer.decode(generated, skip_special_tokens=True).strip()
    finally:
        # Release tensor references and cached GPU memory even if generation raised.
        del inputs, outputs, generated
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    answer = strip_thinking(answer)
    if not answer:
        answer = "The model did not produce a response."
    else:
        answer = ensure_english_feedback(answer, sentence)
    return final_result(answer)


@app.api(name="corriger")
def generate_correction(
    context: str,
    sentence: str,
    target_tone: str,
    correction_mode: str = "tone-aware",
) -> dict:
    """API endpoint: correct one sentence and append a usage-metrics record.

    Returns the rendered ``result`` together with parsed fields, timing,
    ``status`` ("ok" when a corrected sentence was parsed, otherwise
    "unparsed_or_error") and metrics error strings. ``request_id`` is empty
    when the input was empty or over MAX_INPUT_CHARS, since those inputs are
    not recorded.
    """
    original_sentence = (sentence or "").strip()
    profile = context_tone_profile(context)

    started = time.perf_counter()
    result = _generate_correction_gpu(context, sentence, target_tone, correction_mode)
    generation_time_seconds = round(time.perf_counter() - started, 3)

    corrected_sentence = extract_result_field(result, "Corrected sentence")
    error_type = extract_result_field(result, "Error type")
    status = "ok" if corrected_sentence else "unparsed_or_error"

    should_record = bool(original_sentence) and len(original_sentence) <= MAX_INPUT_CHARS
    request_id = str(uuid.uuid4()) if should_record else ""
    metrics_error = ""
    if should_record:
        record = {
            "request_id": request_id,
            "created_at": utc_now_iso(),
            "model_id": MODEL_ID,
            "context": profile["key"],
            "target_tone": profile["tone"],
            "correction_mode": profile["correction_style"],
            "original_sentence": original_sentence,
            "corrected_sentence": corrected_sentence,
            "evaluation": None,
            "generation_time_seconds": generation_time_seconds,
            "error_type": error_type,
            "status": status,
        }
        try:
            append_usage_record(record)
        except Exception as exc:
            # Metrics are best-effort: report the failure but still return the correction.
            metrics_error = f"Metrics save failed: {exc}"

    return {
        "ok": bool(result),
        "request_id": request_id,
        "result": result,
        "original_sentence": original_sentence,
        "corrected_sentence": corrected_sentence,
        "evaluation": None,
        "generation_time_seconds": generation_time_seconds,
        "status": status,
        "metrics_error": metrics_error,
        "metrics_sync_error": metrics_sync_error,
    }


@app.api(name="rate_response")
def rate_response(request_id: str, evaluation: str) -> dict:
    """API endpoint: store a thumbs-up / thumbs-down rating for a request.

    ``evaluation`` is normalized (lower-cased; hyphens and spaces become
    underscores) and accepts aliases such as "up", "positive", "down" and
    "negative". Returns ``{"ok": False, "error": ...}`` on invalid input, on
    storage failure, or when no record matches.
    """
    request_id = normalize_space(request_id)
    evaluation_key = normalize_space(evaluation).lower().replace("-", "_").replace(" ", "_")
    aliases = {
        "up": "thumbs_up",
        "thumb_up": "thumbs_up",
        "thumbs_up": "thumbs_up",
        "positive": "thumbs_up",
        "down": "thumbs_down",
        "thumb_down": "thumbs_down",
        "thumbs_down": "thumbs_down",
        "negative": "thumbs_down",
    }
    normalized_evaluation = aliases.get(evaluation_key)

    if not request_id:
        return {"ok": False, "error": "Missing request_id."}
    if not normalized_evaluation:
        return {"ok": False, "error": "Evaluation must be thumbs_up or thumbs_down."}

    try:
        record = update_usage_evaluation(request_id, normalized_evaluation)
    except Exception as exc:
        return {"ok": False, "error": f"Metrics update failed: {exc}"}
    if record is None:
        return {"ok": False, "error": "Metric record not found."}

    return {
        "ok": True,
        "request_id": request_id,
        "evaluation": normalized_evaluation,
        "metrics_sync_error": metrics_sync_error,
        "record": metric_public_view(record),
    }


@app.api(name="usage_metrics")
def usage_metrics(limit: int = 500) -> dict:
    """API endpoint: return the most recent usage records plus metrics config.

    ``limit`` is clamped to 1..5000. Falsy or non-integer values fall back
    to 500.
    """
    try:
        limit = max(1, min(int(limit or 500), 5000))
    except Exception:
        limit = 500

    try:
        with METRICS_LOCK:
            records = read_usage_records_unlocked()
    except Exception as exc:
        return {"ok": False, "error": f"Metrics read failed: {exc}", "records": []}

    recent = records[-limit:]
    return {
        "ok": True,
        "count": len(records),
        "returned": len(recent),
        "metrics_file": str(metrics_file_path()),
        "metrics_repo_sync": METRICS_REPO_SYNC,
        "metrics_repo_id": METRICS_REPO_ID,
        "metrics_repo_path": METRICS_REPO_PATH,
        "metrics_sync_error": metrics_sync_error,
        "records": [metric_public_view(record) for record in recent],
    }


# Characters removed before TTS: everything except CJK ideographs, whitespace, and
# sentence punctuation in both full-width (\uff0c \u3002 \uff01 \uff1f \u3001 \uff1b \uff1a)
# and ASCII (, ! ? ; :) forms. Full-width forms are what Chinese text normally uses.
_TTS_DISALLOWED_CHARS_RE = re.compile(
    r"[^\u4e00-\u9fff\uff0c\u3002\uff01\uff1f\u3001\uff1b\uff1a,!?;:\s]"
)


def clean_tts_text(text: str) -> str:
    """Keep only speakable Chinese characters and punctuation, capped at TTS_MAX_CHARS."""
    text = _TTS_DISALLOWED_CHARS_RE.sub("", text or "")
    return normalize_space(text)[:TTS_MAX_CHARS]


def trim_tts_silence(audio, sample_rate: int):
    """Trim leading and trailing near-silence from synthesized audio.

    The audio is converted to float32 and reshaped to frames-first (or 1-D
    for mono). A frame counts as voiced when its peak amplitude exceeds
    max(2.5% of the clip peak, 0.002). 60 ms of padding is kept before the
    first voiced frame and 140 ms after the last one.

    Returns ``(trimmed_audio, trim_start_ms, trim_end_ms)``. The audio is
    returned untrimmed with zero offsets when it is empty, silent, or has no
    sample rate.
    """
    arr = np.asarray(audio, dtype=np.float32)
    if arr.ndim > 2:
        arr = np.squeeze(arr)
    # Channel-first layouts (channels <= 2 along axis 0) are transposed to frames-first.
    if arr.ndim == 2 and arr.shape[0] <= 2 and arr.shape[0] < arr.shape[1]:
        arr = arr.T
    if arr.ndim == 2 and arr.shape[1] == 1:
        arr = arr[:, 0]
    if not sample_rate or arr.size == 0:
        return arr, 0, 0

    energy = np.max(np.abs(arr), axis=1) if arr.ndim == 2 else np.abs(arr)
    peak = float(np.max(energy)) if energy.size else 0.0
    if peak <= 1e-6:
        return arr, 0, 0

    threshold = max(peak * 0.025, 0.002)
    voiced = np.flatnonzero(energy > threshold)
    if voiced.size == 0:
        return arr, 0, 0

    pad_start = int(sample_rate * 0.06)
    pad_end = int(sample_rate * 0.14)
    start = max(0, int(voiced[0]) - pad_start)
    # voiced[-1] is an inclusive index; +1 makes it a slice end so the last voiced
    # frame survives even when pad_end rounds down to 0.
    end = min(len(energy), int(voiced[-1]) + 1 + pad_end)
    trimmed = arr[start:end]
    trim_start_ms = int(start / sample_rate * 1000)
    trim_end_ms = int((len(energy) - end) / sample_rate * 1000)
    return trimmed, trim_start_ms, trim_end_ms


def resample_audio(audio, source_rate: int, target_rate: int):
    """Resample 1-D or frames-first 2-D audio by linear interpolation.

    Returns ``(audio, rate)``. The input is returned unchanged with
    ``source_rate`` when either rate is missing or non-positive, when the
    rates are equal, or when the audio is empty.
    """
    if not source_rate or not target_rate or source_rate == target_rate:
        return audio, source_rate
    if target_rate <= 0 or source_rate <= 0:
        return audio, source_rate
    arr = np.asarray(audio)
    if arr.size == 0:
        return arr, source_rate

    source_len = arr.shape[0]
    target_len = max(1, int(round(source_len * target_rate / source_rate)))
    source_positions = np.linspace(0, source_len - 1, num=source_len)
    target_positions = np.linspace(0, source_len - 1, num=target_len)
    if arr.ndim == 1:
        return np.interp(target_positions, source_positions, arr).astype(arr.dtype), target_rate

    channels = [
        np.interp(target_positions, source_positions, arr[:, channel])
        for channel in range(arr.shape[1])
    ]
    return np.stack(channels, axis=1).astype(arr.dtype), target_rate


def load_tts_model():
    """Load the VoxCPM server TTS model into ``tts_model`` once.

    When server TTS is disabled, only ``tts_load_error`` is set. Load failures
    are stored in ``tts_load_error`` instead of being raised.
    """
    global tts_model, tts_load_error
    if tts_model is not None:
        return
    if not SERVER_TTS_ENABLED:
        tts_load_error = "Server TTS is disabled."
        return
    try:
        from voxcpm import VoxCPM

        try:
            tts_model = VoxCPM.from_pretrained(TTS_MODEL_ID, load_denoiser=False)
        except TypeError:
            # Versions of VoxCPM.from_pretrained without the load_denoiser keyword.
            tts_model = VoxCPM.from_pretrained(TTS_MODEL_ID)
        tts_load_error = None
    except Exception as exc:
        tts_model = None
        tts_load_error = f"Server TTS failed: {exc}"


async def _edge_tts_audio_bytes(text: str) -> bytes:
    """Synthesize ``text`` with edge-tts using the configured voice settings.

    Returns the concatenated audio chunks from the stream.
    """
    import edge_tts

    communicate = edge_tts.Communicate(
        text=text,
        voice=EDGE_TTS_VOICE,
        rate=EDGE_TTS_RATE,
        pitch=EDGE_TTS_PITCH,
        volume=EDGE_TTS_VOLUME,
    )
    chunks = []
    async for chunk in communicate.stream():
        if chunk.get("type") == "audio" and chunk.get("data"):
            chunks.append(chunk["data"])
    return b"".join(chunks)


def run_async_safely(coro):
    """Run ``coro`` to completion from synchronous code and return its result.

    If no event loop is running in this thread, ``asyncio.run`` is used
    directly. Otherwise the coroutine runs on a fresh loop in a worker
    thread, and this call blocks until it finishes. Exceptions are re-raised
    in the caller.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)

    result = {}

    def runner():
        try:
            result["value"] = asyncio.run(coro)
        except Exception as exc:
            result["error"] = exc

    thread = Thread(target=runner)
    thread.start()
    thread.join()
    if "error" in result:
        raise result["error"]
    return result.get("value")


def generate_edge_tts(text: str, speaker: str = "edge-tts") -> dict:
    """Synthesize speech with edge-tts and return it as an MP3 data URL.

    Returns ``{"ok": False, "error": ...}`` when there is no usable Chinese
    text, when server TTS is disabled, or when synthesis fails or produces no
    audio. ``duration_ms`` is always 0 for this provider.
    """
    phrase = clean_tts_text(text)
    if not phrase:
        return {"ok": False, "error": "No Chinese text to read."}
    if not SERVER_TTS_ENABLED:
        return {"ok": False, "error": "Server TTS is disabled."}
    try:
        audio_bytes = run_async_safely(_edge_tts_audio_bytes(phrase))
        if not audio_bytes:
            return {"ok": False, "error": "Edge TTS returned no audio."}
        payload = base64.b64encode(audio_bytes).decode("ascii")
        return {
            "ok": True,
            "audio": f"data:audio/mpeg;base64,{payload}",
            "duration_ms": 0,
            "speaker": speaker or EDGE_TTS_VOICE,
            "voice": EDGE_TTS_VOICE,
            "source": "edge-tts",
        }
    except Exception as exc:
        return {"ok": False, "error": f"Edge TTS generation failed: {exc}"}


@spaces.GPU(duration=60)
def _generate_tts_gpu(text: str, speaker: str = "VoxCPM2") -> dict:
    """Synthesize speech with VoxCPM and return it as a WAV data URL.

    The result is silence-trimmed and resampled to VOXCPM_OUTPUT_SAMPLE_RATE.
    Failures are returned as ``{"ok": False, "error": ...}``.
    """
    phrase = clean_tts_text(text)
    if not phrase:
        return {"ok": False, "error": "No Chinese text to read."}
    load_tts_model()
    if tts_load_error or tts_model is None:
        return {"ok": False, "error": tts_load_error or "Server TTS model is not available."}
    try:
        import soundfile as sf

        # A parenthesized style description is prepended to the text when configured.
        synthesis_text = f"({VOXCPM_VOICE_STYLE}){phrase}" if VOXCPM_VOICE_STYLE else phrase
        try:
            audio = tts_model.generate(
                text=synthesis_text,
                cfg_value=VOXCPM_CFG_VALUE,
                inference_timesteps=VOXCPM_INFERENCE_TIMESTEPS,
                normalize=True,
                denoise=False,
                retry_badcase=VOXCPM_RETRY_BADCASE,
                retry_badcase_max_times=1,
            )
        except TypeError:
            # generate() signatures without the optional keywords above.
            audio = tts_model.generate(
                text=synthesis_text,
                cfg_value=VOXCPM_CFG_VALUE,
                inference_timesteps=VOXCPM_INFERENCE_TIMESTEPS,
            )
        if isinstance(audio, (list, tuple)):
            audio = audio[0]
        if hasattr(audio, "detach"):
            audio = audio.detach().cpu().float().numpy()
        sample_rate = getattr(getattr(tts_model, "tts_model", None), "sample_rate", 48000)
        audio, trim_start_ms, trim_end_ms = trim_tts_silence(audio, sample_rate)
        audio, sample_rate = resample_audio(audio, sample_rate, VOXCPM_OUTPUT_SAMPLE_RATE)
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format="WAV")
        audio_bytes = buffer.getvalue()
        duration_ms = int(len(audio) / sample_rate * 1000) if sample_rate else 0
        payload = base64.b64encode(audio_bytes).decode("ascii")
        return {
            "ok": True,
            "audio": f"data:audio/wav;base64,{payload}",
            "duration_ms": duration_ms,
            "sample_rate": sample_rate,
            "speaker": speaker or "VoxCPM2",
            "source": "server",
            "trim_start_ms": trim_start_ms,
            "trim_end_ms": trim_end_ms,
        }
    except Exception as exc:
        return {"ok": False, "error": f"Server TTS generation failed: {exc}"}


@app.api(name="tts")
def generate_tts(text: str, speaker: str = "edge-tts") -> dict:
    """API endpoint: dispatch speech synthesis to the configured TTS_PROVIDER."""
    provider = normalize_space(TTS_PROVIDER).lower()
    if provider in {"edge", "edge-tts", "microsoft", "microsoft-edge"}:
        return generate_edge_tts(text, speaker)
    if provider in {"voxcpm", "voxcpm2", "server"}:
        return _generate_tts_gpu(text, speaker or "VoxCPM2")
    return {"ok": False, "error": f"Unsupported TTS provider: {TTS_PROVIDER}"}


def strip_thinking(text: str) -> str:
    """Remove ``<think>...</think>`` reasoning spans (case-insensitive, multi-line) and trim."""
    return re.sub(r"(?is)<think>.*?</think>", "", text or "").strip()


def runtime_info() -> str:
    """Return a newline-separated summary of model, TTS, device and metrics settings."""
    loaded = "yes" if model is not None and tokenizer is not None else "no"
    return "\n".join(
        [
            f"MODEL_ID: {MODEL_ID}",
            f"TTS_PROVIDER: {TTS_PROVIDER}",
            f"TTS_MODEL_ID: {TTS_MODEL_ID}",
            f"EDGE_TTS_VOICE: {EDGE_TTS_VOICE}",
            f"EDGE_TTS_KARAOKE_DURATION_FACTOR: {EDGE_TTS_KARAOKE_DURATION_FACTOR}",
            f"Model loaded: {loaded}",
            f"Server TTS enabled: {SERVER_TTS_ENABLED}",
            f"LOAD_IN_4BIT: {LOAD_IN_4BIT}",
            device_label(),
            f"MAX_INPUT_CHARS: {MAX_INPUT_CHARS}",
            f"MAX_NEW_TOKENS: {MAX_NEW_TOKENS}",
            f"METRICS_FILE: {metrics_file_path()}",
            f"METRICS_REPO_SYNC: {METRICS_REPO_SYNC}",
            f"METRICS_REPO_ID: {METRICS_REPO_ID or '(not configured)'}",
            f"METRICS_REPO_PATH: {METRICS_REPO_PATH}",
            f"METRICS_SYNC_ERROR: {metrics_sync_error or '(none)'}",
            f"VOXCPM_INFERENCE_TIMESTEPS: {VOXCPM_INFERENCE_TIMESTEPS}",
            f"VOXCPM_OUTPUT_SAMPLE_RATE: {VOXCPM_OUTPUT_SAMPLE_RATE}",
            f"VOXCPM_RETRY_BADCASE: {VOXCPM_RETRY_BADCASE}",
        ]
    )


# Frontend page served at "/". index() substitutes the __SERVER_TTS_ENABLED__,
# __TTS_PROVIDER__ and __EDGE_TTS_KARAOKE_DURATION_FACTOR__ placeholders if they
# are present in this markup.
FRONTEND_HTML = r""" ToneBridge Mandarin Coach
Mandarin sentence coach

ToneBridge

Build natural Mandarin sentences, one gentle correction at a time. 😊

Context aware Natural tone Reading voice
Write

Your sentence

ToneBridge applies a conservative tone-aware correction for the selected situation.

Voice mode listens until you click stop, then corrects the sentence and reads the corrected version aloud.

Examples: tap one to fill the form.

Learning notes
Your patterns will appear here.
  • Your last correction types will appear here.
Coach answer

Correction 😊

Ready
Ready when you are Your correction will appear here.
🎧 Reading

Replay the corrected sentence and follow the characters.

"""


@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the frontend with server-side configuration substituted in."""
    return (
        FRONTEND_HTML
        .replace("__SERVER_TTS_ENABLED__", "true" if SERVER_TTS_ENABLED else "false")
        .replace("__TTS_PROVIDER__", TTS_PROVIDER)
        .replace("__EDGE_TTS_KARAOKE_DURATION_FACTOR__", str(EDGE_TTS_KARAOKE_DURATION_FACTOR))
    )


demo = app

if __name__ == "__main__":
    demo.launch(ssr_mode=False)