""" ShukGEN v3 — Flask Backend Server (Hugging Face Spaces Edition) Serves both the UI (index.html) and the API on port 7860. Model is auto-loaded from shukgen_v3_final.pth in the same directory at startup. """ import io, base64, os, sys, traceback, threading, time, math from flask import Flask, request, jsonify, send_from_directory from flask_cors import CORS try: import numpy as np print(f" numpy {np.__version__} ✅") except ImportError as e: print(f"❌ numpy not found: {e}") raise try: from PIL import Image, ImageFilter, ImageEnhance, ImageDraw, ImageFont except ImportError: sys.exit("❌ Pillow not found. Run: pip install Pillow") try: from scipy.ndimage import gaussian_filter except ImportError: sys.exit("❌ scipy not found. Run: pip install scipy") try: import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms except ImportError: sys.exit("❌ PyTorch not found. Run: pip install torch torchvision") app = Flask(__name__) CORS(app) # ───────────────────────────────────────────────────────────────────────────── # SERVE FRONTEND — index.html is served at the root URL # ───────────────────────────────────────────────────────────────────────────── BASE_DIR = os.path.dirname(os.path.abspath(__file__)) @app.route('/') def index(): return send_from_directory(BASE_DIR, 'index.html') # ───────────────────────────────────────────────────────────────────────────── # MODEL DEFINITION (identical to new_gui.py) # ───────────────────────────────────────────────────────────────────────────── def norm(ch): for g in [8, 4, 2, 1]: if ch % g == 0: return nn.GroupNorm(g, ch) class ResBlock(nn.Module): def __init__(self, ch): super().__init__() self.net = nn.Sequential( norm(ch), nn.SiLU(), nn.Conv2d(ch, ch, 3, padding=1, bias=False), norm(ch), nn.SiLU(), nn.Conv2d(ch, ch, 3, padding=1, bias=False), ) def forward(self, x): return x + self.net(x) class SE(nn.Module): def __init__(self, ch, ratio=8): super().__init__() mid = max(ch // ratio, 4) self.net = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(ch, mid), nn.SiLU(), nn.Linear(mid, ch), nn.Sigmoid(), ) def forward(self, x): return x * self.net(x).view(x.size(0), x.size(1), 1, 1) class AttnBlock(nn.Module): def __init__(self, ch): super().__init__() self.norm = norm(ch) self.q = nn.Conv2d(ch, ch, 1) self.k = nn.Conv2d(ch, ch, 1) self.v = nn.Conv2d(ch, ch, 1) self.proj = nn.Conv2d(ch, ch, 1) def forward(self, x): B, C, H, W = x.shape h = self.norm(x) q = self.q(h).reshape(B, C, -1).permute(0, 2, 1) k = self.k(h).reshape(B, C, -1).permute(0, 2, 1) v = self.v(h).reshape(B, C, -1).permute(0, 2, 1) a = F.scaled_dot_product_attention(q, k, v) a = a.permute(0, 2, 1).reshape(B, C, H, W) return x + self.proj(a) class DownBlock(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), norm(out_ch), nn.SiLU()) self.res = ResBlock(out_ch) self.se = SE(out_ch) self.down = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1, bias=False) def forward(self, x): x = self.se(self.res(self.conv(x))) skip = x return self.down(x), skip class UpBlock(nn.Module): def __init__(self, in_ch, skip_ch, out_ch): super().__init__() self.up = nn.ConvTranspose2d(in_ch, in_ch, 4, stride=2, padding=1, bias=False) self.conv = nn.Conv2d(in_ch + skip_ch, out_ch, 3, padding=1, bias=False) self.norm = norm(out_ch) self.act = nn.SiLU() self.res = ResBlock(out_ch) def forward(self, x, skip): x = self.up(x) if x.shape[2:] != skip.shape[2:]: x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False) x = torch.cat([x, skip], dim=1) return self.res(self.act(self.norm(self.conv(x)))) class FaceVAE(nn.Module): def __init__(self, latent_dim=512, base_filters=48): super().__init__() f = base_filters self.f = f self.latent = latent_dim self.pool_sz = 4 flat = f*16 * self.pool_sz * self.pool_sz self.inc = nn.Sequential(nn.Conv2d(3, f, 3, padding=1, bias=False), norm(f), nn.SiLU()) self.down1 = DownBlock(f, f*2) self.down2 = DownBlock(f*2, f*4) self.down3 = DownBlock(f*4, f*8) self.down4 = DownBlock(f*8, f*16) self.btn = nn.Sequential(ResBlock(f*16), AttnBlock(f*16), ResBlock(f*16), SE(f*16)) self.pool = nn.AdaptiveAvgPool2d(self.pool_sz) self.fc_mu = nn.Linear(flat, latent_dim) self.fc_logvar = nn.Linear(flat, latent_dim) self.fc_dec = nn.Linear(latent_dim, flat) self.up1 = UpBlock(f*16, f*16, f*8) self.up2 = UpBlock(f*8, f*8, f*4) self.up3 = UpBlock(f*4, f*4, f*2) self.up4 = UpBlock(f*2, f*2, f) self.outc = nn.Sequential(norm(f), nn.SiLU(), nn.Conv2d(f, 3, 3, padding=1), nn.Tanh()) def encode(self, x): x = self.inc(x) x, s1 = self.down1(x) x, s2 = self.down2(x) x, s3 = self.down3(x) x, s4 = self.down4(x) h = self.pool(self.btn(x)).flatten(1) return self.fc_mu(h), self.fc_logvar(h), (s1, s2, s3, s4) def decode(self, z, skips): s1, s2, s3, s4 = skips h = self.fc_dec(z).view(-1, self.f*16, self.pool_sz, self.pool_sz) h = F.interpolate(h, size=(16,16), mode='bilinear', align_corners=False) h = self.up1(h, s4) h = self.up2(h, s3) h = self.up3(h, s2) h = self.up4(h, s1) return self.outc(h) def reparameterize(self, mu, logvar): return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar) def forward(self, x): mu, logvar, skips = self.encode(x) z = self.reparameterize(mu, logvar) out = self.decode(z, skips) return out, mu, logvar @torch.no_grad() def reconstruct(self, x): self.eval() mu, _, skips = self.encode(x) return self.decode(mu, skips) # ───────────────────────────────────────────────────────────────────────────── # STYLE TRANSFORMS (identical to new_gui.py) # ───────────────────────────────────────────────────────────────────────────── class StyleTransformBank: STYLE_NAMES = [ "Youthful", "Aged / Mature", "Dramatic Light", "Soft Glow", "Intense / Bold", "Warm Golden Hour", "Cool / Moody", "Sketch / Artistic", ] @classmethod def _np(cls, img): return np.array(img.convert('RGB'), dtype=np.float32) / 255.0 @classmethod def _pil(cls, arr): return Image.fromarray((arr.clip(0, 1) * 255).astype(np.uint8)) @classmethod def _blur(cls, arr, sigma=2.0): return np.stack([gaussian_filter(arr[..., c], sigma=sigma) for c in range(3)], axis=-1) @classmethod def _saturate(cls, arr, factor): grey = arr.mean(axis=-1, keepdims=True) return (grey + (arr - grey) * factor).clip(0, 1) @classmethod def _vignette(cls, arr, strength=0.5): H, W = arr.shape[:2] Y, X = np.ogrid[:H, :W] cx, cy = W / 2, H / 2 dist = np.sqrt(((X - cx) / cx) ** 2 + ((Y - cy) / cy) ** 2) mask = 1 - strength * dist.clip(0, 1) return (arr * mask[..., None]).clip(0, 1) @classmethod def _grain(cls, arr, amount=0.03, seed=0): rng = np.random.default_rng(seed) return (arr + rng.normal(0, amount, arr.shape)).clip(0, 1) @classmethod def style_0_youthful(cls, img): arr = cls._np(img) arr = cls._saturate(arr, 1.35) arr[..., 0] = (arr[..., 0] * 0.95).clip(0, 1) arr[..., 2] = (arr[..., 2] * 1.12).clip(0, 1) blr = cls._blur(arr, sigma=0.7) arr = arr * 0.88 + blr * 0.12 arr = (arr * 0.9 + 0.06).clip(0, 1) return cls._pil(arr) @classmethod def style_1_aged_mature(cls, img): arr = cls._np(img) arr = cls._saturate(arr, 0.45) arr[..., 0] = (arr[..., 0] * 1.08 + 0.03).clip(0, 1) arr[..., 2] = (arr[..., 2] * 0.80).clip(0, 1) arr = cls._grain(arr, amount=0.035, seed=42) arr = cls._vignette(arr, 0.55) arr = ((arr - 0.5) * 0.88 + 0.5).clip(0, 1) return cls._pil(arr) @classmethod def style_2_dramatic_light(cls, img): arr = cls._np(img) H, W = arr.shape[:2] x = np.linspace(0, 1, W) grad = np.power(x, 1.5) arr[..., 0] = (arr[..., 0] * (0.3 + 1.4*grad)).clip(0, 1) arr[..., 1] = (arr[..., 1] * (0.15 + 1.2*grad)).clip(0, 1) arr[..., 2] = (arr[..., 2] * (0.1 + 0.9*grad)).clip(0, 1) arr = ((arr - 0.5) * 3.0 + 0.5).clip(0, 1) arr = cls._vignette(arr, 0.7) return cls._pil(arr) @classmethod def style_3_soft_glow(cls, img): arr = cls._np(img) blr = cls._blur(arr, sigma=8.0) arr = (arr + blr * 0.55).clip(0, 1) arr[..., 0] = (arr[..., 0] * 1.08).clip(0, 1) arr[..., 2] = (arr[..., 2] * 0.88).clip(0, 1) arr = (arr * 0.85 + 0.08).clip(0, 1) return cls._pil(arr) @classmethod def style_4_intense_bold(cls, img): arr = cls._np(img) arr = cls._saturate(arr, 2.5) arr = ((arr - 0.5) * 1.8 + 0.5).clip(0, 1) pil = cls._pil(arr) return ImageEnhance.Sharpness(pil).enhance(3.0) @classmethod def style_5_warm_golden(cls, img): arr = cls._np(img) arr[..., 0] = (arr[..., 0] * 1.25 + 0.05).clip(0, 1) arr[..., 1] = (arr[..., 1] * 1.10 + 0.02).clip(0, 1) arr[..., 2] = (arr[..., 2] * 0.55).clip(0, 1) return cls._pil(arr) @classmethod def style_6_cool_moody(cls, img): arr = cls._np(img) arr[..., 0] = (arr[..., 0] * 0.72).clip(0, 1) arr[..., 2] = (arr[..., 2] * 1.30 + 0.05).clip(0, 1) arr = cls._grain(arr, amount=0.03, seed=99) arr = cls._vignette(arr, 0.45) return cls._pil(arr) @classmethod def style_7_sketch_artistic(cls, img): arr = cls._np(img) post = (arr * 4).astype(int) / 4.0 grey = post.mean(axis=-1) blr = gaussian_filter(grey, sigma=1.5) edge = np.abs(grey - blr) * 6.0 canvas = cls._grain(np.ones_like(arr) * 0.92, amount=0.02, seed=12) result = post * canvas * (1 - edge[..., None] * 0.8) return cls._pil(result) @classmethod def apply(cls, img, style_idx): methods = [ cls.style_0_youthful, cls.style_1_aged_mature, cls.style_2_dramatic_light, cls.style_3_soft_glow, cls.style_4_intense_bold, cls.style_5_warm_golden, cls.style_6_cool_moody, cls.style_7_sketch_artistic, ] return methods[style_idx](img) @classmethod def apply_with_strength(cls, recon_pil, style_idx, alpha): styled = cls.apply(recon_pil, style_idx) rec_arr = np.array(recon_pil, dtype=np.float32) styled_arr = np.array(styled, dtype=np.float32) blended = rec_arr * (1 - alpha) + styled_arr * alpha return Image.fromarray(blended.clip(0, 255).astype(np.uint8)) # ───────────────────────────────────────────────────────────────────────────── # HELPERS # ───────────────────────────────────────────────────────────────────────────── def load_model(path, device): ckpt = torch.load(path, map_location=device, weights_only=False) cfg = ckpt['config'] m = FaceVAE(cfg['latent_dim'], cfg['base_filters']).to(device) m.load_state_dict(ckpt['model_state']) m.eval() return m, cfg['image_size'], cfg def preprocess_image(pil_img, size): tf = transforms.Compose([ transforms.Resize((size, size), antialias=True), transforms.ToTensor(), transforms.Normalize([0.5]*3, [0.5]*3), ]) return tf(pil_img.convert('RGB')).unsqueeze(0) def tensor_to_pil(t): t = t.squeeze(0).cpu().clamp(-1, 1) return transforms.ToPILImage()((t + 1.0) / 2.0) def pil_to_b64(pil_img): buf = io.BytesIO() pil_img.convert('RGB').save(buf, format='PNG') return 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode() def b64_to_pil(data_url): header, data = data_url.split(',', 1) return Image.open(io.BytesIO(base64.b64decode(data))).convert('RGB') # ───────────────────────────────────────────────────────────────────────────── # APP STATE # ───────────────────────────────────────────────────────────────────────────── STATE = { 'model': None, 'model_config': {}, 'image_size': 256, 'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 'orig_pil': None, 'recon_pil': None, 'styled_pils': [None] * 8, 'mu': None, 'skips': None, } # ───────────────────────────────────────────────────────────────────────────── # ROUTES # ───────────────────────────────────────────────────────────────────────────── @app.route('/status', methods=['GET']) def status(): dev = STATE['device'] dev_str = f"GPU: {torch.cuda.get_device_name(0)}" if dev.type == 'cuda' else "CPU" return jsonify({ 'model_loaded': STATE['model'] is not None, 'model_config': STATE['model_config'], 'image_loaded': STATE['orig_pil'] is not None, 'device': dev_str, 'torch_version': torch.__version__, }) @app.route('/load_model', methods=['POST']) def api_load_model(): data = request.json path = data.get('path', '').strip() if not path: return jsonify({'error': 'No path provided'}), 400 if not os.path.isfile(path): return jsonify({'error': f'File not found: {path}'}), 404 try: model, image_size, cfg = load_model(path, STATE['device']) STATE['model'] = model STATE['image_size'] = image_size STATE['model_config'] = cfg STATE['styled_pils'] = [None] * 8 return jsonify({ 'success': True, 'config': cfg, 'image_size': image_size, 'filename': os.path.basename(path), }) except Exception as e: return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500 @app.route('/load_image', methods=['POST']) def api_load_image(): data = request.json if 'image_b64' in data: try: pil = b64_to_pil(data['image_b64']) STATE['orig_pil'] = pil STATE['recon_pil'] = None STATE['styled_pils'] = [None] * 8 STATE['mu'] = None STATE['skips'] = None return jsonify({'success': True, 'preview': pil_to_b64(pil)}) except Exception as e: return jsonify({'error': str(e)}), 500 return jsonify({'error': 'No image data'}), 400 @app.route('/generate', methods=['POST']) def api_generate(): """Encode image, produce reconstruction + all 8 styles.""" if STATE['model'] is None: return jsonify({'error': 'No model loaded'}), 400 if STATE['orig_pil'] is None: return jsonify({'error': 'No image loaded'}), 400 try: x = preprocess_image(STATE['orig_pil'], STATE['image_size']).to(STATE['device']) with torch.no_grad(): mu, logvar, skips = STATE['model'].encode(x) recon_t = STATE['model'].decode(mu, skips) recon_pil = tensor_to_pil(recon_t) STATE['recon_pil'] = recon_pil STATE['mu'] = mu STATE['skips'] = skips styled_b64 = [] for i in range(8): pil = StyleTransformBank.apply(recon_pil, i) STATE['styled_pils'][i] = pil styled_b64.append(pil_to_b64(pil)) return jsonify({ 'success': True, 'recon': pil_to_b64(recon_pil), 'styles': styled_b64, }) except Exception as e: return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500 @app.route('/strength_preview', methods=['POST']) def api_strength_preview(): """Alpha-blend reconstruction with a style.""" if STATE['recon_pil'] is None: return jsonify({'error': 'Generate styles first'}), 400 data = request.json style_idx = int(data.get('style_idx', 0)) alpha = float(data.get('alpha', 1.0)) try: blend = StyleTransformBank.apply_with_strength(STATE['recon_pil'], style_idx, alpha) return jsonify({'success': True, 'blend': pil_to_b64(blend)}) except Exception as e: return jsonify({'error': str(e)}), 500 @app.route('/latent_walk', methods=['POST']) def api_latent_walk(): """Latent walk between two styles.""" if STATE['model'] is None: return jsonify({'error': 'No model loaded'}), 400 if STATE['mu'] is None: return jsonify({'error': 'Generate styles first'}), 400 data = request.json style_a = int(data.get('style_a', 0)) style_b = int(data.get('style_b', 1)) steps = int(data.get('steps', 7)) steps = max(2, min(steps, 16)) try: mu = STATE['mu'] skips = STATE['skips'] direction = F.normalize(torch.randn_like(mu), dim=-1) frames_b64 = [] with torch.no_grad(): for i, t in enumerate(torch.linspace(-2.5, 2.5, steps)): z_walk = mu + t.item() * direction rec_t = tensor_to_pil(STATE['model'].decode(z_walk, skips)) alpha = i / (steps - 1) if steps > 1 else 0 styled_a = StyleTransformBank.apply(rec_t, style_a) styled_b = StyleTransformBank.apply(rec_t, style_b) arr_a = np.array(styled_a, dtype=np.float32) arr_b = np.array(styled_b, dtype=np.float32) blended_arr = arr_a * (1 - alpha) + arr_b * alpha frame = Image.fromarray(blended_arr.clip(0, 255).astype(np.uint8)) frames_b64.append({'img': pil_to_b64(frame), 'alpha': round(alpha, 2)}) return jsonify({'success': True, 'frames': frames_b64}) except Exception as e: return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500 @app.route('/save_style', methods=['POST']) def api_save_style(): """Return a single styled image as downloadable b64.""" data = request.json idx = int(data.get('style_idx', 0)) if STATE['styled_pils'][idx] is None: return jsonify({'error': 'Style not generated yet'}), 400 return jsonify({'success': True, 'image': pil_to_b64(STATE['styled_pils'][idx])}) @app.route('/save_all', methods=['GET']) def api_save_all(): """Return all generated images as b64.""" result = {} if STATE['orig_pil']: result['original'] = pil_to_b64(STATE['orig_pil']) if STATE['recon_pil']: result['reconstruction'] = pil_to_b64(STATE['recon_pil']) for i, pil in enumerate(STATE['styled_pils']): if pil: result[f'style_{i}'] = pil_to_b64(pil) return jsonify({'success': True, 'images': result, 'style_names': StyleTransformBank.STYLE_NAMES}) # ───────────────────────────────────────────────────────────────────────────── # AUTO-LOAD MODEL ON STARTUP # Place your model file as shukgen_v3_final.pth in the same directory as this script. # It will be loaded automatically when the Space starts. # ───────────────────────────────────────────────────────────────────────────── MODEL_PATH = os.path.join(BASE_DIR, 'shukgen_v3_final.pth') def auto_load_model(): if os.path.isfile(MODEL_PATH): print(f" 🔍 Found shukgen_v3_final.pth — loading...") try: model, image_size, cfg = load_model(MODEL_PATH, STATE['device']) STATE['model'] = model STATE['image_size'] = image_size STATE['model_config'] = cfg print(f" ✅ Model loaded! latent_dim={cfg['latent_dim']} image_size={image_size}") except Exception as e: print(f" ❌ Failed to auto-load model: {e}") traceback.print_exc() else: print(" ⚠️ No shukgen_v3_final.pth found in app directory.") print(" Upload shukgen_v3_final.pth to the Space repo and restart.") if __name__ == '__main__': print("╔══════════════════════════════════════════════════════╗") print("║ ShukGEN v3 — HF Spaces Edition (port 7860) ║") print("╚══════════════════════════════════════════════════════╝") print(f" PyTorch : {torch.__version__}") print(f" Device : {'GPU — ' + torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}") print() auto_load_model() print() # host=0.0.0.0 is REQUIRED for HF Spaces — do not change to 127.0.0.1 app.run(host='0.0.0.0', port=7860, debug=False)