Spaces:
Runtime error
Runtime error
| # pipeline.py | |
| import os | |
| import re | |
| import time | |
| import torch | |
| import torchvision | |
| from huggingface_hub import HfApi, HfFolder | |
| from factories import UNet_conditional | |
| from wrapper import DiffusionManager, Schedule | |
| from bert_vectorize import vectorize_text_with_bert | |
| from logger import save_grid_with_label | |
class TextToImagePipeline:
    """Text-conditioned image generation pipeline.

    Loads a ``UNet_conditional`` checkpoint from ``model_dir``, wraps it in a
    ``DiffusionManager`` using a linear schedule, and writes generated image
    grids under ``<model_dir>/inferred``.
    """

    def __init__(self, model_dir: str = "runs/run_3_jxa", device: str = "cpu"):
        """Load the network and prepare the diffusion wrapper.

        Args:
            model_dir: Run directory; must contain ``ckpt/latest_cpu.pt``.
            device: Device string the network and diffusion manager run on.
        """
        self.device = device
        self.model_dir = model_dir

        # Output directory for generated grids.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # NOTE(review): 768 presumably matches the BERT embedding width used
        # by vectorize_text_with_bert -- confirm against that module.
        self.net = UNet_conditional(num_classes=768)
        checkpoint_path = os.path.join(model_dir, "ckpt/latest_cpu.pt")
        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from; the model is moved to the target device
        # right after.
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        self.net.load_state_dict(state_dict)
        self.net.to(self.device)

        # Default wrapper; __call__ replaces it with one built for the
        # requested number of steps.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def __call__(self, prompt, num_steps=1000, amt=1):
        """Generate ``amt`` images for ``prompt`` using ``num_steps`` noise steps.

        The defaults mirror ``__init__`` (1000 steps) and
        ``generate_sample_save_images`` (one image), so callers may omit
        either argument.

        Returns:
            Path of the saved image grid.
        """
        # Rebuild the wrapper so the requested step count takes effect.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=num_steps)
        self.wrapper.set_schedule(Schedule.LINEAR)
        return self.generate_sample_save_images(prompt, amt)

    @staticmethod
    def _sanitize_prompt(prompt: str) -> str:
        """Turn a prompt into a filename stem: ASCII letters and underscores only."""
        letters_and_whitespace = re.sub(r'[^a-zA-Z\s]', '', prompt)
        # Replace every whitespace character (not just spaces) so tabs and
        # newlines never end up in the filename.
        return re.sub(r'\s', '_', letters_and_whitespace)

    def generate_sample_save_images(self, prompt: str, amt: int = 1):
        """Sample ``amt`` images for ``prompt`` and save them as one labelled grid.

        The file is written to ``<model_dir>/inferred`` and named after the
        sanitized prompt followed by the current Unix timestamp.

        Returns:
            Path of the saved PNG.
        """
        filename = self._sanitize_prompt(prompt) + str(int(time.time())) + ".png"
        output_path = os.path.join(self.model_dir, "inferred", filename)

        # Add a leading dimension to the prompt embedding before sampling.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

        # NOTE(review): 64 is presumably the image resolution -- confirm
        # against DiffusionManager.sample.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        save_grid_with_label(torchvision.utils.make_grid(generated), prompt, output_path)
        return output_path
# Usage example: load the pipeline, prompt on stdin, generate an 8-image grid.
if __name__ == "__main__":
    device = "cpu"
    model_dir = "runs/run_3_jxa"  # Path to your model directory

    # Create an instance of the pipeline
    pipeline = TextToImagePipeline(model_dir=model_dir, device=device)

    # Get user input and generate an image.
    # num_steps is passed explicitly: __call__ takes it before amt, and
    # 1000 matches the noise_steps the pipeline is initialised with.
    prompt = input("Prompt? ")
    image_path = pipeline(prompt, num_steps=1000, amt=8)
    print(f"Generated image saved at: {image_path}")