"""Flask app serving horse<->zebra image translation with two Keras generators."""
import base64
import os
import uuid
from io import BytesIO

import numpy as np
import tensorflow as tf
from flask import Flask, request, jsonify, render_template, send_from_directory
from PIL import Image
from werkzeug.utils import secure_filename

from tf_models import Generator

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB limit

# Create the upload folder at import time (not only under __main__) so the app
# also works when it is served by a WSGI server that imports this module.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Candidate weight files for each direction, tried in order; first success wins.
H2Z_WEIGHTS = ["GeneratorHtoZ.h5", "GeneratorHtoZ_25.h5", "gen_g_epoch_0.h5"]
Z2H_WEIGHTS = ["GeneratorZtoH.h5", "GeneratorZtoH_25.h5", "gen_f_epoch_0.h5"]


def _load_first_available(model, weight_paths, label):
    """Load into *model* the first file in *weight_paths* that exists and loads.

    Args:
        model: Model exposing ``load_weights`` (a ``Generator`` here).
        weight_paths: Candidate weight-file paths, tried in order.
        label: Short tag (e.g. ``"H2Z"``) used in log messages.

    Returns:
        True if some file was loaded, False if none existed or all failed.
    """
    for weight_path in weight_paths:
        if not os.path.exists(weight_path):
            continue
        try:
            model.load_weights(weight_path, by_name=True, skip_mismatch=True)
        except Exception as e:
            # Best effort: report the failure and fall through to the next candidate.
            print(f"Failed to load {label} {weight_path}: {e}")
            continue
        print(f"Loaded {label} weights from {weight_path}")
        return True
    # Without weights the generator still runs, but with its initial weights;
    # make that visible instead of silently serving meaningless output.
    print(f"WARNING: no {label} weights could be loaded; outputs will not be meaningful")
    return False


# Load the models. The globals start as None so that, if construction fails,
# /predict can report a clear error instead of raising NameError.
generator_h2z = None
generator_z2h = None
h2z_loaded = False
z2h_loaded = False
try:
    generator_h2z = Generator()
    generator_z2h = Generator()
    h2z_loaded = _load_first_available(generator_h2z, H2Z_WEIGHTS, "H2Z")
    z2h_loaded = _load_first_available(generator_z2h, Z2H_WEIGHTS, "Z2H")
except Exception as e:
    print(f"Error initializing model: {e}")


def preprocess_image(image_path, size=(256, 256)):
    """Load an image and turn it into a single-image batch in [-1, 1].

    Args:
        image_path: Path (or binary file object) accepted by ``PIL.Image.open``.
        size: Target ``(width, height)`` passed to ``Image.resize``.

    Returns:
        float32 array of shape ``(1, height, width, 3)`` with values in [-1, 1].
    """
    # Use a context manager so the underlying file handle is closed promptly.
    with Image.open(image_path) as src:
        img = src.convert('RGB').resize(size)
    img_array = np.array(img).astype(np.float32)
    img_array = (img_array * 2 / 255.0) - 1.0  # Normalize [0, 255] -> [-1, 1]
    return np.expand_dims(img_array, axis=0)  # add the batch dimension


def postprocess_image(tensor):
    """Convert the first image of a batch in [-1, 1] back into a PIL RGB image.

    Args:
        tensor: Array indexed as ``tensor[0]`` -> ``(H, W, 3)``; values are
            assumed to lie roughly in [-1, 1] (out-of-range values are clipped).

    Returns:
        ``PIL.Image.Image`` built from the uint8 pixel data.
    """
    img = tensor[0]
    img = (img + 1.0) * 127.5  # Scale back to [0, 255]
    img = np.clip(img, 0, 255).astype(np.uint8)
    return Image.fromarray(img)


@app.route('/')
def index():
    """Serve the upload page."""
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    """Translate an uploaded image and return it as a base64 PNG data URI.

    Form fields:
        image: The uploaded image file (required).
        mode: ``'z2h'`` for zebra->horse; any other value (default ``'h2z'``)
            selects horse->zebra.

    Returns:
        JSON ``{'success': True, 'result': 'data:image/png;base64,...'}``, or
        ``{'error': ...}`` with status 400 (bad request), 503 (model not
        initialised) or 500 (inference failure).
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400

    mode = request.form.get('mode', 'h2z')  # Default to horse to zebra
    file = request.files['image']
    if file.filename == '':
        return jsonify({'error': 'Empty filename'}), 400

    generator = generator_z2h if mode == 'z2h' else generator_h2z
    if generator is None:
        return jsonify({'error': 'Model is not available'}), 503

    # secure_filename() may return '' (e.g. all-non-ASCII names), and two
    # concurrent uploads with the same name would overwrite each other, so
    # prefix a random token to get a unique, non-empty file name.
    filename = f"{uuid.uuid4().hex}_{secure_filename(file.filename)}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)

    try:
        file.save(filepath)
        input_tensor = preprocess_image(filepath)
        prediction = generator(input_tensor, training=False)
        output_img = postprocess_image(prediction.numpy())

        # Encode to PNG in memory and return it inline as a data URI.
        buffered = BytesIO()
        output_img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return jsonify({
            'success': True,
            'result': f"data:image/png;base64,{img_str}"
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    finally:
        # The upload is only needed for this request; delete it so the folder
        # does not grow without bound. Best effort: it may never have been saved.
        try:
            os.remove(filepath)
        except OSError:
            pass


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=False)