shwethd committed on
Commit
adc8386
·
verified ·
1 Parent(s): c1e7837

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -138,6 +138,9 @@ try:
138
  )
139
  state_dict = load_file(model_path, device=device)
140
  model.load_state_dict(state_dict)
 
 
 
141
  model_loaded = True
142
  print(f"✅ Model loaded successfully from SafeTensors: {repo_id}")
143
  except Exception as e:
@@ -148,7 +151,9 @@ try:
148
  filename="model_checkpoint_final.pt",
149
  cache_dir=None
150
  )
151
- checkpoint = torch.load(model_path, map_location=device)
 
 
152
 
153
  # Handle different checkpoint formats
154
  if 'model_state_dict' in checkpoint:
@@ -168,7 +173,8 @@ try:
168
  filename="model_checkpoint_final.pt",
169
  cache_dir=None
170
  )
171
- checkpoint = torch.load(model_path, map_location=device)
 
172
 
173
  # Handle different checkpoint formats
174
  if 'model_state_dict' in checkpoint:
@@ -185,7 +191,8 @@ try:
185
  print(f"⚠️ Could not load from Hub ({e}), trying local file...")
186
  try:
187
  # Fallback to local file
188
- checkpoint = torch.load('model_checkpoint_final.pt', map_location=device)
 
189
  if 'model_state_dict' in checkpoint:
190
  model.load_state_dict(checkpoint['model_state_dict'])
191
  elif 'state_dict' in checkpoint:
@@ -369,5 +376,6 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
369
  """)
370
 
371
  if __name__ == "__main__":
372
- demo.launch(share=True)
 
373
 
 
138
  )
139
  state_dict = load_file(model_path, device=device)
140
  model.load_state_dict(state_dict)
141
+ # Restore weight sharing (broken during SafeTensors conversion)
142
+ # lm_head.weight and transformer.wte.weight should share memory
143
+ model.transformer.wte.weight = model.lm_head.weight
144
  model_loaded = True
145
  print(f"✅ Model loaded successfully from SafeTensors: {repo_id}")
146
  except Exception as e:
 
151
  filename="model_checkpoint_final.pt",
152
  cache_dir=None
153
  )
154
+ # PyTorch 2.6+ defaults torch.load to weights_only=True; pass weights_only=False because this checkpoint pickles custom (non-tensor) classes
155
+ # weights_only=False can execute arbitrary pickled code, so this is acceptable only because the checkpoint is our own trusted trained model
156
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
157
 
158
  # Handle different checkpoint formats
159
  if 'model_state_dict' in checkpoint:
 
173
  filename="model_checkpoint_final.pt",
174
  cache_dir=None
175
  )
176
+ # PyTorch 2.6+ defaults torch.load to weights_only=True; pass weights_only=False because this checkpoint pickles custom (non-tensor) classes
177
+ checkpoint = torch.load(model_path, map_location=device, weights_only=False)
178
 
179
  # Handle different checkpoint formats
180
  if 'model_state_dict' in checkpoint:
 
191
  print(f"⚠️ Could not load from Hub ({e}), trying local file...")
192
  try:
193
  # Fallback to local file
194
+ # PyTorch 2.6+ defaults torch.load to weights_only=True; pass weights_only=False because this checkpoint pickles custom (non-tensor) classes
195
+ checkpoint = torch.load('model_checkpoint_final.pt', map_location=device, weights_only=False)
196
  if 'model_state_dict' in checkpoint:
197
  model.load_state_dict(checkpoint['model_state_dict'])
198
  elif 'state_dict' in checkpoint:
 
376
  """)
377
 
378
  if __name__ == "__main__":
379
+ # Don't use share=True on Hugging Face Spaces: the Space already serves the app publicly, and share links are not supported there
380
+ demo.launch()
381