satyanayak commited on
Commit
d640193
·
1 Parent(s): 852e307

Typo in the transformer import fixed, and model train mode is forcibly set to false

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import torch.nn.functional as F
4
  import tiktoken
5
  from huggingface_hub import hf_hub_download
6
- from transformer import GPT, GPTConfig # Import your model class
7
 
8
  # Load the model from Hugging Face Hub
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -17,11 +17,19 @@ def load_model_from_hf():
17
  model = GPT(config)
18
  model.load_state_dict(checkpoint['model_state_dict'])
19
  model.to(device)
20
- model.eval()
 
 
 
 
 
21
  return model
22
 
23
  model = load_model_from_hf()
24
 
 
 
 
25
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
26
  enc = tiktoken.get_encoding('gpt2')
27
  tokens = enc.encode(prompt)
 
3
  import torch.nn.functional as F
4
  import tiktoken
5
  from huggingface_hub import hf_hub_download
6
+ from transformerg import GPT, GPTConfig # Import your model class
7
 
8
  # Load the model from Hugging Face Hub
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
17
  model = GPT(config)
18
  model.load_state_dict(checkpoint['model_state_dict'])
19
  model.to(device)
20
+ model.eval() # Set to evaluation mode
21
+
22
+ # Disable gradient computation
23
+ for param in model.parameters():
24
+ param.requires_grad = False
25
+
26
  return model
27
 
28
  model = load_model_from_hf()
29
 
30
+ # Force model to stay in eval mode
31
+ model.train(False)
32
+
33
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
34
  enc = tiktoken.get_encoding('gpt2')
35
  tokens = enc.encode(prompt)