Tousifahamed commited on
Commit
9d9dde9
·
verified ·
1 Parent(s): f77e017

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. model_utils.py +10 -6
app.py CHANGED
@@ -56,6 +56,6 @@ demo = gr.Interface(
56
  )
57
 
58
  if __name__ == "__main__":
59
- demo.launch(share=True)
60
  else:
61
- app = demo.launch(share=False)
 
56
  )
57
 
58
  if __name__ == "__main__":
59
+ demo.launch()
60
  else:
61
+ app = demo.launch()
model_utils.py CHANGED
@@ -110,12 +110,16 @@ class GPT(nn.Module):
110
 
111
  def load_model(model_path):
112
  """Load the trained model"""
113
- checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
114
- config = checkpoint['config']
115
- model = GPT(config)
116
- model.load_state_dict(checkpoint['model_state_dict'])
117
- model.eval()
118
- return model
 
 
 
 
119
 
120
  def generate_text(model, prompt, max_new_tokens=50, temperature=0.8, top_k=40):
121
  """Generate text based on a prompt
 
110
 
111
  def load_model(model_path):
112
  """Load the trained model"""
113
+ try:
114
+ checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), weights_only=True)
115
+ config = GPTConfig(**checkpoint['config'])
116
+ model = GPT(config)
117
+ model.load_state_dict(checkpoint['model_state_dict'])
118
+ model.eval()
119
+ return model
120
+ except AttributeError as e:
121
+ print(f"Error loading model: {e}")
122
+ return None
123
 
124
  def generate_text(model, prompt, max_new_tokens=50, temperature=0.8, top_k=40):
125
  """Generate text based on a prompt