shwethd committed on
Commit
0634381
·
verified ·
1 Parent(s): 3216812

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -4
app.py CHANGED
@@ -114,14 +114,35 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
114
  config = GPTConfig()
115
  model = GPT(config)
116
 
117
- # Try to load model (works both locally and on HuggingFace)
118
  try:
119
- checkpoint = torch.load('model_checkpoint_final.pt', map_location=device)
120
- model.load_state_dict(checkpoint['model_state_dict'])
121
- print("Model loaded from checkpoint")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  except FileNotFoundError:
123
  print("Warning: Model checkpoint not found. Using untrained model.")
124
  # Model will be randomly initialized - not ideal but won't crash
 
 
 
125
 
126
  model.to(device)
127
  model.eval()
 
114
  config = GPTConfig()
115
  model = GPT(config)
116
 
117
+ # Try to load model from HuggingFace Model Hub first, then local file
118
  try:
119
+ from huggingface_hub import hf_hub_download
120
+ import os
121
+
122
+ # Try to get model path from environment variable or use default
123
+ repo_id = os.getenv('HF_MODEL_REPO', 'YOUR_USERNAME/gpt2-shakespeare-124m') # Update with your repo
124
+
125
+ try:
126
+ model_path = hf_hub_download(
127
+ repo_id=repo_id,
128
+ filename="model_checkpoint_final.pt",
129
+ cache_dir=None
130
+ )
131
+ checkpoint = torch.load(model_path, map_location=device)
132
+ model.load_state_dict(checkpoint['model_state_dict'])
133
+ print(f"Model loaded from HuggingFace Hub: {repo_id}")
134
+ except Exception as e:
135
+ print(f"Could not load from Hub ({e}), trying local file...")
136
+ # Fallback to local file
137
+ checkpoint = torch.load('model_checkpoint_final.pt', map_location=device)
138
+ model.load_state_dict(checkpoint['model_state_dict'])
139
+ print("Model loaded from local checkpoint")
140
  except FileNotFoundError:
141
  print("Warning: Model checkpoint not found. Using untrained model.")
142
  # Model will be randomly initialized - not ideal but won't crash
143
+ except Exception as e:
144
+ print(f"Error loading model: {e}")
145
+ print("Using untrained model as fallback.")
146
 
147
  model.to(device)
148
  model.eval()