mansaripo commited on
Commit
e5f90f7
·
verified ·
1 Parent(s): 99063a5

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_cloverlm.py +20 -2
modeling_cloverlm.py CHANGED
@@ -255,13 +255,31 @@ class CloverLMForCausalLM(PreTrainedModel, GenerationMixin):
255
  )
256
  self.post_init()
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  @classmethod
259
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
260
  import os
261
  from safetensors import safe_open
262
 
263
- st_path = os.path.join(str(pretrained_model_name_or_path), "model.safetensors")
264
- if not os.path.exists(st_path):
265
  return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
266
 
267
  with safe_open(st_path, framework="pt") as f:
 
255
  )
256
  self.post_init()
257
 
258
+ @classmethod
259
+ def _resolve_safetensors(cls, pretrained_model_name_or_path, **kwargs):
260
+ """Locate model.safetensors for a local dir or Hub repo ID."""
261
+ import os
262
+ path = str(pretrained_model_name_or_path)
263
+ local = os.path.join(path, "model.safetensors")
264
+ if os.path.exists(local):
265
+ return local
266
+ try:
267
+ from huggingface_hub import hf_hub_download
268
+ return hf_hub_download(
269
+ repo_id=path,
270
+ filename="model.safetensors",
271
+ token=kwargs.get("token"),
272
+ )
273
+ except Exception:
274
+ return None
275
+
276
  @classmethod
277
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
278
  import os
279
  from safetensors import safe_open
280
 
281
+ st_path = cls._resolve_safetensors(pretrained_model_name_or_path, **kwargs)
282
+ if st_path is None:
283
  return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
284
 
285
  with safe_open(st_path, framework="pt") as f: