saeedabdulmuizz committed
Commit d3c2948 · verified · 1 Parent(s): 7da8398

Update app.py

Files changed (1): app.py (+31 -27)
app.py CHANGED
```diff
@@ -15,7 +15,7 @@ except ImportError:
 import soundfile as sf
 import traceback
 from huggingface_hub import hf_hub_download
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.hifigan.models import Generator as HiFiGAN
@@ -55,8 +55,17 @@ def load_models():
 TRANSLATION_BASE_MODEL = "sarvamai/sarvam-translate"
 TRANSLATION_ADAPTER = "GAASH-Lab/Sarvam-Kashmiri-finetuned"
 
+# Global cache for translation model (loaded lazily when GPU is available)
+_trans_cache = {"tokenizer": None, "model": None, "loaded": False}
+
 def load_translation_models():
-    print("[*] Loading Sarvam Translate Adapter...")
+    """Load translation model lazily on first use (CPU deployment)."""
+    global _trans_cache
+
+    if _trans_cache["loaded"]:
+        return _trans_cache["tokenizer"], _trans_cache["model"]
+
+    print("[*] Loading Sarvam Translate Adapter (CPU mode)...")
     try:
         # Load the tokenizer with left padding (required for causal LM)
         tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
@@ -64,21 +73,14 @@ def load_translation_models():
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        # Use 4-bit quantization to fit in 16GB memory limit
-        print("[*] Using 4-bit quantization to reduce memory usage...")
-        quantization_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4"
-        )
-
-        # Load the base model with 4-bit quantization
-        print("[*] Loading base model as AutoModelForCausalLM (4-bit)...")
+        # Load the base model on CPU with bfloat16 to reduce memory
+        # bfloat16 is better supported on CPU than float16
+        print("[*] Loading base model on CPU (bfloat16)...")
         base_model = AutoModelForCausalLM.from_pretrained(
-            TRANSLATION_BASE_MODEL,
-            quantization_config=quantization_config,
-            device_map="auto",
+            TRANSLATION_BASE_MODEL,
+            torch_dtype=torch.bfloat16,
+            device_map="cpu",
+            low_cpu_mem_usage=True,
             trust_remote_code=True
         )
 
@@ -87,18 +89,25 @@ def load_translation_models():
         model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
         model.eval()
 
-        print(f"[+] Translation model loaded successfully.")
+        print(f"[+] Translation model loaded successfully on CPU.")
+        _trans_cache["tokenizer"] = tokenizer
+        _trans_cache["model"] = model
+        _trans_cache["loaded"] = True
         return tokenizer, model
     except Exception as e:
         print(f"[-] Error loading translation model: {e}")
         traceback.print_exc()
         return None, None
 
+# Load TTS models at startup (they're smaller)
 model, vocoder = load_models()
-trans_tokenizer, trans_model = load_translation_models()
+# Translation model will be loaded lazily when GPU is available
 
 def _translate_impl(text):
     """Internal translation implementation - matching evaluate_model.py approach."""
+    # Load model lazily (will be cached after first load)
+    trans_tokenizer, trans_model = load_translation_models()
+
     if trans_model is None:
         return "Translation model unavailable."
 
@@ -113,7 +122,7 @@ def _translate_impl(text):
     prompt = trans_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = trans_tokenizer(prompt, padding=True, truncation=True, max_length=512, return_tensors="pt")
 
-    # Move inputs to model's device (handles device_map="auto")
+    # Move inputs to model's device
     inputs = {k: v.to(trans_model.device) for k, v in inputs.items()}
 
     print(f"[DEBUG] Input tokens: {inputs['input_ids'].shape[1]}")
@@ -155,14 +164,9 @@
         traceback.print_exc()
         return "Error during translation generation."
 
-# Wrap with GPU decorator if available
-if SPACES_AVAILABLE:
-    @spaces.GPU
-    def translate(text):
-        return _translate_impl(text)
-else:
-    def translate(text):
-        return _translate_impl(text)
+# Simple wrapper function for CPU deployment
+def translate(text):
+    return _translate_impl(text)
 
 
 # --- Update the function signature to accept two arguments ---
```
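Why the 4-bit path was dropped: bitsandbytes 4-bit (NF4) kernels generally require a CUDA GPU, so on CPU-only hardware the old load would fail; bfloat16 halves memory versus float32 while remaining CPU-friendly, and `low_cpu_mem_usage=True` keeps peak RAM close to the final footprint during loading. As rough, illustrative arithmetic (the parameter count of sarvamai/sarvam-translate is not stated in this diff): a hypothetical 9B-parameter model needs about 9 × 10⁹ × 2 B ≈ 18 GB of weights in bfloat16, versus roughly 9 × 10⁹ × 0.5 B ≈ 4.5 GB with NF4 weights, which is why the removed comment cited a 16 GB limit.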
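For context on the removed `SPACES_AVAILABLE` branch: the `except ImportError:` shown in the first hunk's context suggests the usual ZeroGPU import guard near the top of app.py. A minimal sketch of that pattern, as an assumption since those lines are not visible in this diff:

```python
# Hypothetical reconstruction of the guard implied by the hunk context;
# the `spaces` package exists only on Hugging Face ZeroGPU Spaces.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    spaces = None
    SPACES_AVAILABLE = False
```

With the Space now on CPU, the `@spaces.GPU` decorator has no effect, so the commit replaces the conditional wrapper with a plain `translate` function.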
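Finally, a self-contained sketch of how the lazily loaded tokenizer/model pair is consumed end to end. The tokenizer handling mirrors the lines shown in `_translate_impl`; everything after it (the generation and decode step, `max_new_tokens`, and the exact `messages` content) is not visible in this diff and is assumed here:

```python
import torch

def translate_sketch(tokenizer, model, text):
    # Chat-style prompt, as in _translate_impl (message content assumed).
    messages = [{"role": "user", "content": text}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(
        prompt, padding=True, truncation=True, max_length=512,
        return_tensors="pt",
    )
    # Move inputs to the model's device (CPU after this commit).
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=256)
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```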