Spaces: Build error
```python
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Load the pre-trained CLIP model and processor
# (checkpoint names on the Hub include the "openai/" namespace)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model.to(torch.device("cpu"))
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# CLIP is a text/image *encoder*, not a generator: it can embed a prompt,
# but it has no decoder that turns embeddings back into pixels.
def generate_image(prompt, height=256, width=256):
    # Tokenize the input text
    inputs = processor(text=prompt, return_tensors="pt")
    # Encode the prompt into a text embedding
    text_embeddings = model.get_text_features(**inputs)
    # Placeholder output: random noise in [-1, 1]; producing a real image
    # from the embedding requires a generative model (see below)
    image = torch.clamp(torch.randn(1, 3, height, width), -1, 1)
    # Convert the tensor from (1, 3, H, W) to an (H, W, 3) PIL image
    array = image.squeeze(0).permute(1, 2, 0).numpy()
    array = (array + 1) / 2
    return Image.fromarray((array * 255).astype("uint8"))

# Test the function
prompt = "A picture of a cat"
image = generate_image(prompt)
image.save("cat.png")
```
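If the goal is actual text-to-image generation, CLIP alone cannot do it; a diffusion pipeline (which uses a CLIP text encoder internally) is the usual route. Below is a minimal sketch using the diffusers library, assuming it is listed in the Space's requirements.txt and that the runwayml/stable-diffusion-v1-5 checkpoint is still available on the Hub:

```python
import torch
from diffusers import StableDiffusionPipeline

# Assumed checkpoint; any Stable Diffusion checkpoint on the Hub works
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float32,  # float32 for CPU; use float16 on GPU
)
pipe = pipe.to("cpu")

# Generate and save an image for the same prompt
image = pipe("A picture of a cat", num_inference_steps=25).images[0]
image.save("cat.png")
```

On a CPU-only Space this will be slow; lowering num_inference_steps trades image quality for build/run time.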