# GAN_Project/backend/app2.py
# Last commit by Utkarsh64: "Update backend/app2.py" (7df5273, verified)
import io
import os
from pathlib import Path
import numpy as np
import tensorflow as tf
from flask import Flask, jsonify, request, send_file, send_from_directory
from flask_cors import CORS
from PIL import Image, ImageEnhance, ImageFilter
# Hide every CUDA device so TensorFlow runs on the CPU and does not try to
# initialise a GPU on CPU-only cloud hosts.
# NOTE(review): tensorflow is imported above this line; the setting only takes
# effect if TF has not initialised its devices yet — consider moving it before
# `import tensorflow` to be safe.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Built React frontend: <directory containing backend/>/frontend/dist.
FRONTEND_DIST = Path(__file__).resolve().parent.parent / "frontend" / "dist"
# Serve the frontend build as static files from the URL root
# (static_url_path="" removes the default "/static" prefix).
# NOTE(review): with static_folder set, Flask also registers its own
# "/<path:filename>" static route; confirm it does not shadow the
# serve_static() catch-all route defined further down.
app = Flask(
    __name__,
    static_folder=str(FRONTEND_DIST),
    static_url_path=""
)
# Enable CORS on all routes using flask_cors default settings.
CORS(app)
# Directory containing this file; the model weights are expected next to it.
BASE_DIR = Path(__file__).resolve().parent
MODEL_PATH = BASE_DIR / "model.h5"
# Final output sizing used by improve_clarity(): scale so the short side
# reaches TARGET_SHORT_SIDE pixels, capped so the long side does not exceed
# MAX_LONG_SIDE (the cap can also shrink very large inputs).
TARGET_SHORT_SIDE = 2048
MAX_LONG_SIDE = 4096
# Inputs whose long side exceeds this many pixels are downscaled before they
# are fed to the generator (see resize_for_generator()).
GENERATOR_WORKING_LONG_SIDE = 768
# Populated by _load_gan_generator() at import time: the loaded generator on
# success, otherwise None with a human-readable reason in model_load_error.
gan_generator = None
model_load_error = None
class GANEnhancementGenerator:
    """Run a frozen Keras generator and apply its curve maps to an image.

    The model is loaded with ``tf.keras.models.load_model`` and must produce
    24 output channels (checked when its static output shape is known); those
    channels are split into eight curve maps by apply_generator_enhancement().
    """

    def __init__(self, model_path):
        """Load the generator from *model_path* and freeze it for inference.

        Raises:
            ValueError: if the model's declared output shape does not end in
                24 channels. Errors raised by ``load_model`` propagate as-is.
        """
        self.model_path = model_path
        self.generator = tf.keras.models.load_model(str(model_path), compile=False)
        self.generator.trainable = False

        declared_shape = getattr(self.generator, "output_shape", None)
        if declared_shape is None:
            # Nothing to validate until the model is actually called.
            return
        if declared_shape[-1] != 24:
            raise ValueError(
                f"Expected GAN generator output with 24 enhancement channels, got {declared_shape}"
            )

    def generate(self, image):
        """Enhance a PIL RGB *image* and return a new PIL image.

        The generator runs on a possibly downscaled copy (resize_for_generator);
        improve_clarity() then combines that result with the full-size original.
        """
        model_input = preprocess(resize_for_generator(image))
        curve_maps = self.generator(model_input, training=False)
        enhanced_pixels = postprocess(apply_generator_enhancement(model_input, curve_maps))
        return improve_clarity(image, Image.fromarray(enhanced_pixels))
def _load_gan_generator():
    """Load the generator at MODEL_PATH into the module-level globals.

    Returns:
        True if ``gan_generator`` was set. False otherwise, in which case
        ``model_load_error`` holds a message explaining why. Exceptions raised
        while loading are caught so the web server can still start.
    """
    global gan_generator, model_load_error
    if MODEL_PATH.exists():
        try:
            gan_generator = GANEnhancementGenerator(MODEL_PATH)
            print(f"Loaded GAN generator from {MODEL_PATH.name}")
        except Exception as err:
            model_load_error = f"Failed to load GAN generator: {err}"
            return False
        return True
    model_load_error = f"{MODEL_PATH.name} not found in backend folder"
    return False
# Load the model once, at import time. A failure is only reported here; the
# app still starts and /enhance answers with the stored model_load_error.
if not _load_gan_generator():
    print(f"No model loaded: {model_load_error}")
def preprocess(image):
    """Turn an image into a float32 batch of one with values scaled by 1/255.

    Accepts anything numpy can convert to an array (e.g. a PIL image) and
    returns that array with a new leading axis of length 1.
    """
    pixels = np.asarray(image, dtype="float32")
    return (pixels / 255.0)[np.newaxis, ...]
def resize_for_generator(image):
    """Downscale *image* so its long side is at most GENERATOR_WORKING_LONG_SIDE.

    Images already within the limit are returned unchanged (the same object).
    Otherwise the aspect ratio is preserved and Lanczos resampling is used.
    Each output side is clamped to at least 1 pixel: for very elongated
    inputs the short side would otherwise round to 0, which Pillow's
    ``resize`` rejects.
    """
    width, height = image.size
    longest_side = max(width, height)
    if longest_side <= GENERATOR_WORKING_LONG_SIDE:
        return image
    scale = GENERATOR_WORKING_LONG_SIDE / longest_side
    resized_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    return image.resize(resized_size, Image.Resampling.LANCZOS)
def apply_generator_enhancement(image_tensor, generated_tensor):
    """Apply the generator's eight curve maps to *image_tensor*.

    *generated_tensor* is split along its last axis into eight equal maps
    ``r``; each one in turn updates the image as ``x + r * (x**2 - x)``.
    The final result is clipped to [0, 1].
    """
    curve_maps = tf.split(generated_tensor, 8, axis=-1)
    x = image_tensor
    for r in curve_maps:
        x = x + r * (tf.square(x) - x)
    return tf.clip_by_value(x, 0.0, 1.0)
def postprocess(enhanced_tensor):
    """Convert the first item of a 0-1 float batch into a uint8 pixel array.

    *enhanced_tensor* must support ``[0]`` indexing and ``.numpy()`` (an eager
    TensorFlow tensor does). Values are scaled to 0-255, rounded to the
    nearest integer, clipped, and cast to uint8.
    """
    enhanced = enhanced_tensor[0].numpy()
    # Round before casting: astype("uint8") truncates toward zero, which would
    # bias every channel down by half a level on average.
    return np.clip(np.rint(enhanced * 255.0), 0, 255).astype("uint8")
def _balance_brightness(image):
    """Nudge the mean brightness of an image toward a mid range.

    The mean is taken over all pixels and channels on the 0-255 scale.
    Below 95 the image is brightened by 8%, below 135 by 5%, below 170 by 2%.
    From 170 up to 205 it is left unchanged. Above 205 it is darkened toward
    a mean of 205, by at most 8%.
    """
    brightness = float(np.mean(np.asarray(image, dtype="float32")))
    if brightness < 95:
        factor = 1.08
    elif brightness < 135:
        factor = 1.05
    elif brightness < 170:
        factor = 1.02
    elif brightness > 190:
        # 205 / brightness exceeds 1.0 for means in (190, 205), which would
        # brighten images this branch is meant to tone down; cap it at 1.0.
        factor = min(1.0, max(0.92, 205 / brightness))
    else:
        factor = 1.0
    if factor != 1.0:
        image = ImageEnhance.Brightness(image).enhance(factor)
    return image


def _upscale_to_target(image):
    """Resize so the short side reaches TARGET_SHORT_SIDE, long side <= MAX_LONG_SIDE.

    Images are never enlarged beyond what TARGET_SHORT_SIDE requires. The
    MAX_LONG_SIDE cap wins over the target, so very large inputs may shrink.
    """
    width, height = image.size
    scale = max(1.0, TARGET_SHORT_SIDE / min(width, height))
    scale = min(scale, MAX_LONG_SIDE / max(width, height))
    # Clamp to 1px so extreme aspect ratios do not round a side down to zero,
    # which Pillow's resize rejects.
    new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    return image.resize(new_size, Image.Resampling.LANCZOS)


def improve_clarity(original_image, enhanced_image):
    """Merge the generator output with the original and finish the image.

    Steps:
      1. Resize *enhanced_image* to the original's size.
      2. Blend 60% enhanced with 40% original.
      3. Adjust mean brightness.
      4. Resize toward TARGET_SHORT_SIDE / MAX_LONG_SIDE.
      5. Boost contrast and sharpen.

    Both inputs must be PIL images with the same mode (Image.blend requires
    it). Returns a new PIL image.
    """
    enhanced_image = enhanced_image.resize(original_image.size, Image.Resampling.LANCZOS)
    # Image.blend(a, b, alpha) = a * (1 - alpha) + b * alpha.
    image = Image.blend(original_image, enhanced_image, 0.6)
    image = _balance_brightness(image)
    image = _upscale_to_target(image)
    image = ImageEnhance.Contrast(image).enhance(1.08)
    image = image.filter(ImageFilter.UnsharpMask(radius=0.8, percent=175, threshold=2))
    return ImageEnhance.Sharpness(image).enhance(1.18)
# ---------- FRONTEND ROUTES ----------
@app.route("/")
def serve_react():
    """Return the frontend's entry page for the site root."""
    entry_page = "index.html"
    return send_from_directory(app.static_folder, entry_page)
@app.route("/<path:path>")
def serve_static(path):
    """Serve a file from the frontend build, or fall back to index.html.

    Existing files inside FRONTEND_DIST are returned as-is. Any other path
    (e.g. a client-side route) gets index.html so the frontend can route it.
    """
    dist_root = FRONTEND_DIST.resolve()
    requested = (dist_root / path).resolve()
    # Resolve first and require the result to live under the build folder.
    # Otherwise "../" segments let clients probe whether arbitrary host files
    # exist: an existing one gets a 404 from send_from_directory, a missing
    # one gets index.html.
    if dist_root in requested.parents and requested.is_file():
        return send_from_directory(app.static_folder, path)
    return send_from_directory(app.static_folder, "index.html")
# ---------- BACKEND API ----------
@app.route("/enhance", methods=["POST"])
def enhance():
    """Enhance an uploaded image and return the result as a PNG.

    Expects a multipart upload in the ``image`` form field.

    Responses:
        200 -- ``image/png`` body containing the enhanced image.
        400 -- no ``image`` field, or the upload could not be decoded.
        413 -- the image exceeds Pillow's decompression-bomb pixel limit.
        500 -- the generator is not loaded, or enhancement failed.
    Error responses are JSON objects of the form ``{"error": "..."}``.
    """
    if gan_generator is None:
        return jsonify({"error": f"GAN generator not loaded: {model_load_error}"}), 500
    try:
        if "image" not in request.files:
            return jsonify({"error": "No image file provided in 'image' field"}), 400
        upload = request.files["image"]
        try:
            # Image.open() is lazy and convert() forces the full decode, so
            # both stay in this try. An unreadable upload is a client error.
            with Image.open(upload.stream) as source:
                image = source.convert("RGB")
        except Image.DecompressionBombError as err:
            return jsonify({"error": str(err)}), 413
        except OSError as err:
            # OSError also covers PIL's UnidentifiedImageError and truncated files.
            return jsonify({"error": f"Could not read uploaded image: {err}"}), 400

        enhanced = gan_generator.generate(image)
        buf = io.BytesIO()
        enhanced.save(buf, format="PNG")
        buf.seek(0)
        return send_file(buf, mimetype="image/png")
    except Exception as err:
        # Request boundary: log the full traceback server-side and still
        # answer with a JSON error.
        app.logger.exception("Enhancement failed")
        return jsonify({"error": str(err)}), 500
if __name__ == "__main__":
    # Listen on every interface with debug off. The fixed port 7860 is
    # presumably what the hosting platform expects — confirm before changing.
    app.run(debug=False, host="0.0.0.0", port=7860)