rupkotha / core /model_config.py
Deb
Prepare for HF Space: disable mock, module-level demo, README
06223b0
Raw
History Blame Contribute Delete
8.31 kB
# core/model_config.py
from dataclasses import dataclass
# UI mock mode — DISABLED for the published Hugging Face Space (always uses the real
# Modal-served models). The core wrappers keep a mock branch for offline UI dev; set
# this True locally to use canned story/audio without calling Modal. Keep False here.
UI_MOCK: bool = False
# ─────────────────────────────────────────────
# ACTIVE STACK — key into STACKS below; get_config() raises ValueError for any value
# that is not a defined key. Stack A is the sole submission stack (OpenBMB prize path).
# The StackConfig machinery below is kept so a stack could be added/swapped, but
# only "A" is defined; Stacks B/C were dropped.
ACTIVE_STACK: str = "A"
# ─────────────────────────────────────────────
# WHERE INFERENCE RUNS
# "modal" → models served on Modal cloud GPUs (current; see CLAUDE.md infra note)
# "local" → models served on the local machine (legacy offline path; unsupported)
# get_compute() raises ValueError for anything other than these two values.
COMPUTE_LOCATION: str = "modal"
# Modal GPU tier for inference containers. A10G (24 GB) covers Stack A's runtime.
# Exposed to callers as get_compute()["gpu"].
MODAL_GPU: str = "A10G"
# Minimum number of warm containers (exposed as get_compute()["min_containers"]).
# Keep one container warm to hide cold starts. =1 during a demo (no cold start,
# but bills idle GPU — incl. the A100 FT-serve container, so revert to 0 after).
# =0 scales to zero and only bills per request (cold start ~3-5 min on first call).
MODAL_MIN_CONTAINERS: int = 0
# ─────────────────────────────────────────────
@dataclass(frozen=True)
class StackConfig:
    """Immutable description of one model stack (vision + STT + TTS choices).

    Instances are registered in ``STACKS`` and the active one is returned by
    ``get_config()``. The dataclass is frozen, so fields cannot be reassigned
    after construction.
    """

    name: str  # human-readable stack label
    description: str  # one-line summary of the models in the stack
    vision_model: str  # Ollama model tag
    # "ollama" (all stacks use Ollama). NOTE(review): Bengali may instead route to
    # FINETUNED_VISION_MODEL (served via vLLM) — see the fine-tuning section.
    vision_backend: str
    stt_model: str  # faster-whisper model size
    stt_bn_model: str | None  # optional Bengali-specific STT model (HF repo)
    # TTS backend *keys* (not repos); resolve them with get_tts_repo().
    tts_en_backend: str  # "voxcpm2"
    tts_bn_backend: str  # "indic_tts" (chosen) | "indic_parler" (alt)
    tts_bn_ref_audio: str | None  # unused (was for the removed IndicF5 voice-clone)
    total_params_b: float  # total parameters in billions; informational — for README generation
    openbmb_prize_eligible: bool  # whether the stack qualifies for the OpenBMB prize path
# Registry of defined stacks, keyed by the values ACTIVE_STACK may take.
# Only "A" exists; get_config() looks up ACTIVE_STACK here.
STACKS: dict[str, StackConfig] = {
    "A": StackConfig(
        name="Stack A — OpenBMB Prize Path",
        description="MiniCPM-V 4.5 + VoxCPM2 + AI4Bharat Indic-TTS. ~12.4B. OpenBMB prize eligible.",
        # Default (~Q4, 6.1GB). q8_0 was tested and gave NO Bengali quality gain at
        # higher cost/latency — precision is not the bottleneck, the 8B model's
        # Bengali capability is. Bengali quality is addressed via two-pass (Lever C).
        vision_model="openbmb/minicpm-v4.5",
        vision_backend="ollama",
        stt_model="large-v3",
        stt_bn_model="bangla-asr/whisper-medium-bn",
        tts_en_backend="voxcpm2",
        # Bengali TTS: AI4Bharat Indic-TTS (FastPitch). Chosen over Indic Parler-TTS
        # (sounded artificial) and IndicF5 (voice clone — removed: too slow even on
        # A100, needs a reference clip). FastPitch is fast, no reference needed.
        tts_bn_backend="indic_tts",
        tts_bn_ref_audio=None,
        total_params_b=12.4,
        openbmb_prize_eligible=True,
    ),
}
def get_config() -> StackConfig:
    """Return the active stack config. Import this everywhere model details are needed.

    Returns:
        The ``StackConfig`` registered in ``STACKS`` under ``ACTIVE_STACK``.

    Raises:
        ValueError: if ``ACTIVE_STACK`` is not a key of ``STACKS``.
    """
    try:
        return STACKS[ACTIVE_STACK]
    except KeyError:
        # List the keys that actually exist, so the message stays correct when a
        # stack is added or removed (the registry is designed to be extensible).
        raise ValueError(
            f"ACTIVE_STACK='{ACTIVE_STACK}' is not valid. "
            f"Defined stacks: {sorted(STACKS)}."
        ) from None
def get_all_stacks() -> dict[str, StackConfig]:
    """Return all defined stacks (currently just Stack A), keyed by stack id.

    A shallow copy of ``STACKS`` is returned. Callers can therefore add or remove
    entries without corrupting the module-wide registry that ``get_config()``
    reads. The ``StackConfig`` values are frozen dataclasses, so sharing them is safe.
    """
    return dict(STACKS)
# TTS backend key → model source. Model names live ONLY in this file — the
# StackConfig stores the backend *key* ('voxcpm2' | 'indic_parler' | 'indic_tts');
# this maps that key to the value passed to core/modal_infra.py. Most values are
# Hugging Face repo IDs; 'indic_tts' is a direct checkpoint-zip URL instead.
TTS_BACKEND_REPOS: dict[str, str] = {
    "voxcpm2": "openbmb/VoxCPM2",  # English (Voice Design)
    "indic_parler": "ai4bharat/indic-parler-tts",
    # AI4Bharat Indic-TTS (FastPitch + HiFi-GAN) — no HF repo; the value is the
    # GitHub-release checkpoint zip (per-language). Bengali = bn.zip (~1.5 GB).
    # Dedicated, MOS-tuned, no reference clip; fixed voice (no persona control).
    "indic_tts": "https://github.com/AI4Bharat/Indic-TTS/releases/download/v1-checkpoints-release/bn.zip",
}


def get_tts_repo(backend: str) -> str:
    """Resolve a TTS backend key to its model source.

    Args:
        backend: A key of ``TTS_BACKEND_REPOS``
            ('voxcpm2' | 'indic_parler' | 'indic_tts').

    Returns:
        A Hugging Face repo ID for most backends. For 'indic_tts' it is a direct
        checkpoint-zip URL, because Indic-TTS is not published on the HF Hub.

    Raises:
        ValueError: if ``backend`` is not a known key.
    """
    try:
        return TTS_BACKEND_REPOS[backend]
    except KeyError:
        # `from None`: the KeyError adds nothing beyond this message, so don't
        # surface it as a chained "during handling of ..." traceback.
        raise ValueError(
            f"Unknown TTS backend '{backend}'. Known: {list(TTS_BACKEND_REPOS)}"
        ) from None
# Per-language decoding params for the vision/story model (passed to Ollama).
# Bengali uses a more conservative profile: lower temperature + min_p floor +
# repetition penalty suppress the wrong-script (Latin) tokens, invented non-words,
# and phrase repetition that high-temperature sampling produces in a lower-resource
# language. English can afford a livelier profile. Tune these here only.
VISION_GEN_OPTIONS: dict[str, dict] = {
    "en": {
        "temperature": 0.8,
        "top_p": 0.95,
        "repeat_penalty": 1.1,
        "num_predict": 500,  # bound the response; a bedtime story is short
    },
    "bn": {
        "temperature": 0.45,
        "top_p": 0.9,
        "top_k": 40,
        "min_p": 0.05,
        "repeat_penalty": 1.18,
        "repeat_last_n": 64,
        "num_predict": 700,  # Bengali uses more tokens per word; still bounded
    },
}


def get_vision_options(language: str) -> dict:
    """Look up the decoding profile for ``language`` ('en' | 'bn').

    Any language code without its own profile falls back to the English one.
    A fresh dict is built on every call, so callers may tweak it without
    touching VISION_GEN_OPTIONS. The values are plain numbers, so a shallow
    copy is sufficient.
    """
    if language not in VISION_GEN_OPTIONS:
        language = "en"
    return {**VISION_GEN_OPTIONS[language]}
# Translation-pivot path (research option #1): MiniCPM writes the story in English
# (its strength), then IndicTrans2 translates it to fluent Bengali. Model name lives
# here only. 1B (gated, MIT) for quality; dist-200M is the faster/smaller option.
TRANSLATION_MODEL = "ai4bharat/indictrans2-en-indic-1B"
# IndicTrans2 FLORES-style language codes (<ISO 639-3 language>_<ISO 15924 script>).
INDICTRANS_LANG_CODES: dict[str, str] = {"en": "eng_Latn", "bn": "ben_Beng"}


def get_indictrans_code(language: str) -> str:
    """Map our 'en'/'bn' codes to IndicTrans2's FLORES codes.

    Args:
        language: Short language code used across the app ('en' | 'bn').

    Returns:
        The IndicTrans2 code, e.g. 'ben_Beng' for 'bn'.

    Raises:
        ValueError: if ``language`` has no entry in ``INDICTRANS_LANG_CODES``.
    """
    try:
        return INDICTRANS_LANG_CODES[language]
    except KeyError:
        # Suppress the KeyError context (it adds nothing to this message) and
        # list the known codes, matching get_tts_repo's error style.
        raise ValueError(
            f"No IndicTrans2 code for language '{language}'. "
            f"Known: {list(INDICTRANS_LANG_CODES)}"
        ) from None
# ── Bengali distillation fine-tuning (see finetune/) ────────────────────────
# Teacher that writes native Bengali story labels from a drawing. Gemma 3 is
# multimodal and writes excellent Bengali (চাঁদমামা "moon-uncle" / পুকুর "pond"
# folk-tale register). 27B gives the best labels (fewer code-switch leaks); 12B is
# faster. Label quality caps the student, and data-gen is a one-time job, so
# quality is prioritised here.
# NOTE(review): "gemma3:27b" looks like an Ollama model tag — confirm in finetune/.
TEACHER_MODEL = "gemma3:27b"
# Base student that gets fine-tuned (the Stack A vision model, HF repo form for
# training — Ollama tag form for serving is in the StackConfig).
STUDENT_BASE_REPO = "openbmb/MiniCPM-V-4_5"
# Merged fine-tuned vision model: a LoRA trained on STUDENT_BASE_REPO, merged and
# served via vLLM. The vision path routes to it; None = use the base model.
# Set 2026-06-13 after the held-out eval (finetune/eval_ft.py): the fine-tuned
# model decisively beats the base on Bengali (base output was garbled + looping;
# FT is coherent native রূপকথা, i.e. fairy tale), confirmed by a Bengali speaker.
# Bengali now routes to the FT model served by finetune/serve_vllm.py (app
# `rupkotha-ft-serve`).
# NOTE(review): the value looks like a filesystem path inside the serving
# environment (e.g. a mounted volume), not a Hub repo — confirm in serve_vllm.py.
FINETUNED_VISION_MODEL: str | None = "/data/out/minicpm-v-bengali-merged"
def get_compute() -> dict:
    """Return the active compute-location settings for the Modal infra.

    Import this in core/modal_infra.py and the core/ wrappers — never hardcode
    the GPU tier or location.

    Returns:
        A dict with keys "location", "gpu" and "min_containers".

    Raises:
        ValueError: if COMPUTE_LOCATION is neither "modal" nor "local".
    """
    allowed_locations = ("modal", "local")
    if COMPUTE_LOCATION in allowed_locations:
        return dict(
            location=COMPUTE_LOCATION,
            gpu=MODAL_GPU,
            min_containers=MODAL_MIN_CONTAINERS,
        )
    raise ValueError(f"COMPUTE_LOCATION='{COMPUTE_LOCATION}' is not valid. Use 'modal' or 'local'.")