# text-to-image-model — pipeline.py
# Source: JBlitzar / text-to-image-model (commit 6b36151)
# pipeline.py
import os
import re
import time
import torch
import torchvision
from huggingface_hub import HfApi, HfFolder
from transformers import Pipeline
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
from bert_vectorize import vectorize_text_with_bert
from logger import save_grid_with_label
class TextToImagePipeline(Pipeline):
    """Text-to-image diffusion pipeline.

    Vectorizes a text prompt with BERT, samples a conditional UNet via the
    diffusion wrapper, and saves the generated images as a labeled grid PNG.

    NOTE(review): this subclasses ``transformers.Pipeline`` but never calls
    ``super().__init__()`` — confirm the base class is actually needed here.
    """

    def __init__(self, model_dir: str = "runs/run_3_jxa", device: str = "cpu"):
        """Load the conditional UNet checkpoint and set up the diffusion sampler.

        Args:
            model_dir: Run directory containing ``ckpt/latest.pt``; generated
                images are written under ``<model_dir>/inferred``.
            device: Torch device string (e.g. ``"cpu"``, ``"mps"``).
        """
        self.device = device
        self.model_dir = model_dir
        # Output directory for generated images.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)
        # 768 matches the BERT embedding width used for conditioning.
        self.net = UNet_conditional(num_classes=768)
        self.net.to(self.device)
        self.net.load_state_dict(
            torch.load(os.path.join(model_dir, "ckpt/latest.pt"), weights_only=True)
        )
        # Inference mode: disable dropout and use running batch-norm statistics.
        self.net.eval()
        # Diffusion sampler wrapping the network.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def __call__(self, inputs, amt: int = 8):
        """Generate ``amt`` images for prompt ``inputs``; return the saved PNG's path.

        Bug fix: ``amt`` was previously hard-coded and not accepted as a
        parameter, so calls like ``pipeline(prompt, amt=8)`` raised TypeError.
        The default of 8 preserves the original behavior.
        """
        return self.generate_sample_save_images(inputs, amt)

    def generate_sample_save_images(self, prompt: str, amt: int = 1) -> str:
        """Sample ``amt`` images conditioned on ``prompt`` and save a labeled grid.

        Args:
            prompt: Free-text prompt; sanitized into the output filename.
            amt: Number of images to sample.

        Returns:
            Path of the written PNG under ``<model_dir>/inferred``.
        """
        # Sanitize the prompt into a filename; timestamp avoids collisions.
        slug = re.sub(r"[^a-zA-Z\s]", "", prompt).replace(" ", "_")
        output_path = os.path.join(
            self.model_dir, "inferred", slug + str(int(time.time())) + ".png"
        )
        # Shape (1, 768): batch dimension added to the BERT embedding.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)
        # 64 — presumably the image resolution the UNet was trained at; verify
        # against DiffusionManager.sample's signature.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()
        # Save all samples as one labeled grid image.
        save_grid_with_label(torchvision.utils.make_grid(generated), prompt, output_path)
        return output_path  # Path to the saved image.
# Usage example
if __name__ == "__main__":
    # Prefer the Apple-Silicon GPU backend when available.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model_dir = "runs/run_3_jxa"  # Path to your model directory

    # Create an instance of the pipeline.
    pipeline = TextToImagePipeline(model_dir=model_dir, device=device)

    # Get user input and generate an image.
    prompt = input("Prompt? ")
    # Bug fix: __call__ did not accept an ``amt`` keyword, so the original
    # ``pipeline(prompt, amt=8)`` raised TypeError. The pipeline already
    # generates 8 images per call by default, so pass only the prompt.
    image_path = pipeline(prompt)
    print(f"Generated image saved at: {image_path}")