Spaces:
Sleeping
Sleeping
| from flask import Flask, render_template, request, jsonify | |
| from tensorflow.keras.models import load_model | |
| from numpy.random import randn | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import base64 | |
| from io import BytesIO | |
import os

app = Flask(__name__)

# Load the GAN generator from the H5 file once, at import time; every request
# served by this process reuses the same model object.
# The path is resolved against the app's root directory (the directory of
# this module, which Flask also uses to locate templates) instead of the
# process's working directory, so the app can be launched from anywhere.
# compile=False: the model is only used for predict(), so the saved
# optimizer/loss configuration is not needed.
model = load_model(os.path.join(app.root_path, 'gan.h5'), compile=False)
def generate_latent_points(latent_dim, n_samples):
    """Draw ``n_samples`` latent vectors from a standard normal distribution.

    Samples come from ``numpy.random.randn``, i.e. the global NumPy RNG.

    Returns:
        A NumPy array of shape ``(n_samples, latent_dim)``.
    """
    flat = randn(n_samples * latent_dim)
    return flat.reshape((n_samples, latent_dim))
def generate_images(model, latent_points):
    """Run ``latent_points`` through ``model`` and return its raw predictions.

    ``model`` can be any object with a Keras-style ``predict`` method. Its
    output is returned unchanged, with no rescaling or clipping.
    """
    return model.predict(latent_points)
def plot_generated(examples, n_rows, n_cols, image_size=(80, 80)):
    """Render ``examples`` on an ``n_rows`` x ``n_cols`` grid and return a base64 PNG.

    Args:
        examples: Batch of images. Each cell shows ``examples[k, :, :]``, which
            assumes a NumPy array of rank >= 3 — presumably (N, H, W[, C]);
            TODO confirm against the generator's output.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns. Cells with index >= ``len(examples)``
            are left blank. Every cell has its axes turned off.
        image_size: Unused; kept only so existing callers keep working.

    Returns:
        The figure's PNG bytes, base64-encoded, as a UTF-8 ``str``.
    """
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps global
    # figure state that is not thread-safe, and Flask may serve requests on
    # several threads. A bare Figure needs no plt.close(); it is freed when it
    # goes out of scope, even if savefig raises.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(15, 10))
    # squeeze=False always returns a 2-D array of Axes. Without it, 1xN, Nx1
    # and 1x1 grids (e.g. n_samples 1-3) return a 1-D array or a single Axes,
    # and cell indexing fails.
    axes = fig.subplots(n_rows, n_cols, squeeze=False)
    # axes.flat walks the grid row-major, so the index equals i * n_cols + j.
    for index, ax in enumerate(axes.flat):
        ax.axis('off')
        if index < len(examples):
            ax.imshow(examples[index, :, :])
    buf = BytesIO()
    fig.savefig(buf, format='png')
    return base64.b64encode(buf.getvalue()).decode('utf-8')
# NOTE(review): this view was never registered with the app, so every request
# returned 404; it is now bound to the site root.
@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    return render_template('index.html')
| import math | |
# Upper bound on images per request. The latent batch, the predict() call and
# the plotted grid all scale with n_samples, which comes from the client.
MAX_SAMPLES = 64


# NOTE(review): this view was never registered with the app. '/generate' with
# POST (it reads request.form) is assumed — confirm it matches the URL that
# the index.html page submits to.
@app.route('/generate', methods=['POST'])
def generate():
    """Generate a grid of GAN samples and return it as a base64 PNG in JSON.

    Reads the optional ``n_samples`` form field (default 4) and clamps it to
    ``[1, MAX_SAMPLES]``. It then draws that many 100-dimensional latent
    vectors, runs them through the module-level ``model`` and renders the
    results with ``plot_generated``.

    Returns:
        ``{'success': True, 'generated_image': <base64 PNG>}`` on success.
        ``({'success': False, 'error': ...}, 400)`` when ``n_samples`` is not
        an integer.
    """
    latent_dim = 100
    try:
        n_samples = int(request.form.get('n_samples', 4))
    except ValueError:
        # Malformed client input is a client error, not a 500.
        return jsonify({'success': False,
                        'error': 'n_samples must be an integer'}), 400
    n_samples = min(max(n_samples, 1), MAX_SAMPLES)
    # Near-square layout: rows = floor(sqrt(n)), cols = ceil(n / rows).
    n_rows = max(int(math.sqrt(n_samples)), 1)
    n_cols = (n_samples + n_rows - 1) // n_rows
    latent_points = generate_latent_points(latent_dim, n_samples)
    generated_images = generate_images(model, latent_points)
    # Map generator output from [-1, 1] (presumably a tanh output layer —
    # confirm) to [0, 1] for imshow. Clip stray values instead of letting
    # imshow warn about and clip them per image.
    generated_images = np.clip((generated_images + 1) / 2.0, 0.0, 1.0)
    img_data = plot_generated(generated_images, n_rows, n_cols)
    return jsonify({'success': True, 'generated_image': img_data})
if __name__ == '__main__':
    import os

    # Debug mode is opt-in (FLASK_DEBUG=1) rather than hard-coded on. The
    # Werkzeug debugger allows running arbitrary code from the browser, and
    # the reloader re-imports this module, which loads gan.h5 a second time.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1')