Update app.py
Browse files
app.py
CHANGED
|
@@ -15,7 +15,7 @@ except ImportError:
|
|
| 15 |
import soundfile as sf
|
| 16 |
import traceback
|
| 17 |
from huggingface_hub import hf_hub_download
|
| 18 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 19 |
from peft import PeftModel
|
| 20 |
from matcha.models.matcha_tts import MatchaTTS
|
| 21 |
from matcha.hifigan.models import Generator as HiFiGAN
|
|
@@ -55,8 +55,17 @@ def load_models():
|
|
| 55 |
TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"
|
| 56 |
TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"
|
| 57 |
|
|
|
|
|
|
|
|
|
|
| 58 |
def load_translation_models():
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
try:
|
| 61 |
# Load the tokenizer with left padding (required for causal LM)
|
| 62 |
tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
|
|
@@ -64,21 +73,14 @@ def load_translation_models():
|
|
| 64 |
if tokenizer.pad_token is None:
|
| 65 |
tokenizer.pad_token = tokenizer.eos_token
|
| 66 |
|
| 67 |
-
#
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
load_in_4bit=True,
|
| 71 |
-
bnb_4bit_compute_dtype=torch.float16,
|
| 72 |
-
bnb_4bit_use_double_quant=True,
|
| 73 |
-
bnb_4bit_quant_type="nf4"
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
# Load the base model with 4-bit quantization
|
| 77 |
-
print("[*] Loading base model as AutoModelForCausalLM (4-bit)...")
|
| 78 |
base_model = AutoModelForCausalLM.from_pretrained(
|
| 79 |
-
TRANSLATION_BASE_MODEL,
|
| 80 |
-
|
| 81 |
-
device_map="
|
|
|
|
| 82 |
trust_remote_code=True
|
| 83 |
)
|
| 84 |
|
|
@@ -87,18 +89,25 @@ def load_translation_models():
|
|
| 87 |
model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
|
| 88 |
model.eval()
|
| 89 |
|
| 90 |
-
print(f"[+] Translation model loaded successfully.")
|
|
|
|
|
|
|
|
|
|
| 91 |
return tokenizer, model
|
| 92 |
except Exception as e:
|
| 93 |
print(f"[-] Error loading translation model: {e}")
|
| 94 |
traceback.print_exc()
|
| 95 |
return None, None
|
| 96 |
|
|
|
|
| 97 |
model, vocoder = load_models()
|
| 98 |
-
|
| 99 |
|
| 100 |
def _translate_impl(text):
|
| 101 |
"""Internal translation implementation - matching evaluate_model.py approach."""
|
|
|
|
|
|
|
|
|
|
| 102 |
if trans_model is None:
|
| 103 |
return "Translation model unavailable."
|
| 104 |
|
|
@@ -113,7 +122,7 @@ def _translate_impl(text):
|
|
| 113 |
prompt = trans_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 114 |
inputs = trans_tokenizer(prompt, padding=True, truncation=True, max_length=512, return_tensors="pt")
|
| 115 |
|
| 116 |
-
# Move inputs to model's device
|
| 117 |
inputs = {k: v.to(trans_model.device) for k, v in inputs.items()}
|
| 118 |
|
| 119 |
print(f"[DEBUG] Input tokens: {inputs['input_ids'].shape[1]}")
|
|
@@ -155,14 +164,9 @@ def _translate_impl(text):
|
|
| 155 |
traceback.print_exc()
|
| 156 |
return "Error during translation generation."
|
| 157 |
|
| 158 |
-
#
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
def translate(text):
|
| 162 |
-
return _translate_impl(text)
|
| 163 |
-
else:
|
| 164 |
-
def translate(text):
|
| 165 |
-
return _translate_impl(text)
|
| 166 |
|
| 167 |
|
| 168 |
# --- Update the function signature to accept two arguments ---
|
|
|
|
| 15 |
import soundfile as sf
|
| 16 |
import traceback
|
| 17 |
from huggingface_hub import hf_hub_download
|
| 18 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 19 |
from peft import PeftModel
|
| 20 |
from matcha.models.matcha_tts import MatchaTTS
|
| 21 |
from matcha.hifigan.models import Generator as HiFiGAN
|
|
|
|
| 55 |
TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"
|
| 56 |
TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"
|
| 57 |
|
| 58 |
+
# Global cache for the translation model (populated lazily on first use; CPU deployment)
|
| 59 |
+
_trans_cache = {"tokenizer": None, "model": None, "loaded": False}
|
| 60 |
+
|
| 61 |
def load_translation_models():
|
| 62 |
+
"""Load translation model lazily on first use (CPU deployment)."""
|
| 63 |
+
global _trans_cache
|
| 64 |
+
|
| 65 |
+
if _trans_cache["loaded"]:
|
| 66 |
+
return _trans_cache["tokenizer"], _trans_cache["model"]
|
| 67 |
+
|
| 68 |
+
print("[*] Loading Sarvam Translate Adapter (CPU mode)...")
|
| 69 |
try:
|
| 70 |
# Load the tokenizer with left padding (required for causal LM)
|
| 71 |
tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
|
|
|
|
| 73 |
if tokenizer.pad_token is None:
|
| 74 |
tokenizer.pad_token = tokenizer.eos_token
|
| 75 |
|
| 76 |
+
# Load the base model on CPU with bfloat16 to reduce memory
|
| 77 |
+
# bfloat16 is better supported on CPU than float16
|
| 78 |
+
print("[*] Loading base model on CPU (bfloat16)...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
base_model = AutoModelForCausalLM.from_pretrained(
|
| 80 |
+
TRANSLATION_BASE_MODEL,
|
| 81 |
+
torch_dtype=torch.bfloat16,
|
| 82 |
+
device_map="cpu",
|
| 83 |
+
low_cpu_mem_usage=True,
|
| 84 |
trust_remote_code=True
|
| 85 |
)
|
| 86 |
|
|
|
|
| 89 |
model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
|
| 90 |
model.eval()
|
| 91 |
|
| 92 |
+
print(f"[+] Translation model loaded successfully on CPU.")
|
| 93 |
+
_trans_cache["tokenizer"] = tokenizer
|
| 94 |
+
_trans_cache["model"] = model
|
| 95 |
+
_trans_cache["loaded"] = True
|
| 96 |
return tokenizer, model
|
| 97 |
except Exception as e:
|
| 98 |
print(f"[-] Error loading translation model: {e}")
|
| 99 |
traceback.print_exc()
|
| 100 |
return None, None
|
| 101 |
|
| 102 |
+
# Load TTS models at startup (they're smaller)
|
| 103 |
model, vocoder = load_models()
|
| 104 |
+
# Translation model is loaded lazily on the first translation request (CPU) and then cached
|
| 105 |
|
| 106 |
def _translate_impl(text):
|
| 107 |
"""Internal translation implementation - matching evaluate_model.py approach."""
|
| 108 |
+
# Load model lazily (will be cached after first load)
|
| 109 |
+
trans_tokenizer, trans_model = load_translation_models()
|
| 110 |
+
|
| 111 |
if trans_model is None:
|
| 112 |
return "Translation model unavailable."
|
| 113 |
|
|
|
|
| 122 |
prompt = trans_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 123 |
inputs = trans_tokenizer(prompt, padding=True, truncation=True, max_length=512, return_tensors="pt")
|
| 124 |
|
| 125 |
+
# Move inputs to model's device
|
| 126 |
inputs = {k: v.to(trans_model.device) for k, v in inputs.items()}
|
| 127 |
|
| 128 |
print(f"[DEBUG] Input tokens: {inputs['input_ids'].shape[1]}")
|
|
|
|
| 164 |
traceback.print_exc()
|
| 165 |
return "Error during translation generation."
|
| 166 |
|
| 167 |
+
# Simple wrapper function for CPU deployment
|
| 168 |
+
def translate(text):
|
| 169 |
+
return _translate_impl(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
|
| 172 |
# --- Update the function signature to accept two arguments ---
|