satyanayak committed on
Commit
852e307
·
1 Parent(s): 2bb12cf

model source file name changed

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. input.txt +0 -0
  3. transformer-basic.py → transformer.py +1 -3
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import torch.nn.functional as F
4
  import tiktoken
5
  from huggingface_hub import hf_hub_download
6
- from transformer-basic import GPT, GPTConfig # Import your model class
7
 
8
  # Load the model from Hugging Face Hub
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
3
  import torch.nn.functional as F
4
  import tiktoken
5
  from huggingface_hub import hf_hub_download
6
+ from transformer import GPT, GPTConfig # Import your model class
7
 
8
  # Load the model from Hugging Face Hub
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
input.txt ADDED
The diff for this file is too large to render. See raw diff
 
transformer-basic.py → transformer.py RENAMED
@@ -269,7 +269,7 @@ best_loss = float('inf')
269
  step = 0
270
  losses = [] # Keep track of losses for monitoring
271
  last_time = time.time()
272
- interval = 10 # Print every 10 steps
273
 
274
  while step < total_steps and best_loss > 0.099999:
275
  x, y = train_loader.next_batch()
@@ -317,8 +317,6 @@ print(f'Average of last 100 losses: {sum(losses[-100:]) / min(len(losses), 100):
317
  save_path = 'trained_model.pt'
318
  torch.save({
319
  'model_state_dict': model.state_dict(),
320
- 'optimizer_state_dict': optimizer.state_dict(),
321
- 'scheduler_state_dict': scheduler.state_dict(),
322
  'best_loss': best_loss,
323
  'config': model.config,
324
  }, save_path)
 
269
  step = 0
270
  losses = [] # Keep track of losses for monitoring
271
  last_time = time.time()
272
+ interval = 2 # Print every 2 steps
273
 
274
  while step < total_steps and best_loss > 0.099999:
275
  x, y = train_loader.next_batch()
 
317
  save_path = 'trained_model.pt'
318
  torch.save({
319
  'model_state_dict': model.state_dict(),
 
 
320
  'best_loss': best_loss,
321
  'config': model.config,
322
  }, save_path)