theguywhosucks committed on
Commit
428bb9e
·
verified ·
1 Parent(s): 262bf6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -2,25 +2,32 @@ import torch
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
 
5
  repo_id = "theguywhosucks/mochaV2"
6
 
7
- # Load the tokenizer from the repo (uses tokenizer.json internally)
8
  tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
9
 
10
- # GPT2-style models often don't have a pad token
11
  if tokenizer.pad_token is None:
12
  tokenizer.pad_token = tokenizer.eos_token
13
 
14
- # Load the model (safetensors used automatically)
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model = AutoModelForCausalLM.from_pretrained(
17
  repo_id,
18
- dtype=torch.float32, # torch_dtype is deprecated
19
- trust_remote_code=True
20
  )
21
  model.to(device)
22
  model.eval()
23
 
 
 
 
 
 
 
24
  def complete_sentence(prompt, max_new_tokens=50, temperature=0.7):
25
  # Tokenize input safely
26
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
@@ -34,6 +41,7 @@ def complete_sentence(prompt, max_new_tokens=50, temperature=0.7):
34
  pad_token_id=tokenizer.pad_token_id
35
  )
36
 
 
37
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
38
 
39
  # Launch Gradio app
 
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model repo on the Hugging Face Hub
repo_id = "theguywhosucks/mochaV2"

# Load the tokenizer shipped with the model (tokenizer.json internally).
# NOTE(review): use_fast=False forces the slow (Python) tokenizer; if the repo
# only ships a tokenizer.json this may fail — confirm the repo also has the
# slow-tokenizer files, or drop use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=False)

# GPT2-style models often lack a pad token; reuse eos_token so that
# generate() can pad batches without warnings.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the model and move it to the best available device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,   # required if the repo defines custom model code
    dtype=torch.float32,      # torch_dtype is deprecated, use dtype
)
model.to(device)
model.eval()  # inference only: disables dropout etc.

# Sanity check: a tokenizer/model vocab mismatch silently produces garbage
# generations. Raise explicitly rather than using `assert`, because asserts
# are stripped when Python runs with -O and the check would vanish.
if tokenizer.vocab_size != model.config.vocab_size:
    raise ValueError(
        f"Tokenizer vocab size ({tokenizer.vocab_size}) does not match "
        f"model ({model.config.vocab_size})"
    )
  def complete_sentence(prompt, max_new_tokens=50, temperature=0.7):
32
  # Tokenize input safely
33
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
41
  pad_token_id=tokenizer.pad_token_id
42
  )
43
 
44
+ # Decode output, skipping special tokens
45
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
46
 
47
  # Launch Gradio app