Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,7 +21,19 @@ if torch.cuda.is_available():
|
|
| 21 |
add_watermarker=False
|
| 22 |
)
|
| 23 |
#pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
pipe.to("cuda")
|
| 26 |
|
| 27 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
|
|
|
| 21 |
add_watermarker=False
|
| 22 |
)
|
| 23 |
#pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
| 24 |
+
# 设置 tokenizer 的最大长度
|
| 25 |
+
max_token_length = 512
|
| 26 |
+
pipe.tokenizer.model_max_length = max_token_length
|
| 27 |
+
|
| 28 |
+
# 调整文本编码器的配置
|
| 29 |
+
pipe.text_encoder.config.max_position_embeddings = max_token_length
|
| 30 |
+
|
| 31 |
+
# 如果需要,重新初始化位置嵌入
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
old_emb = pipe.text_encoder.text_model.embeddings.position_embedding.weight.data
|
| 34 |
+
new_emb = nn.Parameter(torch.zeros(max_token_length, old_emb.shape[1]))
|
| 35 |
+
new_emb.data[:old_emb.shape[0], :] = old_emb
|
| 36 |
+
pipe.text_encoder.text_model.embeddings.position_embedding.weight = new_emb
|
| 37 |
pipe.to("cuda")
|
| 38 |
|
| 39 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|