File size: 3,849 Bytes
d1bfee5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import tensorflow as tf
import numpy as np
from flask import Flask, request, jsonify, render_template, send_from_directory
from werkzeug.utils import secure_filename
from tf_models import Generator
from PIL import Image
import base64
from io import BytesIO

app = Flask(__name__)
# Both settings are applied in one call; paths are relative to the process CWD.
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # 16MB request-body limit
)

# Load the models.
# Candidate weight files are tried in priority order; the first one that loads
# wins. Paths are relative to the process working directory.


def _load_first_weights(model, candidates, label):
    """Load weights into *model* from the first usable file in *candidates*.

    Paths that do not exist are skipped. A path whose load raises is reported
    and the next candidate is tried.

    Args:
        model: Object exposing a Keras-style ``load_weights`` method.
        candidates: Iterable of weight-file paths, highest priority first.
        label: Short name used in log messages (e.g. ``"H2Z"``).

    Returns:
        True if weights were loaded from some candidate, False otherwise.
    """
    for weight_path in candidates:
        if not os.path.exists(weight_path):
            continue
        try:
            # by_name + skip_mismatch tolerates checkpoints whose layers do not
            # line up exactly with this Generator. NOTE(review): it also means
            # a "successful" load may have matched few or no layers.
            model.load_weights(weight_path, by_name=True, skip_mismatch=True)
        except Exception as e:
            print(f"Failed to load {label} {weight_path}: {e}")
            continue
        print(f"Loaded {label} weights from {weight_path}")
        return True
    # Without this warning the app would silently serve an untrained generator.
    print(f"WARNING: no {label} weights could be loaded; "
          f"the {label} generator is randomly initialised")
    return False


h2z_weights = ["GeneratorHtoZ.h5", "GeneratorHtoZ_25.h5", "gen_g_epoch_0.h5"]
z2h_weights = ["GeneratorZtoH.h5", "GeneratorZtoH_25.h5", "gen_f_epoch_0.h5"]
# Defined up front so the flags exist even if Generator() construction fails.
h2z_loaded = False
z2h_loaded = False

try:
    generator_h2z = Generator()
    generator_z2h = Generator()

    h2z_loaded = _load_first_weights(generator_h2z, h2z_weights, "H2Z")
    z2h_loaded = _load_first_weights(generator_z2h, z2h_weights, "Z2H")
except Exception as e:
    print(f"Error initializing model: {e}")

def preprocess_image(image_path, size=(256, 256)):
    """Load an image and turn it into a single-item generator input batch.

    Args:
        image_path: Filename, path-like, or binary file object accepted by
            ``PIL.Image.open`` (e.g. an upload stream).
        size: ``(width, height)`` the image is resized to. Defaults to
            256x256, the size the rest of this file assumes.

    Returns:
        ``float32`` NumPy array of shape ``(1, height, width, 3)`` with pixel
        values scaled from [0, 255] to [-1, 1].
    """
    # The context manager closes the underlying file deterministically; Pillow
    # opens lazily and can otherwise keep the handle until garbage collection.
    with Image.open(image_path) as src:
        rgb = src.convert('RGB').resize(size)
    img_array = np.asarray(rgb, dtype=np.float32)
    img_array = (img_array * 2 / 255.0) - 1.0  # Normalize to [-1, 1]
    return np.expand_dims(img_array, axis=0)

def postprocess_image(tensor):
    """Convert the first image of a generator output batch into a PIL image.

    Args:
        tensor: Array batch whose first element is an ``(H, W, 3)`` image
            with values in [-1, 1]. The predict route passes a
            ``(1, 256, 256, 3)`` array.

    Returns:
        ``PIL.Image.Image`` built from the uint8 pixel array.
    """
    img = tensor[0]
    img = (img + 1.0) * 127.5  # Scale back to [0, 255]
    # Round rather than truncate: a bare astype(uint8) floors, so float error
    # from the [-1, 1] round trip (e.g. 99.9999) would come out as 99.
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    return Image.fromarray(img)

@app.route('/')
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = 'index.html'
    return render_template(page)

@app.route('/predict', methods=['POST'])
def predict():
    """Translate an uploaded image with one of the two generators.

    Expects a multipart form with an ``image`` file field and an optional
    ``mode`` field: ``'z2h'`` selects ``generator_z2h``; any other value
    (default ``'h2z'``) selects ``generator_h2z``.

    Returns:
        JSON ``{'success': True, 'result': <PNG data URL>}`` on success;
        ``{'error': ...}`` with status 400 for a missing or unnamed upload,
        or 500 if decoding/inference fails.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image uploaded'}), 400

    mode = request.form.get('mode', 'h2z')  # Default to horse to zebra

    file = request.files['image']
    if file.filename == '':
        return jsonify({'error': 'Empty filename'}), 400

    try:
        # Decode straight from the upload stream instead of first saving it to
        # UPLOAD_FOLDER under the client-supplied name. Saving had several
        # problems:
        # - concurrent uploads with the same name could overwrite each other
        #   mid-request;
        # - files accumulated forever, likely publicly served under /static;
        # - file.save() ran outside this try, so a missing folder (it is only
        #   created under __main__) or an empty secure_filename() result
        #   produced an unhandled error instead of a JSON response.
        input_tensor = preprocess_image(file.stream)

        generator = generator_z2h if mode == 'z2h' else generator_h2z
        prediction = generator(input_tensor, training=False)
        output_img = postprocess_image(prediction.numpy())

        # Encode the result as PNG in memory for a base64 data URL.
        buffered = BytesIO()
        output_img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500

    return jsonify({
        'success': True,
        'result': f"data:image/png;base64,{img_str}"
    })

if __name__ == '__main__':
    # Create the configured upload directory (if absent), then serve on all
    # interfaces with the debugger disabled.
    upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)
    app.run(debug=False, host='0.0.0.0', port=7860)