| import random |
| import torch |
| from torch import nn |
| import numpy as np |
| import re |
| import urllib.parse as ul |
| from bs4 import BeautifulSoup |
| from einops import rearrange |
| from dataclasses import dataclass |
| from torchvision import transforms |
| from diffusers.models.modeling_utils import ModelMixin |
|
|
| from transformers import AutoImageProcessor, AutoModel |
| from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer |
| from transformers.utils import ModelOutput |
| from typing import Iterable, Optional, Union, List |
|
|
| import step1x3d_geometry |
| from step1x3d_geometry.utils.typing import * |
|
|
| from .base import BaseCaptionEncoder |
|
|
# Matches a run of one or more "noise" punctuation characters:
#   # ® • © ™ & @ · º ½ ¾ ¿ ¡ § ~ ) ( ] [ } { | \ / *
# ``clean_caption`` replaces every such run with a single space.
# Every escaped piece is a raw string so that the backslashes reach the regex
# engine verbatim and Python does not warn about invalid string escapes.
bad_punct_regex = re.compile(
    r"["
    + "#®•©™&@·º½¾¿¡§~"
    + r"\)"
    + r"\("
    + r"\]"
    + r"\["
    + r"\}"
    + r"\{"
    + r"\|"
    + r"\\"
    + r"\/"
    + r"\*"
    + r"]{1,}"
)
|
|
|
|
@step1x3d_geometry.register("t5-encoder")
class T5Encoder(BaseCaptionEncoder, ModelMixin):
    """Caption encoder that embeds text with a bfloat16 T5 encoder.

    ``configure`` builds the tokenizer and the ``T5EncoderModel``, prepares the
    unconditional ("empty") text embedding and, when a checkpoint path is
    configured, restores this module's weights from it.  ``encode_text`` maps a
    list of captions to the encoder's last hidden states.
    """

    @dataclass
    class Config(BaseCaptionEncoder.Config):
        # Checkpoint file read with ``torch.load``; its "state_dict" entry must
        # hold this module's weights under the "caption_condition." prefix.
        # When ``pretrained_t5_name_or_path`` is unset, the path must also
        # contain "t5small" or "t5base" so the backbone can be selected.
        pretrained_model_name_or_path: Optional[str] = None
        # Name or directory handed to ``from_pretrained`` for both tokenizer
        # and encoder.  Takes precedence over the backbone selection above.
        pretrained_t5_name_or_path: Optional[str] = None
        # True: captions go through ``clean_caption``.
        # False: captions are only lower-cased and stripped.
        preprocessing_text: bool = False
        # Captions are padded / truncated to this many tokens.
        text_max_length: int = 77
        # Hub identifier such as "google-t5/t5-small".  Must be provided when
        # neither path above is set; it is overwritten in the other cases.
        t5_type: Optional[str] = None

    cfg: Config

    def configure(self) -> None:
        """Create tokenizer and encoder, the empty embedding, and load weights.

        Raises:
            AssertionError: neither path is configured and ``t5_type`` is None.
            ValueError: ``pretrained_model_name_or_path`` is set (and
                ``pretrained_t5_name_or_path`` is not) but contains neither
                "t5small" nor "t5base".
        """
        super().configure()
        cfg = self.cfg

        # ---- tokenizer and encoder ----------------------------------------
        if cfg.pretrained_t5_name_or_path is not None:
            t5_path = cfg.pretrained_t5_name_or_path
            # Use the text after the last "google-t5--" up to the next "/" as
            # the model name (presumably a Hugging Face cache directory such as
            # ".../models--google-t5--t5-base/..." -- TODO confirm).
            model_name = t5_path.split("google-t5--")[-1].split("/")[0]
            cfg.t5_type = f"google-t5/{model_name}"
            self.tokenizer = T5Tokenizer.from_pretrained(t5_path)
            self.text_model = T5EncoderModel.from_pretrained(
                t5_path, torch_dtype=torch.bfloat16
            )
        else:
            ckpt_path = cfg.pretrained_model_name_or_path
            if ckpt_path is None:
                assert cfg.t5_type is not None, "The t5_type should be provided"
                print(f"Loading T5 model from {cfg.t5_type}")
                # Only the architecture is instantiated from its config here;
                # no ``from_pretrained`` weights are loaded in this branch.
                self.text_model = T5EncoderModel(
                    config=T5EncoderModel.config_class.from_pretrained(
                        cfg.t5_type,
                    )
                ).to(torch.bfloat16)
            else:
                # The checkpoint path encodes which backbone it was trained on.
                known_backbones = (
                    ("t5small", "google-t5/t5-small"),
                    ("t5base", "google-t5/t5-base"),
                )
                hub_id = next(
                    (hub for marker, hub in known_backbones if marker in ckpt_path),
                    None,
                )
                if hub_id is None:
                    raise ValueError(f"Unknown T5 model: {ckpt_path}")
                print(f"Loading T5 model from {hub_id}")
                cfg.t5_type = hub_id
                self.text_model = T5EncoderModel.from_pretrained(
                    hub_id, torch_dtype=torch.bfloat16
                )
            # The branches above load only the model, so the tokenizer is
            # fetched separately from the resolved hub identifier.
            self.tokenizer = T5Tokenizer.from_pretrained(cfg.t5_type)

        # ---- unconditional embedding --------------------------------------
        # NOTE(review): ``zero_uncond_embeds`` is not declared in this Config;
        # presumably it comes from ``BaseCaptionEncoder.Config`` -- confirm.
        if cfg.zero_uncond_embeds:
            self.empty_text_embeds = torch.zeros(
                (1, cfg.text_max_length, self.text_model.config.hidden_size)
            ).detach()
        else:
            self.empty_text_embeds = self.encode_text([""]).detach()

        # ---- optional checkpoint ------------------------------------------
        if cfg.pretrained_model_name_or_path is not None:
            print(f"Loading ckpt from {cfg.pretrained_model_name_or_path}")
            # NOTE: ``torch.load`` unpickles the file; only point this at
            # checkpoints from a trusted source.
            ckpt = torch.load(
                cfg.pretrained_model_name_or_path, map_location="cpu"
            )["state_dict"]
            prefix = "caption_condition."
            pretrained_model_ckpt = {
                key.replace(prefix, ""): value
                for key, value in ckpt.items()
                if key.startswith(prefix)
            }
            self.load_state_dict(pretrained_model_ckpt, strict=True)

    def clean_caption(self, caption) -> str:
        """Normalise a raw, web-scraped caption into plain lower-case text.

        The input is converted with ``str``, URL-unquoted and lower-cased, then
        passed through a fixed sequence of substitutions that drop URLs, HTML
        markup, @-handles, CJK characters, ids, file names and marketing
        boilerplate and that normalise dashes, quotes and whitespace.  The
        order of the substitutions matters; later patterns rely on earlier
        ones having run.

        Args:
            caption: Any object; its ``str()`` form is cleaned.

        Returns:
            The cleaned caption with surrounding whitespace removed.
        """
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)

        # URLs: "http(s):..." / "www:..." prefixes and bare domains.
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
            "",
            caption,
        )
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
            "",
            caption,
        )

        # Strip HTML markup, keeping only the text content.
        caption = BeautifulSoup(caption, features="html.parser").text

        # @handles
        caption = re.sub(r"@[\w\d]+\b", "", caption)

        # CJK-related Unicode blocks:
        #   31C0-31EF CJK Strokes
        #   31F0-31FF Katakana Phonetic Extensions
        #   3200-32FF Enclosed CJK Letters and Months
        #   3300-33FF CJK Compatibility
        #   3400-4DBF CJK Unified Ideographs Extension A
        #   4DC0-4DFF Yijing Hexagram Symbols
        #   4E00-9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)

        # Collapse every kind of Unicode dash (and runs of them) into "-".
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
            "-",
            caption,
        )

        # Unify quotation marks.
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)

        # Left-over HTML entity text for the double quote (semicolon optional)
        # and for the ampersand.  The entity names sit inside a group so that
        # an HTML-entity decoding pass over this source file cannot rewrite
        # the patterns; the strings matched are the plain entity spellings.
        caption = re.sub(r"&(?:quot);?", "", caption)
        caption = re.sub(r"&(?:amp)", "", caption)

        # Dotted quads (IP addresses).
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

        # A "d:dd" token followed by whitespace at the very end.
        caption = re.sub(r"\d:\d\d\s+$", "", caption)

        # Literal backslash-n sequences (not real newlines).
        caption = re.sub(r"\\n", " ", caption)

        # "#123", "#12345...", long bare numbers, and file names.
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        caption = re.sub(r"#\d{5,}\b", "", caption)
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        caption = re.sub(
            r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
        )

        # Repeated quotes -> one double quote; repeated dots -> space.
        caption = re.sub(r"[\"\']{2,}", r'"', caption)
        caption = re.sub(r"[\.]{2,}", r" ", caption)

        # Noise punctuation -> space; then drop free-standing dots.
        caption = bad_punct_regex.sub(r" ", caption)
        caption = re.sub(r"\s+\.\s+", r" ", caption)

        # With more than three "-" / "_" the text looks like a slug
        # ("this-is-my-cute-cat"): turn all of them into spaces.
        dash_or_underscore = re.compile(r"(?:\-|\_)")
        if len(dash_or_underscore.findall(caption)) > 3:
            caption = dash_or_underscore.sub(" ", caption)

        # NOTE(review): ``basic_clean`` is not defined in this class as visible
        # here; presumably it is provided by ``BaseCaptionEncoder`` -- confirm.
        caption = self.basic_clean(caption)

        # Id-like tokens mixing letters and digits, e.g. "jc6640", "ab12cd",
        # "12ab34".
        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)

        # Marketing / navigation boilerplate.
        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(
            r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
        )
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)

        # Longer alternating letter/digit ids, e.g. "j2d1a2a".
        caption = re.sub(
            r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
        )

        # Dimensions such as "1920x1080" (Latin x, Cyrillic х or ×).
        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

        # Spacing around ":" and after ", . /" ; then collapse whitespace.
        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)

        # The result must be assigned back: the anchored patterns below only
        # work when no whitespace is left at either end.
        caption = caption.strip()

        # Unwrap a fully quoted caption and trim stray edge punctuation.
        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        # A caption that is a single ".token" carries no content.
        caption = re.sub(r"^\.\S+$", "", caption)

        return caption.strip()

    def text_preprocessing(self, text):
        """Prepare one caption for tokenisation.

        Returns ``clean_caption(text)`` when ``cfg.preprocessing_text`` is set,
        otherwise ``text`` lower-cased and stripped.
        """
        if self.cfg.preprocessing_text:
            return self.clean_caption(text)
        return text.lower().strip()

    def encode_text(self, texts: List[str]) -> torch.FloatTensor:
        """Embed captions with the T5 encoder, without tracking gradients.

        Args:
            texts: Captions to encode; each is passed through
                ``text_preprocessing`` first.

        Returns:
            The encoder's detached ``last_hidden_state``.  Every caption is
            padded / truncated to ``cfg.text_max_length`` tokens, so the
            expected shape is ``(len(texts), text_max_length, hidden_size)``.
        """
        texts = [self.text_preprocessing(text) for text in texts]

        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.cfg.text_max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        # The tokenizer returns CPU tensors; move them to the model's device.
        device = self.text_model.device
        with torch.no_grad():
            label_embeds = self.text_model(
                input_ids=text_tokens_and_mask["input_ids"].to(device),
                attention_mask=text_tokens_and_mask["attention_mask"].to(device),
            )["last_hidden_state"].detach()

        return label_embeds
|
|