Instructions to use babkasotona/2b with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use babkasotona/2b with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("babkasotona/2b", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files- .gitattributes +4 -0
- .gitignore +21 -0
- Untitled.ipynb +139 -0
- dataset.py +300 -0
- dataset_sample.ipynb +170 -0
- model_index.json +24 -0
- pipeline_sdxs.py +348 -0
- pipeline_sdxs_t5.py +291 -0
- scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json +22 -0
- scheduler/scheduler_config.json +22 -0
- t.py +116 -0
- test.ipynb +3 -0
- text_encoder/.ipynb_checkpoints/config-checkpoint.json +101 -0
- text_encoder/config.json +101 -0
- text_encoder/model.safetensors +3 -0
- tokenizer/chat_template.jinja +154 -0
- tokenizer/tokenizer.json +3 -0
- tokenizer/tokenizer_config.json +32 -0
- train-Copy1.py +924 -0
- transformer/config.json +37 -0
- transformer/diffusion_pytorch_model.safetensors +3 -0
- vae/.ipynb_checkpoints/config-checkpoint.json +56 -0
- vae/config.json +56 -0
- vae/diffusion_pytorch_model.safetensors +3 -0
- wandb/debug-cli.root.log +0 -0
- wandb/debug-internal.log +0 -0
- wandb/debug.log +19 -0
- wandb/offline-run-20260428_132658-o9052r27/files/requirements.txt +117 -0
- wandb/offline-run-20260428_132658-o9052r27/logs/debug-core.log +14 -0
- wandb/offline-run-20260428_132658-o9052r27/logs/debug-internal.log +15 -0
- wandb/offline-run-20260428_132658-o9052r27/logs/debug.log +21 -0
- wandb/offline-run-20260428_132658-o9052r27/run-o9052r27.wandb +0 -0
- wandb/run-20260428_171645-wt40fdyx/files/output.log +385 -0
- wandb/run-20260428_171645-wt40fdyx/files/requirements.txt +117 -0
- wandb/run-20260428_171645-wt40fdyx/files/wandb-metadata.json +46 -0
- wandb/run-20260428_171645-wt40fdyx/logs/debug-core.log +7 -0
- wandb/run-20260428_171645-wt40fdyx/logs/debug-internal.log +0 -0
- wandb/run-20260428_171645-wt40fdyx/logs/debug.log +19 -0
- wandb/run-20260428_171645-wt40fdyx/run-wt40fdyx.wandb +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
media/refined.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
test.ipynb filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
wandb/run-20260428_171645-wt40fdyx/run-wt40fdyx.wandb filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Jupyter Notebook
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
.ipynb_checkpoints/
|
| 5 |
+
*.ipynb_checkpoints/*
|
| 6 |
+
.ipynb_checkpoints/*
|
| 7 |
+
src/samples
|
| 8 |
+
# cache
|
| 9 |
+
cache
|
| 10 |
+
datasets
|
| 11 |
+
test
|
| 12 |
+
wandb
|
| 13 |
+
nohup.out
|
| 14 |
+
samples/
|
| 15 |
+
transformer/
|
| 16 |
+
*.jpg
|
| 17 |
+
*.png
|
| 18 |
+
datasets/
|
| 19 |
+
samples/
|
| 20 |
+
*.jpg
|
| 21 |
+
train.py
|
Untitled.ipynb
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "7e8f9dc5-d07a-4538-bc03-8953412a72fa",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"Keyword arguments {'safety_checker': <__main__.DummyCosmosSafetyChecker object at 0x7f7e8c3fb620>} are not expected by SdxsPipeline and will be ignored.\n"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"data": {
|
| 18 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 19 |
+
"model_id": "99e2522e93064308b5dd34923a133c39",
|
| 20 |
+
"version_major": 2,
|
| 21 |
+
"version_minor": 0
|
| 22 |
+
},
|
| 23 |
+
"text/plain": [
|
| 24 |
+
"Loading pipeline components...: 0%| | 0/5 [00:00<?, ?it/s]"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"output_type": "display_data"
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"data": {
|
| 32 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 33 |
+
"model_id": "def23cc245b2470a9012f70d5e4c78ed",
|
| 34 |
+
"version_major": 2,
|
| 35 |
+
"version_minor": 0
|
| 36 |
+
},
|
| 37 |
+
"text/plain": [
|
| 38 |
+
"Loading weights: 0%| | 0/195 [00:00<?, ?it/s]"
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"output_type": "display_data"
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"name": "stderr",
|
| 46 |
+
"output_type": "stream",
|
| 47 |
+
"text": [
|
| 48 |
+
"The config attributes {'final_sigmas_type': 'sigma_min', 'sigma_data': 1.0, 'sigma_max': 80.0, 'sigma_min': 0.002} were passed to FlowMatchEulerDiscreteScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.\n",
|
| 49 |
+
"Sampling: 100%|██████████| 40/40 [00:10<00:00, 3.65it/s]\n"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"name": "stdout",
|
| 54 |
+
"output_type": "stream",
|
| 55 |
+
"text": [
|
| 56 |
+
"Готово! Изображение сохранено как output.png\n"
|
| 57 |
+
]
|
| 58 |
+
}
|
| 59 |
+
],
|
| 60 |
+
"source": [
|
| 61 |
+
"import torch\n",
|
| 62 |
+
"from diffusers import Cosmos2TextToImagePipeline\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"class DummyCosmosSafetyChecker:\n",
|
| 65 |
+
" def to(self, *args, **kwargs):\n",
|
| 66 |
+
" return self\n",
|
| 67 |
+
" \n",
|
| 68 |
+
" def eval(self):\n",
|
| 69 |
+
" return self\n",
|
| 70 |
+
"\n",
|
| 71 |
+
" # Обход проверки текста\n",
|
| 72 |
+
" def check_text_safety(self, prompt, *args, **kwargs):\n",
|
| 73 |
+
" return True\n",
|
| 74 |
+
"\n",
|
| 75 |
+
" # Обход проверки \"видео\" (картинки из 1 кадра)\n",
|
| 76 |
+
" def check_video_safety(self, vid, *args, **kwargs):\n",
|
| 77 |
+
" # Просто возвращаем тензор обратно без изменений\n",
|
| 78 |
+
" return vid\n",
|
| 79 |
+
"\n",
|
| 80 |
+
" # На всякий случай оставляем оригинальный __call__\n",
|
| 81 |
+
" def __call__(self, images, **kwargs):\n",
|
| 82 |
+
" return images, [False] * len(images)\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"model_id = \"/workspace/sdxs-2b\"\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"pipe = Cosmos2TextToImagePipeline.from_pretrained(\n",
|
| 87 |
+
" model_id,\n",
|
| 88 |
+
" safety_checker=DummyCosmosSafetyChecker(), \n",
|
| 89 |
+
" torch_dtype=torch.bfloat16 \n",
|
| 90 |
+
")\n",
|
| 91 |
+
"pipe.to(\"cuda\")\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"prompt = \"In a serene garden, two young girls stand side by side, their youthful energy palpable. The girl on the left, adorned with a blue dress and a matching blue flower in her hair, gazes directly at the viewer, her eyes sparkling with curiosity.\"#\"There is a young male character standing against a vibrant, colorful graffiti wall. he is wearing a hat, a jacket adorned with gold accents, and black shorts.\"\n",
|
| 94 |
+
"negative_prompt = \"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.\"\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"# 3. Генерируем изображение\n",
|
| 97 |
+
"output = pipe(\n",
|
| 98 |
+
" height = 1024,\n",
|
| 99 |
+
" width=1024,\n",
|
| 100 |
+
" prompt=prompt, \n",
|
| 101 |
+
" negative_prompt=negative_prompt, \n",
|
| 102 |
+
" generator=torch.Generator(device=\"cuda\").manual_seed(1)\n",
|
| 103 |
+
").images[0]\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"output.save(\"output.png\")\n",
|
| 106 |
+
"print(\"Готово! Изображение сохранено как output.png\")"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"id": "8a173167-6c28-4bbd-8879-1375e0fd37f0",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"outputs": [],
|
| 115 |
+
"source": []
|
| 116 |
+
}
|
| 117 |
+
],
|
| 118 |
+
"metadata": {
|
| 119 |
+
"kernelspec": {
|
| 120 |
+
"display_name": "Python3 (ipykernel)",
|
| 121 |
+
"language": "python",
|
| 122 |
+
"name": "python3"
|
| 123 |
+
},
|
| 124 |
+
"language_info": {
|
| 125 |
+
"codemirror_mode": {
|
| 126 |
+
"name": "ipython",
|
| 127 |
+
"version": 3
|
| 128 |
+
},
|
| 129 |
+
"file_extension": ".py",
|
| 130 |
+
"mimetype": "text/x-python",
|
| 131 |
+
"name": "python",
|
| 132 |
+
"nbconvert_exporter": "python",
|
| 133 |
+
"pygments_lexer": "ipython3",
|
| 134 |
+
"version": "3.12.13"
|
| 135 |
+
}
|
| 136 |
+
},
|
| 137 |
+
"nbformat": 4,
|
| 138 |
+
"nbformat_minor": 5
|
| 139 |
+
}
|
dataset.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pip install flash-attn --no-build-isolation
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import gc
|
| 5 |
+
import numpy as np
|
| 6 |
+
import random
|
| 7 |
+
import json
|
| 8 |
+
import shutil
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
from datasets import Dataset, load_from_disk, concatenate_datasets
|
| 12 |
+
from diffusers import AutoencoderKLQwenImage
|
| 13 |
+
from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
|
| 14 |
+
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
|
| 15 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from datetime import timedelta
|
| 19 |
+
from accelerate import Accelerator
|
| 20 |
+
|
| 21 |
+
accelerator = Accelerator()
|
| 22 |
+
device = accelerator.device
|
| 23 |
+
is_main_process = accelerator.is_main_process
|
| 24 |
+
process_index = accelerator.process_index
|
| 25 |
+
num_processes = accelerator.num_processes
|
| 26 |
+
|
| 27 |
+
# ---------------- 1️⃣ Настройки ----------------
|
| 28 |
+
dtype = torch.float16
|
| 29 |
+
batch_size = 5
|
| 30 |
+
min_size = 320
|
| 31 |
+
max_size = 640
|
| 32 |
+
step = 32
|
| 33 |
+
empty_share = 0.0
|
| 34 |
+
limit = 0
|
| 35 |
+
|
| 36 |
+
folder_path = "/workspace/dataset/d23"
|
| 37 |
+
save_path = "/workspace/ds234_640_vae_qwen"
|
| 38 |
+
os.makedirs(save_path, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
def clear_cuda_memory():
|
| 41 |
+
if torch.cuda.is_available():
|
| 42 |
+
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
| 43 |
+
print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
|
| 44 |
+
torch.cuda.empty_cache()
|
| 45 |
+
gc.collect()
|
| 46 |
+
|
| 47 |
+
# ---------------- 2️⃣ Загрузка моделей ----------------
|
| 48 |
+
def load_models():
|
| 49 |
+
print(f"[GPU {process_index}] Загрузка моделей...")
|
| 50 |
+
vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
|
| 51 |
+
return vae
|
| 52 |
+
|
| 53 |
+
vae = load_models()
|
| 54 |
+
|
| 55 |
+
shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
|
| 56 |
+
scaling_factor = getattr(vae.config, "scaling_factor", 1.0) or 1.0
|
| 57 |
+
|
| 58 |
+
mean = getattr(vae.config, "latents_mean", None)
|
| 59 |
+
std = getattr(vae.config, "latents_std", None)
|
| 60 |
+
if mean is not None and std is not None:
|
| 61 |
+
latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
|
| 62 |
+
latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
|
| 63 |
+
|
| 64 |
+
# ---------------- 3️⃣ Трансформации ----------------
|
| 65 |
+
def get_image_transform(min_size=256, max_size=512, step=64):
|
| 66 |
+
def transform(img, dry_run=False):
|
| 67 |
+
original_width, original_height = img.size
|
| 68 |
+
|
| 69 |
+
if original_width >= original_height:
|
| 70 |
+
new_width = max_size
|
| 71 |
+
new_height = int(max_size * original_height / original_width)
|
| 72 |
+
else:
|
| 73 |
+
new_height = max_size
|
| 74 |
+
new_width = int(max_size * original_width / original_height)
|
| 75 |
+
|
| 76 |
+
if new_height < min_size or new_width < min_size:
|
| 77 |
+
if original_width <= original_height:
|
| 78 |
+
new_width = min_size
|
| 79 |
+
new_height = int(min_size * original_height / original_width)
|
| 80 |
+
else:
|
| 81 |
+
new_height = min_size
|
| 82 |
+
new_width = int(min_size * original_width / original_height)
|
| 83 |
+
|
| 84 |
+
crop_width = min(max_size, (new_width // step) * step)
|
| 85 |
+
crop_height = min(max_size, (new_height // step) * step)
|
| 86 |
+
|
| 87 |
+
crop_width = max(min_size, crop_width)
|
| 88 |
+
crop_height = max(min_size, crop_height)
|
| 89 |
+
|
| 90 |
+
if dry_run:
|
| 91 |
+
return crop_width, crop_height
|
| 92 |
+
|
| 93 |
+
img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
|
| 94 |
+
|
| 95 |
+
top = (new_height - crop_height) // 3
|
| 96 |
+
left = 0
|
| 97 |
+
|
| 98 |
+
img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
|
| 99 |
+
|
| 100 |
+
final_width, final_height = img_cropped.size
|
| 101 |
+
|
| 102 |
+
img_tensor = ToTensor()(img_cropped)
|
| 103 |
+
img_tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(img_tensor)
|
| 104 |
+
return img_tensor, img_cropped, final_width, final_height
|
| 105 |
+
|
| 106 |
+
return transform
|
| 107 |
+
|
| 108 |
+
# ---------------- 4️⃣ Функции обработки ----------------
|
| 109 |
+
def clean_label(label):
|
| 110 |
+
label = label.replace("Image 1","").replace("Image 2","").replace("Image 3","").replace("Image 4","")
|
| 111 |
+
label = label.replace("The image depicts ","").replace("The image presents ","")
|
| 112 |
+
label = label.replace("The image features ","").replace("The image portrays ","").replace("The image is ","").strip()
|
| 113 |
+
if label.startswith("."):
|
| 114 |
+
label = label[1:].lstrip()
|
| 115 |
+
return label
|
| 116 |
+
|
| 117 |
+
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
|
| 118 |
+
labels_for_model = []
|
| 119 |
+
labels_for_logging = []
|
| 120 |
+
|
| 121 |
+
for label in original_labels:
|
| 122 |
+
if random.random() < prob_to_make_empty:
|
| 123 |
+
labels_for_model.append("")
|
| 124 |
+
labels_for_logging.append(f"zero: {label}")
|
| 125 |
+
else:
|
| 126 |
+
labels_for_model.append(label)
|
| 127 |
+
labels_for_logging.append(label)
|
| 128 |
+
|
| 129 |
+
return labels_for_model, labels_for_logging
|
| 130 |
+
|
| 131 |
+
def encode_to_latents(images, texts):
|
| 132 |
+
transform = get_image_transform(min_size, max_size, step)
|
| 133 |
+
|
| 134 |
+
transformed_tensors = []
|
| 135 |
+
widths, heights = [], []
|
| 136 |
+
|
| 137 |
+
for img in images:
|
| 138 |
+
try:
|
| 139 |
+
t_img, _, w, h = transform(img)
|
| 140 |
+
transformed_tensors.append(t_img)
|
| 141 |
+
widths.append(w)
|
| 142 |
+
heights.append(h)
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"Ошибка трансформации: {e}")
|
| 145 |
+
|
| 146 |
+
if not transformed_tensors:
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
|
| 150 |
+
|
| 151 |
+
if batch_tensor.ndim==4:
|
| 152 |
+
batch_tensor = batch_tensor.unsqueeze(2)
|
| 153 |
+
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
posteriors = vae.encode(batch_tensor).latent_dist.mode()
|
| 156 |
+
if mean is not None and std is not None:
|
| 157 |
+
posteriors = (posteriors - latents_mean) / latents_std
|
| 158 |
+
posteriors = (posteriors - shift_factor) / scaling_factor
|
| 159 |
+
|
| 160 |
+
#latents_np = posteriors.cpu().numpy()
|
| 161 |
+
latents_np = posteriors.squeeze(2).cpu().numpy()
|
| 162 |
+
|
| 163 |
+
text_labels = [clean_label(text) for text in texts]
|
| 164 |
+
_, text_labels = process_labels_for_guidance(text_labels, empty_share)
|
| 165 |
+
|
| 166 |
+
return {
|
| 167 |
+
"vae": latents_np,
|
| 168 |
+
"text": text_labels,
|
| 169 |
+
"width": widths,
|
| 170 |
+
"height": heights
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# ---------------- 5️⃣ Обработка папки ----------------
|
| 174 |
+
def process_folder(folder_path, limit=None):
|
| 175 |
+
image_paths, text_paths, width, height = [], [], [], []
|
| 176 |
+
transform = get_image_transform(min_size, max_size, step)
|
| 177 |
+
|
| 178 |
+
for root, _, files in os.walk(folder_path):
|
| 179 |
+
for filename in files:
|
| 180 |
+
if filename.lower().endswith((".jpg",".jpeg",".png")):
|
| 181 |
+
image_path = os.path.join(root, filename)
|
| 182 |
+
try:
|
| 183 |
+
img = Image.open(image_path)
|
| 184 |
+
except:
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
w,h = transform(img, dry_run=True)
|
| 188 |
+
text_path = os.path.splitext(image_path)[0]+".txt"
|
| 189 |
+
|
| 190 |
+
if os.path.exists(text_path):
|
| 191 |
+
image_paths.append(image_path)
|
| 192 |
+
text_paths.append(text_path)
|
| 193 |
+
width.append(w)
|
| 194 |
+
height.append(h)
|
| 195 |
+
|
| 196 |
+
print(f"Найдено {len(image_paths)} изображений")
|
| 197 |
+
return image_paths, text_paths, width, height
|
| 198 |
+
|
| 199 |
+
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
|
| 200 |
+
total_files = len(image_paths)
|
| 201 |
+
start_time = time.time()
|
| 202 |
+
|
| 203 |
+
for chunk_idx, start in enumerate(range(0,total_files,chunk_size),1):
|
| 204 |
+
end = min(start+chunk_size,total_files)
|
| 205 |
+
|
| 206 |
+
chunk_image_paths = image_paths[start:end]
|
| 207 |
+
chunk_text_paths = text_paths[start:end]
|
| 208 |
+
chunk_widths = width[start:end]
|
| 209 |
+
chunk_heights = height[start:end]
|
| 210 |
+
|
| 211 |
+
chunk_texts = []
|
| 212 |
+
for text_path in chunk_text_paths:
|
| 213 |
+
try:
|
| 214 |
+
with open(text_path,'r',encoding='utf-8') as f:
|
| 215 |
+
chunk_texts.append(f.read().strip())
|
| 216 |
+
except:
|
| 217 |
+
chunk_texts.append("")
|
| 218 |
+
|
| 219 |
+
size_groups = {}
|
| 220 |
+
for i in range(len(chunk_image_paths)):
|
| 221 |
+
key=(chunk_widths[i],chunk_heights[i])
|
| 222 |
+
size_groups.setdefault(key,{"image_paths":[],"texts":[]})
|
| 223 |
+
size_groups[key]["image_paths"].append(chunk_image_paths[i])
|
| 224 |
+
size_groups[key]["texts"].append(chunk_texts[i])
|
| 225 |
+
|
| 226 |
+
for size_key,group_data in size_groups.items():
|
| 227 |
+
group_dataset = Dataset.from_dict(group_data)
|
| 228 |
+
|
| 229 |
+
processed_group = group_dataset.map(
|
| 230 |
+
lambda ex: encode_to_latents(
|
| 231 |
+
[Image.open(p) for p in ex["image_paths"]],
|
| 232 |
+
#[Image.open(p).convert("RGB") for p in ex["image_paths"]], # <--- Добавил .convert("RGB"), чтобы картинка загрузилась в память
|
| 233 |
+
ex["texts"]
|
| 234 |
+
),
|
| 235 |
+
batched=True,
|
| 236 |
+
batch_size=batch_size,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# --- NEW: уникальный путь ---
|
| 240 |
+
group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{size_key[0]}x{size_key[1]}_proc_{process_index}_"
|
| 241 |
+
# --- END NEW ---
|
| 242 |
+
|
| 243 |
+
processed_group.save_to_disk(group_save_path)
|
| 244 |
+
clear_cuda_memory()
|
| 245 |
+
|
| 246 |
+
# ---------------- 7️⃣ Объединение ----------------
|
| 247 |
+
def combine_chunks(temp_path, final_path):
|
| 248 |
+
chunks = sorted([
|
| 249 |
+
os.path.join(temp_path,d)
|
| 250 |
+
for d in os.listdir(temp_path)
|
| 251 |
+
if "chunk_" in d
|
| 252 |
+
])
|
| 253 |
+
|
| 254 |
+
datasets = [load_from_disk(c) for c in chunks]
|
| 255 |
+
combined = concatenate_datasets(datasets)
|
| 256 |
+
combined.save_to_disk(final_path)
|
| 257 |
+
|
| 258 |
+
print("✅ Сохранено")
|
| 259 |
+
|
| 260 |
+
# ---------------- MAIN ----------------
|
| 261 |
+
temp_path = f"{save_path}_temp"
|
| 262 |
+
os.makedirs(temp_path, exist_ok=True)
|
| 263 |
+
|
| 264 |
+
image_paths, text_paths, width, height = process_folder(folder_path,limit)
|
| 265 |
+
|
| 266 |
+
# сортировка
|
| 267 |
+
sorted_indices = sorted(range(len(width)), key=lambda i:(width[i],height[i]))
|
| 268 |
+
image_paths = [image_paths[i] for i in sorted_indices]
|
| 269 |
+
text_paths = [text_paths[i] for i in sorted_indices]
|
| 270 |
+
width = [width[i] for i in sorted_indices]
|
| 271 |
+
height = [height[i] for i in sorted_indices]
|
| 272 |
+
|
| 273 |
+
# --- shard по GPU ---
|
| 274 |
+
indices = list(range(len(image_paths)))
|
| 275 |
+
indices = indices[process_index::num_processes]
|
| 276 |
+
|
| 277 |
+
image_paths = [image_paths[i] for i in indices]
|
| 278 |
+
text_paths = [text_paths[i] for i in indices]
|
| 279 |
+
width = [width[i] for i in indices]
|
| 280 |
+
height = [height[i] for i in indices]
|
| 281 |
+
|
| 282 |
+
print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")
|
| 283 |
+
|
| 284 |
+
process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=batch_size)
|
| 285 |
+
|
| 286 |
+
accelerator.wait_for_everyone()
|
| 287 |
+
|
| 288 |
+
# --- NEW: только главный процесс ---
|
| 289 |
+
if is_main_process:
|
| 290 |
+
#try:
|
| 291 |
+
#shutil.rmtree(folder_path)
|
| 292 |
+
#except:
|
| 293 |
+
# pass
|
| 294 |
+
|
| 295 |
+
combine_chunks(temp_path, save_path)
|
| 296 |
+
|
| 297 |
+
try:
|
| 298 |
+
shutil.rmtree(temp_path)
|
| 299 |
+
except:
|
| 300 |
+
pass
|
dataset_sample.ipynb
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 3,
|
| 6 |
+
"id": "9c312df2-cb57-44f6-af54-3af6ab8f962f",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"ename": "ModuleNotFoundError",
|
| 11 |
+
"evalue": "No module named 'numpy'",
|
| 12 |
+
"output_type": "error",
|
| 13 |
+
"traceback": [
|
| 14 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 15 |
+
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
| 16 |
+
"Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#from datasets import load_from_disk\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mPIL\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image\n",
|
| 17 |
+
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'numpy'"
|
| 18 |
+
]
|
| 19 |
+
}
|
| 20 |
+
],
|
| 21 |
+
"source": [
|
| 22 |
+
"from datasets import load_from_disk\n",
|
| 23 |
+
"import numpy as np\n",
|
| 24 |
+
"import torch\n",
|
| 25 |
+
"from PIL import Image\n",
|
| 26 |
+
"from collections import defaultdict\n",
|
| 27 |
+
"from diffusers import AutoencoderKLQwenImage\n",
|
| 28 |
+
"import gc\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"def analyze_dataset_by_size(dataset_path):\n",
|
| 31 |
+
" \"\"\"\n",
|
| 32 |
+
" Группирует датасет по размерам изображений и выводит базовую информацию.\n",
|
| 33 |
+
" \"\"\"\n",
|
| 34 |
+
" # Настройка устройства и типа данных\n",
|
| 35 |
+
" dtype = torch.float16\n",
|
| 36 |
+
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 37 |
+
" \n",
|
| 38 |
+
" # Загрузка VAE модели\n",
|
| 39 |
+
" print(\"Загрузка VAE модели...\")\n",
|
| 40 |
+
" vae = AutoencoderKLQwenImage.from_pretrained(\"vae\",torch_dtype=dtype).to(device).eval()\n",
|
| 41 |
+
" shift_factor = getattr(vae.config, \"shift_factor\", 0.0)\n",
|
| 42 |
+
" if shift_factor is None:\n",
|
| 43 |
+
" shift_factor = 0.0\n",
|
| 44 |
+
" \n",
|
| 45 |
+
" scaling_factor = getattr(vae.config, \"scaling_factor\", 1.0)\n",
|
| 46 |
+
" if scaling_factor is None:\n",
|
| 47 |
+
" scaling_factor = 1.0\n",
|
| 48 |
+
" \n",
|
| 49 |
+
" mean = getattr(vae.config, \"latents_mean\", None)\n",
|
| 50 |
+
" std = getattr(vae.config, \"latents_std\", None)\n",
|
| 51 |
+
" if mean is not None and std is not None:\n",
|
| 52 |
+
" latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)\n",
|
| 53 |
+
" latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)\n",
|
| 54 |
+
" \n",
|
| 55 |
+
" # Загружаем датасет\n",
|
| 56 |
+
" print(f\"Загрузка датасета из {dataset_path}...\")\n",
|
| 57 |
+
" dataset = load_from_disk(dataset_path)\n",
|
| 58 |
+
"\n",
|
| 59 |
+
" print(f\"Осталось примеров после фильтрации: {len(dataset)}\")\n",
|
| 60 |
+
" \n",
|
| 61 |
+
" # Группируем примеры по размерам\n",
|
| 62 |
+
" print(\"\\nГруппировка примеров по размерам...\")\n",
|
| 63 |
+
" size_to_indices = defaultdict(list)\n",
|
| 64 |
+
" \n",
|
| 65 |
+
" # Собираем примеры с одинаковыми размерами\n",
|
| 66 |
+
" # Собираем примеры с одинаковыми размерами (оптимизированная версия)\n",
|
| 67 |
+
" widths = dataset[\"width\"]\n",
|
| 68 |
+
" heights = dataset[\"height\"]\n",
|
| 69 |
+
" for i, (w, h) in enumerate(zip(widths, heights)):\n",
|
| 70 |
+
" size_to_indices[(w, h)].append(i)\n",
|
| 71 |
+
" \n",
|
| 72 |
+
" # Сортируем размеры по количеству примеров\n",
|
| 73 |
+
" print(\"\\nСортируем...\")\n",
|
| 74 |
+
" size_stats = [(size, len(indices)) for size, indices in size_to_indices.items()]\n",
|
| 75 |
+
" size_stats.sort(key=lambda x: x[1], reverse=True)\n",
|
| 76 |
+
" \n",
|
| 77 |
+
" # Выводим информацию о каждой группе и показываем первый пример\n",
|
| 78 |
+
" for size, count in size_stats:\n",
|
| 79 |
+
" width, height = size\n",
|
| 80 |
+
" first_idx = size_to_indices[size][1]\n",
|
| 81 |
+
" example = dataset[first_idx]\n",
|
| 82 |
+
" \n",
|
| 83 |
+
" print(f\"\\n--- Батч {width}x{height}: {count} примеров ---\")\n",
|
| 84 |
+
" \n",
|
| 85 |
+
" # Декодируем латентное представление для первого примера\n",
|
| 86 |
+
" latent = torch.tensor(example[\"vae\"], dtype=dtype).unsqueeze(0).to(device)\n",
|
| 87 |
+
" \n",
|
| 88 |
+
" # 1. Снова обманываем VAE, превращая картинку в \"видео из 1 кадра\" [B, C, 1, H, W]\n",
|
| 89 |
+
" if latent.ndim == 4:\n",
|
| 90 |
+
" latent = latent.unsqueeze(2)\n",
|
| 91 |
+
" \n",
|
| 92 |
+
" with torch.no_grad():\n",
|
| 93 |
+
" if latents_mean is not None and latents_std is not None:\n",
|
| 94 |
+
" latent = latent * latents_std + latents_mean\n",
|
| 95 |
+
" \n",
|
| 96 |
+
" print(f\"Min of latent_for_vae: {latent.min()}\")\n",
|
| 97 |
+
" print(f\"Max of latent_for_vae: {latent.max()}\")\n",
|
| 98 |
+
" print(f\"Mean of latent_for_vae: {latent.mean()}\")\n",
|
| 99 |
+
" print(f\"Std: {latent.std().item():.4f}\")\n",
|
| 100 |
+
" if torch.isnan(latent).any() or torch.isinf(latent).any():\n",
|
| 101 |
+
" print(\"WARNING: Raw latents contain NaN or Inf values!\")\n",
|
| 102 |
+
" \n",
|
| 103 |
+
" reconstructed_image = vae.decode(latent).sample\n",
|
| 104 |
+
" \n",
|
| 105 |
+
" # 2. Вытаскиваем обычную 3D-картинку [C, H, W] из 5D-видеотензора\n",
|
| 106 |
+
" if reconstructed_image.ndim == 5:\n",
|
| 107 |
+
" # Берем нулевой батч, все каналы, нулевой кадр, всю высоту и ширину\n",
|
| 108 |
+
" img_tensor = reconstructed_image[0, :, 0, :, :] \n",
|
| 109 |
+
" else:\n",
|
| 110 |
+
" img_tensor = reconstructed_image.squeeze(0) # На всякий случай, если VAE вернул 4D\n",
|
| 111 |
+
" \n",
|
| 112 |
+
" img_array = img_tensor.cpu().numpy()\n",
|
| 113 |
+
" img_array = np.transpose(img_array, (1, 2, 0))\n",
|
| 114 |
+
" img_array = (img_array + 1) / 2 # Нормализация к [0, 1]\n",
|
| 115 |
+
" img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8) # Преобразуем в uint8 для PIL\n",
|
| 116 |
+
" \n",
|
| 117 |
+
" # Создаем PIL изображение из массива\n",
|
| 118 |
+
" pil_image = Image.fromarray(img_array)\n",
|
| 119 |
+
" print(f\"Текст: {example['text']}\")\n",
|
| 120 |
+
" print(f\"Ключи: {', '.join(example.keys())}\")\n",
|
| 121 |
+
" print(f\"latent: {latent.shape}\")\n",
|
| 122 |
+
" pil_image.save(\"1.jpg\")\n",
|
| 123 |
+
" \n",
|
| 124 |
+
" # Очистка памяти\n",
|
| 125 |
+
" if torch.cuda.is_available():\n",
|
| 126 |
+
" torch.cuda.empty_cache()\n",
|
| 127 |
+
" gc.collect()\n",
|
| 128 |
+
" \n",
|
| 129 |
+
" return size_to_indices # Возвращаем словарь с индексами по группам\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"# Использование\n",
|
| 132 |
+
"if __name__ == \"__main__\":\n",
|
| 133 |
+
" # Путь к датасету\n",
|
| 134 |
+
" save_path = \"datasets/ds234_640_vae_qwen\"\n",
|
| 135 |
+
" \n",
|
| 136 |
+
" # Анализ датасета\n",
|
| 137 |
+
" size_groups = analyze_dataset_by_size(save_path)"
|
| 138 |
+
]
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"cell_type": "code",
|
| 142 |
+
"execution_count": null,
|
| 143 |
+
"id": "74a5d11d-369f-4f25-9ee0-31d3bccd0254",
|
| 144 |
+
"metadata": {},
|
| 145 |
+
"outputs": [],
|
| 146 |
+
"source": []
|
| 147 |
+
}
|
| 148 |
+
],
|
| 149 |
+
"metadata": {
|
| 150 |
+
"kernelspec": {
|
| 151 |
+
"display_name": "Python 3 (ipykernel)",
|
| 152 |
+
"language": "python",
|
| 153 |
+
"name": "python3"
|
| 154 |
+
},
|
| 155 |
+
"language_info": {
|
| 156 |
+
"codemirror_mode": {
|
| 157 |
+
"name": "ipython",
|
| 158 |
+
"version": 3
|
| 159 |
+
},
|
| 160 |
+
"file_extension": ".py",
|
| 161 |
+
"mimetype": "text/x-python",
|
| 162 |
+
"name": "python",
|
| 163 |
+
"nbconvert_exporter": "python",
|
| 164 |
+
"pygments_lexer": "ipython3",
|
| 165 |
+
"version": "3.12.3"
|
| 166 |
+
}
|
| 167 |
+
},
|
| 168 |
+
"nbformat": 4,
|
| 169 |
+
"nbformat_minor": 5
|
| 170 |
+
}
|
model_index.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": ["pipeline_sdxs", "SdxsPipeline"],
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"scheduler": [
|
| 5 |
+
"diffusers",
|
| 6 |
+
"FlowMatchEulerDiscreteScheduler"
|
| 7 |
+
],
|
| 8 |
+
"text_encoder": [
|
| 9 |
+
"transformers",
|
| 10 |
+
"Qwen3_5ForConditionalGeneration"
|
| 11 |
+
],
|
| 12 |
+
"tokenizer": [
|
| 13 |
+
"transformers",
|
| 14 |
+
"Qwen3_5Tokenizer"
|
| 15 |
+
],
|
| 16 |
+
"transformer": [
|
| 17 |
+
"diffusers",
|
| 18 |
+
"CosmosTransformer3DModel"
|
| 19 |
+
],
|
| 20 |
+
"vae": [
|
| 21 |
+
"diffusers",
|
| 22 |
+
"AutoencoderKLQwenImage"
|
| 23 |
+
]
|
| 24 |
+
}
|
pipeline_sdxs.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from typing import List, Union, Optional, Tuple
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from diffusers import DiffusionPipeline
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class SdxsPipelineOutput(BaseOutput):
|
| 13 |
+
images: Union[List[Image.Image], np.ndarray]
|
| 14 |
+
prompt: Optional[Union[str, List[str]]] = None
|
| 15 |
+
|
| 16 |
+
class SdxsPipeline(DiffusionPipeline):
|
| 17 |
+
# Cosmos требует 512 токенов
|
| 18 |
+
MAX_TEXT_TOKENS = 512
|
| 19 |
+
|
| 20 |
+
def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
|
| 21 |
+
super().__init__()
|
| 22 |
+
# Регистрируем модули (с Qwen)
|
| 23 |
+
self.register_modules(
|
| 24 |
+
vae=vae,
|
| 25 |
+
text_encoder=text_encoder,
|
| 26 |
+
tokenizer=tokenizer,
|
| 27 |
+
transformer=transformer,
|
| 28 |
+
scheduler=scheduler
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
self.vae_scale_factor = getattr(self.vae.config, "spatial_compression_ratio", 8)
|
| 32 |
+
if hasattr(self.vae.config, "block_out_channels"):
|
| 33 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 34 |
+
|
| 35 |
+
# Загружаем mean и std для VAE (Cosmos-style)
|
| 36 |
+
mean = getattr(self.vae.config, "latents_mean", None)
|
| 37 |
+
std = getattr(self.vae.config, "latents_std", None)
|
| 38 |
+
if mean is not None and std is not None:
|
| 39 |
+
self.vae_latents_mean = torch.tensor(mean).view(1, len(mean), 1, 1, 1)
|
| 40 |
+
# Внимание: Cosmos использует инвертированный std для декодирования (1.0 / std)
|
| 41 |
+
self.vae_latents_std = torch.tensor(std).view(1, len(std), 1, 1, 1)
|
| 42 |
+
else:
|
| 43 |
+
self.vae_latents_mean = None
|
| 44 |
+
self.vae_latents_std = None
|
| 45 |
+
|
| 46 |
+
# Регистрируем параметры Cosmos в шедулере (если они еще не там)
|
| 47 |
+
if self.scheduler is not None:
|
| 48 |
+
self.scheduler.register_to_config(
|
| 49 |
+
sigma_max=getattr(self.scheduler.config, "sigma_max", 80.0),
|
| 50 |
+
sigma_min=getattr(self.scheduler.config, "sigma_min", 0.002),
|
| 51 |
+
sigma_data=getattr(self.scheduler.config, "sigma_data", 1.0),
|
| 52 |
+
final_sigmas_type=getattr(self.scheduler.config, "final_sigmas_type", "sigma_min"),
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
@staticmethod
|
| 56 |
+
def _pad_tensor_to_length(tensor: torch.Tensor, target_len: int, dim: int = 1, pad_value: float = 0) -> torch.Tensor:
|
| 57 |
+
current_len = tensor.shape[dim]
|
| 58 |
+
if current_len >= target_len:
|
| 59 |
+
return tensor
|
| 60 |
+
pad_size = target_len - current_len
|
| 61 |
+
if tensor.dim() == 3:
|
| 62 |
+
padding = (0, 0, 0, pad_size, 0, 0)
|
| 63 |
+
elif tensor.dim() == 2:
|
| 64 |
+
padding = (0, pad_size, 0, 0)
|
| 65 |
+
else:
|
| 66 |
+
raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
|
| 67 |
+
return torch.nn.functional.pad(tensor, padding, value=pad_value)
|
| 68 |
+
|
| 69 |
+
@torch.no_grad()
|
| 70 |
+
def refine_prompts(
|
| 71 |
+
self,
|
| 72 |
+
prompts: Union[str, List[str]],
|
| 73 |
+
system_prompt: Optional[str] = None,
|
| 74 |
+
temperature: float = 0.7
|
| 75 |
+
) -> List[str]:
|
| 76 |
+
"""Refines a list of prompts using the Text Encoder (LLM)."""
|
| 77 |
+
device = self.device
|
| 78 |
+
|
| 79 |
+
if system_prompt is None:
|
| 80 |
+
system_prompt = (
|
| 81 |
+
"You are a skilled text-to-image prompt engineer whose sole function is to transform "
|
| 82 |
+
"the user's input into an aesthetically optimized, detailed, and visually descriptive two-sentence output. "
|
| 83 |
+
"**The primary subject MUST be the main focus of the revised prompt "
|
| 84 |
+
"and MUST be described in rich detail within the first sentence.** "
|
| 85 |
+
"Output **only** the final revised prompt, with absolutely no commentary. "
|
| 86 |
+
"Don't use cliches like warm, soft, vibrant, wildflowers. Be creative. User input prompt: "
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
pad_id = getattr(self.text_encoder.config, "pad_token_id", None) or \
|
| 90 |
+
getattr(self.text_encoder.config, "eos_token_id", None)
|
| 91 |
+
|
| 92 |
+
prompts_list = [prompts] if isinstance(prompts, str) else prompts
|
| 93 |
+
refined_list = []
|
| 94 |
+
|
| 95 |
+
for p in prompts_list:
|
| 96 |
+
full_text = system_prompt + p
|
| 97 |
+
messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}]
|
| 98 |
+
|
| 99 |
+
inputs = self.tokenizer.apply_chat_template(
|
| 100 |
+
messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
|
| 101 |
+
).to(device)
|
| 102 |
+
|
| 103 |
+
generated_ids = self.text_encoder.generate(
|
| 104 |
+
**inputs,
|
| 105 |
+
max_new_tokens=self.MAX_TEXT_TOKENS,
|
| 106 |
+
do_sample=True,
|
| 107 |
+
temperature=temperature,
|
| 108 |
+
pad_token_id=pad_id
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
generated_ids_trimmed = [
|
| 112 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 113 |
+
]
|
| 114 |
+
output_text = self.tokenizer.batch_decode(
|
| 115 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 116 |
+
)
|
| 117 |
+
refined_list.append(output_text[0])
|
| 118 |
+
|
| 119 |
+
return refined_list
|
| 120 |
+
|
| 121 |
+
@torch.no_grad()
|
| 122 |
+
def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 123 |
+
"""Qwen-specific text encoding (using chat_template and hidden_states[-2])"""
|
| 124 |
+
device = self.device
|
| 125 |
+
dtype = self.transformer.dtype
|
| 126 |
+
if text is None: text = ""
|
| 127 |
+
if isinstance(text, str): text = [text]
|
| 128 |
+
|
| 129 |
+
formatted_prompts = []
|
| 130 |
+
for t in text:
|
| 131 |
+
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 132 |
+
formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False))
|
| 133 |
+
|
| 134 |
+
toks = self.tokenizer(formatted_prompts, padding="max_length", max_length=self.MAX_TEXT_TOKENS, truncation=True, return_tensors="pt").to(device)
|
| 135 |
+
outputs = self.text_encoder(input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True)
|
| 136 |
+
|
| 137 |
+
# Берем предпоследний слой эмбеддингов, как того требуют современные пайплайны
|
| 138 |
+
last_hidden = outputs.hidden_states[-2]
|
| 139 |
+
|
| 140 |
+
return last_hidden.to(dtype=dtype), toks.attention_mask.to(dtype=torch.int64)
|
| 141 |
+
|
| 142 |
+
@torch.no_grad()
|
| 143 |
+
def image_upscale(self, image: Union[str, Image.Image, List[Union[str, Image.Image]]], batch_size: int = 1) -> List[Image.Image]:
|
| 144 |
+
images = [image] if isinstance(image, (str, Image.Image)) else image
|
| 145 |
+
|
| 146 |
+
batch_data = []
|
| 147 |
+
for img in images:
|
| 148 |
+
if isinstance(img, str): img = Image.open(img)
|
| 149 |
+
if img.mode == "RGBA":
|
| 150 |
+
img = Image.alpha_composite(Image.new("RGBA", img.size, (255, 255, 255)), img)
|
| 151 |
+
img = img.convert("RGB")
|
| 152 |
+
|
| 153 |
+
w, h = img.size
|
| 154 |
+
pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
|
| 155 |
+
if pw or ph:
|
| 156 |
+
padded = Image.new("RGB", (w + pw, h + ph), (255, 255, 255))
|
| 157 |
+
padded.paste(img)
|
| 158 |
+
img = padded
|
| 159 |
+
|
| 160 |
+
t = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0).permute(2, 0, 1)
|
| 161 |
+
batch_data.append((t.to(self.device, torch.float16), w, h))
|
| 162 |
+
|
| 163 |
+
unique_shapes = {t.shape for t, _, _ in batch_data}
|
| 164 |
+
step = batch_size if len(unique_shapes) == 1 else 1
|
| 165 |
+
|
| 166 |
+
output_images = []
|
| 167 |
+
for i in range(0, len(batch_data), step):
|
| 168 |
+
chunk = batch_data[i : i + step]
|
| 169 |
+
tensors = torch.stack([c[0] for c in chunk]).unsqueeze(2)
|
| 170 |
+
|
| 171 |
+
latents = self.vae.encode(tensors).latent_dist.mean
|
| 172 |
+
decoded = self.vae.decode(latents.to(self.vae.dtype))[0]
|
| 173 |
+
|
| 174 |
+
if decoded.ndim == 5:
|
| 175 |
+
decoded = decoded.squeeze(2)
|
| 176 |
+
|
| 177 |
+
decoded = (decoded.clamp(-1, 1) + 1) / 2
|
| 178 |
+
for j, tensor in enumerate(decoded):
|
| 179 |
+
w, h = chunk[j][1], chunk[j][2]
|
| 180 |
+
arr = tensor.cpu().permute(1, 2, 0).float().numpy()
|
| 181 |
+
arr = arr[:h * 2, :w * 2]
|
| 182 |
+
output_images.append(Image.fromarray((arr * 255).astype("uint8")))
|
| 183 |
+
|
| 184 |
+
return output_images
|
| 185 |
+
|
| 186 |
+
@torch.no_grad()
|
| 187 |
+
def __call__(
|
| 188 |
+
self,
|
| 189 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 190 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 191 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 192 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 193 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 194 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 195 |
+
latents: Optional[torch.Tensor] = None,
|
| 196 |
+
height: int = 1024,
|
| 197 |
+
width: int = 1024,
|
| 198 |
+
num_inference_steps: int = 40,
|
| 199 |
+
guidance_scale: float = 4.0,
|
| 200 |
+
generator: Optional[torch.Generator] = None,
|
| 201 |
+
seed: Optional[int] = None,
|
| 202 |
+
output_type: str = "pil",
|
| 203 |
+
return_dict: bool = True,
|
| 204 |
+
**kwargs,
|
| 205 |
+
):
|
| 206 |
+
device = self.device
|
| 207 |
+
dtype = self.transformer.dtype
|
| 208 |
+
|
| 209 |
+
if generator is None and seed is not None:
|
| 210 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 211 |
+
|
| 212 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 213 |
+
|
| 214 |
+
# 1. Encode Positive
|
| 215 |
+
if prompt_embeds is None:
|
| 216 |
+
if prompt is None: raise ValueError("`prompt` or `prompt_embeds` required.")
|
| 217 |
+
prompt_embeds, prompt_attention_mask = self.encode_text(prompt)
|
| 218 |
+
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
|
| 219 |
+
prompt_attention_mask = prompt_attention_mask.to(device=device, dtype=torch.int64)
|
| 220 |
+
batch_size = prompt_embeds.shape[0]
|
| 221 |
+
|
| 222 |
+
# 2. Encode Negative
|
| 223 |
+
if do_classifier_free_guidance:
|
| 224 |
+
if negative_prompt_embeds is None:
|
| 225 |
+
neg_text = negative_prompt if negative_prompt is not None else ("" if isinstance(prompt, str) else [""] * len(prompt))
|
| 226 |
+
negative_prompt_embeds, negative_prompt_attention_mask = self.encode_text(neg_text)
|
| 227 |
+
|
| 228 |
+
negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype)
|
| 229 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device, dtype=torch.int64)
|
| 230 |
+
|
| 231 |
+
if negative_prompt_embeds.shape[0] != batch_size:
|
| 232 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
|
| 233 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(batch_size, 1)
|
| 234 |
+
|
| 235 |
+
max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
|
| 236 |
+
prompt_embeds = self._pad_tensor_to_length(prompt_embeds, max_len, dim=1, pad_value=0)
|
| 237 |
+
negative_prompt_embeds = self._pad_tensor_to_length(negative_prompt_embeds, max_len, dim=1, pad_value=0)
|
| 238 |
+
prompt_attention_mask = self._pad_tensor_to_length(prompt_attention_mask, max_len, dim=1, pad_value=0)
|
| 239 |
+
negative_prompt_attention_mask = self._pad_tensor_to_length(negative_prompt_attention_mask, max_len, dim=1, pad_value=0)
|
| 240 |
+
|
| 241 |
+
text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 242 |
+
else:
|
| 243 |
+
text_embeddings = prompt_embeds
|
| 244 |
+
|
| 245 |
+
# 3. Prepare Timesteps (Cosmos specific schedule)
|
| 246 |
+
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
| 247 |
+
sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
|
| 248 |
+
self.scheduler.set_timesteps(sigmas=sigmas, device=device)
|
| 249 |
+
timesteps = self.scheduler.timesteps
|
| 250 |
+
|
| 251 |
+
# Защита от деления на ноль на последнем шаге
|
| 252 |
+
if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
|
| 253 |
+
self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
|
| 254 |
+
if self.scheduler.sigmas[-1] == 0.0:
|
| 255 |
+
self.scheduler.sigmas[-1] = 1e-4
|
| 256 |
+
|
| 257 |
+
# 4. Prepare Latents (Noise)
|
| 258 |
+
latent_h = height // self.vae_scale_factor
|
| 259 |
+
latent_w = width // self.vae_scale_factor
|
| 260 |
+
in_channels = self.transformer.config.in_channels
|
| 261 |
+
sigma_max = getattr(self.scheduler.config, "sigma_max", 80.0)
|
| 262 |
+
|
| 263 |
+
if latents is None:
|
| 264 |
+
# Создаем 5D тензор [Batch, Channels, Frames, Height, Width]
|
| 265 |
+
latents = torch.randn((batch_size, in_channels, 1, latent_h, latent_w), generator=generator, device=device, dtype=dtype)
|
| 266 |
+
latents = latents * sigma_max
|
| 267 |
+
else:
|
| 268 |
+
latents = latents.to(device=device, dtype=dtype) * sigma_max
|
| 269 |
+
|
| 270 |
+
# Cosmos Padding Mask
|
| 271 |
+
padding_mask = torch.zeros((1, 1, height, width), device=device, dtype=dtype)
|
| 272 |
+
|
| 273 |
+
# 5. Denoising Loop (Continuous Flow Math)
|
| 274 |
+
for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
|
| 275 |
+
current_sigma = self.scheduler.sigmas[i]
|
| 276 |
+
|
| 277 |
+
# Защита от деления на 0 при вычислении current_t
|
| 278 |
+
if current_sigma == 0.0:
|
| 279 |
+
current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
|
| 280 |
+
|
| 281 |
+
current_t = current_sigma / (current_sigma + 1.0)
|
| 282 |
+
c_in = 1.0 - current_t
|
| 283 |
+
c_skip = 1.0 - current_t
|
| 284 |
+
c_out = -current_t
|
| 285 |
+
|
| 286 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 287 |
+
latent_model_input = (latent_model_input * c_in).to(dtype)
|
| 288 |
+
|
| 289 |
+
# Трансформер ждет timestep в виде 1D тензора [B]
|
| 290 |
+
t_val = float(current_t.item()) if torch.is_tensor(current_t) else float(current_t)
|
| 291 |
+
timestep_tensor = torch.tensor(
|
| 292 |
+
[t_val],
|
| 293 |
+
device=device,
|
| 294 |
+
dtype=dtype
|
| 295 |
+
).view(1, 1, 1, 1, 1).expand(latent_model_input.shape[0], 1, 1, 1, 1)
|
| 296 |
+
|
| 297 |
+
model_out = self.transformer(
|
| 298 |
+
hidden_states=latent_model_input,
|
| 299 |
+
timestep=timestep_tensor,
|
| 300 |
+
encoder_hidden_states=text_embeddings,
|
| 301 |
+
padding_mask=padding_mask,
|
| 302 |
+
return_dict=False,
|
| 303 |
+
)[0]
|
| 304 |
+
|
| 305 |
+
batched_latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 306 |
+
noise_pred = (c_skip * batched_latents + c_out * model_out.float()).to(dtype)
|
| 307 |
+
|
| 308 |
+
if do_classifier_free_guidance:
|
| 309 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 310 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 311 |
+
|
| 312 |
+
noise_pred = (latents - noise_pred) / current_sigma
|
| 313 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 314 |
+
|
| 315 |
+
# 6. Decode
|
| 316 |
+
if output_type == "latent":
|
| 317 |
+
if not return_dict: return (latents, prompt)
|
| 318 |
+
return SdxsPipelineOutput(images=latents)
|
| 319 |
+
|
| 320 |
+
if getattr(self.vae.config, "latents_std", None) is not None and getattr(self.vae.config, "latents_mean", None) is not None:
|
| 321 |
+
sigma_data = getattr(self.scheduler.config, "sigma_data", 1.0)
|
| 322 |
+
|
| 323 |
+
l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
|
| 324 |
+
l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
|
| 325 |
+
|
| 326 |
+
# Оригинальная формула: делим на инвертированный std (что равноценно умножению на std)
|
| 327 |
+
#latents_std_inv = 1.0 / l_std
|
| 328 |
+
latents = latents * l_std + l_mean
|
| 329 |
+
|
| 330 |
+
image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
| 331 |
+
|
| 332 |
+
if image_output.ndim == 5:
|
| 333 |
+
image_output = image_output.squeeze(2)
|
| 334 |
+
|
| 335 |
+
image_output = (image_output.clamp(-1, 1) + 1) / 2
|
| 336 |
+
image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 337 |
+
|
| 338 |
+
# На всякий случай вычищаем NaNs
|
| 339 |
+
image_np = np.nan_to_num(image_np, nan=0.0, posinf=1.0, neginf=0.0)
|
| 340 |
+
|
| 341 |
+
if output_type == "pil":
|
| 342 |
+
images = [(Image.fromarray((img * 255).round().astype("uint8"))) for img in image_np]
|
| 343 |
+
else:
|
| 344 |
+
images = image_np
|
| 345 |
+
|
| 346 |
+
if not return_dict:
|
| 347 |
+
return (images,)
|
| 348 |
+
return SdxsPipelineOutput(images=images)
|
pipeline_sdxs_t5.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from typing import List, Union, Optional, Tuple
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from diffusers import DiffusionPipeline
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class SdxsPipelineOutput(BaseOutput):
|
| 13 |
+
images: Union[List[Image.Image], np.ndarray]
|
| 14 |
+
prompt: Optional[Union[str, List[str]]] = None
|
| 15 |
+
|
| 16 |
+
class SdxsPipeline(DiffusionPipeline):
|
| 17 |
+
# Cosmos требует 512 токенов
|
| 18 |
+
MAX_TEXT_TOKENS = 512
|
| 19 |
+
|
| 20 |
+
def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.register_modules(
|
| 23 |
+
vae=vae,
|
| 24 |
+
text_encoder=text_encoder,
|
| 25 |
+
tokenizer=tokenizer,
|
| 26 |
+
transformer=transformer,
|
| 27 |
+
scheduler=scheduler
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
self.vae_scale_factor = getattr(self.vae.config, "spatial_compression_ratio", 8)
|
| 31 |
+
if hasattr(self.vae.config, "block_out_channels"):
|
| 32 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 33 |
+
|
| 34 |
+
# Регистрируем параметры Cosmos в шедулере
|
| 35 |
+
if self.scheduler is not None:
|
| 36 |
+
self.scheduler.register_to_config(
|
| 37 |
+
sigma_max=getattr(self.scheduler.config, "sigma_max", 80.0),
|
| 38 |
+
sigma_min=getattr(self.scheduler.config, "sigma_min", 0.002),
|
| 39 |
+
sigma_data=getattr(self.scheduler.config, "sigma_data", 1.0),
|
| 40 |
+
final_sigmas_type=getattr(self.scheduler.config, "final_sigmas_type", "sigma_min"),
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
@staticmethod
|
| 44 |
+
def _pad_tensor_to_length(tensor: torch.Tensor, target_len: int, dim: int = 1, pad_value: float = 0) -> torch.Tensor:
|
| 45 |
+
current_len = tensor.shape[dim]
|
| 46 |
+
if current_len >= target_len:
|
| 47 |
+
return tensor
|
| 48 |
+
pad_size = target_len - current_len
|
| 49 |
+
if tensor.dim() == 3:
|
| 50 |
+
padding = (0, 0, 0, pad_size, 0, 0)
|
| 51 |
+
elif tensor.dim() == 2:
|
| 52 |
+
padding = (0, pad_size, 0, 0)
|
| 53 |
+
else:
|
| 54 |
+
raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")
|
| 55 |
+
return torch.nn.functional.pad(tensor, padding, value=pad_value)
|
| 56 |
+
|
| 57 |
+
@torch.no_grad()
|
| 58 |
+
def refine_prompts(
|
| 59 |
+
self,
|
| 60 |
+
prompts: Union[str, List[str]],
|
| 61 |
+
system_prompt: Optional[str] = None,
|
| 62 |
+
temperature: float = 0.7
|
| 63 |
+
) -> List[str]:
|
| 64 |
+
return [prompts] if isinstance(prompts, str) else prompts
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
|
| 67 |
+
def encode_text(self, text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 68 |
+
device = self.device
|
| 69 |
+
dtype = self.transformer.dtype
|
| 70 |
+
if text is None: text = ""
|
| 71 |
+
if isinstance(text, str): text = [text]
|
| 72 |
+
|
| 73 |
+
text_inputs = self.tokenizer(
|
| 74 |
+
text,
|
| 75 |
+
padding="max_length",
|
| 76 |
+
max_length=self.MAX_TEXT_TOKENS,
|
| 77 |
+
truncation=True,
|
| 78 |
+
return_tensors="pt"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
| 82 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
| 83 |
+
|
| 84 |
+
outputs = self.text_encoder(input_ids=text_input_ids, attention_mask=attention_mask)
|
| 85 |
+
prompt_embeds = outputs.last_hidden_state
|
| 86 |
+
|
| 87 |
+
lengths = attention_mask.sum(dim=1)
|
| 88 |
+
for i, length in enumerate(lengths):
|
| 89 |
+
prompt_embeds[i, length:] = 0
|
| 90 |
+
|
| 91 |
+
return prompt_embeds.to(dtype=dtype), attention_mask.to(dtype=torch.int64)
|
| 92 |
+
|
| 93 |
+
@torch.no_grad()
|
| 94 |
+
def image_upscale(self, image: Union[str, Image.Image, List[Union[str, Image.Image]]], batch_size: int = 1) -> List[Image.Image]:
|
| 95 |
+
images = [image] if isinstance(image, (str, Image.Image)) else image
|
| 96 |
+
|
| 97 |
+
batch_data = []
|
| 98 |
+
for img in images:
|
| 99 |
+
if isinstance(img, str): img = Image.open(img)
|
| 100 |
+
if img.mode == "RGBA":
|
| 101 |
+
img = Image.alpha_composite(Image.new("RGBA", img.size, (255, 255, 255)), img)
|
| 102 |
+
img = img.convert("RGB")
|
| 103 |
+
|
| 104 |
+
w, h = img.size
|
| 105 |
+
pw, ph = (8 - w % 8) % 8, (8 - h % 8) % 8
|
| 106 |
+
if pw or ph:
|
| 107 |
+
padded = Image.new("RGB", (w + pw, h + ph), (255, 255, 255))
|
| 108 |
+
padded.paste(img)
|
| 109 |
+
img = padded
|
| 110 |
+
|
| 111 |
+
t = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0).permute(2, 0, 1)
|
| 112 |
+
batch_data.append((t.to(self.device, torch.float16), w, h))
|
| 113 |
+
|
| 114 |
+
unique_shapes = {t.shape for t, _, _ in batch_data}
|
| 115 |
+
step = batch_size if len(unique_shapes) == 1 else 1
|
| 116 |
+
|
| 117 |
+
output_images = []
|
| 118 |
+
for i in range(0, len(batch_data), step):
|
| 119 |
+
chunk = batch_data[i : i + step]
|
| 120 |
+
tensors = torch.stack([c[0] for c in chunk]).unsqueeze(2)
|
| 121 |
+
|
| 122 |
+
latents = self.vae.encode(tensors).latent_dist.mean
|
| 123 |
+
decoded = self.vae.decode(latents.to(self.vae.dtype))[0]
|
| 124 |
+
|
| 125 |
+
if decoded.ndim == 5:
|
| 126 |
+
decoded = decoded.squeeze(2)
|
| 127 |
+
|
| 128 |
+
decoded = (decoded.clamp(-1, 1) + 1) / 2
|
| 129 |
+
for j, tensor in enumerate(decoded):
|
| 130 |
+
w, h = chunk[j][1], chunk[j][2]
|
| 131 |
+
arr = tensor.cpu().permute(1, 2, 0).float().numpy()
|
| 132 |
+
arr = arr[:h * 2, :w * 2]
|
| 133 |
+
output_images.append(Image.fromarray((arr * 255).astype("uint8")))
|
| 134 |
+
|
| 135 |
+
return output_images
|
| 136 |
+
|
| 137 |
+
@torch.no_grad()
|
| 138 |
+
def __call__(
|
| 139 |
+
self,
|
| 140 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 141 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 142 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 143 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 144 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 145 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 146 |
+
latents: Optional[torch.Tensor] = None,
|
| 147 |
+
height: int = 1024,
|
| 148 |
+
width: int = 1024,
|
| 149 |
+
num_inference_steps: int = 40,
|
| 150 |
+
guidance_scale: float = 7.0,
|
| 151 |
+
generator: Optional[torch.Generator] = None,
|
| 152 |
+
seed: Optional[int] = None,
|
| 153 |
+
output_type: str = "pil",
|
| 154 |
+
return_dict: bool = True,
|
| 155 |
+
**kwargs,
|
| 156 |
+
):
|
| 157 |
+
device = self.device
|
| 158 |
+
dtype = self.transformer.dtype
|
| 159 |
+
|
| 160 |
+
if generator is None and seed is not None:
|
| 161 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 162 |
+
|
| 163 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 164 |
+
|
| 165 |
+
# 1. Encode Positive
|
| 166 |
+
if prompt_embeds is None:
|
| 167 |
+
if prompt is None: raise ValueError("`prompt` or `prompt_embeds` required.")
|
| 168 |
+
prompt_embeds, _ = self.encode_text(prompt)
|
| 169 |
+
prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
|
| 170 |
+
batch_size = prompt_embeds.shape[0]
|
| 171 |
+
|
| 172 |
+
# 2. Encode Negative
|
| 173 |
+
if do_classifier_free_guidance:
|
| 174 |
+
if negative_prompt_embeds is None:
|
| 175 |
+
neg_text = negative_prompt if negative_prompt is not None else ("" if isinstance(prompt, str) else [""] * len(prompt))
|
| 176 |
+
negative_prompt_embeds, _ = self.encode_text(neg_text)
|
| 177 |
+
|
| 178 |
+
negative_prompt_embeds = negative_prompt_embeds.to(device=device, dtype=dtype)
|
| 179 |
+
|
| 180 |
+
if negative_prompt_embeds.shape[0] != batch_size:
|
| 181 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)
|
| 182 |
+
|
| 183 |
+
max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
|
| 184 |
+
prompt_embeds = self._pad_tensor_to_length(prompt_embeds, max_len, dim=1, pad_value=0)
|
| 185 |
+
negative_prompt_embeds = self._pad_tensor_to_length(negative_prompt_embeds, max_len, dim=1, pad_value=0)
|
| 186 |
+
|
| 187 |
+
# 3. Prepare Timesteps
|
| 188 |
+
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
| 189 |
+
sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
|
| 190 |
+
self.scheduler.set_timesteps(sigmas=sigmas, device=device)
|
| 191 |
+
timesteps = self.scheduler.timesteps
|
| 192 |
+
|
| 193 |
+
# Защита от деления на ноль на последнем шаге
|
| 194 |
+
if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
|
| 195 |
+
self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
|
| 196 |
+
if self.scheduler.sigmas[-1] == 0.0:
|
| 197 |
+
self.scheduler.sigmas[-1] = 1e-4
|
| 198 |
+
|
| 199 |
+
# 4. Prepare Latents (Noise)
|
| 200 |
+
latent_h = height // self.vae_scale_factor
|
| 201 |
+
latent_w = width // self.vae_scale_factor
|
| 202 |
+
in_channels = self.transformer.config.in_channels
|
| 203 |
+
sigma_max = getattr(self.scheduler.config, "sigma_max", 80.0)
|
| 204 |
+
|
| 205 |
+
if latents is None:
|
| 206 |
+
latents = torch.randn((batch_size, in_channels, 1, latent_h, latent_w), generator=generator, device=device, dtype=dtype)
|
| 207 |
+
latents = latents * sigma_max
|
| 208 |
+
else:
|
| 209 |
+
latents = latents.to(device=device, dtype=dtype) * sigma_max
|
| 210 |
+
|
| 211 |
+
# Cosmos Padding Mask
|
| 212 |
+
padding_mask = latents.new_zeros(1, 1, height, width, dtype=dtype)
|
| 213 |
+
|
| 214 |
+
# 5. Denoising Loop
|
| 215 |
+
for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
|
| 216 |
+
current_sigma = self.scheduler.sigmas[i]
|
| 217 |
+
|
| 218 |
+
# Защита от деления на 0 при вычислении current_t
|
| 219 |
+
if current_sigma == 0.0:
|
| 220 |
+
current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
|
| 221 |
+
|
| 222 |
+
current_t = current_sigma / (current_sigma + 1.0)
|
| 223 |
+
c_in = 1.0 - current_t
|
| 224 |
+
c_skip = 1.0 - current_t
|
| 225 |
+
c_out = -current_t
|
| 226 |
+
|
| 227 |
+
latent_model_input = (latents * c_in).to(dtype)
|
| 228 |
+
timestep = current_t.expand(latents.shape[0]).to(dtype)
|
| 229 |
+
|
| 230 |
+
# Проход 1
|
| 231 |
+
noise_pred = self.transformer(
|
| 232 |
+
hidden_states=latent_model_input,
|
| 233 |
+
timestep=timestep,
|
| 234 |
+
encoder_hidden_states=prompt_embeds,
|
| 235 |
+
padding_mask=padding_mask,
|
| 236 |
+
return_dict=False,
|
| 237 |
+
)[0]
|
| 238 |
+
|
| 239 |
+
noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(dtype)
|
| 240 |
+
|
| 241 |
+
# Проход 2
|
| 242 |
+
if do_classifier_free_guidance:
|
| 243 |
+
noise_pred_uncond = self.transformer(
|
| 244 |
+
hidden_states=latent_model_input,
|
| 245 |
+
timestep=timestep,
|
| 246 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 247 |
+
padding_mask=padding_mask,
|
| 248 |
+
return_dict=False,
|
| 249 |
+
)[0]
|
| 250 |
+
|
| 251 |
+
noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(dtype)
|
| 252 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
| 253 |
+
|
| 254 |
+
noise_pred = (latents - noise_pred) / current_sigma
|
| 255 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 256 |
+
|
| 257 |
+
# 6. Decode
|
| 258 |
+
if output_type == "latent":
|
| 259 |
+
if not return_dict: return (latents, prompt)
|
| 260 |
+
return SdxsPipelineOutput(images=latents)
|
| 261 |
+
|
| 262 |
+
# Точная математика NVIDIA для декодирования (без двойных инверсий)
|
| 263 |
+
if getattr(self.vae.config, "latents_std", None) is not None and getattr(self.vae.config, "latents_mean", None) is not None:
|
| 264 |
+
sigma_data = getattr(self.scheduler.config, "sigma_data", 1.0)
|
| 265 |
+
|
| 266 |
+
l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
|
| 267 |
+
l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
|
| 268 |
+
|
| 269 |
+
# Оригинальная формула: делим на инвертированный std (что равноценно умножению на std)
|
| 270 |
+
latents_std_inv = 1.0 / l_std
|
| 271 |
+
latents = latents / latents_std_inv / sigma_data + l_mean
|
| 272 |
+
|
| 273 |
+
image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
| 274 |
+
|
| 275 |
+
if image_output.ndim == 5:
|
| 276 |
+
image_output = image_output.squeeze(2)
|
| 277 |
+
|
| 278 |
+
image_output = (image_output.clamp(-1, 1) + 1) / 2
|
| 279 |
+
image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 280 |
+
|
| 281 |
+
# На всякий случай вычищаем NaNs, если они проскользнули, чтобы скрипт не падал с кастом
|
| 282 |
+
image_np = np.nan_to_num(image_np, nan=0.0, posinf=1.0, neginf=0.0)
|
| 283 |
+
|
| 284 |
+
if output_type == "pil":
|
| 285 |
+
images = [(Image.fromarray((img * 255).round().astype("uint8"))) for img in image_np]
|
| 286 |
+
else:
|
| 287 |
+
images = image_np
|
| 288 |
+
|
| 289 |
+
if not return_dict:
|
| 290 |
+
return (images,)
|
| 291 |
+
return SdxsPipelineOutput(images=images)
|
scheduler/.ipynb_checkpoints/scheduler_config-checkpoint.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
| 3 |
+
"_diffusers_version": "0.34.0.dev0",
|
| 4 |
+
"base_image_seq_len": 256,
|
| 5 |
+
"base_shift": 0.5,
|
| 6 |
+
"final_sigmas_type": "sigma_min",
|
| 7 |
+
"invert_sigmas": false,
|
| 8 |
+
"max_image_seq_len": 4096,
|
| 9 |
+
"max_shift": 1.15,
|
| 10 |
+
"num_train_timesteps": 1000,
|
| 11 |
+
"shift": 1.0,
|
| 12 |
+
"shift_terminal": null,
|
| 13 |
+
"sigma_data": 1.0,
|
| 14 |
+
"sigma_max": 80.0,
|
| 15 |
+
"sigma_min": 0.002,
|
| 16 |
+
"stochastic_sampling": false,
|
| 17 |
+
"time_shift_type": "exponential",
|
| 18 |
+
"use_beta_sigmas": false,
|
| 19 |
+
"use_dynamic_shifting": false,
|
| 20 |
+
"use_exponential_sigmas": false,
|
| 21 |
+
"use_karras_sigmas": true
|
| 22 |
+
}
|
scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
| 3 |
+
"_diffusers_version": "0.34.0.dev0",
|
| 4 |
+
"base_image_seq_len": 256,
|
| 5 |
+
"base_shift": 0.5,
|
| 6 |
+
"final_sigmas_type": "sigma_min",
|
| 7 |
+
"invert_sigmas": false,
|
| 8 |
+
"max_image_seq_len": 4096,
|
| 9 |
+
"max_shift": 1.15,
|
| 10 |
+
"num_train_timesteps": 1000,
|
| 11 |
+
"shift": 1.0,
|
| 12 |
+
"shift_terminal": null,
|
| 13 |
+
"sigma_data": 1.0,
|
| 14 |
+
"sigma_max": 80.0,
|
| 15 |
+
"sigma_min": 0.002,
|
| 16 |
+
"stochastic_sampling": false,
|
| 17 |
+
"time_shift_type": "exponential",
|
| 18 |
+
"use_beta_sigmas": false,
|
| 19 |
+
"use_dynamic_shifting": false,
|
| 20 |
+
"use_exponential_sigmas": false,
|
| 21 |
+
"use_karras_sigmas": true
|
| 22 |
+
}
|
t.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_from_disk
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from diffusers import AutoencoderKLQwenImage
|
| 7 |
+
import gc
|
| 8 |
+
|
| 9 |
+
def analyze_dataset_by_size(dataset_path):
|
| 10 |
+
"""
|
| 11 |
+
Группирует датасет по размерам изображений и выводит базовую информацию.
|
| 12 |
+
"""
|
| 13 |
+
# Настройка устройства и типа данных
|
| 14 |
+
dtype = torch.float32
|
| 15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
|
| 17 |
+
# Загрузка VAE модели
|
| 18 |
+
print("Загрузка VAE модели...")
|
| 19 |
+
vae = AutoencoderKLQwenImage.from_pretrained("vae",torch_dtype=dtype).to(device).eval()
|
| 20 |
+
shift_factor = getattr(vae.config, "shift_factor", 0.0)
|
| 21 |
+
if shift_factor is None:
|
| 22 |
+
shift_factor = 0.0
|
| 23 |
+
|
| 24 |
+
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
|
| 25 |
+
if scaling_factor is None:
|
| 26 |
+
scaling_factor = 1.0
|
| 27 |
+
|
| 28 |
+
mean = getattr(vae.config, "latents_mean", None)
|
| 29 |
+
std = getattr(vae.config, "latents_std", None)
|
| 30 |
+
if mean is not None and std is not None:
|
| 31 |
+
latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
|
| 32 |
+
latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
|
| 33 |
+
|
| 34 |
+
# Загружаем датасет
|
| 35 |
+
print(f"Загрузка датасета из {dataset_path}...")
|
| 36 |
+
dataset = load_from_disk(dataset_path)
|
| 37 |
+
|
| 38 |
+
print(f"Осталось примеров после фильтрации: {len(dataset)}")
|
| 39 |
+
|
| 40 |
+
# Группируем примеры по размерам
|
| 41 |
+
print("\nГруппировка примеров по размерам...")
|
| 42 |
+
size_to_indices = defaultdict(list)
|
| 43 |
+
|
| 44 |
+
# Собираем примеры с одинаковыми размерами
|
| 45 |
+
# Собираем примеры с одинаковыми размерами (оптимизированная версия)
|
| 46 |
+
widths = dataset["width"]
|
| 47 |
+
heights = dataset["height"]
|
| 48 |
+
for i, (w, h) in enumerate(zip(widths, heights)):
|
| 49 |
+
size_to_indices[(w, h)].append(i)
|
| 50 |
+
|
| 51 |
+
# Сортируем размеры по количеству примеров
|
| 52 |
+
print("\nСортируем...")
|
| 53 |
+
size_stats = [(size, len(indices)) for size, indices in size_to_indices.items()]
|
| 54 |
+
size_stats.sort(key=lambda x: x[1], reverse=True)
|
| 55 |
+
|
| 56 |
+
# Выводим информацию о каждой группе и показываем первый пример
|
| 57 |
+
for size, count in size_stats:
|
| 58 |
+
width, height = size
|
| 59 |
+
first_idx = size_to_indices[size][1]
|
| 60 |
+
example = dataset[first_idx]
|
| 61 |
+
|
| 62 |
+
print(f"\n--- Батч {width}x{height}: {count} примеров ---")
|
| 63 |
+
|
| 64 |
+
# Декодируем латентное представление для первого примера
|
| 65 |
+
latent = torch.tensor(example["vae"], dtype=dtype).unsqueeze(0).to(device)
|
| 66 |
+
|
| 67 |
+
# 1. Снова обманываем VAE, превращая картинку в "видео из 1 кадра" [B, C, 1, H, W]
|
| 68 |
+
if latent.ndim == 4:
|
| 69 |
+
latent = latent.unsqueeze(2)
|
| 70 |
+
|
| 71 |
+
with torch.no_grad():
|
| 72 |
+
if latents_mean is not None and latents_std is not None:
|
| 73 |
+
latent = latent * latents_std + latents_mean
|
| 74 |
+
|
| 75 |
+
print(f"Min of latent_for_vae: {latent.min()}")
|
| 76 |
+
print(f"Max of latent_for_vae: {latent.max()}")
|
| 77 |
+
print(f"Mean of latent_for_vae: {latent.mean()}")
|
| 78 |
+
print(f"Std: {latent.std().item():.4f}")
|
| 79 |
+
if torch.isnan(latent).any() or torch.isinf(latent).any():
|
| 80 |
+
print("WARNING: Raw latents contain NaN or Inf values!")
|
| 81 |
+
|
| 82 |
+
reconstructed_image = vae.decode(latent).sample
|
| 83 |
+
|
| 84 |
+
# 2. Вытаскиваем обычную 3D-картинку [C, H, W] из 5D-видеотензора
|
| 85 |
+
if reconstructed_image.ndim == 5:
|
| 86 |
+
# Берем нулевой батч, все каналы, нулевой кадр, всю высоту и ширину
|
| 87 |
+
img_tensor = reconstructed_image[0, :, 0, :, :]
|
| 88 |
+
else:
|
| 89 |
+
img_tensor = reconstructed_image.squeeze(0) # На всякий случай, если VAE вернул 4D
|
| 90 |
+
|
| 91 |
+
img_array = img_tensor.cpu().numpy()
|
| 92 |
+
img_array = np.transpose(img_array, (1, 2, 0))
|
| 93 |
+
img_array = (img_array + 1) / 2 # Нормализация к [0, 1]
|
| 94 |
+
img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8) # Преобразуем в uint8 для PIL
|
| 95 |
+
|
| 96 |
+
# Создаем PIL изображение из массива
|
| 97 |
+
pil_image = Image.fromarray(img_array)
|
| 98 |
+
print(f"Текст: {example['text']}")
|
| 99 |
+
print(f"Ключи: {', '.join(example.keys())}")
|
| 100 |
+
print(f"latent: {latent.shape}")
|
| 101 |
+
pil_image.save("1.jpg")
|
| 102 |
+
|
| 103 |
+
# Очистка памяти
|
| 104 |
+
if torch.cuda.is_available():
|
| 105 |
+
torch.cuda.empty_cache()
|
| 106 |
+
gc.collect()
|
| 107 |
+
|
| 108 |
+
return size_to_indices # Возвращаем словарь с индексами по группам
|
| 109 |
+
|
| 110 |
+
# Использование
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
# Путь к датасету
|
| 113 |
+
save_path = "datasets/ds234_640_vae_qwen"
|
| 114 |
+
|
| 115 |
+
# Анализ датасета
|
| 116 |
+
size_groups = analyze_dataset_by_size(save_path)
|
test.ipynb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:677906d20fb691440965fb107de2c9d8e9b7c75884d9e3e15b4375f4257df8ae
|
| 3 |
+
size 21416092
|
text_encoder/.ipynb_checkpoints/config-checkpoint.json
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3_5Model"
|
| 4 |
+
],
|
| 5 |
+
"dtype": "bfloat16",
|
| 6 |
+
"image_token_id": 248056,
|
| 7 |
+
"model_type": "qwen3_5",
|
| 8 |
+
"text_config": {
|
| 9 |
+
"attention_bias": false,
|
| 10 |
+
"attention_dropout": 0.0,
|
| 11 |
+
"attn_output_gate": true,
|
| 12 |
+
"bos_token_id": null,
|
| 13 |
+
"dtype": "bfloat16",
|
| 14 |
+
"eos_token_id": 248044,
|
| 15 |
+
"full_attention_interval": 4,
|
| 16 |
+
"head_dim": 256,
|
| 17 |
+
"hidden_act": "silu",
|
| 18 |
+
"hidden_size": 1024,
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"intermediate_size": 3584,
|
| 21 |
+
"layer_types": [
|
| 22 |
+
"linear_attention",
|
| 23 |
+
"linear_attention",
|
| 24 |
+
"linear_attention",
|
| 25 |
+
"full_attention",
|
| 26 |
+
"linear_attention",
|
| 27 |
+
"linear_attention",
|
| 28 |
+
"linear_attention",
|
| 29 |
+
"full_attention",
|
| 30 |
+
"linear_attention",
|
| 31 |
+
"linear_attention",
|
| 32 |
+
"linear_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"linear_attention",
|
| 35 |
+
"linear_attention",
|
| 36 |
+
"linear_attention",
|
| 37 |
+
"full_attention",
|
| 38 |
+
"linear_attention",
|
| 39 |
+
"linear_attention",
|
| 40 |
+
"linear_attention",
|
| 41 |
+
"full_attention",
|
| 42 |
+
"linear_attention",
|
| 43 |
+
"linear_attention",
|
| 44 |
+
"linear_attention",
|
| 45 |
+
"full_attention"
|
| 46 |
+
],
|
| 47 |
+
"linear_conv_kernel_dim": 4,
|
| 48 |
+
"linear_key_head_dim": 128,
|
| 49 |
+
"linear_num_key_heads": 16,
|
| 50 |
+
"linear_num_value_heads": 16,
|
| 51 |
+
"linear_value_head_dim": 128,
|
| 52 |
+
"mamba_ssm_dtype": "float32",
|
| 53 |
+
"max_position_embeddings": 262144,
|
| 54 |
+
"mlp_only_layers": [],
|
| 55 |
+
"model_type": "qwen3_5_text",
|
| 56 |
+
"mtp_num_hidden_layers": 1,
|
| 57 |
+
"mtp_use_dedicated_embeddings": false,
|
| 58 |
+
"num_attention_heads": 8,
|
| 59 |
+
"num_hidden_layers": 24,
|
| 60 |
+
"num_key_value_heads": 2,
|
| 61 |
+
"pad_token_id": null,
|
| 62 |
+
"partial_rotary_factor": 0.25,
|
| 63 |
+
"rms_norm_eps": 1e-06,
|
| 64 |
+
"rope_parameters": {
|
| 65 |
+
"mrope_interleaved": true,
|
| 66 |
+
"mrope_section": [
|
| 67 |
+
11,
|
| 68 |
+
11,
|
| 69 |
+
10
|
| 70 |
+
],
|
| 71 |
+
"partial_rotary_factor": 0.25,
|
| 72 |
+
"rope_theta": 10000000,
|
| 73 |
+
"rope_type": "default"
|
| 74 |
+
},
|
| 75 |
+
"tie_word_embeddings": true,
|
| 76 |
+
"use_cache": true,
|
| 77 |
+
"vocab_size": 248320
|
| 78 |
+
},
|
| 79 |
+
"tie_word_embeddings": true,
|
| 80 |
+
"transformers_version": "5.6.1",
|
| 81 |
+
"video_token_id": 248057,
|
| 82 |
+
"vision_config": {
|
| 83 |
+
"deepstack_visual_indexes": [],
|
| 84 |
+
"depth": 12,
|
| 85 |
+
"dtype": "bfloat16",
|
| 86 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 87 |
+
"hidden_size": 768,
|
| 88 |
+
"in_channels": 3,
|
| 89 |
+
"initializer_range": 0.02,
|
| 90 |
+
"intermediate_size": 3072,
|
| 91 |
+
"model_type": "qwen3_5_vision",
|
| 92 |
+
"num_heads": 12,
|
| 93 |
+
"num_position_embeddings": 2304,
|
| 94 |
+
"out_hidden_size": 1024,
|
| 95 |
+
"patch_size": 16,
|
| 96 |
+
"spatial_merge_size": 2,
|
| 97 |
+
"temporal_patch_size": 2
|
| 98 |
+
},
|
| 99 |
+
"vision_end_token_id": 248054,
|
| 100 |
+
"vision_start_token_id": 248053
|
| 101 |
+
}
|
text_encoder/config.json
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3_5Model"
|
| 4 |
+
],
|
| 5 |
+
"dtype": "bfloat16",
|
| 6 |
+
"image_token_id": 248056,
|
| 7 |
+
"model_type": "qwen3_5",
|
| 8 |
+
"text_config": {
|
| 9 |
+
"attention_bias": false,
|
| 10 |
+
"attention_dropout": 0.0,
|
| 11 |
+
"attn_output_gate": true,
|
| 12 |
+
"bos_token_id": null,
|
| 13 |
+
"dtype": "bfloat16",
|
| 14 |
+
"eos_token_id": 248044,
|
| 15 |
+
"full_attention_interval": 4,
|
| 16 |
+
"head_dim": 256,
|
| 17 |
+
"hidden_act": "silu",
|
| 18 |
+
"hidden_size": 1024,
|
| 19 |
+
"initializer_range": 0.02,
|
| 20 |
+
"intermediate_size": 3584,
|
| 21 |
+
"layer_types": [
|
| 22 |
+
"linear_attention",
|
| 23 |
+
"linear_attention",
|
| 24 |
+
"linear_attention",
|
| 25 |
+
"full_attention",
|
| 26 |
+
"linear_attention",
|
| 27 |
+
"linear_attention",
|
| 28 |
+
"linear_attention",
|
| 29 |
+
"full_attention",
|
| 30 |
+
"linear_attention",
|
| 31 |
+
"linear_attention",
|
| 32 |
+
"linear_attention",
|
| 33 |
+
"full_attention",
|
| 34 |
+
"linear_attention",
|
| 35 |
+
"linear_attention",
|
| 36 |
+
"linear_attention",
|
| 37 |
+
"full_attention",
|
| 38 |
+
"linear_attention",
|
| 39 |
+
"linear_attention",
|
| 40 |
+
"linear_attention",
|
| 41 |
+
"full_attention",
|
| 42 |
+
"linear_attention",
|
| 43 |
+
"linear_attention",
|
| 44 |
+
"linear_attention",
|
| 45 |
+
"full_attention"
|
| 46 |
+
],
|
| 47 |
+
"linear_conv_kernel_dim": 4,
|
| 48 |
+
"linear_key_head_dim": 128,
|
| 49 |
+
"linear_num_key_heads": 16,
|
| 50 |
+
"linear_num_value_heads": 16,
|
| 51 |
+
"linear_value_head_dim": 128,
|
| 52 |
+
"mamba_ssm_dtype": "float32",
|
| 53 |
+
"max_position_embeddings": 262144,
|
| 54 |
+
"mlp_only_layers": [],
|
| 55 |
+
"model_type": "qwen3_5_text",
|
| 56 |
+
"mtp_num_hidden_layers": 1,
|
| 57 |
+
"mtp_use_dedicated_embeddings": false,
|
| 58 |
+
"num_attention_heads": 8,
|
| 59 |
+
"num_hidden_layers": 24,
|
| 60 |
+
"num_key_value_heads": 2,
|
| 61 |
+
"pad_token_id": null,
|
| 62 |
+
"partial_rotary_factor": 0.25,
|
| 63 |
+
"rms_norm_eps": 1e-06,
|
| 64 |
+
"rope_parameters": {
|
| 65 |
+
"mrope_interleaved": true,
|
| 66 |
+
"mrope_section": [
|
| 67 |
+
11,
|
| 68 |
+
11,
|
| 69 |
+
10
|
| 70 |
+
],
|
| 71 |
+
"partial_rotary_factor": 0.25,
|
| 72 |
+
"rope_theta": 10000000,
|
| 73 |
+
"rope_type": "default"
|
| 74 |
+
},
|
| 75 |
+
"tie_word_embeddings": true,
|
| 76 |
+
"use_cache": true,
|
| 77 |
+
"vocab_size": 248320
|
| 78 |
+
},
|
| 79 |
+
"tie_word_embeddings": true,
|
| 80 |
+
"transformers_version": "5.6.1",
|
| 81 |
+
"video_token_id": 248057,
|
| 82 |
+
"vision_config": {
|
| 83 |
+
"deepstack_visual_indexes": [],
|
| 84 |
+
"depth": 12,
|
| 85 |
+
"dtype": "bfloat16",
|
| 86 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 87 |
+
"hidden_size": 768,
|
| 88 |
+
"in_channels": 3,
|
| 89 |
+
"initializer_range": 0.02,
|
| 90 |
+
"intermediate_size": 3072,
|
| 91 |
+
"model_type": "qwen3_5_vision",
|
| 92 |
+
"num_heads": 12,
|
| 93 |
+
"num_position_embeddings": 2304,
|
| 94 |
+
"out_hidden_size": 1024,
|
| 95 |
+
"patch_size": 16,
|
| 96 |
+
"spatial_merge_size": 2,
|
| 97 |
+
"temporal_patch_size": 2
|
| 98 |
+
},
|
| 99 |
+
"vision_end_token_id": 248054,
|
| 100 |
+
"vision_start_token_id": 248053
|
| 101 |
+
}
|
text_encoder/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be05a6e8dcacdae04865491110f227b71229110e321aa655982c4bd793ea411a
|
| 3 |
+
size 1706027688
|
tokenizer/chat_template.jinja
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- set image_count = namespace(value=0) %}
|
| 2 |
+
{%- set video_count = namespace(value=0) %}
|
| 3 |
+
{%- macro render_content(content, do_vision_count, is_system_content=false) %}
|
| 4 |
+
{%- if content is string %}
|
| 5 |
+
{{- content }}
|
| 6 |
+
{%- elif content is iterable and content is not mapping %}
|
| 7 |
+
{%- for item in content %}
|
| 8 |
+
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
|
| 9 |
+
{%- if is_system_content %}
|
| 10 |
+
{{- raise_exception('System message cannot contain images.') }}
|
| 11 |
+
{%- endif %}
|
| 12 |
+
{%- if do_vision_count %}
|
| 13 |
+
{%- set image_count.value = image_count.value + 1 %}
|
| 14 |
+
{%- endif %}
|
| 15 |
+
{%- if add_vision_id %}
|
| 16 |
+
{{- 'Picture ' ~ image_count.value ~ ': ' }}
|
| 17 |
+
{%- endif %}
|
| 18 |
+
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
|
| 19 |
+
{%- elif 'video' in item or item.type == 'video' %}
|
| 20 |
+
{%- if is_system_content %}
|
| 21 |
+
{{- raise_exception('System message cannot contain videos.') }}
|
| 22 |
+
{%- endif %}
|
| 23 |
+
{%- if do_vision_count %}
|
| 24 |
+
{%- set video_count.value = video_count.value + 1 %}
|
| 25 |
+
{%- endif %}
|
| 26 |
+
{%- if add_vision_id %}
|
| 27 |
+
{{- 'Video ' ~ video_count.value ~ ': ' }}
|
| 28 |
+
{%- endif %}
|
| 29 |
+
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
|
| 30 |
+
{%- elif 'text' in item %}
|
| 31 |
+
{{- item.text }}
|
| 32 |
+
{%- else %}
|
| 33 |
+
{{- raise_exception('Unexpected item type in content.') }}
|
| 34 |
+
{%- endif %}
|
| 35 |
+
{%- endfor %}
|
| 36 |
+
{%- elif content is none or content is undefined %}
|
| 37 |
+
{{- '' }}
|
| 38 |
+
{%- else %}
|
| 39 |
+
{{- raise_exception('Unexpected content type.') }}
|
| 40 |
+
{%- endif %}
|
| 41 |
+
{%- endmacro %}
|
| 42 |
+
{%- if not messages %}
|
| 43 |
+
{{- raise_exception('No messages provided.') }}
|
| 44 |
+
{%- endif %}
|
| 45 |
+
{%- if tools and tools is iterable and tools is not mapping %}
|
| 46 |
+
{{- '<|im_start|>system\n' }}
|
| 47 |
+
{{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
|
| 48 |
+
{%- for tool in tools %}
|
| 49 |
+
{{- "\n" }}
|
| 50 |
+
{{- tool | tojson }}
|
| 51 |
+
{%- endfor %}
|
| 52 |
+
{{- "\n</tools>" }}
|
| 53 |
+
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
| 54 |
+
{%- if messages[0].role == 'system' %}
|
| 55 |
+
{%- set content = render_content(messages[0].content, false, true)|trim %}
|
| 56 |
+
{%- if content %}
|
| 57 |
+
{{- '\n\n' + content }}
|
| 58 |
+
{%- endif %}
|
| 59 |
+
{%- endif %}
|
| 60 |
+
{{- '<|im_end|>\n' }}
|
| 61 |
+
{%- else %}
|
| 62 |
+
{%- if messages[0].role == 'system' %}
|
| 63 |
+
{%- set content = render_content(messages[0].content, false, true)|trim %}
|
| 64 |
+
{{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
|
| 65 |
+
{%- endif %}
|
| 66 |
+
{%- endif %}
|
| 67 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 68 |
+
{%- for message in messages[::-1] %}
|
| 69 |
+
{%- set index = (messages|length - 1) - loop.index0 %}
|
| 70 |
+
{%- if ns.multi_step_tool and message.role == "user" %}
|
| 71 |
+
{%- set content = render_content(message.content, false)|trim %}
|
| 72 |
+
{%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
|
| 73 |
+
{%- set ns.multi_step_tool = false %}
|
| 74 |
+
{%- set ns.last_query_index = index %}
|
| 75 |
+
{%- endif %}
|
| 76 |
+
{%- endif %}
|
| 77 |
+
{%- endfor %}
|
| 78 |
+
{%- if ns.multi_step_tool %}
|
| 79 |
+
{{- raise_exception('No user query found in messages.') }}
|
| 80 |
+
{%- endif %}
|
| 81 |
+
{%- for message in messages %}
|
| 82 |
+
{%- set content = render_content(message.content, true)|trim %}
|
| 83 |
+
{%- if message.role == "system" %}
|
| 84 |
+
{%- if not loop.first %}
|
| 85 |
+
{{- raise_exception('System message must be at the beginning.') }}
|
| 86 |
+
{%- endif %}
|
| 87 |
+
{%- elif message.role == "user" %}
|
| 88 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 89 |
+
{%- elif message.role == "assistant" %}
|
| 90 |
+
{%- set reasoning_content = '' %}
|
| 91 |
+
{%- if message.reasoning_content is string %}
|
| 92 |
+
{%- set reasoning_content = message.reasoning_content %}
|
| 93 |
+
{%- else %}
|
| 94 |
+
{%- if '</think>' in content %}
|
| 95 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 96 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 97 |
+
{%- endif %}
|
| 98 |
+
{%- endif %}
|
| 99 |
+
{%- set reasoning_content = reasoning_content|trim %}
|
| 100 |
+
{%- if loop.index0 > ns.last_query_index %}
|
| 101 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
|
| 102 |
+
{%- else %}
|
| 103 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 104 |
+
{%- endif %}
|
| 105 |
+
{%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
|
| 106 |
+
{%- for tool_call in message.tool_calls %}
|
| 107 |
+
{%- if tool_call.function is defined %}
|
| 108 |
+
{%- set tool_call = tool_call.function %}
|
| 109 |
+
{%- endif %}
|
| 110 |
+
{%- if loop.first %}
|
| 111 |
+
{%- if content|trim %}
|
| 112 |
+
{{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
| 113 |
+
{%- else %}
|
| 114 |
+
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
| 115 |
+
{%- endif %}
|
| 116 |
+
{%- else %}
|
| 117 |
+
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
| 118 |
+
{%- endif %}
|
| 119 |
+
{%- if tool_call.arguments is defined %}
|
| 120 |
+
{%- for args_name, args_value in tool_call.arguments|items %}
|
| 121 |
+
{{- '<parameter=' + args_name + '>\n' }}
|
| 122 |
+
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
| 123 |
+
{{- args_value }}
|
| 124 |
+
{{- '\n</parameter>\n' }}
|
| 125 |
+
{%- endfor %}
|
| 126 |
+
{%- endif %}
|
| 127 |
+
{{- '</function>\n</tool_call>' }}
|
| 128 |
+
{%- endfor %}
|
| 129 |
+
{%- endif %}
|
| 130 |
+
{{- '<|im_end|>\n' }}
|
| 131 |
+
{%- elif message.role == "tool" %}
|
| 132 |
+
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
| 133 |
+
{{- '<|im_start|>user' }}
|
| 134 |
+
{%- endif %}
|
| 135 |
+
{{- '\n<tool_response>\n' }}
|
| 136 |
+
{{- content }}
|
| 137 |
+
{{- '\n</tool_response>' }}
|
| 138 |
+
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
| 139 |
+
{{- '<|im_end|>\n' }}
|
| 140 |
+
{%- elif loop.last %}
|
| 141 |
+
{{- '<|im_end|>\n' }}
|
| 142 |
+
{%- endif %}
|
| 143 |
+
{%- else %}
|
| 144 |
+
{{- raise_exception('Unexpected message role.') }}
|
| 145 |
+
{%- endif %}
|
| 146 |
+
{%- endfor %}
|
| 147 |
+
{%- if add_generation_prompt %}
|
| 148 |
+
{{- '<|im_start|>assistant\n' }}
|
| 149 |
+
{%- if enable_thinking is defined and enable_thinking is true %}
|
| 150 |
+
{{- '<think>\n' }}
|
| 151 |
+
{%- else %}
|
| 152 |
+
{{- '<think>\n\n</think>\n\n' }}
|
| 153 |
+
{%- endif %}
|
| 154 |
+
{%- endif %}
|
tokenizer/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:06b9509352d2af50381ab2247e083b80d32d5c0aba91c272ca9ff729b6a0e523
|
| 3 |
+
size 19989325
|
tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"audio_bos_token": "<|audio_start|>",
|
| 4 |
+
"audio_eos_token": "<|audio_end|>",
|
| 5 |
+
"audio_token": "<|audio_pad|>",
|
| 6 |
+
"backend": "tokenizers",
|
| 7 |
+
"bos_token": null,
|
| 8 |
+
"clean_up_tokenization_spaces": false,
|
| 9 |
+
"eos_token": "<|im_end|>",
|
| 10 |
+
"errors": "replace",
|
| 11 |
+
"image_token": "<|image_pad|>",
|
| 12 |
+
"is_local": false,
|
| 13 |
+
"local_files_only": false,
|
| 14 |
+
"model_max_length": 262144,
|
| 15 |
+
"model_specific_special_tokens": {
|
| 16 |
+
"audio_bos_token": "<|audio_start|>",
|
| 17 |
+
"audio_eos_token": "<|audio_end|>",
|
| 18 |
+
"audio_token": "<|audio_pad|>",
|
| 19 |
+
"image_token": "<|image_pad|>",
|
| 20 |
+
"video_token": "<|video_pad|>",
|
| 21 |
+
"vision_bos_token": "<|vision_start|>",
|
| 22 |
+
"vision_eos_token": "<|vision_end|>"
|
| 23 |
+
},
|
| 24 |
+
"pad_token": "<|endoftext|>",
|
| 25 |
+
"pretokenize_regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
| 26 |
+
"split_special_tokens": false,
|
| 27 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 28 |
+
"unk_token": null,
|
| 29 |
+
"video_token": "<|video_pad|>",
|
| 30 |
+
"vision_bos_token": "<|vision_start|>",
|
| 31 |
+
"vision_eos_token": "<|vision_end|>"
|
| 32 |
+
}
|
train-Copy1.py
ADDED
|
@@ -0,0 +1,924 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import wandb, comet_ml
|
| 7 |
+
import random, time
|
| 8 |
+
import gc
|
| 9 |
+
import bitsandbytes as bnb
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import argparse
|
| 12 |
+
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from diffusers import CosmosTransformer3DModel, AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler
|
| 15 |
+
from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
|
| 16 |
+
from torch.utils.data import DataLoader, Sampler
|
| 17 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
from accelerate import Accelerator
|
| 20 |
+
from datasets import load_from_disk
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
from PIL import Image, ImageOps
|
| 23 |
+
from torch.utils.checkpoint import checkpoint
|
| 24 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
| 25 |
+
from contextlib import nullcontext
|
| 26 |
+
from transformers.optimization import Adafactor
|
| 27 |
+
|
| 28 |
+
# Muon not tested! pip install git+https://github.com/recoilme/muon_adamw8bit.git
|
| 29 |
+
from muon_adamw8bit import MuonAdamW8bit
|
| 30 |
+
|
| 31 |
+
os.environ["NCCL_P2P_DISABLE"] = "1"
|
| 32 |
+
os.environ["NCCL_IB_DISABLE"] = "1" # comment this on H100!
|
| 33 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 34 |
+
|
| 35 |
+
# --------------------------- Параметры ---------------------------
|
| 36 |
+
ds_path = "datasets/ds234_640_vae_qwen"
|
| 37 |
+
project = "transformer"
|
| 38 |
+
|
| 39 |
+
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 40 |
+
local_bs = max(1, int((gpu_mem_gb / 32) * 7))
|
| 41 |
+
num_gpus = torch.cuda.device_count()
|
| 42 |
+
batch_size = local_bs * num_gpus
|
| 43 |
+
|
| 44 |
+
base_learning_rate = 4e-5
|
| 45 |
+
min_learning_rate = 4e-6
|
| 46 |
+
|
| 47 |
+
learning_rate_scale = 3
|
| 48 |
+
base_learning_rate = base_learning_rate / learning_rate_scale
|
| 49 |
+
min_learning_rate = min_learning_rate / learning_rate_scale
|
| 50 |
+
print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")
|
| 51 |
+
|
| 52 |
+
num_epochs = num_gpus
|
| 53 |
+
sink_interval_share = 10
|
| 54 |
+
sample_interval_min = 20
|
| 55 |
+
cfg_dropout = 0.10
|
| 56 |
+
# Время t, bias = -0.5 (Фокус на Деталях ~300) bias = 0.5 (Фокус на структуре) bias = 0 (колокол/ равномерно)
|
| 57 |
+
sigmoid_bias = 0.1
|
| 58 |
+
max_length = 250
|
| 59 |
+
use_precomputed_embeddings = False
|
| 60 |
+
use_wandb = False
|
| 61 |
+
use_comet_ml = False
|
| 62 |
+
save_model = True
|
| 63 |
+
use_decay = True
|
| 64 |
+
fbp = False
|
| 65 |
+
torch_compile = False
|
| 66 |
+
transformer_gradient = True
|
| 67 |
+
loss_normalize = False
|
| 68 |
+
fixed_seed = False
|
| 69 |
+
shuffle = True
|
| 70 |
+
optimizer_type = "adafactor"
|
| 71 |
+
|
| 72 |
+
if optimizer_type == "muon_adam8bit":
|
| 73 |
+
batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
|
| 74 |
+
muon_lr_scale = 500
|
| 75 |
+
|
| 76 |
+
comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
|
| 77 |
+
comet_ml_workspace = "recoilme"
|
| 78 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 79 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 80 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 81 |
+
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 82 |
+
torch.backends.cuda.enable_math_sdp(False)
|
| 83 |
+
save_barrier = 1.25
|
| 84 |
+
warmup_percent = 0.0025
|
| 85 |
+
betta2 = 0.997
|
| 86 |
+
eps = 1e-6
|
| 87 |
+
clip_grad_norm = 1.0
|
| 88 |
+
limit = 0
|
| 89 |
+
checkpoints_folder = ""
|
| 90 |
+
gradient_accumulation_steps = 1
|
| 91 |
+
|
| 92 |
+
dtype = torch.float32
|
| 93 |
+
mixed_precision = "bf16"
|
| 94 |
+
|
| 95 |
+
# Параметры для диффузии
|
| 96 |
+
n_diffusion_steps = 40
|
| 97 |
+
samples_to_generate = 12
|
| 98 |
+
guidance_scale = 7.0
|
| 99 |
+
|
| 100 |
+
# Папки для сохранения результатов
|
| 101 |
+
generated_folder = "samples"
|
| 102 |
+
os.makedirs(generated_folder, exist_ok=True)
|
| 103 |
+
|
| 104 |
+
# Настройка seed
|
| 105 |
+
current_date = datetime.now()
|
| 106 |
+
seed = int(current_date.strftime("%Y%m%d")) + 42
|
| 107 |
+
if fixed_seed:
|
| 108 |
+
torch.manual_seed(seed)
|
| 109 |
+
np.random.seed(seed)
|
| 110 |
+
random.seed(seed)
|
| 111 |
+
if torch.cuda.is_available():
|
| 112 |
+
torch.cuda.manual_seed_all(seed)
|
| 113 |
+
|
| 114 |
+
accelerator = Accelerator(
|
| 115 |
+
mixed_precision=mixed_precision,
|
| 116 |
+
gradient_accumulation_steps=gradient_accumulation_steps
|
| 117 |
+
)
|
| 118 |
+
device = accelerator.device
|
| 119 |
+
|
| 120 |
+
print("init")
|
| 121 |
+
parser = argparse.ArgumentParser(description='Train a model on a dataset.')
|
| 122 |
+
parser.add_argument('--ds-path', type=str, default=ds_path, help='Path to the dataset')
|
| 123 |
+
parser.add_argument('--ep', type=int, default=num_epochs, help='Number of epochs to train the model')
|
| 124 |
+
parser.add_argument('--batch', type=int, default=batch_size, help='Total batch size')
|
| 125 |
+
parser.add_argument('--min-lr', type=float, default=min_learning_rate, help='Minimum learning rate')
|
| 126 |
+
parser.add_argument('--max-lr', type=float, default=base_learning_rate, help='Maximum learning rate')
|
| 127 |
+
parser.add_argument('--dry-run', action='store_true',default=False, help='Dry run train without saving/sampling')
|
| 128 |
+
parser.add_argument('--lvl', type=float, default=0.0, help='Train level, from 0.5 to 5')
|
| 129 |
+
|
| 130 |
+
args = parser.parse_args()
|
| 131 |
+
|
| 132 |
+
batch_size = args.batch
|
| 133 |
+
ds_path = args.ds_path
|
| 134 |
+
base_learning_rate = args.max_lr
|
| 135 |
+
min_learning_rate = args.min_lr
|
| 136 |
+
num_epochs = args.ep
|
| 137 |
+
lvl = args.lvl
|
| 138 |
+
if args.dry_run:
|
| 139 |
+
save_model = False
|
| 140 |
+
if lvl >= 0.1:
|
| 141 |
+
base_learning_rate = base_learning_rate / lvl
|
| 142 |
+
min_learning_rate = min_learning_rate / lvl
|
| 143 |
+
print(f"max-lr:{base_learning_rate} min-lr:{min_learning_rate}")
|
| 144 |
+
|
| 145 |
+
# --------------------------- Инициализация WandB ---------------------------
|
| 146 |
+
if accelerator.is_main_process:
|
| 147 |
+
if use_wandb:
|
| 148 |
+
wandb.init(project=project, config={
|
| 149 |
+
"batch_size": batch_size,
|
| 150 |
+
"base_learning_rate": base_learning_rate,
|
| 151 |
+
"num_epochs": num_epochs,
|
| 152 |
+
"optimizer_type": optimizer_type,
|
| 153 |
+
})
|
| 154 |
+
if use_comet_ml:
|
| 155 |
+
from comet_ml import Experiment
|
| 156 |
+
comet_experiment = Experiment(
|
| 157 |
+
api_key=comet_ml_api_key,
|
| 158 |
+
project_name=project,
|
| 159 |
+
workspace=comet_ml_workspace
|
| 160 |
+
)
|
| 161 |
+
hyper_params = {
|
| 162 |
+
"batch_size": batch_size,
|
| 163 |
+
"base_learning_rate": base_learning_rate,
|
| 164 |
+
"num_epochs": num_epochs,
|
| 165 |
+
}
|
| 166 |
+
comet_experiment.log_parameters(hyper_params)
|
| 167 |
+
|
| 168 |
+
# --------------------------- Загрузка моделей ---------------------------
|
| 169 |
+
vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).to(dtype=dtype).eval()
|
| 170 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
|
| 171 |
+
tokenizer = None
|
| 172 |
+
text_encoder = None
|
| 173 |
+
|
| 174 |
+
def load_text_encoder():
|
| 175 |
+
global tokenizer, text_encoder
|
| 176 |
+
if tokenizer is None:
|
| 177 |
+
tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer")
|
| 178 |
+
if text_encoder is None:
|
| 179 |
+
text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained(
|
| 180 |
+
"text_encoder",
|
| 181 |
+
torch_dtype=dtype
|
| 182 |
+
).to(device).eval()
|
| 183 |
+
|
| 184 |
+
load_text_encoder()
|
| 185 |
+
|
| 186 |
+
@torch.no_grad()
|
| 187 |
+
def encode_texts(text, max_length=max_length):
|
| 188 |
+
if text is None:
|
| 189 |
+
text = ""
|
| 190 |
+
if isinstance(text, str):
|
| 191 |
+
text = [text]
|
| 192 |
+
|
| 193 |
+
formatted_prompts = []
|
| 194 |
+
for t in text:
|
| 195 |
+
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 196 |
+
formatted_prompts.append(
|
| 197 |
+
tokenizer.apply_chat_template(
|
| 198 |
+
messages,
|
| 199 |
+
tokenize=False,
|
| 200 |
+
add_generation_prompt=False
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
toks = tokenizer(
|
| 205 |
+
formatted_prompts,
|
| 206 |
+
padding="max_length",
|
| 207 |
+
max_length=max_length,
|
| 208 |
+
truncation=True,
|
| 209 |
+
return_tensors="pt"
|
| 210 |
+
).to(device)
|
| 211 |
+
|
| 212 |
+
outputs = text_encoder(
|
| 213 |
+
input_ids=toks.input_ids,
|
| 214 |
+
attention_mask=toks.attention_mask,
|
| 215 |
+
output_hidden_states=True
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
hidden = outputs.hidden_states[-2].to(dtype=dtype)
|
| 219 |
+
|
| 220 |
+
lengths = toks.attention_mask.sum(dim=1)
|
| 221 |
+
for i, length in enumerate(lengths):
|
| 222 |
+
hidden[i, length:] = 0
|
| 223 |
+
|
| 224 |
+
return hidden, toks.attention_mask.to(dtype=torch.int64)
|
| 225 |
+
|
| 226 |
+
@torch.no_grad()
|
| 227 |
+
def encode_texts_fast(text, max_length=max_length):
|
| 228 |
+
if text is None: text = ""
|
| 229 |
+
if isinstance(text, str): text = [text]
|
| 230 |
+
|
| 231 |
+
formatted_prompts = []
|
| 232 |
+
for t in text:
|
| 233 |
+
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 234 |
+
formatted_prompts.append(tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False))
|
| 235 |
+
|
| 236 |
+
toks = tokenizer(formatted_prompts, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt").to(device)
|
| 237 |
+
outputs = text_encoder(input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True)
|
| 238 |
+
|
| 239 |
+
last_hidden = outputs.hidden_states[-2].to(dtype=dtype)
|
| 240 |
+
|
| 241 |
+
lengths = toks.attention_mask.sum(dim=1)
|
| 242 |
+
for i, length in enumerate(lengths):
|
| 243 |
+
last_hidden[i, length:] = 0
|
| 244 |
+
|
| 245 |
+
return last_hidden, toks.attention_mask.to(dtype=torch.int64)
|
| 246 |
+
|
| 247 |
+
shift_factor = getattr(vae.config, "shift_factor", 0.0)
|
| 248 |
+
if shift_factor is None:
|
| 249 |
+
shift_factor = 0.0
|
| 250 |
+
|
| 251 |
+
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
|
| 252 |
+
if scaling_factor is None:
|
| 253 |
+
scaling_factor = 1.0
|
| 254 |
+
|
| 255 |
+
mean = getattr(vae.config, "latents_mean", None)
|
| 256 |
+
std = getattr(vae.config, "latents_std", None)
|
| 257 |
+
if mean is not None and std is not None:
|
| 258 |
+
latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
|
| 259 |
+
latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
|
| 260 |
+
# Внимание: Cosmos использует инвертированный std для декодирования (1.0 / std)
|
| 261 |
+
#latents_std = 1.0 / torch.tensor(std).view(1, len(std), 1, 1, 1)
|
| 262 |
+
else:
|
| 263 |
+
latents_std = None
|
| 264 |
+
latents_mean = None
|
| 265 |
+
|
| 266 |
+
if scheduler is not None:
|
| 267 |
+
scheduler.register_to_config(
|
| 268 |
+
sigma_max=getattr(scheduler.config, "sigma_max", 80.0),
|
| 269 |
+
sigma_min=getattr(scheduler.config, "sigma_min", 0.002),
|
| 270 |
+
sigma_data=getattr(scheduler.config, "sigma_data", 1.0),
|
| 271 |
+
final_sigmas_type=getattr(scheduler.config, "final_sigmas_type", "sigma_min"),
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
import numpy as np
|
| 275 |
+
from torch.utils.data import Sampler
|
| 276 |
+
|
| 277 |
+
class DistributedResolutionBatchSampler(Sampler):
|
| 278 |
+
def __init__(self, dataset, batch_size, num_replicas, rank, drop_last=True, shuffle=True):
|
| 279 |
+
self.dataset = dataset
|
| 280 |
+
self.num_replicas = num_replicas
|
| 281 |
+
self.rank = rank
|
| 282 |
+
self.shuffle = shuffle
|
| 283 |
+
self.drop_last = drop_last
|
| 284 |
+
self.epoch = 0
|
| 285 |
+
|
| 286 |
+
self.batch_size = max(1, batch_size // num_replicas)
|
| 287 |
+
self.global_batch = self.batch_size * num_replicas
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
widths = np.asarray(dataset["width"])
|
| 291 |
+
heights = np.asarray(dataset["height"])
|
| 292 |
+
except KeyError:
|
| 293 |
+
widths = np.zeros(len(dataset))
|
| 294 |
+
heights = np.zeros(len(dataset))
|
| 295 |
+
|
| 296 |
+
groups = {}
|
| 297 |
+
for i, (w, h) in enumerate(zip(widths, heights)):
|
| 298 |
+
groups.setdefault((w, h), []).append(i)
|
| 299 |
+
|
| 300 |
+
all_batches = []
|
| 301 |
+
for indices in groups.values():
|
| 302 |
+
idx = np.asarray(indices, dtype=np.int64)
|
| 303 |
+
num_batches = len(idx) // self.global_batch
|
| 304 |
+
if num_batches == 0:
|
| 305 |
+
continue
|
| 306 |
+
idx = idx[: num_batches * self.global_batch]
|
| 307 |
+
batches = idx.reshape(num_batches, self.global_batch)
|
| 308 |
+
all_batches.append(batches)
|
| 309 |
+
|
| 310 |
+
if len(all_batches) > 0:
|
| 311 |
+
self.global_batches = np.concatenate(all_batches, axis=0)
|
| 312 |
+
else:
|
| 313 |
+
self.global_batches = np.empty((0, self.global_batch), dtype=np.int64)
|
| 314 |
+
|
| 315 |
+
self.num_batches = len(self.global_batches)
|
| 316 |
+
|
| 317 |
+
def __iter__(self):
|
| 318 |
+
rng = np.random.RandomState(self.epoch)
|
| 319 |
+
order = np.arange(self.num_batches)
|
| 320 |
+
|
| 321 |
+
if self.shuffle:
|
| 322 |
+
rng.shuffle(order)
|
| 323 |
+
|
| 324 |
+
start = self.rank * self.batch_size
|
| 325 |
+
end = start + self.batch_size
|
| 326 |
+
|
| 327 |
+
for i in order:
|
| 328 |
+
yield self.global_batches[i][start:end]
|
| 329 |
+
|
| 330 |
+
def __len__(self):
|
| 331 |
+
return self.num_batches
|
| 332 |
+
|
| 333 |
+
def set_epoch(self, epoch):
|
| 334 |
+
self.epoch = epoch
|
| 335 |
+
|
| 336 |
+
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
|
| 337 |
+
size_groups = defaultdict(list)
|
| 338 |
+
try:
|
| 339 |
+
widths = dataset["width"]
|
| 340 |
+
heights = dataset["height"]
|
| 341 |
+
except KeyError:
|
| 342 |
+
widths = [0] * len(dataset)
|
| 343 |
+
heights = [0] * len(dataset)
|
| 344 |
+
for i, (w, h) in enumerate(zip(widths, heights)):
|
| 345 |
+
size = (w, h)
|
| 346 |
+
size_groups[size].append(i)
|
| 347 |
+
|
| 348 |
+
fixed_samples = {}
|
| 349 |
+
for size, indices in size_groups.items():
|
| 350 |
+
n_samples = min(samples_per_group, len(indices))
|
| 351 |
+
if len(size_groups)==1:
|
| 352 |
+
n_samples = samples_to_generate
|
| 353 |
+
if n_samples == 0:
|
| 354 |
+
continue
|
| 355 |
+
sample_indices = random.sample(indices, n_samples)
|
| 356 |
+
samples_data = [dataset[idx] for idx in sample_indices]
|
| 357 |
+
|
| 358 |
+
latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
|
| 359 |
+
|
| 360 |
+
if latents.ndim == 4:
|
| 361 |
+
latents = latents.unsqueeze(2)
|
| 362 |
+
elif latents.ndim == 6:
|
| 363 |
+
latents = latents.squeeze(2)
|
| 364 |
+
|
| 365 |
+
texts = [item["text"] for item in samples_data]
|
| 366 |
+
|
| 367 |
+
if use_precomputed_embeddings:
|
| 368 |
+
embeddings = torch.tensor(
|
| 369 |
+
np.array([item["embeddings"] for item in samples_data]),
|
| 370 |
+
device=device,
|
| 371 |
+
dtype=dtype
|
| 372 |
+
)
|
| 373 |
+
masks = torch.tensor(
|
| 374 |
+
np.array([item["attention_mask"] for item in samples_data]),
|
| 375 |
+
device=device,
|
| 376 |
+
dtype=torch.int64
|
| 377 |
+
)
|
| 378 |
+
else:
|
| 379 |
+
embeddings, masks = encode_texts(texts,max_length)
|
| 380 |
+
|
| 381 |
+
fixed_samples[size] = (latents, embeddings, masks, texts)
|
| 382 |
+
|
| 383 |
+
print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
|
| 384 |
+
return fixed_samples
|
| 385 |
+
|
| 386 |
+
if limit > 0:
|
| 387 |
+
dataset = load_from_disk(ds_path).select(range(limit))
|
| 388 |
+
else:
|
| 389 |
+
dataset = load_from_disk(ds_path)
|
| 390 |
+
|
| 391 |
+
print(f"images: {len(dataset)}")
|
| 392 |
+
|
| 393 |
+
def collate_fn_simple(batch):
|
| 394 |
+
latents = torch.from_numpy(
|
| 395 |
+
np.array([item["vae"] for item in batch], dtype=np.float16)
|
| 396 |
+
).to(device, dtype=dtype)
|
| 397 |
+
|
| 398 |
+
if latents.ndim == 4:
|
| 399 |
+
latents = latents.unsqueeze(2)
|
| 400 |
+
elif latents.ndim == 6:
|
| 401 |
+
latents = latents.squeeze(2)
|
| 402 |
+
|
| 403 |
+
if use_precomputed_embeddings:
|
| 404 |
+
embeddings = torch.from_numpy(
|
| 405 |
+
np.array([item["embeddings"] for item in batch], dtype=np.float16)
|
| 406 |
+
).to(device, dtype=dtype)
|
| 407 |
+
|
| 408 |
+
attention_mask = torch.from_numpy(
|
| 409 |
+
np.array([item["attention_mask"] for item in batch], dtype=np.int64)
|
| 410 |
+
).to(device)
|
| 411 |
+
|
| 412 |
+
return latents, embeddings, attention_mask
|
| 413 |
+
|
| 414 |
+
raw_texts = [item["text"] for item in batch]
|
| 415 |
+
|
| 416 |
+
texts = [
|
| 417 |
+
"" if t.lower().startswith("zero")
|
| 418 |
+
else "" if random.random() < cfg_dropout
|
| 419 |
+
else t[1:].lstrip() if t.startswith(".")
|
| 420 |
+
else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
|
| 421 |
+
for t in raw_texts
|
| 422 |
+
]
|
| 423 |
+
|
| 424 |
+
embeddings, attention_mask = encode_texts(texts,max_length)
|
| 425 |
+
attention_mask = attention_mask.to(dtype=torch.int64)
|
| 426 |
+
|
| 427 |
+
return latents, embeddings, attention_mask
|
| 428 |
+
|
| 429 |
+
batch_sampler = DistributedResolutionBatchSampler(
|
| 430 |
+
dataset=dataset,
|
| 431 |
+
batch_size=batch_size,
|
| 432 |
+
num_replicas=accelerator.num_processes,
|
| 433 |
+
rank=accelerator.process_index,
|
| 434 |
+
shuffle = shuffle
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
|
| 438 |
+
|
| 439 |
+
if accelerator.is_main_process:
|
| 440 |
+
print("Total samples", len(dataloader))
|
| 441 |
+
dataloader = accelerator.prepare(dataloader)
|
| 442 |
+
|
| 443 |
+
start_epoch = 0
|
| 444 |
+
global_step = 0
|
| 445 |
+
total_training_steps = (len(dataloader) * num_epochs)
|
| 446 |
+
world_size = accelerator.state.num_processes
|
| 447 |
+
|
| 448 |
+
latest_checkpoint = os.path.join(checkpoints_folder, project)
|
| 449 |
+
if os.path.isdir(latest_checkpoint):
|
| 450 |
+
print("Загружаем Transformer из чекпоинта:", latest_checkpoint)
|
| 451 |
+
transformer = CosmosTransformer3DModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
|
| 452 |
+
if transformer_gradient:
|
| 453 |
+
transformer.enable_gradient_checkpointing()
|
| 454 |
+
else:
|
| 455 |
+
raise FileNotFoundError(f"Transformer checkpoint not found at {latest_checkpoint}")
|
| 456 |
+
|
| 457 |
+
def create_optimizer(name, params):
|
| 458 |
+
if name == "adam8bit":
|
| 459 |
+
return bnb.optim.AdamW8bit(
|
| 460 |
+
params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
|
| 461 |
+
)
|
| 462 |
+
elif name == "adam":
|
| 463 |
+
return torch.optim.AdamW(
|
| 464 |
+
params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
|
| 465 |
+
)
|
| 466 |
+
elif name == "adafactor":
|
| 467 |
+
return Adafactor(
|
| 468 |
+
params,
|
| 469 |
+
lr=base_learning_rate,
|
| 470 |
+
eps=(1e-30, 1e-3),
|
| 471 |
+
clip_threshold=1.0,
|
| 472 |
+
decay_rate=-0.8,
|
| 473 |
+
beta1=None,
|
| 474 |
+
weight_decay=0.001,
|
| 475 |
+
relative_step=False,
|
| 476 |
+
scale_parameter=False,
|
| 477 |
+
warmup_init=False
|
| 478 |
+
)
|
| 479 |
+
elif name == "muon_adam8bit":
|
| 480 |
+
return MuonAdamW8bit(
|
| 481 |
+
params,
|
| 482 |
+
lr=base_learning_rate,
|
| 483 |
+
betas=(0.9, betta2),
|
| 484 |
+
eps=eps,
|
| 485 |
+
weight_decay=0.01,
|
| 486 |
+
muon_lr_mult=muon_lr_scale,
|
| 487 |
+
)
|
| 488 |
+
else:
|
| 489 |
+
raise ValueError(f"Unknown optimizer: {name}")
|
| 490 |
+
|
| 491 |
+
if fbp:
|
| 492 |
+
trainable_params = list(transformer.parameters())
|
| 493 |
+
optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
|
| 494 |
+
def optimizer_hook(param):
|
| 495 |
+
optimizer_dict[param].step()
|
| 496 |
+
optimizer_dict[param].zero_grad(set_to_none=True)
|
| 497 |
+
for param in trainable_params:
|
| 498 |
+
param.register_post_accumulate_grad_hook(optimizer_hook)
|
| 499 |
+
transformer, optimizer = accelerator.prepare(transformer, optimizer_dict)
|
| 500 |
+
else:
|
| 501 |
+
#transformer.requires_grad_(True)
|
| 502 |
+
# 1. Сначала замораживаем ВООБЩЕ ВСЕ параметры
|
| 503 |
+
transformer.requires_grad_(False)
|
| 504 |
+
|
| 505 |
+
# 2. Определяем ключевое слово для слоев, которые нужно учить (Cross-Attention)
|
| 506 |
+
trainable_params_names = ["attn2"]
|
| 507 |
+
trainable_params = []
|
| 508 |
+
|
| 509 |
+
print("--- РАЗМОРОЖЕННЫЕ СЛОИ ---")
|
| 510 |
+
for name, param in transformer.named_parameters():
|
| 511 |
+
if any(target in name for target in trainable_params_names):
|
| 512 |
+
param.requires_grad_(True) # Размораживаем
|
| 513 |
+
trainable_params.append(param)
|
| 514 |
+
print(f"Обучаемый слой: {name}")
|
| 515 |
+
print("--------------------------")
|
| 516 |
+
|
| 517 |
+
# Защита от дурака
|
| 518 |
+
if len(trainable_params) == 0:
|
| 519 |
+
raise ValueError("Ошибка: ни один слой не был разморожен! Проверь ключи.")
|
| 520 |
+
|
| 521 |
+
optimizer = create_optimizer(optimizer_type, transformer.parameters())
|
| 522 |
+
|
| 523 |
+
def lr_schedule(step):
|
| 524 |
+
x = step / (total_training_steps * world_size)
|
| 525 |
+
warmup = warmup_percent
|
| 526 |
+
if not use_decay:
|
| 527 |
+
return base_learning_rate
|
| 528 |
+
if x < warmup:
|
| 529 |
+
return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
|
| 530 |
+
decay_ratio = (x - warmup) / (1 - warmup)
|
| 531 |
+
return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
|
| 532 |
+
(1 + math.cos(math.pi * decay_ratio))
|
| 533 |
+
lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
|
| 534 |
+
|
| 535 |
+
if torch_compile:
|
| 536 |
+
print("Compiling Transformer... Это займет несколько минут, не прерывайте!")
|
| 537 |
+
transformer = torch.compile(transformer)
|
| 538 |
+
print("Compiling - ok")
|
| 539 |
+
|
| 540 |
+
if not fbp:
|
| 541 |
+
transformer, optimizer, lr_scheduler = accelerator.prepare(transformer, optimizer, lr_scheduler)
|
| 542 |
+
|
| 543 |
+
# Фиксированные семплы
|
| 544 |
+
fixed_samples = get_fixed_samples_by_resolution(dataset)
|
| 545 |
+
|
| 546 |
+
def get_negative_embedding(neg_prompt="", batch_size=1):
|
| 547 |
+
if not neg_prompt:
|
| 548 |
+
hidden_dim = 2048
|
| 549 |
+
seq_len = max_length
|
| 550 |
+
empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
|
| 551 |
+
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
|
| 552 |
+
return empty_emb, empty_mask
|
| 553 |
+
|
| 554 |
+
uncond_emb, uncond_mask = encode_texts([neg_prompt],max_length)
|
| 555 |
+
uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
|
| 556 |
+
uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
|
| 557 |
+
|
| 558 |
+
return uncond_emb, uncond_mask
|
| 559 |
+
|
| 560 |
+
if use_precomputed_embeddings:
|
| 561 |
+
load_text_encoder()
|
| 562 |
+
uncond_emb, uncond_mask = get_negative_embedding("low quality")
|
| 563 |
+
uncond_emb = uncond_emb.to("cpu")
|
| 564 |
+
uncond_mask = uncond_mask.to("cpu")
|
| 565 |
+
del text_encoder
|
| 566 |
+
torch.cuda.empty_cache()
|
| 567 |
+
gc.collect()
|
| 568 |
+
text_encoder = None
|
| 569 |
+
else:
|
| 570 |
+
uncond_emb, uncond_mask = get_negative_embedding("low quality")
|
| 571 |
+
|
| 572 |
+
def pad_to_match(a, b, pad_value=0):
|
| 573 |
+
Ta, Tb = a.shape[1], b.shape[1]
|
| 574 |
+
if Ta == Tb:
|
| 575 |
+
return a, b
|
| 576 |
+
T = max(Ta, Tb)
|
| 577 |
+
def pad(x, T_target):
|
| 578 |
+
pad_len = T_target - x.shape[1]
|
| 579 |
+
if pad_len <= 0:
|
| 580 |
+
return x
|
| 581 |
+
return torch.nn.functional.pad(x, (0, 0, 0, pad_len), value=pad_value)
|
| 582 |
+
return pad(a, T), pad(b, T)
|
| 583 |
+
|
| 584 |
+
@torch.compiler.disable()
|
| 585 |
+
@torch.no_grad()
|
| 586 |
+
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
|
| 587 |
+
uncond_emb, uncond_mask = uncond_data
|
| 588 |
+
uncond_emb = uncond_emb.to(device)
|
| 589 |
+
uncond_mask = uncond_mask.to(device)
|
| 590 |
+
|
| 591 |
+
original_model = None
|
| 592 |
+
try:
|
| 593 |
+
if not torch_compile:
|
| 594 |
+
original_model = accelerator.unwrap_model(transformer, keep_torch_compile=True).eval()
|
| 595 |
+
else:
|
| 596 |
+
original_model = transformer.eval()
|
| 597 |
+
|
| 598 |
+
vae.to(device=device).eval()
|
| 599 |
+
|
| 600 |
+
all_generated_images = []
|
| 601 |
+
all_captions = []
|
| 602 |
+
|
| 603 |
+
for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
|
| 604 |
+
width, height = size
|
| 605 |
+
|
| 606 |
+
curr_batch_size = sample_latents.shape[0]
|
| 607 |
+
in_channels = original_model.config.in_channels
|
| 608 |
+
|
| 609 |
+
sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
|
| 610 |
+
|
| 611 |
+
sigmas_dtype = torch.float32
|
| 612 |
+
sigmas = torch.linspace(0, 1, n_diffusion_steps, dtype=sigmas_dtype)
|
| 613 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device)
|
| 614 |
+
|
| 615 |
+
if scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
|
| 616 |
+
scheduler.sigmas[-1] = scheduler.sigmas[-2]
|
| 617 |
+
if scheduler.sigmas[-1] == 0.0:
|
| 618 |
+
scheduler.sigmas[-1] = 1e-4
|
| 619 |
+
|
| 620 |
+
sigma_max = getattr(scheduler.config, "sigma_max", 80.0)
|
| 621 |
+
|
| 622 |
+
latents = torch.randn(
|
| 623 |
+
(curr_batch_size, in_channels, 1, sample_latents.shape[3], sample_latents.shape[4]),
|
| 624 |
+
device=device,
|
| 625 |
+
dtype=dtype,
|
| 626 |
+
generator=torch.Generator(device=device).manual_seed(seed)
|
| 627 |
+
) * sigma_max
|
| 628 |
+
|
| 629 |
+
padding_mask = torch.zeros((1, 1, sample_latents.shape[3], sample_latents.shape[4]), device=device, dtype=dtype)
|
| 630 |
+
|
| 631 |
+
if guidance_scale != 1:
|
| 632 |
+
neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
|
| 633 |
+
neg_emb_batch, sample_text_embeddings = pad_to_match(neg_emb_batch, sample_text_embeddings)
|
| 634 |
+
|
| 635 |
+
for i, t in enumerate(scheduler.timesteps):
|
| 636 |
+
current_sigma = scheduler.sigmas[i]
|
| 637 |
+
if current_sigma == 0.0:
|
| 638 |
+
current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
|
| 639 |
+
|
| 640 |
+
current_t = current_sigma / (current_sigma + 1.0)
|
| 641 |
+
c_in = 1.0 - current_t
|
| 642 |
+
c_skip = 1.0 - current_t
|
| 643 |
+
c_out = -current_t
|
| 644 |
+
|
| 645 |
+
latent_model_input = (latents * c_in).to(dtype)
|
| 646 |
+
|
| 647 |
+
t_val = float(current_t.item()) if torch.is_tensor(current_t) else float(current_t)
|
| 648 |
+
timestep_tensor = torch.tensor([t_val], device=device, dtype=dtype).expand(curr_batch_size)
|
| 649 |
+
|
| 650 |
+
noise_pred = original_model(
|
| 651 |
+
hidden_states=latent_model_input,
|
| 652 |
+
timestep=timestep_tensor,
|
| 653 |
+
encoder_hidden_states=sample_text_embeddings,
|
| 654 |
+
padding_mask=padding_mask,
|
| 655 |
+
return_dict=False
|
| 656 |
+
)[0]
|
| 657 |
+
|
| 658 |
+
noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(dtype)
|
| 659 |
+
|
| 660 |
+
if guidance_scale != 1:
|
| 661 |
+
noise_pred_uncond = original_model(
|
| 662 |
+
hidden_states=latent_model_input,
|
| 663 |
+
timestep=timestep_tensor,
|
| 664 |
+
encoder_hidden_states=neg_emb_batch,
|
| 665 |
+
padding_mask=padding_mask,
|
| 666 |
+
return_dict=False
|
| 667 |
+
)[0]
|
| 668 |
+
noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(dtype)
|
| 669 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
| 670 |
+
|
| 671 |
+
noise_pred = (latents - noise_pred) / current_sigma
|
| 672 |
+
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 673 |
+
|
| 674 |
+
current_latents = latents
|
| 675 |
+
if step == 0:
|
| 676 |
+
current_latents = sample_latents
|
| 677 |
+
|
| 678 |
+
if latents_mean is not None and latents_std is not None:
|
| 679 |
+
sigma_data = getattr(scheduler.config, "sigma_data", 1.0)
|
| 680 |
+
# Переводим векторы нормализации в float32
|
| 681 |
+
l_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, torch.float32)
|
| 682 |
+
l_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, torch.float32)
|
| 683 |
+
|
| 684 |
+
# Кастуем латенты в float32 перед умножением, чтобы сохранить точность
|
| 685 |
+
latents_for_decode = (current_latents.to(torch.float32) * l_std) / sigma_data + l_mean
|
| 686 |
+
else:
|
| 687 |
+
latents_for_decode = current_latents.to(torch.float32)
|
| 688 |
+
|
| 689 |
+
# 2. Декодируем, ПРИНУДИТЕЛЬНО ВКЛЮЧИВ MATH_SDP только для этого шага!
|
| 690 |
+
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
|
| 691 |
+
decoded = vae.decode(latents_for_decode).sample
|
| 692 |
+
|
| 693 |
+
# 3. Отсекаем лишнее видео-измерение
|
| 694 |
+
if decoded.ndim == 5:
|
| 695 |
+
decoded = decoded[:, :, 0, :, :]
|
| 696 |
+
|
| 697 |
+
# 4. Он уже во float32, можно сразу пускать в цикл
|
| 698 |
+
decoded_fp32 = decoded
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
for img_idx, img_tensor in enumerate(decoded_fp32):
|
| 702 |
+
img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
|
| 703 |
+
img = img.transpose(1, 2, 0)
|
| 704 |
+
|
| 705 |
+
if np.isnan(img).any():
|
| 706 |
+
print("NaNs found, saving stopped! Step:", step)
|
| 707 |
+
img = np.nan_to_num(img, nan=0.0)
|
| 708 |
+
pil_img = Image.fromarray((img * 255).astype("uint8"))
|
| 709 |
+
|
| 710 |
+
max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
|
| 711 |
+
max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
|
| 712 |
+
max_w_overall = max(255, max_w_overall)
|
| 713 |
+
max_h_overall = max(255, max_h_overall)
|
| 714 |
+
|
| 715 |
+
padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
|
| 716 |
+
all_generated_images.append(padded_img)
|
| 717 |
+
|
| 718 |
+
caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
|
| 719 |
+
all_captions.append(caption_text)
|
| 720 |
+
|
| 721 |
+
sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
|
| 722 |
+
pil_img.save(sample_path, "JPEG", quality=95)
|
| 723 |
+
|
| 724 |
+
if use_wandb and accelerator.is_main_process:
|
| 725 |
+
wandb_images = [
|
| 726 |
+
wandb.Image(img, caption=f"{all_captions[i]}")
|
| 727 |
+
for i, img in enumerate(all_generated_images)
|
| 728 |
+
]
|
| 729 |
+
wandb.log({"generated_images": wandb_images})
|
| 730 |
+
if use_comet_ml and accelerator.is_main_process:
|
| 731 |
+
for i, img in enumerate(all_generated_images):
|
| 732 |
+
comet_experiment.log_image(
|
| 733 |
+
image_data=img,
|
| 734 |
+
name=f"step_{step}_img_{i}",
|
| 735 |
+
step=step,
|
| 736 |
+
metadata={"caption": all_captions[i]}
|
| 737 |
+
)
|
| 738 |
+
finally:
|
| 739 |
+
vae.to("cpu")
|
| 740 |
+
uncond_emb = uncond_emb.to("cpu")
|
| 741 |
+
uncond_mask = uncond_mask.to("cpu")
|
| 742 |
+
try:
|
| 743 |
+
all_generated_images.clear()
|
| 744 |
+
all_captions.clear()
|
| 745 |
+
del all_generated_images, all_captions
|
| 746 |
+
del latents, current_latents, latent_model_input
|
| 747 |
+
del decoded, decoded_fp32
|
| 748 |
+
del sample_latents, sample_text_embeddings, sample_mask
|
| 749 |
+
del noise_pred, noise_pred_uncond
|
| 750 |
+
except UnboundLocalError:
|
| 751 |
+
pass
|
| 752 |
+
|
| 753 |
+
torch.cuda.synchronize()
|
| 754 |
+
torch.cuda.empty_cache()
|
| 755 |
+
gc.collect()
|
| 756 |
+
|
| 757 |
+
if accelerator.is_main_process:
|
| 758 |
+
if save_model:
|
| 759 |
+
print("Генерация сэмплов до старта обучения...")
|
| 760 |
+
generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
|
| 761 |
+
accelerator.wait_for_everyone()
|
| 762 |
+
|
| 763 |
+
def save_checkpoint(model_net, variant=""):
|
| 764 |
+
if accelerator.is_main_process:
|
| 765 |
+
model_to_save = None
|
| 766 |
+
if not torch_compile:
|
| 767 |
+
model_to_save = accelerator.unwrap_model(model_net)
|
| 768 |
+
else:
|
| 769 |
+
model_to_save = model_net
|
| 770 |
+
|
| 771 |
+
if variant != "":
|
| 772 |
+
model_to_save.to(dtype=torch.bfloat16).save_pretrained(
|
| 773 |
+
os.path.join(checkpoints_folder, f"{project}"), variant=variant
|
| 774 |
+
)
|
| 775 |
+
else:
|
| 776 |
+
model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
|
| 777 |
+
|
| 778 |
+
torch.cuda.synchronize()
|
| 779 |
+
torch.cuda.empty_cache()
|
| 780 |
+
gc.collect()
|
| 781 |
+
|
| 782 |
+
if accelerator.is_main_process:
|
| 783 |
+
print(f"Total steps per GPU: {total_training_steps}")
|
| 784 |
+
|
| 785 |
+
epoch_loss_points = []
|
| 786 |
+
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
|
| 787 |
+
|
| 788 |
+
steps_per_epoch = len(dataloader)
|
| 789 |
+
sink_interval = max(1, steps_per_epoch // sink_interval_share)
|
| 790 |
+
min_loss = 4.
|
| 791 |
+
last_sample_time = time.time()
|
| 792 |
+
sample_interval_seconds = sample_interval_min * 60
|
| 793 |
+
|
| 794 |
+
for epoch in range(start_epoch, start_epoch + num_epochs):
|
| 795 |
+
batch_losses = []
|
| 796 |
+
batch_grads = []
|
| 797 |
+
batch_sampler.set_epoch(epoch)
|
| 798 |
+
accelerator.wait_for_everyone()
|
| 799 |
+
transformer.train()
|
| 800 |
+
|
| 801 |
+
for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
|
| 802 |
+
|
| 803 |
+
if save_model == False and epoch == 0 and step == 5 :
|
| 804 |
+
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
| 805 |
+
print(f"Шаг {step}: {used_gb:.2f} GB")
|
| 806 |
+
|
| 807 |
+
amp_context = accelerator.autocast() if torch_compile else nullcontext()
|
| 808 |
+
with accelerator.accumulate(transformer):
|
| 809 |
+
with amp_context:
|
| 810 |
+
noise = torch.randn_like(latents, dtype=latents.dtype)
|
| 811 |
+
|
| 812 |
+
t = torch.sigmoid(torch.randn(latents.shape[0], device=latents.device, dtype=latents.dtype) + sigmoid_bias)
|
| 813 |
+
|
| 814 |
+
noisy_latents_5d = (1.0 - t.view(-1, 1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1, 1) * noise
|
| 815 |
+
target_5d = noise - latents
|
| 816 |
+
|
| 817 |
+
padding_mask = torch.zeros((1, 1, latents.shape[3], latents.shape[4]), device=device, dtype=dtype)
|
| 818 |
+
|
| 819 |
+
timestep_tensor = t.flatten().to(dtype)
|
| 820 |
+
|
| 821 |
+
model_pred = transformer(
|
| 822 |
+
hidden_states=noisy_latents_5d,
|
| 823 |
+
timestep=timestep_tensor,
|
| 824 |
+
encoder_hidden_states=embeddings,
|
| 825 |
+
padding_mask=padding_mask,
|
| 826 |
+
return_dict=False
|
| 827 |
+
)[0]
|
| 828 |
+
|
| 829 |
+
mse_loss = F.mse_loss(model_pred.float(), target_5d.float())
|
| 830 |
+
batch_losses.append(mse_loss.detach().item())
|
| 831 |
+
|
| 832 |
+
if (global_step % 100 == 0) or (global_step % sink_interval == 0):
|
| 833 |
+
accelerator.wait_for_everyone()
|
| 834 |
+
|
| 835 |
+
losses_dict = {}
|
| 836 |
+
losses_dict["mse"] = mse_loss
|
| 837 |
+
|
| 838 |
+
if (global_step % 100 == 0) or (global_step % sink_interval == 0):
|
| 839 |
+
accelerator.wait_for_everyone()
|
| 840 |
+
|
| 841 |
+
accelerator.backward(mse_loss)
|
| 842 |
+
|
| 843 |
+
if (global_step % 100 == 0) or (global_step % sink_interval == 0):
|
| 844 |
+
accelerator.wait_for_everyone()
|
| 845 |
+
|
| 846 |
+
grad = 0.0
|
| 847 |
+
if not fbp:
|
| 848 |
+
if accelerator.sync_gradients:
|
| 849 |
+
grad_val = accelerator.clip_grad_norm_(transformer.parameters(), clip_grad_norm)
|
| 850 |
+
grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
|
| 851 |
+
optimizer.step()
|
| 852 |
+
lr_scheduler.step()
|
| 853 |
+
optimizer.zero_grad(set_to_none=True)
|
| 854 |
+
|
| 855 |
+
if accelerator.sync_gradients:
|
| 856 |
+
global_step += 1
|
| 857 |
+
progress_bar.update(1)
|
| 858 |
+
if accelerator.is_main_process:
|
| 859 |
+
if fbp:
|
| 860 |
+
current_lr = base_learning_rate
|
| 861 |
+
else:
|
| 862 |
+
current_lr = lr_scheduler.get_last_lr()[0]
|
| 863 |
+
batch_grads.append(grad)
|
| 864 |
+
|
| 865 |
+
log_data = {}
|
| 866 |
+
log_data["loss_mse"] = mse_loss.detach().item()
|
| 867 |
+
log_data["lr"] = current_lr
|
| 868 |
+
log_data["grad"] = grad
|
| 869 |
+
if accelerator.sync_gradients:
|
| 870 |
+
if use_wandb:
|
| 871 |
+
wandb.log(log_data, step=global_step)
|
| 872 |
+
if use_comet_ml:
|
| 873 |
+
comet_experiment.log_metrics(log_data, step=global_step)
|
| 874 |
+
|
| 875 |
+
current_time = time.time()
|
| 876 |
+
is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
|
| 877 |
+
if is_time_to_sample or global_step == 50:
|
| 878 |
+
if save_model:
|
| 879 |
+
generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
|
| 880 |
+
elif epoch % 10 == 0:
|
| 881 |
+
generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
|
| 882 |
+
last_n = sink_interval
|
| 883 |
+
|
| 884 |
+
if save_model:
|
| 885 |
+
has_losses = len(batch_losses) > 0
|
| 886 |
+
avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
|
| 887 |
+
last_loss = batch_losses[-1] if has_losses else 0.0
|
| 888 |
+
max_loss = max(avg_sample_loss, last_loss)
|
| 889 |
+
should_save = max_loss < min_loss * save_barrier
|
| 890 |
+
print(
|
| 891 |
+
f"Saving: {should_save} | Max: {max_loss:.4f} | "
|
| 892 |
+
f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
|
| 893 |
+
)
|
| 894 |
+
if should_save:
|
| 895 |
+
min_loss = max_loss
|
| 896 |
+
save_checkpoint(transformer)
|
| 897 |
+
last_sample_time = current_time
|
| 898 |
+
transformer.train()
|
| 899 |
+
|
| 900 |
+
if accelerator.is_main_process:
|
| 901 |
+
avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
|
| 902 |
+
avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
|
| 903 |
+
|
| 904 |
+
print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
|
| 905 |
+
log_data_ep = {
|
| 906 |
+
"epoch_loss": avg_epoch_loss,
|
| 907 |
+
"epoch_grad": avg_epoch_grad,
|
| 908 |
+
"epoch": epoch + 1,
|
| 909 |
+
}
|
| 910 |
+
if use_wandb:
|
| 911 |
+
wandb.log(log_data_ep)
|
| 912 |
+
if use_comet_ml:
|
| 913 |
+
comet_experiment.log_metrics(log_data_ep)
|
| 914 |
+
|
| 915 |
+
if accelerator.is_main_process:
|
| 916 |
+
print("Обучение завершено! Сохраняем финальную модель...")
|
| 917 |
+
save_checkpoint(transformer,"bf16")
|
| 918 |
+
if use_comet_ml:
|
| 919 |
+
comet_experiment.end()
|
| 920 |
+
accelerator.free_memory()
|
| 921 |
+
if torch.distributed.is_initialized():
|
| 922 |
+
torch.distributed.destroy_process_group()
|
| 923 |
+
|
| 924 |
+
print("Готово!")
|
transformer/config.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "CosmosTransformer3DModel",
|
| 3 |
+
"_diffusers_version": "0.37.1",
|
| 4 |
+
"_name_or_path": "transformer",
|
| 5 |
+
"adaln_lora_dim": 256,
|
| 6 |
+
"attention_head_dim": 128,
|
| 7 |
+
"concat_padding_mask": true,
|
| 8 |
+
"controlnet_block_every_n": null,
|
| 9 |
+
"crossattn_proj_in_channels": 1024,
|
| 10 |
+
"encoder_hidden_states_channels": 1024,
|
| 11 |
+
"extra_pos_embed_type": null,
|
| 12 |
+
"img_context_dim_in": null,
|
| 13 |
+
"img_context_dim_out": 2048,
|
| 14 |
+
"img_context_num_tokens": 256,
|
| 15 |
+
"in_channels": 16,
|
| 16 |
+
"max_size": [
|
| 17 |
+
128,
|
| 18 |
+
240,
|
| 19 |
+
240
|
| 20 |
+
],
|
| 21 |
+
"mlp_ratio": 4.0,
|
| 22 |
+
"num_attention_heads": 16,
|
| 23 |
+
"num_layers": 28,
|
| 24 |
+
"out_channels": 16,
|
| 25 |
+
"patch_size": [
|
| 26 |
+
1,
|
| 27 |
+
2,
|
| 28 |
+
2
|
| 29 |
+
],
|
| 30 |
+
"rope_scale": [
|
| 31 |
+
1.0,
|
| 32 |
+
4.0,
|
| 33 |
+
4.0
|
| 34 |
+
],
|
| 35 |
+
"text_embed_dim": 1024,
|
| 36 |
+
"use_crossattn_projection": false
|
| 37 |
+
}
|
transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:501d3b67f235189364d1bbeb3862fcdfc74e957f033e0714e8c2a12ba95a7041
|
| 3 |
+
size 7825687184
|
vae/.ipynb_checkpoints/config-checkpoint.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKLQwenImage",
|
| 3 |
+
"_diffusers_version": "0.36.0.dev0",
|
| 4 |
+
"attn_scales": [],
|
| 5 |
+
"base_dim": 96,
|
| 6 |
+
"dim_mult": [
|
| 7 |
+
1,
|
| 8 |
+
2,
|
| 9 |
+
4,
|
| 10 |
+
4
|
| 11 |
+
],
|
| 12 |
+
"dropout": 0.0,
|
| 13 |
+
"latents_mean": [
|
| 14 |
+
-0.7571,
|
| 15 |
+
-0.7089,
|
| 16 |
+
-0.9113,
|
| 17 |
+
0.1075,
|
| 18 |
+
-0.1745,
|
| 19 |
+
0.9653,
|
| 20 |
+
-0.1517,
|
| 21 |
+
1.5508,
|
| 22 |
+
0.4134,
|
| 23 |
+
-0.0715,
|
| 24 |
+
0.5517,
|
| 25 |
+
-0.3632,
|
| 26 |
+
-0.1922,
|
| 27 |
+
-0.9497,
|
| 28 |
+
0.2503,
|
| 29 |
+
-0.2921
|
| 30 |
+
],
|
| 31 |
+
"latents_std": [
|
| 32 |
+
2.8184,
|
| 33 |
+
1.4541,
|
| 34 |
+
2.3275,
|
| 35 |
+
2.6558,
|
| 36 |
+
1.2196,
|
| 37 |
+
1.7708,
|
| 38 |
+
2.6052,
|
| 39 |
+
2.0743,
|
| 40 |
+
3.2687,
|
| 41 |
+
2.1526,
|
| 42 |
+
2.8652,
|
| 43 |
+
1.5579,
|
| 44 |
+
1.6382,
|
| 45 |
+
1.1253,
|
| 46 |
+
2.8251,
|
| 47 |
+
1.916
|
| 48 |
+
],
|
| 49 |
+
"num_res_blocks": 2,
|
| 50 |
+
"temperal_downsample": [
|
| 51 |
+
false,
|
| 52 |
+
true,
|
| 53 |
+
true
|
| 54 |
+
],
|
| 55 |
+
"z_dim": 16
|
| 56 |
+
}
|
vae/config.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKLQwenImage",
|
| 3 |
+
"_diffusers_version": "0.36.0.dev0",
|
| 4 |
+
"attn_scales": [],
|
| 5 |
+
"base_dim": 96,
|
| 6 |
+
"dim_mult": [
|
| 7 |
+
1,
|
| 8 |
+
2,
|
| 9 |
+
4,
|
| 10 |
+
4
|
| 11 |
+
],
|
| 12 |
+
"dropout": 0.0,
|
| 13 |
+
"latents_mean": [
|
| 14 |
+
-0.7571,
|
| 15 |
+
-0.7089,
|
| 16 |
+
-0.9113,
|
| 17 |
+
0.1075,
|
| 18 |
+
-0.1745,
|
| 19 |
+
0.9653,
|
| 20 |
+
-0.1517,
|
| 21 |
+
1.5508,
|
| 22 |
+
0.4134,
|
| 23 |
+
-0.0715,
|
| 24 |
+
0.5517,
|
| 25 |
+
-0.3632,
|
| 26 |
+
-0.1922,
|
| 27 |
+
-0.9497,
|
| 28 |
+
0.2503,
|
| 29 |
+
-0.2921
|
| 30 |
+
],
|
| 31 |
+
"latents_std": [
|
| 32 |
+
2.8184,
|
| 33 |
+
1.4541,
|
| 34 |
+
2.3275,
|
| 35 |
+
2.6558,
|
| 36 |
+
1.2196,
|
| 37 |
+
1.7708,
|
| 38 |
+
2.6052,
|
| 39 |
+
2.0743,
|
| 40 |
+
3.2687,
|
| 41 |
+
2.1526,
|
| 42 |
+
2.8652,
|
| 43 |
+
1.5579,
|
| 44 |
+
1.6382,
|
| 45 |
+
1.1253,
|
| 46 |
+
2.8251,
|
| 47 |
+
1.916
|
| 48 |
+
],
|
| 49 |
+
"num_res_blocks": 2,
|
| 50 |
+
"temperal_downsample": [
|
| 51 |
+
false,
|
| 52 |
+
true,
|
| 53 |
+
true
|
| 54 |
+
],
|
| 55 |
+
"z_dim": 16
|
| 56 |
+
}
|
vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0c8bc8b758c649abef9ea407b95408389a3b2f610d0d10fcb054fe171d0a8344
|
| 3 |
+
size 253806966
|
wandb/debug-cli.root.log
ADDED
|
File without changes
|
wandb/debug-internal.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
wandb/debug.log
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Current SDK version is 0.26.1
|
| 2 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Configure stats pid to 14112
|
| 3 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:setup_run_log_directory():723] Logging user logs to /root/sdxs-2b/wandb/run-20260428_171645-wt40fdyx/logs/debug.log
|
| 5 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:setup_run_log_directory():724] Logging internal logs to /root/sdxs-2b/wandb/run-20260428_171645-wt40fdyx/logs/debug-internal.log
|
| 6 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():850] calling init triggers
|
| 7 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():855] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'batch_size': 16, 'base_learning_rate': 1.3333333333333335e-05, 'num_epochs': 1, 'optimizer_type': 'adafactor', '_wandb': {}}
|
| 9 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():898] starting backend
|
| 10 |
+
2026-04-28 17:16:45,343 INFO MainThread:14112 [wandb_init.py:init():913] sending inform_init request
|
| 11 |
+
2026-04-28 17:16:45,731 INFO MainThread:14112 [wandb_init.py:init():918] backend started and connected
|
| 12 |
+
2026-04-28 17:16:45,734 INFO MainThread:14112 [wandb_init.py:init():988] updated telemetry
|
| 13 |
+
2026-04-28 17:16:45,742 INFO MainThread:14112 [wandb_init.py:init():1011] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-04-28 17:16:46,973 INFO MainThread:14112 [wandb_init.py:init():1056] starting run threads in backend
|
| 15 |
+
2026-04-28 17:16:47,099 INFO MainThread:14112 [wandb_run.py:_console_start():2554] atexit reg
|
| 16 |
+
2026-04-28 17:16:47,099 INFO MainThread:14112 [wandb_run.py:_redirect():2403] redirect: wrap_raw
|
| 17 |
+
2026-04-28 17:16:47,100 INFO MainThread:14112 [wandb_run.py:_redirect():2472] Wrapping output streams.
|
| 18 |
+
2026-04-28 17:16:47,100 INFO MainThread:14112 [wandb_run.py:_redirect():2495] Redirects installed.
|
| 19 |
+
2026-04-28 17:16:47,104 INFO MainThread:14112 [wandb_init.py:init():1094] run started, returning control to user process
|
wandb/offline-run-20260428_132658-o9052r27/files/requirements.txt
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cuda-toolkit==13.0.2
|
| 2 |
+
typing_extensions==4.15.0
|
| 3 |
+
nvidia-nvjitlink==13.0.88
|
| 4 |
+
MarkupSafe==3.0.3
|
| 5 |
+
nvidia-cufile==1.15.1.6
|
| 6 |
+
certifi==2026.4.22
|
| 7 |
+
nvidia-cusolver==12.0.4.66
|
| 8 |
+
nvidia-curand==10.4.0.35
|
| 9 |
+
Jinja2==3.1.6
|
| 10 |
+
nvidia-nvtx==13.0.85
|
| 11 |
+
nvidia-cuda-cupti==13.0.85
|
| 12 |
+
torchaudio==2.11.0+cu130
|
| 13 |
+
safetensors==0.7.0
|
| 14 |
+
nvidia-cuda-runtime==13.0.96
|
| 15 |
+
torchvision==0.26.0+cu130
|
| 16 |
+
nvidia-cufft==12.0.0.61
|
| 17 |
+
nvidia-cusparse==12.6.3.3
|
| 18 |
+
nvidia-cuda-nvrtc==13.0.88
|
| 19 |
+
fsspec==2026.2.0
|
| 20 |
+
nvidia-cusparselt-cu13==0.8.0
|
| 21 |
+
nvidia-nccl-cu13==2.28.9
|
| 22 |
+
nvidia-nvshmem-cu13==3.4.5
|
| 23 |
+
nvidia-cublas==13.1.0.3
|
| 24 |
+
nvidia-cudnn-cu13==9.19.0.56
|
| 25 |
+
mpmath==1.3.0
|
| 26 |
+
triton==3.6.0
|
| 27 |
+
networkx==3.6.1
|
| 28 |
+
sympy==1.14.0
|
| 29 |
+
torch==2.11.0+cu130
|
| 30 |
+
hf_transfer==0.1.9
|
| 31 |
+
six==1.17.0
|
| 32 |
+
typer==0.25.0
|
| 33 |
+
typing-inspection==0.4.2
|
| 34 |
+
muon-adamw8bit==0.5.0
|
| 35 |
+
aiosignal==1.4.0
|
| 36 |
+
wurlitzer==3.1.1
|
| 37 |
+
semantic-version==2.10.0
|
| 38 |
+
aiohappyeyeballs==2.6.1
|
| 39 |
+
cycler==0.12.1
|
| 40 |
+
tokenizers==0.22.2
|
| 41 |
+
annotated-doc==0.0.4
|
| 42 |
+
rpds-py==0.30.0
|
| 43 |
+
configobj==5.0.9
|
| 44 |
+
regex==2026.4.4
|
| 45 |
+
zipp==3.23.1
|
| 46 |
+
annotated-types==0.7.0
|
| 47 |
+
everett==3.1.0
|
| 48 |
+
pydantic_core==2.46.3
|
| 49 |
+
mdurl==0.1.2
|
| 50 |
+
platformdirs==4.9.6
|
| 51 |
+
idna==3.13
|
| 52 |
+
psutil==7.2.2
|
| 53 |
+
xxhash==3.7.0
|
| 54 |
+
smmap==5.0.3
|
| 55 |
+
frozenlist==1.8.0
|
| 56 |
+
multidict==6.7.1
|
| 57 |
+
shellingham==1.5.4
|
| 58 |
+
kiwisolver==1.5.0
|
| 59 |
+
propcache==0.4.1
|
| 60 |
+
h11==0.16.0
|
| 61 |
+
hf-xet==1.4.3
|
| 62 |
+
pyparsing==3.3.2
|
| 63 |
+
yarl==1.23.0
|
| 64 |
+
importlib_metadata==9.0.0
|
| 65 |
+
referencing==0.37.0
|
| 66 |
+
requests==2.33.1
|
| 67 |
+
filelock==3.29.0
|
| 68 |
+
charset-normalizer==3.4.7
|
| 69 |
+
wrapt==2.1.2
|
| 70 |
+
contourpy==1.3.3
|
| 71 |
+
python-box==6.1.0
|
| 72 |
+
python-dateutil==2.9.0.post0
|
| 73 |
+
packaging==26.2
|
| 74 |
+
httpx==0.28.1
|
| 75 |
+
PyYAML==6.0.3
|
| 76 |
+
click==8.3.3
|
| 77 |
+
jsonschema-specifications==2025.9.1
|
| 78 |
+
gitdb==4.0.12
|
| 79 |
+
einops==0.8.2
|
| 80 |
+
attrs==26.1.0
|
| 81 |
+
httpcore==1.0.9
|
| 82 |
+
cuda-pathfinder==1.5.4
|
| 83 |
+
requests-toolbelt==1.0.0
|
| 84 |
+
GitPython==3.1.48
|
| 85 |
+
jsonschema==4.26.0
|
| 86 |
+
tqdm==4.67.3
|
| 87 |
+
urllib3==2.6.3
|
| 88 |
+
anyio==4.13.0
|
| 89 |
+
simplejson==4.1.1
|
| 90 |
+
multiprocess==0.70.19
|
| 91 |
+
dill==0.4.1
|
| 92 |
+
protobuf==7.34.1
|
| 93 |
+
markdown-it-py==4.0.0
|
| 94 |
+
bitsandbytes==0.49.2
|
| 95 |
+
cuda-bindings==13.2.0
|
| 96 |
+
aiohttp==3.13.5
|
| 97 |
+
accelerate==1.13.0
|
| 98 |
+
dulwich==0.25.2
|
| 99 |
+
pydantic==2.13.3
|
| 100 |
+
datasets==4.8.5
|
| 101 |
+
rich==15.0.0
|
| 102 |
+
flash-linear-attention==0.5.0
|
| 103 |
+
pillow==12.2.0
|
| 104 |
+
huggingface_hub==1.12.0
|
| 105 |
+
sentry-sdk==2.58.0
|
| 106 |
+
fla-core==0.5.0
|
| 107 |
+
Pygments==2.20.0
|
| 108 |
+
diffusers==0.37.1
|
| 109 |
+
fonttools==4.62.1
|
| 110 |
+
comet_ml==3.57.3
|
| 111 |
+
setuptools==81.0.0
|
| 112 |
+
matplotlib==3.10.9
|
| 113 |
+
pyarrow==24.0.0
|
| 114 |
+
wandb==0.26.1
|
| 115 |
+
numpy==2.4.4
|
| 116 |
+
pandas==3.0.2
|
| 117 |
+
transformers==5.6.2
|
wandb/offline-run-20260428_132658-o9052r27/logs/debug-core.log
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-04-28T13:26:58.701599632Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpll17c794/port-6681.txt","pid":6681,"detached":false,"idle-timeout":600000000000,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
|
| 2 |
+
{"time":"2026-04-28T13:26:58.704326543Z","level":"INFO","msg":"server: will exit if parent process dies","ppid":6681}
|
| 3 |
+
{"time":"2026-04-28T13:26:58.70424692Z","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/tmp/wandb-6681-6712-3956627621/socket","Net":"unix"}}
|
| 4 |
+
{"time":"2026-04-28T13:26:58.806869406Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1(@)"}
|
| 5 |
+
{"time":"2026-04-28T13:26:58.828765063Z","level":"INFO","msg":"handleInformInit: received","streamId":"o9052r27","id":"1(@)"}
|
| 6 |
+
{"time":"2026-04-28T13:26:58.960660655Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"o9052r27","id":"1(@)"}
|
| 7 |
+
{"time":"2026-04-28T13:27:04.392467558Z","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"1(@)"}
|
| 8 |
+
{"time":"2026-04-28T13:27:04.392527721Z","level":"INFO","msg":"server is shutting down"}
|
| 9 |
+
{"time":"2026-04-28T13:27:04.392535141Z","level":"INFO","msg":"connection: closing","id":"1(@)"}
|
| 10 |
+
{"time":"2026-04-28T13:27:04.392635535Z","level":"INFO","msg":"connection: closed successfully","id":"1(@)"}
|
| 11 |
+
{"time":"2026-04-28T13:27:04.392627225Z","level":"INFO","msg":"server: listener closed","addr":{"Name":"/tmp/wandb-6681-6712-3956627621/socket","Net":"unix"}}
|
| 12 |
+
{"time":"2026-04-28T13:27:04.421552415Z","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"1(@)"}
|
| 13 |
+
{"time":"2026-04-28T13:27:04.421573556Z","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"1(@)"}
|
| 14 |
+
{"time":"2026-04-28T13:27:04.421579966Z","level":"INFO","msg":"server is closed"}
|
wandb/offline-run-20260428_132658-o9052r27/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-04-28T13:26:58.829048314Z","level":"INFO","msg":"wandb-core"}
|
| 2 |
+
{"time":"2026-04-28T13:26:58.829092766Z","level":"INFO","msg":"stream: starting","core version":"0.26.1"}
|
| 3 |
+
{"time":"2026-04-28T13:26:58.960323061Z","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
|
| 4 |
+
{"time":"2026-04-28T13:26:58.960354402Z","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
|
| 5 |
+
{"time":"2026-04-28T13:26:58.960391424Z","level":"INFO","msg":"stream: created new stream","id":"o9052r27"}
|
| 6 |
+
{"time":"2026-04-28T13:26:58.960480497Z","level":"INFO","msg":"handler: started"}
|
| 7 |
+
{"time":"2026-04-28T13:26:58.960646764Z","level":"INFO","msg":"stream: started"}
|
| 8 |
+
{"time":"2026-04-28T13:26:58.960704477Z","level":"INFO","msg":"writer: started","stream_id":"o9052r27"}
|
| 9 |
+
{"time":"2026-04-28T13:26:58.960767929Z","level":"INFO","msg":"sender: started"}
|
| 10 |
+
{"time":"2026-04-28T13:26:58.975123911Z","level":"WARN","msg":"featurechecker: GraphQL client is nil, skipping feature loading"}
|
| 11 |
+
{"time":"2026-04-28T13:26:58.975175533Z","level":"WARN","msg":"runupserter: server does not expand metric globs but the x_server_side_expand_glob_metrics setting is set; ignoring"}
|
| 12 |
+
{"time":"2026-04-28T13:27:04.392744599Z","level":"INFO","msg":"stream: finishing up"}
|
| 13 |
+
{"time":"2026-04-28T13:27:04.39276658Z","level":"INFO","msg":"handler: closed"}
|
| 14 |
+
{"time":"2026-04-28T13:27:04.392811252Z","level":"INFO","msg":"sender: closed"}
|
| 15 |
+
{"time":"2026-04-28T13:27:04.392819012Z","level":"INFO","msg":"stream: all finished"}
|
wandb/offline-run-20260428_132658-o9052r27/logs/debug.log
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-04-28 13:26:58,591 INFO MainThread:6681 [wandb_setup.py:_flush():81] Current SDK version is 0.26.1
|
| 2 |
+
2026-04-28 13:26:58,591 INFO MainThread:6681 [wandb_setup.py:_flush():81] Configure stats pid to 6681
|
| 3 |
+
2026-04-28 13:26:58,591 INFO MainThread:6681 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-04-28 13:26:58,591 INFO MainThread:6681 [wandb_init.py:setup_run_log_directory():723] Logging user logs to /root/sdxs-2b/wandb/offline-run-20260428_132658-o9052r27/logs/debug.log
|
| 5 |
+
2026-04-28 13:26:58,592 INFO MainThread:6681 [wandb_init.py:setup_run_log_directory():724] Logging internal logs to /root/sdxs-2b/wandb/offline-run-20260428_132658-o9052r27/logs/debug-internal.log
|
| 6 |
+
2026-04-28 13:26:58,592 INFO MainThread:6681 [wandb_init.py:init():850] calling init triggers
|
| 7 |
+
2026-04-28 13:26:58,592 INFO MainThread:6681 [wandb_init.py:init():855] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'batch_size': 7, 'base_learning_rate': 1.3333333333333335e-05, 'num_epochs': 1, 'optimizer_type': 'adafactor', '_wandb': {}}
|
| 9 |
+
2026-04-28 13:26:58,592 INFO MainThread:6681 [wandb_init.py:init():898] starting backend
|
| 10 |
+
2026-04-28 13:26:58,807 INFO MainThread:6681 [wandb_init.py:init():913] sending inform_init request
|
| 11 |
+
2026-04-28 13:26:58,961 INFO MainThread:6681 [wandb_init.py:init():918] backend started and connected
|
| 12 |
+
2026-04-28 13:26:58,964 INFO MainThread:6681 [wandb_init.py:init():988] updated telemetry
|
| 13 |
+
2026-04-28 13:26:58,971 INFO MainThread:6681 [wandb_init.py:init():1011] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-04-28 13:26:58,977 INFO MainThread:6681 [wandb_init.py:init():1056] starting run threads in backend
|
| 15 |
+
2026-04-28 13:26:59,098 INFO MainThread:6681 [wandb_run.py:_console_start():2554] atexit reg
|
| 16 |
+
2026-04-28 13:26:59,098 INFO MainThread:6681 [wandb_run.py:_redirect():2403] redirect: wrap_raw
|
| 17 |
+
2026-04-28 13:26:59,099 INFO MainThread:6681 [wandb_run.py:_redirect():2472] Wrapping output streams.
|
| 18 |
+
2026-04-28 13:26:59,099 INFO MainThread:6681 [wandb_run.py:_redirect():2495] Redirects installed.
|
| 19 |
+
2026-04-28 13:26:59,115 INFO MainThread:6681 [wandb_init.py:init():1094] run started, returning control to user process
|
| 20 |
+
2026-04-28 13:27:04,393 INFO wandb-AsyncioManager-main:6681 [service_client.py:_forward_responses():134] Reached EOF.
|
| 21 |
+
2026-04-28 13:27:04,393 INFO wandb-AsyncioManager-main:6681 [mailbox.py:close():155] Closing mailbox, abandoning 0 handles.
|
wandb/offline-run-20260428_132658-o9052r27/run-o9052r27.wandb
ADDED
|
Binary file (6.41 kB). View file
|
|
|
wandb/run-20260428_171645-wt40fdyx/files/output.log
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The config attributes {'final_sigmas_type': 'sigma_min', 'sigma_data': 1.0, 'sigma_max': 80.0, 'sigma_min': 0.002} were passed to FlowMatchEulerDiscreteScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
|
| 2 |
+
[transformers] The fast path is not available because one of the required library is not installed. Falling back to torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and https://github.com/Dao-AILab/causal-conv1d
|
| 3 |
+
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████| 473/473 [00:00<00:00, 766.86it/s]
|
| 4 |
+
images: 233407
|
| 5 |
+
Total samples 14580
|
| 6 |
+
Загружаем Transformer из чекпоинта: transformer
|
| 7 |
+
--- РАЗМОРОЖЕННЫЕ СЛОИ ---
|
| 8 |
+
--------------------------
|
| 9 |
+
|
| 10 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 44, 80])
|
| 11 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 44, 80])
|
| 12 |
+
|
| 13 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 48, 80])
|
| 14 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 48, 80])
|
| 15 |
+
|
| 16 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 40])
|
| 17 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 40])
|
| 18 |
+
|
| 19 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 44])
|
| 20 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 44])
|
| 21 |
+
|
| 22 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 48])
|
| 23 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 48])
|
| 24 |
+
|
| 25 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 52])
|
| 26 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 52])
|
| 27 |
+
|
| 28 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 56])
|
| 29 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 56])
|
| 30 |
+
|
| 31 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 52, 80])
|
| 32 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 52, 80])
|
| 33 |
+
|
| 34 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 56, 80])
|
| 35 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 56, 80])
|
| 36 |
+
|
| 37 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 60, 80])
|
| 38 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 60, 80])
|
| 39 |
+
|
| 40 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 60])
|
| 41 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 60])
|
| 42 |
+
|
| 43 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 64, 80])
|
| 44 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 64, 80])
|
| 45 |
+
|
| 46 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 68, 80])
|
| 47 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 68, 80])
|
| 48 |
+
|
| 49 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 72, 80])
|
| 50 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 72, 80])
|
| 51 |
+
|
| 52 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 76, 80])
|
| 53 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 76, 80])
|
| 54 |
+
|
| 55 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 64])
|
| 56 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 64])
|
| 57 |
+
|
| 58 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 68])
|
| 59 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 68])
|
| 60 |
+
|
| 61 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 72])
|
| 62 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 72])
|
| 63 |
+
|
| 64 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 80, 76])
|
| 65 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 80, 76])
|
| 66 |
+
|
| 67 |
+
[ПОЧИНКА FIXED SAMPLES] Отсекаем мусор: torch.Size([1, 16, 16, 40, 80])
|
| 68 |
+
[ОТЛАДКА ДАТАСЕТА] latents final shape: torch.Size([1, 16, 1, 40, 80])
|
| 69 |
+
Создано 20 групп фиксированных семплов по разрешениям
|
| 70 |
+
Генерация сэмплов до старта обучения...
|
| 71 |
+
/usr/lib/python3.12/contextlib.py:105: FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
|
| 72 |
+
self.gen = func(*args, **kwds)
|
| 73 |
+
|
| 74 |
+
==================================================
|
| 75 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 76 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 44, 80])
|
| 77 |
+
min=-1.9811, max=2.2364, std=0.6226
|
| 78 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 79 |
+
sigma_data=1.0
|
| 80 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 44, 80])
|
| 81 |
+
min=-3.9773, max=3.4072, std=1.1547
|
| 82 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 352, 640])
|
| 83 |
+
min=-1.0000, max=0.9945, std=0.6423
|
| 84 |
+
==================================================
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
==================================================
|
| 88 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 89 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 48, 80])
|
| 90 |
+
min=-2.1993, max=1.9178, std=0.5500
|
| 91 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 92 |
+
sigma_data=1.0
|
| 93 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 48, 80])
|
| 94 |
+
min=-3.5397, max=3.3824, std=1.0561
|
| 95 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 384, 640])
|
| 96 |
+
min=-1.0000, max=1.0000, std=0.3971
|
| 97 |
+
==================================================
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
==================================================
|
| 101 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 102 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 40])
|
| 103 |
+
min=-2.7174, max=2.0244, std=0.6368
|
| 104 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 105 |
+
sigma_data=1.0
|
| 106 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 40])
|
| 107 |
+
min=-4.0544, max=4.0678, std=1.1537
|
| 108 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 320])
|
| 109 |
+
min=-0.9997, max=1.0000, std=0.5404
|
| 110 |
+
==================================================
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
==================================================
|
| 114 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 115 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 44])
|
| 116 |
+
min=-2.0394, max=2.0944, std=0.5736
|
| 117 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 118 |
+
sigma_data=1.0
|
| 119 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 44])
|
| 120 |
+
min=-3.8287, max=3.3714, std=1.0290
|
| 121 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 352])
|
| 122 |
+
min=-1.0000, max=1.0000, std=0.4719
|
| 123 |
+
==================================================
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
==================================================
|
| 127 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 128 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 48])
|
| 129 |
+
min=-2.0441, max=1.9221, std=0.5108
|
| 130 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 131 |
+
sigma_data=1.0
|
| 132 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 48])
|
| 133 |
+
min=-3.4324, max=3.7347, std=0.9750
|
| 134 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 384])
|
| 135 |
+
min=-1.0000, max=1.0000, std=0.5049
|
| 136 |
+
==================================================
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
==================================================
|
| 140 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 141 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 52])
|
| 142 |
+
min=-2.0292, max=2.2682, std=0.7043
|
| 143 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 144 |
+
sigma_data=1.0
|
| 145 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 52])
|
| 146 |
+
min=-4.1673, max=4.4971, std=1.3949
|
| 147 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 416])
|
| 148 |
+
min=-1.0000, max=1.0000, std=0.6222
|
| 149 |
+
==================================================
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
==================================================
|
| 153 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 154 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 56])
|
| 155 |
+
min=-1.7528, max=1.6711, std=0.6432
|
| 156 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 157 |
+
sigma_data=1.0
|
| 158 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 56])
|
| 159 |
+
min=-4.0104, max=4.1834, std=1.4406
|
| 160 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 448])
|
| 161 |
+
min=-0.9654, max=1.0000, std=0.4818
|
| 162 |
+
==================================================
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
==================================================
|
| 166 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 167 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 52, 80])
|
| 168 |
+
min=-2.0965, max=2.2269, std=0.4286
|
| 169 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 170 |
+
sigma_data=1.0
|
| 171 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 52, 80])
|
| 172 |
+
min=-3.3608, max=2.9338, std=0.9200
|
| 173 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 416, 640])
|
| 174 |
+
min=-1.0000, max=0.9774, std=0.3019
|
| 175 |
+
==================================================
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
==================================================
|
| 179 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 180 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 56, 80])
|
| 181 |
+
min=-2.3215, max=2.6622, std=0.6174
|
| 182 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 183 |
+
sigma_data=1.0
|
| 184 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 56, 80])
|
| 185 |
+
min=-3.6939, max=4.7696, std=1.3130
|
| 186 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 448, 640])
|
| 187 |
+
min=-1.0000, max=1.0000, std=0.4811
|
| 188 |
+
==================================================
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
==================================================
|
| 192 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 193 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 60, 80])
|
| 194 |
+
min=-2.2899, max=2.1393, std=0.5506
|
| 195 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 196 |
+
sigma_data=1.0
|
| 197 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 60, 80])
|
| 198 |
+
min=-4.0351, max=4.0100, std=1.1577
|
| 199 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 480, 640])
|
| 200 |
+
min=-1.0000, max=1.0000, std=0.6317
|
| 201 |
+
==================================================
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
==================================================
|
| 205 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 206 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 60])
|
| 207 |
+
min=-1.8058, max=2.0032, std=0.5188
|
| 208 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 209 |
+
sigma_data=1.0
|
| 210 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 60])
|
| 211 |
+
min=-3.2342, max=3.6659, std=1.0352
|
| 212 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 480])
|
| 213 |
+
min=-1.0000, max=1.0000, std=0.6372
|
| 214 |
+
==================================================
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
==================================================
|
| 218 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 219 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 64, 80])
|
| 220 |
+
min=-2.1774, max=2.1568, std=0.6666
|
| 221 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 222 |
+
sigma_data=1.0
|
| 223 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 64, 80])
|
| 224 |
+
min=-4.7810, max=5.1935, std=1.3580
|
| 225 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 512, 640])
|
| 226 |
+
min=-1.0000, max=1.0000, std=0.5784
|
| 227 |
+
==================================================
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
==================================================
|
| 231 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 232 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 68, 80])
|
| 233 |
+
min=-1.9091, max=2.1057, std=0.5661
|
| 234 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 235 |
+
sigma_data=1.0
|
| 236 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 68, 80])
|
| 237 |
+
min=-3.4599, max=3.7540, std=1.0538
|
| 238 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 544, 640])
|
| 239 |
+
min=-1.0000, max=1.0000, std=0.6665
|
| 240 |
+
==================================================
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
==================================================
|
| 244 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 245 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 72, 80])
|
| 246 |
+
min=-2.1917, max=2.3725, std=0.6957
|
| 247 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 248 |
+
sigma_data=1.0
|
| 249 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 72, 80])
|
| 250 |
+
min=-3.8205, max=4.1090, std=1.5053
|
| 251 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 576, 640])
|
| 252 |
+
min=-1.0000, max=1.0000, std=0.6376
|
| 253 |
+
==================================================
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
==================================================
|
| 257 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 258 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 76, 80])
|
| 259 |
+
min=-2.3168, max=2.0439, std=0.6811
|
| 260 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 261 |
+
sigma_data=1.0
|
| 262 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 76, 80])
|
| 263 |
+
min=-3.8838, max=4.5797, std=1.3369
|
| 264 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 608, 640])
|
| 265 |
+
min=-1.0000, max=1.0000, std=0.6667
|
| 266 |
+
==================================================
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
==================================================
|
| 270 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 271 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 64])
|
| 272 |
+
min=-2.2767, max=2.3007, std=0.5141
|
| 273 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 274 |
+
sigma_data=1.0
|
| 275 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 64])
|
| 276 |
+
min=-3.7021, max=3.3769, std=0.8752
|
| 277 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 512])
|
| 278 |
+
min=-1.0000, max=1.0000, std=0.4680
|
| 279 |
+
==================================================
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
==================================================
|
| 283 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 284 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 68])
|
| 285 |
+
min=-2.3068, max=2.3424, std=0.7115
|
| 286 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 287 |
+
sigma_data=1.0
|
| 288 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 68])
|
| 289 |
+
min=-3.9636, max=4.6402, std=1.4684
|
| 290 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 544])
|
| 291 |
+
min=-0.9553, max=1.0000, std=0.4083
|
| 292 |
+
==================================================
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
==================================================
|
| 296 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 297 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 72])
|
| 298 |
+
min=-2.3526, max=2.5922, std=0.7641
|
| 299 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 300 |
+
sigma_data=1.0
|
| 301 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 72])
|
| 302 |
+
min=-4.1452, max=4.7889, std=1.6258
|
| 303 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 576])
|
| 304 |
+
min=-1.0000, max=1.0000, std=0.7539
|
| 305 |
+
==================================================
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
==================================================
|
| 309 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 310 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 80, 76])
|
| 311 |
+
min=-1.7528, max=1.9838, std=0.4715
|
| 312 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 313 |
+
sigma_data=1.0
|
| 314 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 80, 76])
|
| 315 |
+
min=-3.3567, max=3.4733, std=1.0891
|
| 316 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 640, 608])
|
| 317 |
+
min=-0.9685, max=0.9913, std=0.4626
|
| 318 |
+
==================================================
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
==================================================
|
| 322 |
+
[ОТЛАДКА VAE DECODE НА НУЛЕВОМ ШАГЕ]
|
| 323 |
+
1. current_latents: shape=torch.Size([1, 16, 1, 40, 80])
|
| 324 |
+
min=-2.1861, max=2.0078, std=0.5870
|
| 325 |
+
2. l_std shape=torch.Size([1, 16, 1, 1, 1]), l_mean shape=torch.Size([1, 16, 1, 1, 1])
|
| 326 |
+
sigma_data=1.0
|
| 327 |
+
3. latents_for_decode: shape=torch.Size([1, 16, 1, 40, 80])
|
| 328 |
+
min=-4.1700, max=4.9017, std=1.3442
|
| 329 |
+
4. decoded_fp32 (после VAE): shape=torch.Size([1, 3, 320, 640])
|
| 330 |
+
min=-0.8055, max=0.9961, std=0.3967
|
| 331 |
+
==================================================
|
| 332 |
+
|
| 333 |
+
Total steps per GPU: 14580
|
| 334 |
+
Training: 36%|████████████████████████████▉ | 5205/14580 [15:41:32<27:32:51, 10.58s/step]
|
| 335 |
+
Saving: True | Max: 0.1232 | Last: 0.1094 | Avg: 0.1232
|
| 336 |
+
Saving: True | Max: 0.1423 | Last: 0.1423 | Avg: 0.1245
|
| 337 |
+
Saving: True | Max: 0.1248 | Last: 0.0964 | Avg: 0.1248
|
| 338 |
+
Saving: True | Max: 0.1423 | Last: 0.1423 | Avg: 0.1246
|
| 339 |
+
Saving: True | Max: 0.1417 | Last: 0.1417 | Avg: 0.1236
|
| 340 |
+
Saving: True | Max: 0.1405 | Last: 0.1405 | Avg: 0.1234
|
| 341 |
+
Saving: True | Max: 0.1534 | Last: 0.1534 | Avg: 0.1232
|
| 342 |
+
Saving: True | Max: 0.1559 | Last: 0.1559 | Avg: 0.1234
|
| 343 |
+
Saving: True | Max: 0.1231 | Last: 0.1104 | Avg: 0.1231
|
| 344 |
+
Saving: True | Max: 0.1229 | Last: 0.1139 | Avg: 0.1229
|
| 345 |
+
Saving: True | Max: 0.1228 | Last: 0.1152 | Avg: 0.1228
|
| 346 |
+
Saving: True | Max: 0.1485 | Last: 0.1485 | Avg: 0.1228
|
| 347 |
+
Saving: True | Max: 0.1231 | Last: 0.0736 | Avg: 0.1231
|
| 348 |
+
Saving: True | Max: 0.1232 | Last: 0.1016 | Avg: 0.1232
|
| 349 |
+
Saving: True | Max: 0.1519 | Last: 0.1519 | Avg: 0.1234
|
| 350 |
+
Saving: True | Max: 0.1233 | Last: 0.1096 | Avg: 0.1233
|
| 351 |
+
Saving: True | Max: 0.1232 | Last: 0.1051 | Avg: 0.1232
|
| 352 |
+
Saving: True | Max: 0.1234 | Last: 0.1173 | Avg: 0.1234
|
| 353 |
+
Saving: True | Max: 0.1233 | Last: 0.1168 | Avg: 0.1233
|
| 354 |
+
Saving: True | Max: 0.1309 | Last: 0.1309 | Avg: 0.1229
|
| 355 |
+
Saving: True | Max: 0.1432 | Last: 0.1432 | Avg: 0.1227
|
| 356 |
+
Saving: True | Max: 0.1226 | Last: 0.1211 | Avg: 0.1226
|
| 357 |
+
Saving: True | Max: 0.1227 | Last: 0.1227 | Avg: 0.1221
|
| 358 |
+
Saving: True | Max: 0.1219 | Last: 0.1029 | Avg: 0.1219
|
| 359 |
+
Saving: True | Max: 0.1217 | Last: 0.1058 | Avg: 0.1217
|
| 360 |
+
Saving: True | Max: 0.1218 | Last: 0.1206 | Avg: 0.1218
|
| 361 |
+
Saving: True | Max: 0.1379 | Last: 0.1379 | Avg: 0.1221
|
| 362 |
+
Saving: True | Max: 0.1228 | Last: 0.1012 | Avg: 0.1228
|
| 363 |
+
Saving: True | Max: 0.1226 | Last: 0.1121 | Avg: 0.1226
|
| 364 |
+
Saving: True | Max: 0.1226 | Last: 0.0930 | Avg: 0.1226
|
| 365 |
+
Saving: False | Max: 0.1564 | Last: 0.1564 | Avg: 0.1230
|
| 366 |
+
Saving: True | Max: 0.1266 | Last: 0.1266 | Avg: 0.1234
|
| 367 |
+
Saving: True | Max: 0.1234 | Last: 0.1050 | Avg: 0.1234
|
| 368 |
+
Saving: True | Max: 0.1235 | Last: 0.1031 | Avg: 0.1235
|
| 369 |
+
Saving: True | Max: 0.1235 | Last: 0.0956 | Avg: 0.1235
|
| 370 |
+
Saving: True | Max: 0.1233 | Last: 0.1117 | Avg: 0.1233
|
| 371 |
+
Saving: False | Max: 0.1559 | Last: 0.1559 | Avg: 0.1229
|
| 372 |
+
Saving: True | Max: 0.1532 | Last: 0.1532 | Avg: 0.1234
|
| 373 |
+
Saving: True | Max: 0.1248 | Last: 0.1248 | Avg: 0.1231
|
| 374 |
+
Saving: True | Max: 0.1445 | Last: 0.1445 | Avg: 0.1228
|
| 375 |
+
Saving: True | Max: 0.1514 | Last: 0.1514 | Avg: 0.1229
|
| 376 |
+
Saving: True | Max: 0.1225 | Last: 0.1021 | Avg: 0.1225
|
| 377 |
+
Saving: True | Max: 0.1317 | Last: 0.1317 | Avg: 0.1221
|
| 378 |
+
Saving: True | Max: 0.1220 | Last: 0.1002 | Avg: 0.1220
|
| 379 |
+
Saving: True | Max: 0.1321 | Last: 0.1321 | Avg: 0.1221
|
| 380 |
+
Saving: True | Max: 0.1260 | Last: 0.1260 | Avg: 0.1218
|
| 381 |
+
Saving: True | Max: 0.1212 | Last: 0.1191 | Avg: 0.1212
|
| 382 |
+
Saving: True | Max: 0.1213 | Last: 0.1155 | Avg: 0.1213
|
| 383 |
+
Saving: True | Max: 0.1212 | Last: 0.1138 | Avg: 0.1212
|
| 384 |
+
Saving: True | Max: 0.1213 | Last: 0.1184 | Avg: 0.1213
|
| 385 |
+
Saving: True | Max: 0.1371 | Last: 0.1371 | Avg: 0.1215
|
wandb/run-20260428_171645-wt40fdyx/files/requirements.txt
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cuda-toolkit==13.0.2
|
| 2 |
+
typing_extensions==4.15.0
|
| 3 |
+
nvidia-nvjitlink==13.0.88
|
| 4 |
+
MarkupSafe==3.0.3
|
| 5 |
+
nvidia-cufile==1.15.1.6
|
| 6 |
+
certifi==2026.4.22
|
| 7 |
+
nvidia-cusolver==12.0.4.66
|
| 8 |
+
nvidia-curand==10.4.0.35
|
| 9 |
+
Jinja2==3.1.6
|
| 10 |
+
nvidia-nvtx==13.0.85
|
| 11 |
+
nvidia-cuda-cupti==13.0.85
|
| 12 |
+
torchaudio==2.11.0+cu130
|
| 13 |
+
safetensors==0.7.0
|
| 14 |
+
nvidia-cuda-runtime==13.0.96
|
| 15 |
+
torchvision==0.26.0+cu130
|
| 16 |
+
nvidia-cufft==12.0.0.61
|
| 17 |
+
nvidia-cusparse==12.6.3.3
|
| 18 |
+
nvidia-cuda-nvrtc==13.0.88
|
| 19 |
+
fsspec==2026.2.0
|
| 20 |
+
nvidia-cusparselt-cu13==0.8.0
|
| 21 |
+
nvidia-nccl-cu13==2.28.9
|
| 22 |
+
nvidia-nvshmem-cu13==3.4.5
|
| 23 |
+
nvidia-cublas==13.1.0.3
|
| 24 |
+
nvidia-cudnn-cu13==9.19.0.56
|
| 25 |
+
mpmath==1.3.0
|
| 26 |
+
triton==3.6.0
|
| 27 |
+
networkx==3.6.1
|
| 28 |
+
sympy==1.14.0
|
| 29 |
+
torch==2.11.0+cu130
|
| 30 |
+
hf_transfer==0.1.9
|
| 31 |
+
six==1.17.0
|
| 32 |
+
typer==0.25.0
|
| 33 |
+
typing-inspection==0.4.2
|
| 34 |
+
muon-adamw8bit==0.5.0
|
| 35 |
+
aiosignal==1.4.0
|
| 36 |
+
wurlitzer==3.1.1
|
| 37 |
+
semantic-version==2.10.0
|
| 38 |
+
aiohappyeyeballs==2.6.1
|
| 39 |
+
cycler==0.12.1
|
| 40 |
+
tokenizers==0.22.2
|
| 41 |
+
annotated-doc==0.0.4
|
| 42 |
+
rpds-py==0.30.0
|
| 43 |
+
configobj==5.0.9
|
| 44 |
+
regex==2026.4.4
|
| 45 |
+
zipp==3.23.1
|
| 46 |
+
annotated-types==0.7.0
|
| 47 |
+
everett==3.1.0
|
| 48 |
+
pydantic_core==2.46.3
|
| 49 |
+
mdurl==0.1.2
|
| 50 |
+
platformdirs==4.9.6
|
| 51 |
+
idna==3.13
|
| 52 |
+
psutil==7.2.2
|
| 53 |
+
xxhash==3.7.0
|
| 54 |
+
smmap==5.0.3
|
| 55 |
+
frozenlist==1.8.0
|
| 56 |
+
multidict==6.7.1
|
| 57 |
+
shellingham==1.5.4
|
| 58 |
+
kiwisolver==1.5.0
|
| 59 |
+
propcache==0.4.1
|
| 60 |
+
h11==0.16.0
|
| 61 |
+
hf-xet==1.4.3
|
| 62 |
+
pyparsing==3.3.2
|
| 63 |
+
yarl==1.23.0
|
| 64 |
+
importlib_metadata==9.0.0
|
| 65 |
+
referencing==0.37.0
|
| 66 |
+
requests==2.33.1
|
| 67 |
+
filelock==3.29.0
|
| 68 |
+
charset-normalizer==3.4.7
|
| 69 |
+
wrapt==2.1.2
|
| 70 |
+
contourpy==1.3.3
|
| 71 |
+
python-box==6.1.0
|
| 72 |
+
python-dateutil==2.9.0.post0
|
| 73 |
+
packaging==26.2
|
| 74 |
+
httpx==0.28.1
|
| 75 |
+
PyYAML==6.0.3
|
| 76 |
+
click==8.3.3
|
| 77 |
+
jsonschema-specifications==2025.9.1
|
| 78 |
+
gitdb==4.0.12
|
| 79 |
+
einops==0.8.2
|
| 80 |
+
attrs==26.1.0
|
| 81 |
+
httpcore==1.0.9
|
| 82 |
+
cuda-pathfinder==1.5.4
|
| 83 |
+
requests-toolbelt==1.0.0
|
| 84 |
+
GitPython==3.1.48
|
| 85 |
+
jsonschema==4.26.0
|
| 86 |
+
tqdm==4.67.3
|
| 87 |
+
urllib3==2.6.3
|
| 88 |
+
anyio==4.13.0
|
| 89 |
+
simplejson==4.1.1
|
| 90 |
+
multiprocess==0.70.19
|
| 91 |
+
dill==0.4.1
|
| 92 |
+
protobuf==7.34.1
|
| 93 |
+
markdown-it-py==4.0.0
|
| 94 |
+
bitsandbytes==0.49.2
|
| 95 |
+
cuda-bindings==13.2.0
|
| 96 |
+
aiohttp==3.13.5
|
| 97 |
+
accelerate==1.13.0
|
| 98 |
+
dulwich==0.25.2
|
| 99 |
+
pydantic==2.13.3
|
| 100 |
+
datasets==4.8.5
|
| 101 |
+
rich==15.0.0
|
| 102 |
+
flash-linear-attention==0.5.0
|
| 103 |
+
pillow==12.2.0
|
| 104 |
+
huggingface_hub==1.12.0
|
| 105 |
+
sentry-sdk==2.58.0
|
| 106 |
+
fla-core==0.5.0
|
| 107 |
+
Pygments==2.20.0
|
| 108 |
+
diffusers==0.37.1
|
| 109 |
+
fonttools==4.62.1
|
| 110 |
+
comet_ml==3.57.3
|
| 111 |
+
setuptools==81.0.0
|
| 112 |
+
matplotlib==3.10.9
|
| 113 |
+
pyarrow==24.0.0
|
| 114 |
+
wandb==0.26.1
|
| 115 |
+
numpy==2.4.4
|
| 116 |
+
pandas==3.0.2
|
| 117 |
+
transformers==5.6.2
|
wandb/run-20260428_171645-wt40fdyx/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-6.8.0-110-generic-x86_64-with-glibc2.39",
|
| 3 |
+
"python": "CPython 3.12.3",
|
| 4 |
+
"startedAt": "2026-04-28T17:16:45.135482Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--batch",
|
| 7 |
+
"16",
|
| 8 |
+
"--lvl",
|
| 9 |
+
"1"
|
| 10 |
+
],
|
| 11 |
+
"program": "/root/sdxs-2b/train.py",
|
| 12 |
+
"codePath": "train.py",
|
| 13 |
+
"codePathLocal": "train.py",
|
| 14 |
+
"git": {
|
| 15 |
+
"remote": "https://huggingface.co/AiArtLab/sdxs-2b",
|
| 16 |
+
"commit": "ab8719f79299a6e86448b407298689048767b261"
|
| 17 |
+
},
|
| 18 |
+
"email": "vadim-kulibaba@yandex.ru",
|
| 19 |
+
"root": "/root/sdxs-2b",
|
| 20 |
+
"host": "O-1649582",
|
| 21 |
+
"executable": "/root/.venv/bin/python3",
|
| 22 |
+
"cpu_count": 48,
|
| 23 |
+
"cpu_count_logical": 96,
|
| 24 |
+
"gpu": "NVIDIA GeForce RTX 5090",
|
| 25 |
+
"gpu_count": 1,
|
| 26 |
+
"disk": {
|
| 27 |
+
"/": {
|
| 28 |
+
"total": "888178696192",
|
| 29 |
+
"used": "598432870400"
|
| 30 |
+
}
|
| 31 |
+
},
|
| 32 |
+
"memory": {
|
| 33 |
+
"total": "134889213952"
|
| 34 |
+
},
|
| 35 |
+
"gpu_nvidia": [
|
| 36 |
+
{
|
| 37 |
+
"name": "NVIDIA GeForce RTX 5090",
|
| 38 |
+
"memoryTotal": "34190917632",
|
| 39 |
+
"cudaCores": 21760,
|
| 40 |
+
"architecture": "Blackwell",
|
| 41 |
+
"uuid": "GPU-af06c899-cefd-2303-137f-17f69c648771"
|
| 42 |
+
}
|
| 43 |
+
],
|
| 44 |
+
"cudaVersion": "13.0",
|
| 45 |
+
"writerId": "9ndk10qtzdsvighcagxlxbtug93n98at"
|
| 46 |
+
}
|
wandb/run-20260428_171645-wt40fdyx/logs/debug-core.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-04-28T17:16:45.179418861Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpvlugdolt/port-14112.txt","pid":14112,"detached":false,"idle-timeout":600000000000,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
|
| 2 |
+
{"time":"2026-04-28T17:16:45.18135139Z","level":"INFO","msg":"server: will exit if parent process dies","ppid":14112}
|
| 3 |
+
{"time":"2026-04-28T17:16:45.181241106Z","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/tmp/wandb-14112-14129-3488405791/socket","Net":"unix"}}
|
| 4 |
+
{"time":"2026-04-28T17:16:45.343101118Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1(@)"}
|
| 5 |
+
{"time":"2026-04-28T17:16:45.350678398Z","level":"INFO","msg":"handleInformInit: received","streamId":"wt40fdyx","id":"1(@)"}
|
| 6 |
+
{"time":"2026-04-28T17:16:45.730308466Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"wt40fdyx","id":"1(@)"}
|
| 7 |
+
{"time":"2026-04-28T17:16:54.001250093Z","level":"INFO","msg":"connection: cancelling request","id":"1(@)","requestId":"lrv3btfsqddl"}
|
wandb/run-20260428_171645-wt40fdyx/logs/debug-internal.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
wandb/run-20260428_171645-wt40fdyx/logs/debug.log
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Current SDK version is 0.26.1
|
| 2 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Configure stats pid to 14112
|
| 3 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:setup_run_log_directory():723] Logging user logs to /root/sdxs-2b/wandb/run-20260428_171645-wt40fdyx/logs/debug.log
|
| 5 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:setup_run_log_directory():724] Logging internal logs to /root/sdxs-2b/wandb/run-20260428_171645-wt40fdyx/logs/debug-internal.log
|
| 6 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():850] calling init triggers
|
| 7 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():855] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'batch_size': 16, 'base_learning_rate': 1.3333333333333335e-05, 'num_epochs': 1, 'optimizer_type': 'adafactor', '_wandb': {}}
|
| 9 |
+
2026-04-28 17:16:45,138 INFO MainThread:14112 [wandb_init.py:init():898] starting backend
|
| 10 |
+
2026-04-28 17:16:45,343 INFO MainThread:14112 [wandb_init.py:init():913] sending inform_init request
|
| 11 |
+
2026-04-28 17:16:45,731 INFO MainThread:14112 [wandb_init.py:init():918] backend started and connected
|
| 12 |
+
2026-04-28 17:16:45,734 INFO MainThread:14112 [wandb_init.py:init():988] updated telemetry
|
| 13 |
+
2026-04-28 17:16:45,742 INFO MainThread:14112 [wandb_init.py:init():1011] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-04-28 17:16:46,973 INFO MainThread:14112 [wandb_init.py:init():1056] starting run threads in backend
|
| 15 |
+
2026-04-28 17:16:47,099 INFO MainThread:14112 [wandb_run.py:_console_start():2554] atexit reg
|
| 16 |
+
2026-04-28 17:16:47,099 INFO MainThread:14112 [wandb_run.py:_redirect():2403] redirect: wrap_raw
|
| 17 |
+
2026-04-28 17:16:47,100 INFO MainThread:14112 [wandb_run.py:_redirect():2472] Wrapping output streams.
|
| 18 |
+
2026-04-28 17:16:47,100 INFO MainThread:14112 [wandb_run.py:_redirect():2495] Redirects installed.
|
| 19 |
+
2026-04-28 17:16:47,104 INFO MainThread:14112 [wandb_init.py:init():1094] run started, returning control to user process
|
wandb/run-20260428_171645-wt40fdyx/run-wt40fdyx.wandb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:84f6382ddf7402e5b98379478d7897d5a678f939eba1a8a5d028a988674120a5
|
| 3 |
+
size 15499264
|