Luis J Camargo committed on
Commit
3aa360d
·
1 Parent(s): bb818a0

refactor: Streamline model loading by using `load_file` for safetensors and adjusting print statements.

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -48,16 +48,15 @@ AutoModel.register(WhisperEncoderOnlyConfig, WhisperEncoderOnlyForClassification
48
  # === LOAD MODEL ===
49
  MODEL_REPO = "tachiwin/language_classification_enconly_model_2"
50
 
51
- print("Loading processor and config...")
52
  processor = WhisperProcessor.from_pretrained(MODEL_REPO)
53
  config = WhisperEncoderOnlyConfig.from_pretrained(MODEL_REPO)
54
-
55
- print("Downloading model weights...")
56
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors")
57
-
58
- print("Initializing model...")
59
  model = WhisperEncoderOnlyForClassification(config)
60
- state_dict = torch.load(model_path, map_location="cpu")
 
 
 
 
61
  model.load_state_dict(state_dict)
62
  model.eval()
63
 
 
48
  # === LOAD MODEL ===
49
  MODEL_REPO = "tachiwin/language_classification_enconly_model_2"
50
 
51
+ print("Loading model...")
52
  processor = WhisperProcessor.from_pretrained(MODEL_REPO)
53
  config = WhisperEncoderOnlyConfig.from_pretrained(MODEL_REPO)
 
 
 
 
 
54
  model = WhisperEncoderOnlyForClassification(config)
55
+
56
+ # Load weights from safetensors
57
+ from huggingface_hub import hf_hub_download
58
+ weights_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors")
59
+ state_dict = load_file(weights_path)
60
  model.load_state_dict(state_dict)
61
  model.eval()
62