balgot commited on
Commit
ea96b59
·
1 Parent(s): 7a82a02

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +49 -1
README.md CHANGED
@@ -17,4 +17,52 @@ library_name: pytorch
17
  This model was created as a part of the project for FI:PA228 (Masaryk University),
18
  inspired by this paper: [Face Generation from Textual Features using Conditionally trained Inputs to Generative Adversarial Networks](https://arxiv.org/abs/2301.09123)
19
 
20
- It was trained on the generated dataset from BLIP and StyleGAN3.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  This model was created as a part of the project for FI:PA228 (Masaryk University),
18
  inspired by this paper: [Face Generation from Textual Features using Conditionally trained Inputs to Generative Adversarial Networks](https://arxiv.org/abs/2301.09123)
19
 
20
+ It was trained on the generated dataset from BLIP and StyleGAN3.
21
+
22
+ ## How to use
23
+
24
+
25
+ ```python
26
+ import torch.nn as nn
27
+
28
+
29
+ # for now, the model class needs to be defined, so...
30
+ class LaTran(nn.Module):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+ self.pipe = nn.Sequential(
34
+ nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
35
+ nn.MaxPool1d(kernel_size=2, stride=2),
36
+ nn.Flatten(),
37
+ nn.Linear(3072, 512),
38
+ nn.ReLU(),
39
+ nn.Linear(512, 512)
40
+ )
41
+
42
+ # works but is big
43
+ self.pipe2 = nn.Sequential(
44
+ nn.Flatten(),
45
+ nn.Linear(384, 1024),
46
+ nn.ReLU(),
47
+ nn.Linear(1024, 512)
48
+ )
49
+
50
+ # works, and will be loaded
51
+ self.pipe3 = nn.Sequential(
52
+ nn.Flatten(),
53
+ nn.Linear(384, 512),
54
+ nn.ReLU(),
55
+ nn.Linear(512, 512)
56
+ )
57
+
58
+ def forward(self, v):
59
+ return self.pipe3(v.unsqueeze(1))
60
+
61
+
62
+
63
+ # Instantiate and load the model
64
+ dev = ... # device to use
65
+ PATH = "local_path_file_ending_like...sd.pt"
66
+ model = LaTran().to(dev)
67
+ model.load_state_dict(torch.load(TRANSLATION_MODEL, map_location=dev))
68
+ ```