saeedabdulmuizz committed on
Commit
6469bde
·
verified ·
1 Parent(s): 5b50bdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -15,7 +15,7 @@ except ImportError:
15
  import soundfile as sf
16
  import traceback
17
  from huggingface_hub import hf_hub_download
18
- from transformers import AutoModel, AutoTokenizer, AutoModelForImageTextToText
19
  from peft import PeftModel
20
  from matcha.models.matcha_tts import MatchaTTS
21
  from matcha.hifigan.models import Generator as HiFiGAN
@@ -61,21 +61,27 @@ def load_translation_models():
61
  # Load the tokenizer
62
  tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
63
 
64
- # Load the base model with the correct class for Gemma 3 (Multimodal Causal LM)
65
- # using standard loading to avoid offloading/partitioning issues with PEFT
66
- base_model = AutoModelForImageTextToText.from_pretrained(
 
67
  TRANSLATION_BASE_MODEL,
68
  torch_dtype=torch.float16,
69
  trust_remote_code=True
70
  )
71
 
72
  # Load the LoRA adapter
 
73
  model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
74
 
 
 
 
 
75
  # Move to device
76
  model.to(DEVICE)
77
  model.eval()
78
- print("[+] Translation model loaded successfully.")
79
  return tokenizer, model
80
  except Exception as e:
81
  print(f"[-] Error loading translation model: {e}")
 
15
  import soundfile as sf
16
  import traceback
17
  from huggingface_hub import hf_hub_download
18
+ from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
19
  from peft import PeftModel
20
  from matcha.models.matcha_tts import MatchaTTS
21
  from matcha.hifigan.models import Generator as HiFiGAN
 
61
  # Load the tokenizer
62
  tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_BASE_MODEL, trust_remote_code=True)
63
 
64
+ # Load the base model with the EXACT class used during training (Gemma3ForConditionalGeneration)
65
+ # This ensures LoRA layers map correctly
66
+ print("[*] Loading base model as Gemma3ForConditionalGeneration...")
67
+ base_model = Gemma3ForConditionalGeneration.from_pretrained(
68
  TRANSLATION_BASE_MODEL,
69
  torch_dtype=torch.float16,
70
  trust_remote_code=True
71
  )
72
 
73
  # Load the LoRA adapter
74
+ print("[*] Loading LoRA adapter...")
75
  model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
76
 
77
+ # Merge the adapter weights into the base model for faster inference
78
+ print("[*] Merging adapter weights...")
79
+ model = model.merge_and_unload()
80
+
81
  # Move to device
82
  model.to(DEVICE)
83
  model.eval()
84
+ print(f"[+] Translation model loaded successfully on {DEVICE}.")
85
  return tokenizer, model
86
  except Exception as e:
87
  print(f"[-] Error loading translation model: {e}")