# Hugging Face Space (status when this page was captured: Sleeping)
| import os | |
| import sys | |
| import random | |
| import torch | |
| import pickle | |
| import numpy as np | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from omegaconf import OmegaConf | |
| from scipy.stats import truncnorm | |
| import subprocess | |
| import traceback | |
| import time | |
# Module-level flag recording whether the pre-trained weights were loaded.
models_loaded_successfully = False

# Run download_models.py when at least one of the required artefacts is absent.
if any(
    not os.path.exists(required_file)
    for required_file in (
        'data/state_epoch_1220.pth',
        'data/text_encoder200.pth',
        'data/captions_DAMSM.pickle',
    )
):
    print("Downloading necessary model files...")
    try:
        subprocess.check_call([sys.executable, "download_models.py"])
    except subprocess.CalledProcessError as download_error:
        # A failed download is reported but does not abort start-up.
        print(f"Error downloading models: {download_error}")
        print("Please check the error message above. The application will attempt to continue with fallback settings.")
# Put the bundled DF-GAN sources first on sys.path, then import the two model
# classes from them. An ImportError is reported but not re-raised.
try:
    sys.path.insert(
        0,
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "DF-GAN/code",
        ),
    )
    from models.DAMSM import RNN_ENCODER
    from models.GAN import NetG
except ImportError as import_error:
    print(f"Error importing required modules: {import_error}")
    print("The application may not function correctly.")
| # Utility functions | |
def load_model_weights(model, weights, multi_gpus=False, train=False):
    """Load ``weights`` into ``model``, stripping a leading ``module.`` key prefix.

    Args:
        model: Object exposing ``load_state_dict``.
        weights: Mapping of parameter name to value.
        multi_gpus: Whether the caller runs the model on several GPUs.
        train: Whether the caller is going to train the model.

    Returns:
        ``model``. Loading is best-effort: if anything goes wrong the error is
        printed and the model is returned with whatever weights it had.
    """
    prefix = 'module.'
    try:
        if multi_gpus and train:
            # Multi-GPU training keeps the checkpoint keys exactly as stored.
            state_dict = weights
        elif any(key.startswith(prefix) for key in weights):
            # Strip the prefix only where it is really present; slicing every
            # key unconditionally would corrupt names that lack it.
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in weights.items()
            }
        else:
            # Pass the original mapping through untouched.
            state_dict = weights
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading model weights: {e}")
        print("Using model with random weights instead.")
    return model
def get_tokenizer():
    """Return an NLTK ``RegexpTokenizer`` that matches runs of word characters."""
    # NLTK is imported inside the function, so it is only needed when this is called.
    from nltk.tokenize import RegexpTokenizer
    return RegexpTokenizer(r'\w+')
def truncated_noise(batch_size=1, dim_z=100, truncation=1.0, seed=None):
    """Sample truncated-normal noise.

    Draws a ``(batch_size, dim_z)`` float32 array via ``truncnorm.rvs(-2, 2)``
    and multiplies it by ``truncation``. With ``seed=None`` SciPy's default
    random state is used; otherwise a dedicated ``RandomState(seed)``.
    """
    if seed is None:
        rng = None
    else:
        rng = np.random.RandomState(seed)
    samples = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=rng)
    return truncation * samples.astype(np.float32)
def tokenize_and_build_captions(input_text, wordtoix, max_len=18, tokenizer=None):
    """Convert free text into a fixed-length array of vocabulary indices.

    Args:
        input_text: Caption text; ``None`` is treated as an empty string.
        wordtoix: Mapping from word to integer index. Words not present in it
            are dropped.
        max_len: Length of the returned array (default 18).
        tokenizer: Optional object with a ``tokenize(str)`` method. Defaults
            to ``get_tokenizer()``.

    Returns:
        ``(cap_array, cap_len)`` where ``cap_array`` is always an int64 NumPy
        array of length ``max_len`` (zero padded) and ``cap_len`` is the number
        of valid leading entries, capped at ``max_len``.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer()
    tokens = tokenizer.tokenize((input_text or "").lower())

    # Drop non-ASCII characters from each token, then keep known words only.
    ascii_tokens = (t.encode('ascii', 'ignore').decode('ascii') for t in tokens)
    cap = [wordtoix[t] for t in ascii_tokens if t and t in wordtoix]

    # Truncate first so the long and short paths both yield an ndarray.
    cap = cap[:max_len]
    cap_len = len(cap)
    cap_array = np.zeros(max_len, dtype='int64')
    if cap_len:
        cap_array[:cap_len] = cap
    return cap_array, cap_len
def encode_caption(caption, caption_len, text_encoder, device):
    """Run one caption through ``text_encoder`` and return its sentence embedding.

    The caption and its length are each wrapped into a batch of one and moved
    to ``device``. The second element returned by the encoder is handed back.
    If anything raises, the error is printed and a random ``(1, 256)`` tensor
    on ``device`` is returned instead.
    """
    try:
        with torch.no_grad():
            token_batch = torch.tensor([caption]).to(device)
            length_batch = torch.tensor([caption_len]).to(device)
            initial_state = text_encoder.init_hidden(1)
            _word_features, sentence_embedding = text_encoder(
                token_batch, length_batch, initial_state
            )
            return sentence_embedding
    except Exception as e:
        print(f"Error encoding caption: {e}")
        # Best-effort fallback: a random embedding.
        return torch.randn(1, 256).to(device)
def save_img(img_tensor):
    """Convert an image tensor into a PIL image.

    Args:
        img_tensor: Tensor whose axes are reordered with ``(1, 2, 0)`` before
            conversion and whose values are mapped from [-1, 1] to [0, 255].
            Presumably channel-first (C, H, W) - confirm against the generator.

    Returns:
        A ``PIL.Image``. On any failure the error is printed and a red
        256x256 placeholder is returned instead.
    """
    try:
        im = img_tensor.detach().cpu().numpy()
        # [-1, 1] --> [0, 255]
        im = (im + 1.0) * 127.5
        # Clamp before the cast: values outside [0, 255] would otherwise wrap
        # around when converted to uint8.
        im = np.clip(im, 0.0, 255.0).astype(np.uint8)
        im = np.transpose(im, (1, 2, 0))
        return Image.fromarray(im)
    except Exception as e:
        print(f"Error converting image tensor to PIL Image: {e}")
        # Return a red placeholder image as fallback
        return Image.new('RGB', (256, 256), color='red')
# Hard-coded settings used when building the models and sampling noise.
config = dict(
    z_dim=100,        # noise dimension (passed to NetG and to the noise sampler)
    cond_dim=256,     # passed to NetG and used as the text encoder's nhidden
    imsize=256,       # forwarded to NetG
    nf=32,            # forwarded to NetG
    ch_size=3,        # forwarded to NetG
    truncation=True,  # sample truncated noise instead of plain torch.randn
    trunc_rate=0.88,  # scale applied to the truncated noise
)

# Prefer CUDA when it is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Module-level state filled in by load_models().
wordtoix = {}
ixtoword = {}
text_encoder = None
netG = None
models_loaded = False
# Load vocab and models
def _fallback_vocabulary():
    """Return the tiny built-in ``(wordtoix, ixtoword)`` pair used when the real vocabulary is unavailable."""
    words = {"the": 1, "bird": 2, "is": 3, "a": 4, "with": 5, "and": 6, "red": 7, "black": 8, "yellow": 9}
    return words, {index: word for word, index in words.items()}


def _vocab_size(word_index):
    """Return a token count large enough to index every value in ``word_index``."""
    highest = max(word_index.values(), default=-1)
    # len() alone is too small when indices do not start at 0, as in the
    # fallback vocabulary (indices 1..9 but only nine entries).
    return max(len(word_index), highest + 1)


def load_models():
    """Load the vocabulary, text encoder and generator into module globals.

    Missing files are replaced by a fallback vocabulary or randomly
    initialised weights, with a printed warning. ``models_loaded_successfully``
    is set to True only when the vocabulary file, the text-encoder checkpoint
    and a generator checkpoint with the expected layout were all found. Note
    that ``load_model_weights`` itself is best-effort, so a weight mismatch
    inside a checkpoint is reported on stdout but not reflected in the flag.

    Returns:
        ``(wordtoix, ixtoword, text_encoder, netG)``. If fallback model
        construction fails as well, the model entries keep whatever value they
        had at that point (possibly ``None``).
    """
    global wordtoix, ixtoword, text_encoder, netG, models_loaded, models_loaded_successfully
    try:
        # Load vocabulary
        vocab_path = 'data/captions_DAMSM.pickle'
        vocab_ok = os.path.exists(vocab_path)
        if vocab_ok:
            # NOTE(security): pickle.load can execute arbitrary code; this file
            # must come from a trusted source.
            with open(vocab_path, 'rb') as f:
                x = pickle.load(f)
            wordtoix = x[3]
            ixtoword = x[2]
            del x
        else:
            print("Warning: captions_DAMSM.pickle not found. Using fallback vocabulary.")
            wordtoix, ixtoword = _fallback_vocabulary()

        # Initialize text encoder
        text_encoder = RNN_ENCODER(_vocab_size(wordtoix), nhidden=config['cond_dim'])
        text_encoder_path = 'data/text_encoder200.pth'
        encoder_ok = os.path.exists(text_encoder_path)
        if encoder_ok:
            state_dict = torch.load(text_encoder_path, map_location='cpu')
            text_encoder = load_model_weights(text_encoder, state_dict)
        else:
            print("Warning: text_encoder200.pth not found. Using random weights.")
        text_encoder.to(device)
        for p in text_encoder.parameters():
            p.requires_grad = False
        text_encoder.eval()

        # Initialize generator
        netG = NetG(config['nf'], config['z_dim'], config['cond_dim'], config['imsize'], config['ch_size'])
        netG_path = 'data/state_epoch_1220.pth'
        generator_ok = False
        if os.path.exists(netG_path):
            state_dict = torch.load(netG_path, map_location='cpu')
            if 'model' in state_dict and 'netG' in state_dict['model']:
                netG = load_model_weights(netG, state_dict['model']['netG'])
                generator_ok = True
            else:
                print("Warning: state_epoch_1220.pth has unexpected format. Using random weights.")
        else:
            print("Warning: state_epoch_1220.pth not found. Using random weights.")
        netG.to(device)
        netG.eval()

        models_loaded = True
        # Report success only when every artefact was actually available.
        models_loaded_successfully = vocab_ok and encoder_ok and generator_ok
        return wordtoix, ixtoword, text_encoder, netG
    except Exception as e:
        print(f"Error loading models: {e}")
        traceback.print_exc()
        print("Using fallback models instead.")
        # A failure anywhere above invalidates any earlier partial success.
        models_loaded = False
        models_loaded_successfully = False
        wordtoix, ixtoword = _fallback_vocabulary()
        # Create fallback models
        try:
            text_encoder = RNN_ENCODER(_vocab_size(wordtoix), nhidden=config['cond_dim']).to(device)
            text_encoder.eval()
            netG = NetG(config['nf'], config['z_dim'], config['cond_dim'], config['imsize'], config['ch_size']).to(device)
            netG.eval()
        except Exception as e2:
            print(f"Failed to create fallback models: {e2}")
        return wordtoix, ixtoword, text_encoder, netG
# Populate the module-level vocabulary and model handles at import time.
# Any exception escaping load_models() is reported and start-up continues.
try:
    (wordtoix, ixtoword, text_encoder, netG) = load_models()
except Exception as load_error:
    print(f"Error during model loading: {load_error}")
    print("The application will attempt to continue but may not function correctly.")
def _placeholder_images(color, count, size=256):
    """Return ``count`` references to one solid-colour ``size`` x ``size`` placeholder image."""
    return [Image.new('RGB', (size, size), color=color)] * count


def generate_image(text_input, num_images=1, seed=None):
    """Generate images from a text description.

    Args:
        text_input: Caption text. Empty, blank or ``None`` yields light-gray
            placeholders.
        num_images: Number of images to produce. Coerced to a non-negative
            int, since the value may arrive as a float (e.g. from a UI slider).
        seed: Optional seed applied to ``random``, NumPy and torch. It is
            reduced modulo 2**32 because ``np.random.seed`` rejects larger
            values.

    Returns:
        A list of PIL images. Placeholder colours signal the failure mode:
        red when no caption word is in the vocabulary, a pale red tile for a
        single failed generation, orange when the whole call fails.
    """
    try:
        num_images = max(0, int(num_images))
    except (TypeError, ValueError):
        num_images = 1

    if not text_input or not text_input.strip():
        return _placeholder_images('lightgray', num_images)

    try:
        cap_array, cap_len = tokenize_and_build_captions(text_input, wordtoix)
        if cap_len == 0:
            return _placeholder_images('red', num_images)
        sent_emb = encode_caption(cap_array, cap_len, text_encoder, device)

        # Set random seed if provided
        if seed is not None:
            seed = int(seed) % (2 ** 32)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

        # Generate multiple images if requested
        result_images = []
        with torch.no_grad():
            for _ in range(num_images):
                # Generate noise
                if config['truncation']:
                    noise = truncated_noise(1, config['z_dim'], config['trunc_rate'])
                    noise = torch.tensor(noise, dtype=torch.float).to(device)
                else:
                    noise = torch.randn(1, config['z_dim']).to(device)
                # Generate image; one failure must not discard the others.
                try:
                    fake_img = netG(noise, sent_emb)
                    img = save_img(fake_img[0])
                except Exception as e:
                    print(f"Error generating image: {e}")
                    # Use a placeholder image as fallback
                    img = Image.new('RGB', (256, 256), color=(255, 200, 200))
                result_images.append(img)
        return result_images
    except Exception as e:
        print(f"Error in generate_image: {e}")
        traceback.print_exc()
        return _placeholder_images('orange', num_images)
# One-line status text for the UI, derived from the model-loading flag.
if models_loaded_successfully:
    model_status = "✅ Models loaded successfully"
else:
    model_status = "⚠️ Using fallback models - images may not look good"
# Function to render error page if needed
def serve_error_page():
    """Return the HTML shown when the models could not be loaded.

    Reads ``error_page.html`` (as UTF-8) from the current working directory.
    If that file does not exist, a minimal built-in page is returned.
    """
    try:
        # Explicit encoding keeps the result independent of the host locale;
        # opening directly avoids an exists-then-open race.
        with open('error_page.html', 'r', encoding='utf-8') as f:
            return f.read()
    except FileNotFoundError:
        return "<html><body><h1>Error loading models</h1><p>The application failed to load the required models.</p></body></html>"
# Adapter between the Gradio widgets and generate_image()
def generate_images_interface(text, num_images, random_seed):
    """Parse the raw widget values and delegate to ``generate_image``.

    Args:
        text: Caption text from the description box.
        num_images: Requested image count; converted to ``int``.
        random_seed: Seed as typed by the user. Only a non-negative decimal
            string (or a non-negative int) is used; anything else means
            "no seed".

    Returns:
        The list of images produced by ``generate_image``.
    """
    seed = None
    if isinstance(random_seed, str):
        seed_text = random_seed.strip()
        # isdecimal() rather than isdigit(): the latter also accepts
        # characters such as superscript digits that int() cannot parse.
        if seed_text.isdecimal():
            seed = int(seed_text)
    elif isinstance(random_seed, int) and random_seed >= 0:
        seed = random_seed

    try:
        count = int(num_images)
    except (TypeError, ValueError):
        count = 1
    return generate_image(text, count, seed)
# Build the Gradio UI. Both variants (models loaded / fallback) share the same
# input and output widgets; only the headings, the button and gallery labels,
# and the trailing section differ.
with gr.Blocks(title="Bird Image Generator") as demo:
    # Heading section
    if models_loaded_successfully:
        gr.Markdown("# Bird Image Generator using DF-GAN")
        gr.Markdown("Enter a description of a bird and the model will generate corresponding images.")
        gr.Markdown(f"**Model Status:** {model_status}")
    else:
        gr.Markdown("# ⚠️ Bird Image Generator - Limited Functionality")
        gr.Markdown("The pre-trained models could not be loaded correctly. The application will run with randomly initialized models.")

    # Inputs on the left, generated gallery on the right
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Bird Description",
                placeholder="Enter a description of a bird (e.g., 'a small bird with a red head and black wings')",
                lines=3,
            )
            num_images = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Images")
            seed = gr.Textbox(label="Random Seed (optional)", placeholder="Leave empty for random results")
            submit_btn = gr.Button(
                "Generate Image"
                if models_loaded_successfully
                else "Generate Image (Results will be random shapes)"
            )
        with gr.Column():
            # NOTE(review): `.style()` may not exist in newer Gradio releases -
            # confirm against the pinned Gradio version.
            image_output = gr.Gallery(
                label="Generated Images"
                if models_loaded_successfully
                else "Generated Images (Random)"
            ).style(grid=2, height="auto")

    submit_btn.click(
        fn=generate_images_interface,
        inputs=[text_input, num_images, seed],
        outputs=image_output,
    )

    # Trailing section: examples when healthy, an explanation otherwise
    if models_loaded_successfully:
        gr.Markdown("## Example Descriptions")
        example_descriptions = [
            "this bird has an orange bill, a white belly and white eyebrows",
            "a small bird with a red head, breast, and belly and black wings",
            "this bird is yellow with black and has a long, pointy beak",
            "this bird is white in color, and has a orange beak",
        ]
        gr.Examples(
            examples=[[desc, 1, ""] for desc in example_descriptions],
            inputs=[text_input, num_images, seed],
            outputs=image_output,
            fn=generate_images_interface,
        )
    else:
        gr.Markdown("""
        ### Model Loading Error
        The application encountered an error while loading the pre-trained models. This could be due to:
        1. Network connectivity issues
        2. The model hosting service might be temporarily unavailable
        3. The model files might have been moved or deleted
        Please try refreshing the page or contact the Space owner if the issue persists.
        """)
# Entry point: start the Gradio server when the file is run as a script.
if __name__ == "__main__":
    # Short pause so the start-up log lines above are printed first.
    time.sleep(1)
    launch_options = {
        "server_name": "0.0.0.0",  # listen on every network interface
        "share": False,  # no public share link
        # NOTE(review): this is a remote URL - confirm the installed Gradio
        # accepts a URL here rather than requiring a local file path.
        "favicon_path": "https://raw.githubusercontent.com/tobran/DF-GAN/main/framework.png",
    }
    demo.launch(**launch_options)