Utkarsh64 commited on
Commit
656123f
·
verified ·
1 Parent(s): 36ba560

Upload 3 files

Browse files
Files changed (3) hide show
  1. backend/app2.py +167 -0
  2. backend/model.h5 +3 -0
  3. backend/requirements.txt +5 -0
backend/app2.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from flask import Flask, jsonify, request, send_file
7
+ from flask_cors import CORS
8
+ from PIL import Image, ImageEnhance, ImageFilter
9
+
10
+ app = Flask(__name__)
11
+ CORS(app)
12
+
13
+ BASE_DIR = Path(__file__).resolve().parent
14
+ MODEL_PATH = BASE_DIR.parent / "model.h5"
15
+ TARGET_SHORT_SIDE = 2048
16
+ MAX_LONG_SIDE = 4096
17
+ GENERATOR_WORKING_LONG_SIDE = 768
18
+
19
+ gan_generator = None
20
+ model_load_error = None
21
+
22
+
23
+ class GANEnhancementGenerator:
24
+ def __init__(self, model_path):
25
+ self.model_path = model_path
26
+ self.generator = tf.keras.models.load_model(str(model_path), compile=False)
27
+ self.generator.trainable = False
28
+
29
+ output_shape = getattr(self.generator, "output_shape", None)
30
+ if output_shape is not None and output_shape[-1] != 24:
31
+ raise ValueError(
32
+ f"Expected GAN generator output with 24 enhancement channels, got {output_shape}"
33
+ )
34
+
35
+ def generate(self, image):
36
+ working_image = resize_for_generator(image)
37
+ input_tensor = preprocess(working_image)
38
+ generated_tensor = self.generator(input_tensor, training=False)
39
+ enhanced_tensor = apply_generator_enhancement(input_tensor, generated_tensor)
40
+ result = postprocess(enhanced_tensor)
41
+ return improve_clarity(image, Image.fromarray(result))
42
+
43
+
44
+ def _load_gan_generator():
45
+ global gan_generator, model_load_error
46
+
47
+ if not MODEL_PATH.exists():
48
+ model_load_error = f"{MODEL_PATH.name} not found at project root"
49
+ return False
50
+
51
+ try:
52
+ gan_generator = GANEnhancementGenerator(MODEL_PATH)
53
+ print(f"Loaded GAN generator from {MODEL_PATH.name}")
54
+ return True
55
+ except Exception as err:
56
+ model_load_error = f"Failed to load GAN generator: {err}"
57
+ return False
58
+
59
+
60
+ if not _load_gan_generator():
61
+ print(f"No model loaded: {model_load_error}")
62
+
63
+
64
+ def preprocess(image):
65
+ image = np.array(image).astype("float32") / 255.0
66
+ return np.expand_dims(image, axis=0)
67
+
68
+
69
+ def resize_for_generator(image):
70
+ width, height = image.size
71
+ longest_side = max(width, height)
72
+
73
+ if longest_side <= GENERATOR_WORKING_LONG_SIDE:
74
+ return image
75
+
76
+ scale = GENERATOR_WORKING_LONG_SIDE / longest_side
77
+ resized_size = (round(width * scale), round(height * scale))
78
+ return image.resize(resized_size, Image.Resampling.LANCZOS)
79
+
80
+
81
+ def apply_generator_enhancement(image_tensor, generated_tensor):
82
+ r1, r2, r3, r4, r5, r6, r7, r8 = tf.split(generated_tensor, 8, axis=-1)
83
+
84
+ x = image_tensor + r1 * (tf.square(image_tensor) - image_tensor)
85
+ x = x + r2 * (tf.square(x) - x)
86
+ x = x + r3 * (tf.square(x) - x)
87
+ enhanced = x + r4 * (tf.square(x) - x)
88
+ x = enhanced + r5 * (tf.square(enhanced) - enhanced)
89
+ x = x + r6 * (tf.square(x) - x)
90
+ x = x + r7 * (tf.square(x) - x)
91
+ enhanced = x + r8 * (tf.square(x) - x)
92
+
93
+ return tf.clip_by_value(enhanced, 0.0, 1.0)
94
+
95
+
96
+ def postprocess(enhanced_tensor):
97
+ enhanced = enhanced_tensor[0].numpy()
98
+ return np.clip(enhanced * 255.0, 0, 255).astype("uint8")
99
+
100
+
101
+ def improve_clarity(original_image, enhanced_image):
102
+ enhanced_image = enhanced_image.resize(original_image.size, Image.Resampling.LANCZOS)
103
+
104
+ image = Image.blend(original_image, enhanced_image, 0.6)
105
+
106
+ pixels = np.asarray(image).astype("float32")
107
+ brightness = float(np.mean(pixels))
108
+ night_scene = brightness < 95
109
+ if brightness < 95:
110
+ image = ImageEnhance.Brightness(image).enhance(1.08)
111
+ elif brightness < 135:
112
+ image = ImageEnhance.Brightness(image).enhance(1.05)
113
+ elif brightness < 170:
114
+ image = ImageEnhance.Brightness(image).enhance(1.02)
115
+ elif brightness > 190:
116
+ image = ImageEnhance.Brightness(image).enhance(max(0.92, 205 / brightness))
117
+
118
+ if night_scene:
119
+ boosted_pixels = np.asarray(image).astype("float32")
120
+ boosted_brightness = float(np.mean(boosted_pixels))
121
+ if boosted_brightness > 145:
122
+ image = ImageEnhance.Brightness(image).enhance(145 / boosted_brightness)
123
+
124
+ width, height = image.size
125
+ shortest_side = min(width, height)
126
+ longest_side = max(width, height)
127
+ scale = max(1.0, TARGET_SHORT_SIDE / shortest_side)
128
+ scale = min(scale, MAX_LONG_SIDE / longest_side)
129
+ image = image.resize((round(width * scale), round(height * scale)), Image.Resampling.LANCZOS)
130
+
131
+ image = ImageEnhance.Contrast(image).enhance(1.08)
132
+ image = image.filter(ImageFilter.UnsharpMask(radius=0.8, percent=175, threshold=2))
133
+ image = ImageEnhance.Sharpness(image).enhance(1.18)
134
+ return image
135
+
136
+
137
+ @app.route("/")
138
+ def home():
139
+ return "GAN enhancement backend is running"
140
+
141
+
142
+ @app.route("/enhance", methods=["POST"])
143
+ def enhance():
144
+ if gan_generator is None:
145
+ return jsonify({"error": f"GAN generator not loaded: {model_load_error}"}), 500
146
+
147
+ try:
148
+ if "image" not in request.files:
149
+ return jsonify({"error": "No image file provided in 'image' field"}), 400
150
+
151
+ file = request.files["image"]
152
+ image = Image.open(file.stream).convert("RGB")
153
+
154
+ img = gan_generator.generate(image)
155
+ buf = io.BytesIO()
156
+ img.save(buf, format="PNG")
157
+ buf.seek(0)
158
+
159
+ return send_file(buf, mimetype="image/png")
160
+
161
+ except Exception as e:
162
+ print("Error:", e)
163
+ return jsonify({"error": str(e)}), 500
164
+
165
+
166
+ if __name__ == "__main__":
167
+ app.run(debug=True, use_reloader=False)
backend/model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3afb2c4ca9df1824125544107af4d252e72617170d9d2baf2d394c8498958260
3
+ size 358760
backend/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Flask==3.1.3
2
+ flask-cors==6.0.2
3
+ numpy==2.2.6
4
+ pillow==12.2.0
5
+ tensorflow==2.21.0