mathminakshi commited on
Commit
77c1734
·
verified ·
1 Parent(s): 2584bfb

changed app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -9,12 +9,12 @@ def get_model():
9
  """Load the trained GPT model."""
10
  model = GPT(GPTConfig())
11
  # Load from the Hugging Face Hub instead of local file
12
- model_path = 'YOUR_USERNAME/YOUR_MODEL_REPO/final_best_model.pth'
13
- model.load_state_dict(torch.hub.load_state_dict_from_url(f'https://huggingface.co/{model_path}/resolve/main/final_best_model.pth', map_location='cpu')['model_state_dict'])
14
  model.eval()
15
  return model
16
 
17
- def generate_text(prompt, max_tokens=500, temperature=0.8, top_k=40):
18
  """Generate text based on the prompt."""
19
  # Encode the prompt
20
  enc = tiktoken.get_encoding('gpt2')
 
9
  """Load the trained GPT model."""
10
  model = GPT(GPTConfig())
11
  # Load from the Hugging Face Hub instead of local file
12
+ model_path = 'mathminakshi/custom_gpt2'
13
+ model.load_state_dict(torch.hub.load_state_dict_from_url(f'https://huggingface.co/{model_path}/resolve/main/best_model.pth', map_location='cpu')['model_state_dict'])
14
  model.eval()
15
  return model
16
 
17
+ def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
18
  """Generate text based on the prompt."""
19
  # Encode the prompt
20
  enc = tiktoken.get_encoding('gpt2')