Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision.transforms import v2 | |
| from transformers import CLIPTextModel, CLIPTokenizer, \ | |
| CLIPProcessor, CLIPVisionModelWithProjection, CLIPTextModelWithProjection | |
| import os | |
| # from image_generator import get_output_embeds, position_embeddings | |
# Select the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
elif torch.backends.mps.is_available():
    torch_device = "mps"
else:
    torch_device = "cpu"

if torch_device == "mps":
    # Let PyTorch fall back to CPU for ops not yet implemented on MPS
    # instead of raising at runtime.
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

# Load the CLIP tokenizer, text/vision encoders and processor once at import
# time; all models are moved to the selected device.
clip_model_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(torch_device)
vision_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_model_name).to(torch_device)
processor = CLIPProcessor.from_pretrained(clip_model_name)
| # # additional textual prompt | |
def get_text_embed(prompt = "on a mountain"):
    """Return the CLIP projected text embedding for `prompt`.

    The projection model is loaded lazily on the first call and cached on
    the function object — the original re-loaded the full checkpoint from
    disk/hub on EVERY call, which is very expensive.

    Returns the `text_embeds` tensor produced by
    CLIPTextModelWithProjection, placed on `torch_device`.
    """
    model = getattr(get_text_embed, "_proj_model", None)
    if model is None:
        model = CLIPTextModelWithProjection.from_pretrained(
            clip_model_name).to(torch_device)
        get_text_embed._proj_model = model
    # Tokenize and move the inputs to the same device as the model, so the
    # forward pass runs on the accelerator instead of CPU.
    inputs = processor(text=prompt,
                       return_tensors="pt",
                       padding=True).to(torch_device)
    with torch.no_grad():
        # Output is already on torch_device because the model is.
        return model(**inputs).text_embeds
| # def get_text_embed(prompt = "on a mountain"): | |
| # text_input = tokenizer([prompt], | |
| # padding="max_length", | |
| # max_length=tokenizer.model_max_length, | |
| # truncation=True, | |
| # return_tensors="pt") | |
| # with torch.no_grad(): | |
| # text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] | |
| # input_embeddings = text_embeddings + position_embeddings.to(torch_device) | |
| # modified_output_embeddings = get_output_embeds(input_embeddings) | |
| # return modified_output_embeddings | |
class cosine_loss(nn.Module):
    """Loss = 1 - mean cosine similarity between the CLIP embedding of a
    fixed text prompt and the CLIP embeddings of generated images.

    Gradients flow through the image path (no torch.no_grad in forward),
    so the loss can be used to guide image generation.
    """

    def __init__(self, prompt) -> None:
        # nn.Module.__init__ must run BEFORE any attribute assignment so
        # the module's registration machinery (_parameters/_modules) exists;
        # the original assigned self.text_embed first, which is fragile.
        super().__init__()
        self.text_embed = get_text_embed(prompt)
        # Built once here instead of on every forward() call.
        self.resize = v2.Resize(224)

    def forward(self, gen_image):
        # Scale pixels from [0, 1] to [0, 255] and resize to CLIP's 224px.
        # NOTE(review): the pixels are NOT normalized with CLIP's mean/std
        # before hitting the vision encoder — confirm this is intentional.
        pixels = self.resize(gen_image.clamp(0, 1).mul(255))
        image_embed = vision_encoder(pixels).image_embeds
        similarity = F.cosine_similarity(self.text_embed, image_embed, dim=1)
        return 1 - similarity.mean()
def blue_loss(images):
    """Mean absolute deviation of the blue channel from 0.9.

    `images` is a batch shaped (N, C, H, W); channel index 2 is blue.
    Returns a scalar tensor.
    """
    blue_channel = images[:, 2]       # all images in the batch, blue only
    return (blue_channel - 0.9).abs().mean()