Spaces:
Runtime error
Runtime error
Jonathan Malott
committed on
Commit
·
ba3afcd
1
Parent(s):
1557730
updated dalle/models/__init__.py
Browse files - dalle/models/__init__.py (+11 -7)
dalle/models/__init__.py
CHANGED
|
@@ -43,20 +43,24 @@ class Dalle(nn.Module):
|
|
| 43 |
@classmethod
|
| 44 |
def from_pretrained(cls,
|
| 45 |
path: str) -> nn.Module:
|
| 46 |
-
path = _MODELS[path] if path in _MODELS else path
|
| 47 |
-
path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
|
|
|
|
| 48 |
|
| 49 |
config_base = get_base_config()
|
| 50 |
-
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
| 51 |
config_update = OmegaConf.merge(config_base, config_new)
|
| 52 |
|
| 53 |
model = cls(config_update)
|
| 54 |
-
model.tokenizer = build_tokenizer(
|
| 55 |
context_length=model.config_dataset.context_length,
|
| 56 |
lowercase=True,
|
| 57 |
dropout=None)
|
| 58 |
-
model.stage1.from_ckpt(
|
| 59 |
-
model.stage2.from_ckpt(
|
|
|
|
|
|
|
|
|
|
| 60 |
return model
|
| 61 |
|
| 62 |
@torch.no_grad()
|
|
@@ -199,4 +203,4 @@ class ImageGPT(pl.LightningModule):
|
|
| 199 |
self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
|
| 200 |
|
| 201 |
def on_epoch_start(self):
|
| 202 |
-
self.stage1.eval()
|
|
|
|
| 43 |
@classmethod
|
| 44 |
def from_pretrained(cls,
|
| 45 |
path: str) -> nn.Module:
|
| 46 |
+
#path = _MODELS[path] if path in _MODELS else path
|
| 47 |
+
#path = utils.realpath_url_or_path(path, root=os.path.expanduser(".cache/minDALL-E"))
|
| 48 |
+
path = ''
|
| 49 |
|
| 50 |
config_base = get_base_config()
|
| 51 |
+
config_new = OmegaConf.load(os.path.join(path, '.cache/minDALL-E/1.3B/config.yaml'))
|
| 52 |
config_update = OmegaConf.merge(config_base, config_new)
|
| 53 |
|
| 54 |
model = cls(config_update)
|
| 55 |
+
model.tokenizer = build_tokenizer('.cache/minDALL-E/1.3B/tokenizer',
|
| 56 |
context_length=model.config_dataset.context_length,
|
| 57 |
lowercase=True,
|
| 58 |
dropout=None)
|
| 59 |
+
model.stage1.from_ckpt('.cache/minDALL-E/1.3B/stage1_last.ckpt')
|
| 60 |
+
model.stage2.from_ckpt('.cache/minDALL-E/1.3B/stage2_last.ckpt')
|
| 61 |
+
#model.stage1.from_ckpt('https://utexas.box.com/shared/static/rpt9miyj2kikogyekpqnkd6y115xp51i.ckpt')
|
| 62 |
+
#model.stage2.from_ckpt('https://utexas.box.com/shared/static/54jc9fw0bious5nx6wvayeqaskcrdgv4.ckpt')
|
| 63 |
+
|
| 64 |
return model
|
| 65 |
|
| 66 |
@torch.no_grad()
|
|
|
|
| 203 |
self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
|
| 204 |
|
| 205 |
def on_epoch_start(self):
|
| 206 |
+
self.stage1.eval()
|