flpelerin committed on
Commit
841b8e8
·
1 Parent(s): cb76b78

Update file model.py

Browse files
Files changed (1) hide show
  1. model.py +40 -0
model.py CHANGED
@@ -44,6 +44,45 @@ class Model:
44
  return lm_loss
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def generate_text(self, tokenizer, seed_text, num_predict):
48
  max_len = num_predict + len(seed_text)
49
 
@@ -56,6 +95,7 @@ class Model:
56
  text = tokenizer.decode(logits)
57
 
58
  return text
 
59
 
60
 
61
  @staticmethod
 
44
  return lm_loss
45
 
46
 
47
+
48
def generate_text(self, model,
                  tokenizer,
                  prompt: str,
                  n_tokens_to_gen: int = 50,
                  sample: bool = True,
                  top_k: int = 40):
    """Autoregressively generate text conditioned on *prompt*.

    Args:
        model: language model mapping token ids of shape (batch, seq) to
            logits of shape (batch, seq, vocab) — assumed from the
            ``[:, -1]`` indexing below; TODO confirm against the model class.
            When ``None``, falls back to ``self.model``.
        tokenizer: HF-style tokenizer — callable returning an object with
            ``.input_ids`` and providing ``.decode``.
        prompt: seed string the generation is conditioned on.
        n_tokens_to_gen: number of new tokens to append.
        sample: if True, sample from the (filtered) distribution via
            ``torch.multinomial``; otherwise pick the argmax greedily.
        top_k: keep only the k most probable tokens before sampling;
            ``None`` disables the filter.

    Returns:
        The decoded string: prompt plus the generated continuation.
    """
    # BUG FIX: the passed-in model used to be unconditionally overwritten
    # by ``self.model``; now the argument is honoured, with ``self.model``
    # as the fallback when None is given.
    model = model if model is not None else self.model
    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    for _ in range(n_tokens_to_gen):
        with torch.no_grad():
            # Only the logits at the last position matter for the next token.
            next_token_logits = model(input_ids)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)

        if top_k is not None:
            # Zero out everything below the k-th largest probability,
            # then renormalise so the surviving mass sums to 1.
            values, _ = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            # Canonical torch kwargs (dim/keepdim), not NumPy aliases.
            probs = probs / probs.sum(dim=-1, keepdim=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Decode only the first (and, given the single-prompt input, only)
    # sequence in the batch.
    return tokenizer.decode(input_ids[0].tolist())
83
+
84
+
85
+ """
86
  def generate_text(self, tokenizer, seed_text, num_predict):
87
  max_len = num_predict + len(seed_text)
88
 
 
95
  text = tokenizer.decode(logits)
96
 
97
  return text
98
+ """
99
 
100
 
101
  @staticmethod