Spaces:
Build error
Build error
File size: 1,243 Bytes
b047b31 838a805 b047b31 838a805 b047b31 838a805 b047b31 838a805 b047b31 838a805 b047b31 838a805 b047b31 838a805 b047b31 838a805 b047b31 838a805 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
from transformers import CLIPProcessor, CLIPTokenizer, CLIPModel

# Load the pre-trained CLIP model and its matching processor.
# FIX(review): the original used "clip-vit-base-patch16-224", which is not a
# published checkpoint id — the official OpenAI CLIP checkpoints are namespaced
# (e.g. "openai/clip-vit-base-patch16"). It also passed torch.device("cpu") as
# from_pretrained()'s second positional argument; that slot is *model_args
# (forwarded to the model constructor), NOT a device, and would raise. Devices
# are selected by moving the loaded model with .to().
MODEL_NAME = "openai/clip-vit-base-patch16"
model = CLIPModel.from_pretrained(MODEL_NAME).to(torch.device("cpu"))
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
# Define a function to "generate" an image from a text prompt.
#
# NOTE(review): CLIP is a contrastive *encoder* — it cannot synthesize images.
# The original body called APIs that do not exist or cannot work:
#   * model(**inputs) with text only raises (CLIPModel requires pixel_values);
#   * outputs.image_embeddings is not an attribute of any CLIP output type
#     (the real field is image_embeds, present only when images are passed);
#   * model.get_input_embeddings() is the token-id embedding table — feeding a
#     float image tensor through it is a type error;
#   * Image.fromarray was handed a 4-D batched array, which fails.
# This rewrite keeps the signature and the "PIL image out" contract, deriving a
# deterministic noise image from the prompt's CLIP text embedding so the output
# at least depends on the prompt. For real text-to-image generation use a
# diffusion pipeline (e.g. Stable Diffusion), not CLIP alone.
def generate_image(prompt, height=256, width=256):
    """Return a (height, width) RGB PIL image deterministically derived from prompt.

    Args:
        prompt: Text prompt, encoded with CLIP's text tower.
        height: Output image height in pixels.
        width: Output image width in pixels.

    Returns:
        A PIL.Image.Image in RGB mode of size (width, height).
    """
    # Tokenize the prompt for CLIP's text encoder.
    inputs = processor(text=prompt, return_tensors="pt")
    # Supported API for text-only encoding — no pixel_values needed.
    with torch.no_grad():
        text_embeds = model.get_text_features(**inputs)  # shape (1, embed_dim)
    # Seed a generator from the embedding so identical prompts yield identical
    # images (and different prompts almost surely differ).
    seed = int(torch.abs(text_embeds).sum().item() * 1000) % (2**31)
    gen = torch.Generator().manual_seed(seed)
    # CHW float tensor with values in [0, 1).
    image = torch.rand(3, height, width, generator=gen)
    # Convert CHW float -> HWC uint8 -> PIL (fromarray needs a 3-D HWC array).
    from PIL import Image
    arr = (image.permute(1, 2, 0).numpy() * 255).astype("uint8")
    return Image.fromarray(arr)
# Smoke test: generate an image for a sample prompt and save it to disk.
# Guarded with __main__ so importing this module does not trigger generation.
if __name__ == "__main__":
    prompt = "A picture of a cat"
    image = generate_image(prompt)
    image.save("cat.png")