Tousifahamed commited on
Commit
aa150c0
·
verified ·
1 Parent(s): badb16f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. model_utils.py +10 -3
app.py CHANGED
@@ -59,6 +59,6 @@ demo = gr.Interface(
59
  )
60
 
61
  if __name__ == "__main__":
62
- demo.launch(share=True)
63
  else:
64
- app = demo.launch(share=True)
 
59
  )
60
 
61
  if __name__ == "__main__":
62
+ demo.launch(share=False)
63
  else:
64
+ app = demo.launch(share=False)
model_utils.py CHANGED
@@ -111,9 +111,16 @@ class GPT(nn.Module):
111
  def load_model(model_path):
112
  """Load the trained model"""
113
  try:
114
- torch.serialization.add_safe_globals({'GPTConfig': GPTConfig})
115
- checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), weights_only=True)
116
- config = GPTConfig(**checkpoint['config'])
 
 
 
 
 
 
 
117
  model = GPT(config)
118
  model.load_state_dict(checkpoint['model_state_dict'])
119
  model.eval()
 
111
  def load_model(model_path):
112
  """Load the trained model"""
113
  try:
114
+ checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
115
+
116
+ # Create config from the saved dictionary
117
+ config_dict = checkpoint['config']
118
+ if isinstance(config_dict, str):
119
+ # If config was saved as string, parse it to dict
120
+ import ast
121
+ config_dict = ast.literal_eval(config_dict)
122
+ config = GPTConfig(**config_dict)
123
+
124
  model = GPT(config)
125
  model.load_state_dict(checkpoint['model_state_dict'])
126
  model.eval()