---
license: apache-2.0
---

```python
# You can use the following code to call our trained style encoder. Hope it helps.

import torch
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import transforms
from transformers import (AutoModel, AutoProcessor, AutoTokenizer, AutoConfig,
                          CLIPImageProcessor, CLIPVisionModelWithProjection)
class SEStyleEmbedding:
    """OneIG style encoder wrapper.

    Loads a CLIP vision tower with projection head from *pretrained_path*
    and turns an image file into an L2-normalized style embedding.
    """

    def __init__(self, pretrained_path: str = "xingpng/OneIG-StyleEncoder", device: str = "cuda", dtype=torch.bfloat16):
        self.device = torch.device(device)
        self.dtype = dtype
        # Load the vision encoder once, move it to the target device/dtype,
        # and freeze it in eval mode — this class is inference-only.
        encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_path)
        encoder.to(self.device, dtype=self.dtype)
        encoder.eval()
        self.image_encoder = encoder
        # Default CLIP preprocessing (resize / crop / normalize).
        self.processor = CLIPImageProcessor()

    def _l2_normalize(self, x):
        # Scale vectors to unit length along the last axis.
        return torch.nn.functional.normalize(x, p=2, dim=-1)

    def get_style_embedding(self, image_path: str):
        """Return the L2-normalized CLIP image embedding for *image_path*."""
        rgb = Image.open(image_path).convert('RGB')
        pixels = self.processor(images=rgb, return_tensors="pt").pixel_values
        pixels = pixels.to(self.device, dtype=self.dtype)

        # No gradients needed for feature extraction.
        with torch.no_grad():
            embeds = self.image_encoder(pixels).image_embeds

        return self._l2_normalize(embeds)