Spaces:
Runtime error
Runtime error
Gagan Bhatia commited on
Commit ·
2679662
1
Parent(s): 4ac518a
Update model.py
Browse files- src/models/model.py +15 -5
src/models/model.py
CHANGED
|
@@ -252,24 +252,34 @@ class LightningModel(LightningModule):
|
|
| 252 |
no_decay = ["bias", "LayerNorm.weight"]
|
| 253 |
optimizer_grouped_parameters = [
|
| 254 |
{
|
| 255 |
-
"params": [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
"weight_decay": self.weight_decay,
|
| 257 |
},
|
| 258 |
{
|
| 259 |
-
"params": [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
"weight_decay": 0.0,
|
| 261 |
},
|
| 262 |
]
|
| 263 |
-
optimizer = AdamW(
|
|
|
|
|
|
|
| 264 |
self.opt = optimizer
|
| 265 |
return [optimizer]
|
| 266 |
|
| 267 |
|
| 268 |
class Summarization:
|
| 269 |
-
"""
|
| 270 |
|
| 271 |
def __init__(self) -> None:
|
| 272 |
-
"""
|
| 273 |
pass
|
| 274 |
|
| 275 |
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
|
|
|
|
| 252 |
no_decay = ["bias", "LayerNorm.weight"]
|
| 253 |
optimizer_grouped_parameters = [
|
| 254 |
{
|
| 255 |
+
"params": [
|
| 256 |
+
p
|
| 257 |
+
for n, p in model.named_parameters()
|
| 258 |
+
if not any(nd in n for nd in no_decay)
|
| 259 |
+
],
|
| 260 |
"weight_decay": self.weight_decay,
|
| 261 |
},
|
| 262 |
{
|
| 263 |
+
"params": [
|
| 264 |
+
p
|
| 265 |
+
for n, p in model.named_parameters()
|
| 266 |
+
if any(nd in n for nd in no_decay)
|
| 267 |
+
],
|
| 268 |
"weight_decay": 0.0,
|
| 269 |
},
|
| 270 |
]
|
| 271 |
+
optimizer = AdamW(
|
| 272 |
+
optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon
|
| 273 |
+
)
|
| 274 |
self.opt = optimizer
|
| 275 |
return [optimizer]
|
| 276 |
|
| 277 |
|
| 278 |
class Summarization:
|
| 279 |
+
"""Custom Summarization class"""
|
| 280 |
|
| 281 |
def __init__(self) -> None:
|
| 282 |
+
"""initiates Summarization class"""
|
| 283 |
pass
|
| 284 |
|
| 285 |
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
|