Instructions to use babkasotona/2b with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use babkasotona/2b with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("babkasotona/2b", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # pip install flash-attn --no-build-isolation | |
| import torch | |
| import os | |
| import gc | |
| import numpy as np | |
| import random | |
| import json | |
| import shutil | |
| import time | |
| from datasets import Dataset, load_from_disk, concatenate_datasets | |
| from diffusers import AutoencoderKLQwenImage | |
| from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda | |
| from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM | |
| from typing import Dict, List, Tuple, Optional, Any | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from datetime import timedelta | |
| from accelerate import Accelerator | |
# Distributed context supplied by accelerate; the values below are read by
# the rest of the script for device placement, logging and sharding.
accelerator = Accelerator()
device, process_index = accelerator.device, accelerator.process_index
num_processes, is_main_process = accelerator.num_processes, accelerator.is_main_process

# ---------------- 1️⃣ Settings ----------------
dtype = torch.float16
batch_size = 5
# Bounds and granularity (in pixels) handed to get_image_transform().
min_size, max_size, step = 320, 640, 32
# Probability passed to process_labels_for_guidance().
empty_share = 0.0
# Passed to process_folder() as its `limit` argument.
limit = 0
folder_path = "/workspace/dataset/d23"
save_path = "/workspace/ds234_640_vae_qwen"
os.makedirs(save_path, exist_ok=True)
def clear_cuda_memory():
    """Report peak CUDA memory for this process and release cached GPU memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are freed *before* the caching allocator is emptied;
    collecting afterwards would leave their blocks reserved until the next
    call. Without CUDA only the garbage collection is performed.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    used_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[GPU {process_index}] used_gb: {used_gb:.2f} GB")
    torch.cuda.empty_cache()
| # ---------------- 2️⃣ Загрузка моделей ---------------- | |
def load_models():
    """Load the VAE from the ``"vae"`` location and prepare it for inference.

    Returns:
        The ``AutoencoderKLQwenImage`` instance, loaded with the module-level
        ``dtype``, moved to ``device`` and switched to eval mode.
    """
    print(f"[GPU {process_index}] Загрузка моделей...")
    autoencoder = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype)
    autoencoder = autoencoder.to(device)
    return autoencoder.eval()
vae = load_models()

# Optional latent normalisation constants. A missing or falsy config value
# falls back to the identity (shift 0.0, scale 1.0).
shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0) or 1.0

mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
if not (mean is None or std is None):
    # Shaped (1, C, 1, 1, 1) so they broadcast against 5-D latents.
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1, 1)
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1, 1)
| # ---------------- 3️⃣ Трансформации ---------------- | |
def get_image_transform(min_size=256, max_size=512, step=64):
    """Build a resize-and-crop transform bounded by ``min_size``/``max_size``.

    The returned ``transform(img, dry_run=False)`` callable:

    * scales the image so its longer side equals ``max_size``; if that makes
      either side smaller than ``min_size`` it instead scales the shorter
      side to ``min_size``;
    * derives a crop size by rounding each side down to a multiple of
      ``step`` and clamping it to ``[min_size, max_size]``;
    * with ``dry_run=True`` returns only ``(crop_width, crop_height)`` and
      touches nothing but ``img.size``;
    * otherwise converts to RGB, resizes with Lanczos, crops (left-aligned
      horizontally, one third of the vertical surplus cut from the top) and
      returns ``(tensor, cropped_image, width, height)`` where the tensor is
      normalised with mean 0.5 / std 0.5 per channel.
    """

    def snap(length):
        # Round down to the step grid, then clamp to the allowed range.
        return max(min_size, min(max_size, (length // step) * step))

    def transform(img, dry_run=False):
        src_w, src_h = img.size

        # First attempt: longer side becomes max_size.
        if src_w >= src_h:
            new_width, new_height = max_size, int(max_size * src_h / src_w)
        else:
            new_width, new_height = int(max_size * src_w / src_h), max_size

        # Elongated images: shorter side becomes min_size instead.
        if min(new_width, new_height) < min_size:
            if src_w <= src_h:
                new_width, new_height = min_size, int(min_size * src_h / src_w)
            else:
                new_width, new_height = int(min_size * src_w / src_h), min_size

        crop_width, crop_height = snap(new_width), snap(new_height)
        if dry_run:
            return crop_width, crop_height

        img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
        top = (new_height - crop_height) // 3
        img_cropped = img_resized.crop((0, top, crop_width, top + crop_height))
        final_width, final_height = img_cropped.size
        img_tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(ToTensor()(img_cropped))
        return img_tensor, img_cropped, final_width, final_height

    return transform
| # ---------------- 4️⃣ Функции обработки ---------------- | |
def clean_label(label):
    """Strip captioning boilerplate from ``label``.

    Every occurrence of the listed phrases is removed (in the listed order),
    surrounding whitespace is trimmed, and a single leading ``.`` left over
    from the removal is dropped together with the whitespace after it.
    """
    boilerplate = (
        "Image 1",
        "Image 2",
        "Image 3",
        "Image 4",
        "The image depicts ",
        "The image presents ",
        "The image features ",
        "The image portrays ",
        "The image is ",
    )
    for phrase in boilerplate:
        label = label.replace(phrase, "")
    label = label.strip()
    if label[:1] == ".":
        label = label[1:].lstrip()
    return label
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """Randomly blank out captions.

    For each label one ``random.random()`` draw is compared against
    ``prob_to_make_empty``. A blanked label becomes ``""`` in the first
    returned list and ``"zero: <label>"`` in the second; otherwise the label
    is copied unchanged into both.

    Returns:
        ``(labels_for_model, labels_for_logging)`` — two lists with the same
        length and order as ``original_labels``.
    """
    model_labels, log_labels = [], []
    for caption in original_labels:
        dropped = random.random() < prob_to_make_empty
        model_labels.append("" if dropped else caption)
        log_labels.append(f"zero: {caption}" if dropped else caption)
    return model_labels, log_labels
def encode_to_latents(images, texts):
    """Encode a batch of images to VAE latents and pair them with captions.

    Args:
        images: images accepted by the transform from
            ``get_image_transform`` (they must expose ``size``, ``convert``).
            All of them must resolve to the same crop size, otherwise
            ``torch.stack`` fails.
        texts: one caption per image, in the same order.

    Returns:
        ``None`` when no image could be transformed; otherwise a dict with
        ``"vae"`` (NumPy array of latents with the singleton axis 2 removed),
        ``"text"``, ``"width"`` and ``"height"``. All four have one entry per
        *successfully transformed* image, so they always agree in length.

    An image whose transform raises is reported and skipped together with its
    caption. NOTE: when this runs inside ``Dataset.map`` without
    ``remove_columns``, the retained input columns keep their original
    length, so a skipped row can still make the map call fail.
    """
    transform = get_image_transform(min_size, max_size, step)
    transformed_tensors, kept_texts = [], []
    widths, heights = [], []
    for img, text in zip(images, texts):
        try:
            t_img, _, w, h = transform(img)
        except Exception as e:
            print(f"Ошибка трансформации: {e}")
            continue
        transformed_tensors.append(t_img)
        kept_texts.append(text)  # keep captions aligned with surviving images
        widths.append(w)
        heights.append(h)

    if not transformed_tensors:
        return None

    batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
    if batch_tensor.ndim == 4:
        # Insert a singleton axis at position 2 to get a 5-D input.
        batch_tensor = batch_tensor.unsqueeze(2)

    with torch.no_grad():
        posteriors = vae.encode(batch_tensor).latent_dist.mode()

    if mean is not None and std is not None:
        posteriors = (posteriors - latents_mean) / latents_std
    # NOTE(review): this divides by scaling_factor; confirm the consumer of
    # these latents applies the matching inverse.
    posteriors = (posteriors - shift_factor) / scaling_factor

    latents_np = posteriors.squeeze(2).cpu().numpy()

    cleaned = [clean_label(text) for text in kept_texts]
    # NOTE(review): the *logging* variant is stored, so a blanked caption is
    # saved as "zero: <caption>" rather than "" — confirm this is intended.
    _, text_labels = process_labels_for_guidance(cleaned, empty_share)

    return {
        "vae": latents_np,
        "text": text_labels,
        "width": widths,
        "height": heights,
    }
| # ---------------- 5️⃣ Обработка папки ---------------- | |
def process_folder(folder_path, limit=None):
    """Collect captioned images under ``folder_path`` with their target sizes.

    Walks the tree, and for every ``.jpg``/``.jpeg``/``.png`` file that has a
    sibling ``.txt`` file records the image path, the text path and the crop
    size computed by a dry run of the image transform.

    Args:
        folder_path: root directory to scan.
        limit: maximum number of images to collect; ``None`` or ``0`` means
            no limit. Which files are kept follows ``os.walk`` order.

    Returns:
        ``(image_paths, text_paths, width, height)`` — four parallel lists.

    Images that cannot be opened or measured are reported and skipped.
    """
    image_paths, text_paths, width, height = [], [], [], []
    transform = get_image_transform(min_size, max_size, step)

    def limit_reached():
        return bool(limit) and len(image_paths) >= limit

    for root, _, files in os.walk(folder_path):
        if limit_reached():
            break
        for filename in files:
            if limit_reached():
                break
            if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_path = os.path.join(root, filename)
            text_path = os.path.splitext(image_path)[0] + ".txt"
            # Check for the caption first: no need to open uncaptioned images.
            if not os.path.exists(text_path):
                continue
            try:
                # The context manager closes the file handle right after the
                # size has been read.
                with Image.open(image_path) as img:
                    w, h = transform(img, dry_run=True)
            except Exception as e:
                print(f"Skipping unreadable image {image_path}: {e}")
                continue
            image_paths.append(image_path)
            text_paths.append(text_path)
            width.append(w)
            height.append(h)

    print(f"Найдено {len(image_paths)} изображений")
    return image_paths, text_paths, width, height
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
    """Encode the given files chunk by chunk and save each size group to disk.

    The four input lists are parallel. For every chunk of ``chunk_size``
    items the captions are read, the items are grouped by ``(width, height)``
    and each group is run through ``encode_to_latents`` via ``Dataset.map``
    in batches of ``batch_size``. Every processed group is written to
    ``{save_path}_temp/chunk_<idx>_<w>x<h>_proc_<process_index>_`` and CUDA
    memory is cleared afterwards.

    A caption file that cannot be read or decoded is reported and replaced by
    an empty caption.
    """

    def encode_batch(ex):
        # Open the batch's images and make sure their handles are released
        # once encoding has finished (or failed).
        images = [Image.open(p) for p in ex["image_paths"]]
        try:
            return encode_to_latents(images, ex["texts"])
        finally:
            for im in images:
                im.close()

    total_files = len(image_paths)
    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), 1):
        end = min(start + chunk_size, total_files)

        chunk_texts = []
        for text_path in text_paths[start:end]:
            try:
                with open(text_path, "r", encoding="utf-8") as f:
                    chunk_texts.append(f.read().strip())
            except (OSError, UnicodeDecodeError) as e:
                print(f"Could not read caption {text_path}: {e}")
                chunk_texts.append("")

        # Group by target size so every batch stacks into one tensor.
        size_groups = {}
        rows = zip(image_paths[start:end], chunk_texts, width[start:end], height[start:end])
        for image_path, text, w, h in rows:
            group = size_groups.setdefault((w, h), {"image_paths": [], "texts": []})
            group["image_paths"].append(image_path)
            group["texts"].append(text)

        for size_key, group_data in size_groups.items():
            group_dataset = Dataset.from_dict(group_data)
            processed_group = group_dataset.map(
                encode_batch,
                batched=True,
                batch_size=batch_size,
            )
            # Path is unique per chunk, size and process.
            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_{size_key[0]}x{size_key[1]}_proc_{process_index}_"
            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()
| # ---------------- 7️⃣ Объединение ---------------- | |
def combine_chunks(temp_path, final_path):
    """Concatenate every saved chunk under ``temp_path`` into ``final_path``.

    Entries of ``temp_path`` whose name contains ``"chunk_"`` are loaded in
    sorted path order, concatenated and saved to ``final_path``.

    Raises:
        ValueError: if ``temp_path`` contains no chunk entries. Raised up
            front with the directory name instead of letting the
            concatenation fail on an empty list.
    """
    chunks = sorted(
        os.path.join(temp_path, d)
        for d in os.listdir(temp_path)
        if "chunk_" in d
    )
    if not chunks:
        raise ValueError(f"No chunk_* datasets found in {temp_path!r}; nothing to combine")
    combined = concatenate_datasets([load_from_disk(c) for c in chunks])
    combined.save_to_disk(final_path)
    print("✅ Сохранено")
| # ---------------- MAIN ---------------- | |
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

image_paths, text_paths, width, height = process_folder(folder_path, limit)

# Order by target size, then give this process every num_processes-th item of
# that ordering, starting at its own index.
sorted_indices = sorted(range(len(width)), key=lambda i: (width[i], height[i]))
indices = sorted_indices[process_index::num_processes]
image_paths = [image_paths[i] for i in indices]
text_paths = [text_paths[i] for i in indices]
width = [width[i] for i in indices]
height = [height[i] for i in indices]

print(f"[GPU {process_index}] обрабатывает {len(image_paths)} файлов")
process_in_chunks(image_paths, text_paths, width, height, chunk_size=5000, batch_size=batch_size)

# Every process must have saved its chunks before they are merged.
accelerator.wait_for_everyone()

# Only the main process merges the chunks and removes the temp directory.
if is_main_process:
    combine_chunks(temp_path, save_path)
    # Best-effort cleanup: filesystem errors are ignored, but interrupts are
    # no longer swallowed as they were by the bare `except`.
    shutil.rmtree(temp_path, ignore_errors=True)