"""Gradio demo that generates bird images from text with a pretrained DF-GAN."""

import os
import pickle
import random
import subprocess
import sys

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F  # noqa: F401  (unused here; kept from the original file)
from omegaconf import OmegaConf  # noqa: F401  (unused here; kept from the original file)
from PIL import Image
from scipy.stats import truncnorm

# First run the download_models.py script if models haven't been downloaded
if not os.path.exists('data/state_epoch_1220.pth') or not os.path.exists('data/text_encoder200.pth'):
    print("Downloading necessary model files...")
    try:
        subprocess.check_call([sys.executable, "download_models.py"])
    except subprocess.CalledProcessError as e:
        # Best effort: report the failure and continue; loading the models
        # below will fail if the files are still missing.
        print(f"Error downloading models: {e}")
        print("Please run download_models.py manually before starting the app.")

# Add the code directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "DF-GAN/code"))

# Import necessary modules from the DF-GAN code
from models.DAMSM import RNN_ENCODER  # noqa: E402
from models.GAN import NetG  # noqa: E402

# Key prefix found in some checkpoints; presumably added by
# torch.nn.DataParallel when the model was saved -- TODO confirm.
_MULTI_GPU_PREFIX = 'module.'


# Utility functions
def load_model_weights(model, weights, multi_gpus=False, train=False):
    """Load a state dict into ``model``, handling the ``module.`` key prefix.

    Args:
        model: Object exposing ``load_state_dict``.
        weights: Mapping of parameter names to values.
        multi_gpus: Whether the target model runs on multiple GPUs.
        train: Whether the model is being loaded for training.

    Returns:
        The same ``model`` instance, with the weights loaded.

    The prefix is kept only when both ``multi_gpus`` and ``train`` are
    truthy; in every other case a leading ``module.`` is removed from each
    key that has one.
    """
    if multi_gpus and train:
        state_dict = weights
    else:
        prefix_len = len(_MULTI_GPU_PREFIX)
        # Strip only a *leading* prefix, key by key, so names that merely
        # contain "module" somewhere are left intact.
        state_dict = {
            (key[prefix_len:] if key.startswith(_MULTI_GPU_PREFIX) else key): value
            for key, value in weights.items()
        }
    model.load_state_dict(state_dict)
    return model


def get_tokenizer():
    """Return an NLTK ``RegexpTokenizer`` that splits text on ``\\w+`` runs."""
    # Imported lazily so nltk is only needed once a caption is tokenized.
    from nltk.tokenize import RegexpTokenizer
    return RegexpTokenizer(r'\w+')


def truncated_noise(batch_size=1, dim_z=100, truncation=1.0, seed=None):
    """Sample noise from a normal distribution truncated to [-2, 2].

    Args:
        batch_size: Number of noise vectors.
        dim_z: Length of each noise vector.
        truncation: Factor the samples are multiplied by.
        seed: Optional seed for a private ``RandomState``. When ``None``,
            ``random_state=None`` is passed to SciPy, which then draws from
            NumPy's global random state.

    Returns:
        A float32 array of shape ``(batch_size, dim_z)``.
    """
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
    return truncation * values.astype(np.float32)


def tokenize_and_build_captions(input_text, wordtoix, max_len=18):
    """Tokenize ``input_text`` and map it to a fixed-length array of indices.

    Args:
        input_text: Raw caption text.
        wordtoix: Mapping from word to vocabulary index.
        max_len: Length of the returned array. The default of 18 is the
            value defined in bird.yml.

    Returns:
        ``(cap_array, cap_len)``: an int64 array of length ``max_len``,
        zero-padded on the right, and the number of valid entries in it.
        Words missing from ``wordtoix`` are dropped; captions longer than
        ``max_len`` are truncated.
    """
    tokenizer = get_tokenizer()
    cap = []
    for token in tokenizer.tokenize(input_text.lower()):
        # Drop non-ASCII characters before the vocabulary lookup.
        token = token.encode('ascii', 'ignore').decode('ascii')
        if token and token in wordtoix:
            cap.append(wordtoix[token])

    # Truncate first so the result is always a NumPy array, including for
    # captions longer than max_len.
    cap = cap[:max_len]
    cap_len = len(cap)
    cap_array = np.zeros(max_len, dtype='int64')
    cap_array[:cap_len] = cap
    return cap_array, cap_len


def encode_caption(caption, caption_len, text_encoder, device):
    """Encode one caption into a sentence embedding.

    Args:
        caption: Sequence of word indices for a single caption.
        caption_len: Number of valid entries in ``caption``.
        text_encoder: Encoder exposing ``init_hidden`` and a call taking
            ``(captions, lengths, hidden)``.
        device: Torch device the input tensors are moved to.

    Returns:
        The second value returned by ``text_encoder`` (the sentence
        embedding), computed without gradient tracking.
    """
    with torch.no_grad():
        # Build the (1, len) int64 batch from an array rather than from a
        # list of arrays.
        caption = torch.as_tensor(np.asarray(caption, dtype=np.int64)).unsqueeze(0).to(device)
        caption_len = torch.tensor([caption_len]).to(device)
        hidden = text_encoder.init_hidden(1)
        _, sent_emb = text_encoder(caption, caption_len, hidden)
    return sent_emb


def save_img(img_tensor):
    """Convert an image tensor with values in [-1, 1] to a PIL image.

    Args:
        img_tensor: 3-D tensor; axis 0 is moved to the end before
            conversion, so presumably (C, H, W) -- TODO confirm.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 pixel data.
    """
    im = img_tensor.detach().cpu().numpy()
    # [-1, 1] --> [0, 255]; clip so out-of-range values cannot wrap around
    # in the uint8 cast.
    im = np.clip((im + 1.0) * 127.5, 0, 255).astype(np.uint8)
    im = np.transpose(im, (1, 2, 0))
    return Image.fromarray(im)


# Load configuration
config = {
    'z_dim': 100,
    'cond_dim': 256,
    'imsize': 256,
    'nf': 32,
    'ch_size': 3,
    'truncation': True,
    'trunc_rate': 0.88,
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# Load vocab and models
def load_models():
    """Load the vocabulary, the text encoder and the generator.

    Returns:
        ``(wordtoix, ixtoword, text_encoder, netG)``. Both networks are
        moved to ``device`` and put in eval mode; the text encoder's
        parameters have ``requires_grad`` disabled.
    """
    # Load vocabulary.
    # NOTE(review): pickle.load and torch.load execute pickled code; only
    # use data files from a trusted source.
    with open('data/captions_DAMSM.pickle', 'rb') as f:
        vocab = pickle.load(f)
    wordtoix = vocab[3]
    ixtoword = vocab[2]
    del vocab

    # Initialize text encoder
    text_encoder = RNN_ENCODER(len(wordtoix), nhidden=config['cond_dim'])
    text_encoder_path = 'data/text_encoder200.pth'
    state_dict = torch.load(text_encoder_path, map_location='cpu')
    text_encoder = load_model_weights(text_encoder, state_dict)
    text_encoder.to(device)
    for p in text_encoder.parameters():
        p.requires_grad = False
    text_encoder.eval()

    # Initialize generator
    netG = NetG(config['nf'], config['z_dim'], config['cond_dim'],
                config['imsize'], config['ch_size'])
    netG_path = 'data/state_epoch_1220.pth'
    state_dict = torch.load(netG_path, map_location='cpu')
    netG = load_model_weights(netG, state_dict['model']['netG'])
    netG.to(device)
    netG.eval()

    return wordtoix, ixtoword, text_encoder, netG


wordtoix, ixtoword, text_encoder, netG = load_models()


def generate_image(text_input, num_images=1, seed=None):
    """Generate images from a text description.

    Args:
        text_input: Caption describing the bird.
        num_images: Number of images to generate; coerced to ``int``
            because UI sliders may deliver a float.
        seed: Optional integer seed. It is reduced modulo 2**32, the range
            ``np.random.seed`` accepts.

    Returns:
        A list of ``num_images`` entries: ``None`` placeholders for empty
        input, red placeholder images when no word of the caption is in the
        vocabulary, otherwise the generated PIL images.
    """
    num_images = int(num_images)

    if not text_input or not text_input.strip():
        return [None] * num_images

    cap_array, cap_len = tokenize_and_build_captions(text_input, wordtoix)

    if cap_len == 0:
        # No known word in the caption: signal it with a red placeholder.
        size = config['imsize']
        return [Image.new('RGB', (size, size), color='red')] * num_images

    sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)

    # Set random seed if provided
    if seed is not None:
        seed = int(seed) % (2 ** 32)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Generate multiple images if requested
    result_images = []
    with torch.no_grad():
        for _ in range(num_images):
            # Generate noise
            if config['truncation']:
                # No explicit seed is passed, so the samples come from
                # NumPy's global state (seeded above when a seed is given).
                noise = truncated_noise(1, config['z_dim'], config['trunc_rate'])
                noise = torch.tensor(noise, dtype=torch.float).to(device)
            else:
                noise = torch.randn(1, config['z_dim']).to(device)

            # Generate image
            fake_img = netG(noise, sent_emb)
            result_images.append(save_img(fake_img[0]))

    return result_images


# Create Gradio interface
def generate_images_interface(text, num_images, random_seed):
    """Adapt raw Gradio widget values to ``generate_image``.

    Args:
        text: Caption text from the textbox.
        num_images: Value of the "Number of Images" slider.
        random_seed: Content of the seed textbox; empty or whitespace-only
            means "no seed".

    Returns:
        The list of images returned by ``generate_image``.

    Raises:
        ValueError: If ``random_seed`` is non-empty but not an integer.
    """
    seed_text = '' if random_seed is None else str(random_seed).strip()
    if seed_text:
        try:
            seed = int(seed_text)
        except ValueError:
            raise ValueError(
                f"Random seed must be an integer, got {random_seed!r}"
            ) from None
    else:
        seed = None
    return generate_image(text, num_images, seed)


with gr.Blocks(title="Bird Image Generator") as demo:
    gr.Markdown("# Bird Image Generator using DF-GAN")
    gr.Markdown("Enter a description of a bird and the model will generate corresponding images.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Bird Description",
                placeholder="Enter a description of a bird (e.g., 'a small bird with a red head and black wings')",
                lines=3
            )
            num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Images")
            seed = gr.Textbox(label="Random Seed (optional)", placeholder="Leave empty for random results")
            submit_btn = gr.Button("Generate Image")

        with gr.Column():
            # NOTE(review): `.style(...)` looks tied to the Gradio 3.x API --
            # confirm against the pinned Gradio version before upgrading.
            image_output = gr.Gallery(label="Generated Images").style(grid=2, height="auto")

    submit_btn.click(
        fn=generate_images_interface,
        inputs=[text_input, num_images, seed],
        outputs=image_output
    )

    gr.Markdown("## Example Descriptions")
    example_descriptions = [
        "this bird has an orange bill, a white belly and white eyebrows",
        "a small bird with a red head, breast, and belly and black wings",
        "this bird is yellow with black and has a long, pointy beak",
        "this bird is white in color, and has a orange beak"
    ]

    gr.Examples(
        examples=[[desc, 1, ""] for desc in example_descriptions],
        inputs=[text_input, num_images, seed],
        outputs=image_output,
        fn=generate_images_interface
    )

# Launch the app with appropriate configurations for Hugging Face Spaces
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # Bind to all network interfaces
        share=False,  # Don't use share links
        # NOTE(review): favicon_path is given a URL here; presumably Gradio
        # expects a local file path -- confirm the favicon actually loads.
        favicon_path="https://raw.githubusercontent.com/tobran/DF-GAN/main/framework.png"
    )