ryomo committed
Commit 233a7d4 · verified · 1 Parent(s): a93cf58

Upload MinjaLM

Files changed (3):
  1. README.md +1 -1
  2. model.safetensors +1 -1
  3. modeling.py +38 -9
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 language:
-- ja
+- ja
 license: mit
 ---
 
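The front matter above is the Hub's model-card metadata (language and license), which is also readable programmatically. A minimal sketch using huggingface_hub; the repo id "ryomo/MinjaLM" is an assumption for illustration:

from huggingface_hub import ModelCard

# "ryomo/MinjaLM" is a hypothetical repo id for this example.
card = ModelCard.load("ryomo/MinjaLM")
print(card.data.language)  # expected: ['ja']
print(card.data.license)   # expected: 'mit'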
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c38da631e32f98357b7c2f41b3147af1fb063c185005d27bd76be7921f8fbf71
+oid sha256:42b623bbbed1c65ed75a4b408c68ac8634c77e8b14e964ac026c45cb118fd13b
 size 37524064
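The weights are stored via Git LFS, so the diff shows only the pointer file: the oid is the SHA-256 of the actual tensor payload, and the size stays at 37524064 bytes. A minimal sketch for verifying a downloaded copy against the new oid; the local path is an assumption:

import hashlib

# Hypothetical local path to the downloaded weights file.
path = "model.safetensors"

# Git LFS oids are the SHA-256 of the file contents; hash in chunks
# so the ~37 MB file is not read into memory at once.
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

expected = "42b623bbbed1c65ed75a4b408c68ac8634c77e8b14e964ac026c45cb118fd13b"
print("match" if h.hexdigest() == expected else "mismatch")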
modeling.py CHANGED
@@ -44,21 +44,50 @@ class MinjaLM(PreTrainedModel):
         logits = self.head(x)
         return logits
 
-    def generate(self, tokenizer, prompt, max_new_tokens=20, temperature=0.7, device="cpu"):
+    def generate(self, input_ids, max_new_tokens=20, temperature=0.7, eos_token_id=None, pad_token_id=None, do_sample=True):
         """
-        Generate text using the model and tokenizer with temperature sampling.
+        Generate tokens using the model with temperature sampling.
+
+        Args:
+            input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len)
+            max_new_tokens (int): Maximum number of new tokens to generate
+            temperature (float): Temperature for sampling (higher = more random)
+            eos_token_id (int, optional): Token ID to stop generation
+            pad_token_id (int, optional): Padding token ID (unused for now)
+            do_sample (bool): Whether to use sampling (True) or greedy decoding (False)
+
+        Returns:
+            torch.Tensor: Generated token IDs of shape (batch_size, original_seq_len + generated_tokens)
         """
         self.eval()
+        device = input_ids.device
         self.to(device)
-        idx = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+        # Ensure input_ids has the right shape
+        if input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
+
+        idx = input_ids.clone()
 
         with torch.no_grad():
             for _ in range(max_new_tokens):
-                logits = self(idx[:, -self.config.block_size:])
-                logits = logits[:, -1, :] / temperature
-                probs = torch.softmax(logits, dim=-1)
-                next_id = torch.multinomial(probs, num_samples=1)
+                # Crop to the last block_size tokens if sequence is too long
+                idx_cond = idx[:, -self.config.block_size:] if idx.size(1) > self.config.block_size else idx
+                logits = self(idx_cond)
+                logits = logits[:, -1, :]  # Get the last token's logits
+
+                if do_sample:
+                    logits = logits / temperature
+                    probs = torch.softmax(logits, dim=-1)
+                    next_id = torch.multinomial(probs, num_samples=1)
+                else:
+                    # Greedy decoding
+                    next_id = torch.argmax(logits, dim=-1, keepdim=True)
+
                 idx = torch.cat([idx, next_id], dim=1)
+
+                # Stop if we hit the end-of-sequence token
+                if eos_token_id is not None and next_id.item() == eos_token_id:
                     break
-        return tokenizer.decode(idx[0].tolist(), skip_special_tokens=True)
+
+        return idx
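The new signature mirrors the transformers-style generate API: the caller owns tokenization, and the method returns token IDs rather than decoded text. A minimal usage sketch; the repo id "ryomo/MinjaLM" and loading through AutoModel with trust_remote_code are assumptions for illustration:

from transformers import AutoModel, AutoTokenizer

# "ryomo/MinjaLM" is a hypothetical repo id for this example.
tokenizer = AutoTokenizer.from_pretrained("ryomo/MinjaLM", trust_remote_code=True)
model = AutoModel.from_pretrained("ryomo/MinjaLM", trust_remote_code=True)

# Tokenization moved out of generate(): encode the prompt, pass IDs in,
# and decode the returned IDs yourself.
input_ids = tokenizer.encode("こんにちは", return_tensors="pt")
output_ids = model.generate(
    input_ids,
    max_new_tokens=20,
    temperature=0.7,
    eos_token_id=tokenizer.eos_token_id,
    do_sample=True,
)
print(tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True))

Decoupling generate() from a specific tokenizer is also why the diff adds the unsqueeze guard and drops the device argument: callers may now pass a bare 1-D tensor of IDs, and the device is inferred from input_ids.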