Hugging Face Spaces — status: Runtime error.
Commit: "Update app.py" (Browse files).
File changed: app.py
@@ -150,7 +150,9 @@ class SmolLM2ForCausalLM(PreTrainedModel):
         if config.tie_word_embeddings:
             self.lm_head.weight = self.embed_tokens.weight
 
-    def forward(self, input_ids, attention_mask=None, labels=None):
+    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=None, **kwargs):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         hidden_states = self.embed_tokens(input_ids)
 
         # Create causal attention mask if none provided
@@ -171,8 +173,18 @@ class SmolLM2ForCausalLM(PreTrainedModel):
         loss = None
         if labels is not None:
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
-
-
+
+        if return_dict:
+            from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+            return CausalLMOutputWithCrossAttentions(
+                loss=loss,
+                logits=logits,
+                past_key_values=None,
+                hidden_states=None,
+                attentions=None,
+                cross_attentions=None,
+            )
+        return (loss, logits) if loss is not None else logits
 
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {