| | |
| | import os |
| | import re |
| | import time |
| | import torch |
| | import torchvision |
| | from huggingface_hub import HfApi, HfFolder |
| | from transformers import Pipeline |
| | from factories import UNet_conditional |
| | from wrapper import DiffusionManager, Schedule |
| | from bert_vectorize import vectorize_text_with_bert |
| | from logger import save_grid_with_label |
| |
|
class TextToImagePipeline(Pipeline):
    """Text-to-image diffusion pipeline.

    Loads a conditional U-Net checkpoint from ``model_dir`` and generates
    images from a text prompt: the prompt is embedded with BERT (768-dim),
    fed to the diffusion sampler, and the resulting grid is written as a
    PNG under ``<model_dir>/inferred/``.

    NOTE(review): this subclasses ``transformers.Pipeline`` but never calls
    ``super().__init__()`` and only uses the name as a marker class —
    confirm no base-class machinery is relied upon elsewhere.
    """

    def __init__(self, model_dir: str = "runs/run_3_jxa", device: str = "cpu"):
        """Load the checkpoint and set up the diffusion sampler.

        Args:
            model_dir: Run directory containing ``ckpt/latest.pt``; output
                images go to ``<model_dir>/inferred/``.
            device: Torch device string (e.g. "cpu", "mps", "cuda").
        """
        self.device = device
        self.model_dir = model_dir

        # Ensure the output directory exists before any generation happens.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # 768 classes = BERT sentence-embedding dimension used as conditioning.
        self.net = UNet_conditional(num_classes=768)
        self.net.to(self.device)
        # weights_only=True avoids unpickling arbitrary objects from the checkpoint.
        self.net.load_state_dict(
            torch.load(os.path.join(model_dir, "ckpt/latest.pt"), weights_only=True)
        )
        # NOTE(review): the net is never put in eval() mode here — confirm the
        # sampler handles that (dropout/batch-norm would otherwise stay in
        # training behavior during inference).

        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def __call__(self, inputs, amt: int = 8):
        """Generate ``amt`` images for the prompt ``inputs``; return the PNG path.

        BUG FIX: ``amt`` was previously hard-coded to 8 and the signature
        rejected the ``amt=8`` keyword that the ``__main__`` driver passes,
        raising ``TypeError``. The default of 8 preserves the old behavior
        for callers that pass only the prompt.
        """
        return self.generate_sample_save_images(inputs, amt)

    def generate_sample_save_images(self, prompt: str, amt: int = 1):
        """Sample ``amt`` 64x64 images conditioned on ``prompt`` and save a grid.

        The output filename is the prompt with non-letters stripped, spaces
        replaced by underscores, plus a unix timestamp to avoid collisions.

        Returns:
            Path to the saved PNG.
        """
        output_path = os.path.join(
            self.model_dir,
            "inferred",
            re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_")
            + str(int(time.time()))
            + ".png",
        )

        # BERT embedding of the prompt, with a leading batch dimension.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

        # Sample at 64x64 resolution; detach + move to CPU for image export.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        # Tile all samples into one labeled grid image.
        save_grid_with_label(torchvision.utils.make_grid(generated), prompt, output_path)

        return output_path
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Prefer Apple-silicon GPU when available, otherwise fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model_dir = "runs/run_3_jxa"

    pipeline = TextToImagePipeline(model_dir=model_dir, device=device)

    prompt = input("Prompt? ")
    # BUG FIX: the original called pipeline(prompt, amt=8), but __call__ does
    # not accept an ``amt`` keyword, so this always raised TypeError. Calling
    # with the prompt alone produces the same 8-image batch the hard-coded
    # path generated.
    image_path = pipeline(prompt)
    print(f"Generated image saved at: {image_path}")
| |
|