alex4cip Claude committed on
Commit 884298e · 1 Parent(s): e6dc16b

fix: Add fallback for model loading with better error handling

**Model Loading Improvements:**
- Explicitly set use_safetensors=True for primary loading attempt
- Add try-except fallback to default loading if safetensors fails
- Use torch_dtype=torch.float32 instead of dtype (the dtype keyword is not supported in transformers 4.30)
- Better error messages for debugging

**Error Handling:**
- Primary: Try loading with use_safetensors=True
- Fallback: Try loading without use_safetensors if primary fails
- Print warning message when fallback is used
- Prevents complete failure when safetensors has issues

This should fix the model loading error on HF Spaces while
maintaining compatibility with both safetensors and legacy formats.
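
The primary/fallback flow described above can be sketched as a small helper. This is only an illustration, not the code in this commit: `load_with_fallback` and its `loader` parameter are hypothetical names, and the helper assumes the loader accepts a `use_safetensors` keyword the way `AutoModelForCausalLM.from_pretrained` does in the diff.

```python
from typing import Any, Callable


def load_with_fallback(loader: Callable[..., Any], name: str, **common_kwargs: Any) -> Any:
    """Try a safetensors-only load first, then retry with the loader's defaults.

    `loader` is a hypothetical stand-in for a callable such as
    AutoModelForCausalLM.from_pretrained; `common_kwargs` are the keyword
    arguments shared by both attempts.
    """
    try:
        # Primary attempt: explicitly request safetensors weights.
        return loader(name, use_safetensors=True, **common_kwargs)
    except Exception as exc:
        # Fallback: warn, then let the loader choose the weight format itself.
        print(f"Safetensors loading failed, trying default method: {exc}")
        return loader(name, **common_kwargs)
```

Under these assumptions, the call in `load_model` could shrink to `load_with_fallback(AutoModelForCausalLM.from_pretrained, model_name, token=HF_TOKEN, torch_dtype=torch.float32, low_cpu_mem_usage=True, trust_remote_code=True)`, which keeps the shared arguments in one place. Catching a narrower exception type than `Exception` may be preferable if the failure mode on HF Spaces is known.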

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1)
  1. app.py +20 -7
app.py CHANGED
@@ -82,13 +82,26 @@ def load_model(model_name):
     tokenizer.pad_token = tokenizer.eos_token
 
     # Load model with safetensors support
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        token=HF_TOKEN,
-        dtype=torch.float32,
-        low_cpu_mem_usage=True,
-        trust_remote_code=True
-    )
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            token=HF_TOKEN,
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            use_safetensors=True
+        )
+    except Exception as e:
+        # Fallback to default loading if safetensors fails
+        print(f"⚠️ Safetensors loading failed, trying default method: {e}")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            token=HF_TOKEN,
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
+        )
+
     model.to(device)
     model.eval()
 