# BitRoss / app.py
# OVAWARE — Update app.py (commit cf238b7, verified)
# Hugging Face page chrome preserved as comments: raw · history · blame · 4.06 kB
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel
import gradio as gr
import numpy as np
import os
import time
from train import CVAE, TextEncoder, LATENT_DIM, HIDDEN_DIM
# Initialize the BERT tokenizer once at module level so generate_image() can
# reuse it on every call; from_pretrained downloads/caches the
# 'bert-base-uncased' vocabulary on first use.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def clean_image(image, threshold=0.75):
    """Binarize an image's alpha channel: each pixel becomes fully opaque
    or fully transparent.

    Pixels whose alpha is at or below ``threshold`` (expressed as a fraction
    of 255) are set fully transparent; all others fully opaque.

    Args:
        image: PIL image. Converted to RGBA if it lacks an alpha channel.
        threshold: Opacity cutoff in [0, 1]; defaults to 0.75.

    Returns:
        A new RGBA ``PIL.Image`` with the snapped alpha channel.
    """
    # BUG FIX: the original indexed channel 3 unconditionally and raised
    # IndexError for RGB/L inputs; ensure a 4-channel image first.
    if image.mode != "RGBA":
        image = image.convert("RGBA")
    np_image = np.array(image)
    cutoff = int(threshold * 255)  # compute once instead of per-comparison
    alpha_channel = np_image[:, :, 3]
    alpha_channel[alpha_channel <= cutoff] = 0
    alpha_channel[alpha_channel > cutoff] = 255  # Set to 100% visibility
    return Image.fromarray(np_image)
def generate_image(model, text_prompt, device, input_image=None, img_control=0.5):
    """Generate a 16x16 image from a text prompt with the trained CVAE.

    The prompt is BERT-tokenized, encoded by the model's text encoder, and a
    random latent vector is decoded conditioned on that encoding. If a
    reference ``input_image`` is supplied, it is blended with the decoder
    output: ``img_control`` = 1.0 keeps only the reference, 0.0 only the
    generated image.

    Args:
        model: Trained CVAE with ``text_encoder`` and ``decode`` members.
        text_prompt: Text condition for generation.
        device: torch device the model lives on.
        input_image: Optional PIL image used as a blending reference.
        img_control: Blend weight of the reference image in [0, 1].

    Returns:
        The generated image as a PIL image.
    """
    encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
    input_ids = encoded_input['input_ids'].to(device)
    attention_mask = encoded_input['attention_mask'].to(device)

    with torch.no_grad():
        text_encoding = model.text_encoder(input_ids, attention_mask)
        # Sample a single latent vector and decode it conditioned on the text.
        z = torch.randn(1, LATENT_DIM).to(device)
        generated_image = model.decode(z, text_encoding)

    if input_image is not None:
        input_image = input_image.convert("RGBA").resize((16, 16), resample=Image.NEAREST)
        input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
        # BUG FIX: ToTensor() produces values in [0, 1] while the decoder
        # output is in [-1, 1] (see the rescale below). The original blended
        # the two ranges directly, skewing the result; map the reference
        # image to [-1, 1] first so both operands share the same scale.
        input_image = input_image * 2 - 1
        generated_image = img_control * input_image + (1 - img_control) * generated_image

    generated_image = generated_image.squeeze(0).cpu()
    generated_image = (generated_image + 1) / 2  # Rescale from [-1, 1] to [0, 1]
    generated_image = generated_image.clamp(0, 1)
    return transforms.ToPILImage()(generated_image)
def process(prompt, model_path, clean, size, input_image, img_control, output_dir):
    """Gradio handler: load a CVAE checkpoint, generate an image from the
    prompt, optionally clean/resize it, save it to disk, and return it.

    Args:
        prompt: Text prompt for generation.
        model_path: Path to a trained ``.pth`` state-dict checkpoint.
        clean: If True, binarize the alpha channel via clean_image().
        size: Output side length in pixels (image is resized size x size).
        input_image: Optional path to a reference image for blending.
        img_control: Blend weight passed to generate_image().
        output_dir: Directory where the PNG is written (created if missing).

    Returns:
        The generated PIL image (also saved under ``output_dir``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the optional reference image if a path was provided.
    input_image = Image.open(input_image).convert("RGBA") if input_image else None

    # Initialize model and restore the trained weights.
    text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
    model = CVAE(text_encoder).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    start_time = time.time()
    generated_image = generate_image(model, prompt, device, input_image, img_control)
    generation_time = time.time() - start_time

    # Clean up the image if the flag is set.
    if clean:
        generated_image = clean_image(generated_image)

    # BUG FIX: Gradio sliders may deliver floats; PIL's resize needs ints.
    size = int(size)
    generated_image = generated_image.resize((size, size), resample=Image.NEAREST)

    model_name = os.path.splitext(os.path.basename(model_path))[0]
    # BUG FIX: the raw prompt was interpolated into the filename; characters
    # like '/' or ':' made save() fail or escape output_dir. Keep only
    # filesystem-safe characters.
    safe_prompt = "".join(c if c.isalnum() or c in "._-" else "_" for c in prompt)
    output_file = os.path.join(output_dir, f"{model_name}_{safe_prompt}.png")
    os.makedirs(output_dir, exist_ok=True)
    generated_image.save(output_file)

    print(f"Generated image saved as {output_file}")
    print(f"Generation time: {generation_time:.10f} seconds")
    return generated_image
# Gradio Interface wiring: maps UI components onto process()'s parameters
# (order matters) and returns the generated image.
interface = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="Text Prompt"),
        gr.File(label="Model Path (.pth file)", file_types=['.pth']),
        # BUG FIX: Gradio 3+ components take `value=`, not `default=` (which
        # raises TypeError there); this also matches the Textbox below.
        gr.Checkbox(label="Clean Image (Remove Low Opacity Pixels)", value=False),
        gr.Slider(label="Image Size", minimum=16, maximum=512, step=16, value=16),
        gr.File(label="Input Image (Optional)", file_types=["image"]),
        gr.Slider(label="Image Control (0-1)", minimum=0.0, maximum=1.0, step=0.01, value=0.5),
        gr.Textbox(label="Output Directory", value="generated_images")
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Text-to-Image Generator",
    description="Generate an image from a text prompt using a trained CVAE model."
)

if __name__ == "__main__":
    interface.launch()