Instructions to use ViTeX-Bench/ViTeX-Edit-14B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use ViTeX-Bench/ViTeX-Edit-14B with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("ViTeX-Bench/ViTeX-Edit-14B", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from .operators import * | |
| import torch, json, pandas | |
class UnifiedDataset(torch.utils.data.Dataset):
    """Dataset backed either by a metadata file or by cached ``.pth`` files.

    When ``metadata_path`` is given, samples are the rows of that file
    (``.json``, ``.jsonl``, or anything else read as CSV) and the values
    under ``data_file_keys`` are passed through data operators on access.
    When ``metadata_path`` is ``None``, ``base_path`` is searched recursively
    for ``.pth`` files and each sample is loaded with ``LoadTorchPickle``.
    """

    def __init__(
        self,
        base_path=None, metadata_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
        special_operator_map=None,
        max_data_items=None,
    ):
        """Store the configuration and load the metadata / cache file list.

        Args:
            base_path: Directory searched for cached ``.pth`` files when
                ``metadata_path`` is ``None``.
            metadata_path: Path of the metadata file, or ``None`` to load
                from cached files instead.
            repeat: Multiplier applied to the number of samples in ``__len__``.
            data_file_keys: Keys of a metadata row whose values are passed
                through a data operator in ``__getitem__``.
            main_data_operator: Operator applied to keys in ``data_file_keys``
                that have no entry in ``special_operator_map``.
            special_operator_map: Optional mapping ``key -> operator`` that
                overrides ``main_data_operator`` for specific keys.
            max_data_items: If not ``None``, the value reported by ``__len__``.
        """
        self.base_path = base_path
        self.metadata_path = metadata_path
        self.repeat = repeat
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.cached_data_operator = LoadTorchPickle()
        self.special_operator_map = {} if special_operator_map is None else special_operator_map
        self.max_data_items = max_data_items
        self.data = []
        self.cached_data = []
        self.load_from_cache = metadata_path is None
        self.load_metadata(metadata_path)

    # Declared static: the function takes no ``self``, so without the
    # decorator it could only be called on the class, never on an instance.
    @staticmethod
    def default_image_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
    ):
        """Build the default operator for image entries.

        Returns a ``RouteByType`` that sends a ``str`` through
        ``ToAbsolutePath >> LoadImage >> ImageCropAndResize`` and a ``list``
        through the same chain wrapped in ``SequencialProcess``.
        """
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
            (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
        ])

    # Declared static for the same reason as ``default_image_operator``.
    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        num_frames=81, time_division_factor=4, time_division_remainder=1,
        frame_rate=24, fix_frame_rate=False,
    ):
        """Build the default operator for video entries.

        Returns a ``RouteByType`` that resolves a ``str`` with
        ``ToAbsolutePath`` and then dispatches on the file extension:
        still images go through ``LoadImage`` and are wrapped with
        ``ToList``, ``gif`` goes through ``LoadGIF``, and the listed video
        container extensions go through ``LoadVideo``. Every branch uses
        ``ImageCropAndResize`` with the given size arguments.
        """
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
                (("gif",), LoadGIF(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                    frame_rate=frame_rate, fix_frame_rate=fix_frame_rate,
                )),
            ])),
        ])

    def search_for_cached_data_files(self, path):
        """Recursively append every ``.pth`` file under ``path`` to ``self.cached_data``."""
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        for file_name in sorted(os.listdir(path)):
            subpath = os.path.join(path, file_name)
            if os.path.isdir(subpath):
                self.search_for_cached_data_files(subpath)
            elif subpath.endswith(".pth"):
                self.cached_data.append(subpath)

    def load_metadata(self, metadata_path):
        """Populate ``self.data`` (or ``self.cached_data``) from ``metadata_path``.

        ``None`` triggers a search for cached files under ``self.base_path``;
        ``.json`` is loaded as a whole; ``.jsonl`` is parsed line by line;
        any other path is read with ``pandas.read_csv``.
        """
        if metadata_path is None:
            print("No metadata_path. Searching for cached data files.")
            self.search_for_cached_data_files(self.base_path)
            print(f"{len(self.cached_data)} cached data files found.")
        elif metadata_path.endswith(".json"):
            # JSON text is UTF-8; do not depend on the platform default encoding.
            with open(metadata_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
        elif metadata_path.endswith(".jsonl"):
            metadata = []
            with open(metadata_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    # Tolerate blank lines (e.g. a trailing newline) instead
                    # of failing in json.loads("").
                    if line:
                        metadata.append(json.loads(line))
            self.data = metadata
        else:
            metadata = pandas.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def __getitem__(self, data_id):
        """Return the sample at ``data_id``, wrapping around the stored items.

        Raises:
            IndexError: If the dataset holds no items at all.
        """
        if self.load_from_cache:
            if not self.cached_data:
                raise IndexError("UnifiedDataset is empty: no cached data files were found.")
            cached_path = self.cached_data[data_id % len(self.cached_data)]
            return self.cached_data_operator(cached_path)

        if not self.data:
            raise IndexError("UnifiedDataset is empty: the metadata contains no items.")
        # Shallow copy so the operators below never overwrite the stored row.
        data = self.data[data_id % len(self.data)].copy()
        # NOTE(review): operators are applied to metadata rows only, not to
        # cached items - confirm this matches the intended nesting.
        for key in self.data_file_keys:
            if key not in data:
                continue
            if key in self.special_operator_map:
                data[key] = self.special_operator_map[key](data[key])
            else:
                data[key] = self.main_data_operator(data[key])
        return data

    def __len__(self):
        """Return ``max_data_items`` if set, else the item count times ``repeat``."""
        if self.max_data_items is not None:
            return self.max_data_items
        elif self.load_from_cache:
            return len(self.cached_data) * self.repeat
        else:
            return len(self.data) * self.repeat

    def check_data_equal(self, data1, data2):
        """Return True if both dicts have the same keys with equal values.

        Debug only.
        """
        if len(data1) != len(data2):
            return False
        for k in data1:
            # Same length does not imply same keys; a missing key means
            # the dicts differ rather than being an error.
            if k not in data2 or data1[k] != data2[k]:
                return False
        return True