Update infrance_text_pop.py
Browse files- infrance_text_pop.py +4 -2
infrance_text_pop.py
CHANGED
|
@@ -128,9 +128,11 @@ if __name__ == '__main__':
|
|
| 128 |
import torch.nn as nn
|
| 129 |
|
| 130 |
# Initialize the model
|
| 131 |
-
|
| 132 |
device = 'mps'
|
| 133 |
-
model = InrenceTextVAR(
|
|
|
|
|
|
|
| 134 |
model.to(device)
|
| 135 |
|
| 136 |
def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
|
|
|
|
| 128 |
import torch.nn as nn
|
| 129 |
|
| 130 |
# Initialize the model
|
| 131 |
+
checkpoint = 'VARtext_v1.pth' # Replace with your actual checkpoint path
|
| 132 |
device = 'mps'
|
| 133 |
+
model = InrenceTextVAR(device=device)
|
| 134 |
+
state_dict = torch.load(checkpoint,map_location = "cpu")
|
| 135 |
+
model.load_state_dict(state_dict)
|
| 136 |
model.to(device)
|
| 137 |
|
| 138 |
def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
|