| """ |
ShukGEN v3 — Flask Backend Server (Hugging Face Spaces Edition)
Serves both the UI (index.html) and the API on port 7860.
When this file is run directly, the model is auto-loaded at startup from
shukgen_v3_final.pth in the same directory.
| """ |
|
|
| import io, base64, os, sys, traceback, threading, time, math |
| from flask import Flask, request, jsonify, send_from_directory |
| from flask_cors import CORS |
|
|
| try: |
| import numpy as np |
| print(f" numpy {np.__version__} β
") |
| except ImportError as e: |
| print(f"β numpy not found: {e}") |
| raise |
|
|
| try: |
| from PIL import Image, ImageFilter, ImageEnhance, ImageDraw, ImageFont |
| except ImportError: |
| sys.exit("β Pillow not found. Run: pip install Pillow") |
|
|
| try: |
| from scipy.ndimage import gaussian_filter |
| except ImportError: |
| sys.exit("β scipy not found. Run: pip install scipy") |
|
|
| try: |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms |
| except ImportError: |
| sys.exit("β PyTorch not found. Run: pip install torch torchvision") |
|
|
# Flask application object; the UI route and every JSON API route below are
# registered on it.
app = Flask(__name__)
# NOTE(review): CORS(app) with no arguments uses flask_cors defaults, which (per
# its docs) allow cross-origin requests from any origin on all routes. Confirm
# that is intended for a deployment that exposes /load_model.
CORS(app)


# Absolute directory containing this file; index.html and the default
# checkpoint (shukgen_v3_final.pth) are looked up relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
@app.route('/')
def index():
    """Serve the single-page front end (index.html) stored next to this file."""
    page = 'index.html'
    return send_from_directory(BASE_DIR, page)
|
|
|
|
| |
| |
| |
|
|
def norm(ch):
    """Build a GroupNorm layer for ``ch`` channels.

    The group count is the first of 8, 4, 2, 1 that divides ``ch`` evenly.
    For an integer ``ch`` one always does (1 divides everything). If none
    divides, ``None`` is returned.
    """
    candidates = [g for g in (8, 4, 2, 1) if ch % g == 0]
    if candidates:
        return nn.GroupNorm(candidates[0], ch)
    return None
|
|
|
|
class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages added to the input.

    The convolutions keep ``ch`` channels and use padding 1, so the output has
    the same shape as the input.
    """

    def __init__(self, ch):
        super().__init__()
        # Layers are built in the same order as before (net.0 ... net.5) so the
        # state_dict keys used by saved checkpoints are unchanged.
        stages = []
        for _ in range(2):
            stages.extend([
                norm(ch),
                nn.SiLU(),
                nn.Conv2d(ch, ch, 3, padding=1, bias=False),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.net(x)
        return x + residual
|
|
|
|
class SE(nn.Module):
    """Squeeze-and-excitation channel gate.

    The input is global-average-pooled to one value per channel, passed through
    Linear -> SiLU -> Linear -> Sigmoid, and the resulting per-channel weights in
    (0, 1) rescale the input.

    Args:
        ch: number of channels of the input.
        ratio: bottleneck reduction; the hidden width is ``max(ch // ratio, 4)``.
    """

    def __init__(self, ch, ratio=8):
        super().__init__()
        hidden = max(ch // ratio, 4)
        squeeze = [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
        excite = [nn.Linear(ch, hidden), nn.SiLU(), nn.Linear(hidden, ch), nn.Sigmoid()]
        # Indices net.0 .. net.5 are unchanged, so checkpoint keys still match.
        self.net = nn.Sequential(*squeeze, *excite)

    def forward(self, x):
        gate = self.net(x)                   # (B, C)
        return x * gate[:, :, None, None]    # broadcast over H and W
|
|
|
|
class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every spatial position becomes a token whose features are the ``ch``
    channels. Queries, keys and values come from 1x1 convs of the normalised
    input. The attended map is projected by another 1x1 conv and added back to
    the input.
    """

    def __init__(self, ch):
        super().__init__()
        self.norm = norm(ch)
        # q, k, v are created and registered in this order (same as before).
        self.q, self.k, self.v = (nn.Conv2d(ch, ch, 1) for _ in range(3))
        self.proj = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        batch, chans, height, width = x.shape
        normed = self.norm(x)

        def to_tokens(conv):
            # (B, C, H, W) -> (B, H*W, C)
            return conv(normed).reshape(batch, chans, -1).transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            to_tokens(self.q), to_tokens(self.k), to_tokens(self.v))
        attended = attended.transpose(1, 2).reshape(batch, chans, height, width)
        return x + self.proj(attended)
|
|
|
|
class DownBlock(nn.Module):
    """Encoder stage.

    3x3 conv (``in_ch`` -> ``out_ch``) + GroupNorm + SiLU, then a residual
    block and an SE gate. A stride-2 4x4 conv follows, producing a map of
    ``floor(H/2) x floor(W/2)``.

    ``forward`` returns ``(downsampled, skip)``. ``skip`` is the pre-downsampling
    feature map that the decoder concatenates back in.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        entry = [nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), norm(out_ch), nn.SiLU()]
        self.conv = nn.Sequential(*entry)
        self.res = ResBlock(out_ch)
        self.se = SE(out_ch)
        self.down = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1, bias=False)

    def forward(self, x):
        features = self.conv(x)
        features = self.res(features)
        features = self.se(features)
        return self.down(features), features
|
|
|
|
class UpBlock(nn.Module):
    """Decoder stage.

    The input is upsampled 2x with a transposed conv and concatenated with the
    matching encoder skip along channels. It is then fused by a 3x3 conv
    (-> ``out_ch``), GroupNorm, SiLU and a residual block.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch, 4, stride=2, padding=1, bias=False)
        self.conv = nn.Conv2d(in_ch + skip_ch, out_ch, 3, padding=1, bias=False)
        self.norm = norm(out_ch)
        self.act = nn.SiLU()
        self.res = ResBlock(out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)
        target_hw = skip.shape[2:]
        if upsampled.shape[2:] != target_hw:
            # Spatial sizes disagree; bilinearly resize so the concat lines up.
            upsampled = F.interpolate(upsampled, size=target_hw, mode='bilinear', align_corners=False)
        merged = torch.cat((upsampled, skip), dim=1)
        fused = self.act(self.norm(self.conv(merged)))
        return self.res(fused)
|
|
|
|
class FaceVAE(nn.Module):
    """U-Net-style variational autoencoder for RGB face images.

    Encoder: a stem conv and four DownBlocks, then a bottleneck of
    ResBlock/AttnBlock/ResBlock/SE. The result is average-pooled to
    ``pool_sz x pool_sz`` and projected to ``mu`` and ``logvar`` of size
    ``latent_dim``.

    Decoder: ``z`` is projected back to ``(f*16, pool_sz, pool_sz)``, resized
    to 16x16 and upsampled through four UpBlocks that concatenate the encoder
    skips. The output is a 3-channel ``tanh`` image in [-1, 1]. Decoding
    therefore needs the encoder skips as well as ``z``.

    Args:
        latent_dim: size of the latent vector.
        base_filters: channel width ``f`` of the first stage; deeper stages use
            f*2, f*4, f*8 and f*16.
    """

    def __init__(self, latent_dim=512, base_filters=48):
        super().__init__()
        f = base_filters
        self.f = f
        self.latent = latent_dim
        self.pool_sz = 4
        flat = f * 16 * self.pool_sz * self.pool_sz

        # Encoder. Attribute names define the state_dict keys that checkpoints
        # loaded via load_state_dict must match, so they must not be renamed.
        self.inc = nn.Sequential(nn.Conv2d(3, f, 3, padding=1, bias=False), norm(f), nn.SiLU())
        self.down1 = DownBlock(f, f * 2)
        self.down2 = DownBlock(f * 2, f * 4)
        self.down3 = DownBlock(f * 4, f * 8)
        self.down4 = DownBlock(f * 8, f * 16)
        self.btn = nn.Sequential(ResBlock(f * 16), AttnBlock(f * 16), ResBlock(f * 16), SE(f * 16))
        self.pool = nn.AdaptiveAvgPool2d(self.pool_sz)
        self.fc_mu = nn.Linear(flat, latent_dim)
        self.fc_logvar = nn.Linear(flat, latent_dim)

        # Decoder.
        self.fc_dec = nn.Linear(latent_dim, flat)
        self.up1 = UpBlock(f * 16, f * 16, f * 8)
        self.up2 = UpBlock(f * 8, f * 8, f * 4)
        self.up3 = UpBlock(f * 4, f * 4, f * 2)
        self.up4 = UpBlock(f * 2, f * 2, f)
        self.outc = nn.Sequential(norm(f), nn.SiLU(), nn.Conv2d(f, 3, 3, padding=1), nn.Tanh())

    def encode(self, x):
        """Encode a batch of 3-channel images.

        Returns:
            ``(mu, logvar, skips)``. ``skips`` is the tuple ``(s1, s2, s3, s4)``
            of pre-downsampling feature maps from down1..down4.
        """
        x = self.inc(x)
        x, s1 = self.down1(x)
        x, s2 = self.down2(x)
        x, s3 = self.down3(x)
        x, s4 = self.down4(x)
        h = self.pool(self.btn(x)).flatten(1)
        return self.fc_mu(h), self.fc_logvar(h), (s1, s2, s3, s4)

    def decode(self, z, skips):
        """Decode latent ``z`` using encoder ``skips`` into an image in [-1, 1]."""
        s1, s2, s3, s4 = skips
        h = self.fc_dec(z).view(-1, self.f * 16, self.pool_sz, self.pool_sz)
        # Fixed 16x16 start: this equals the down4 output for 256x256 inputs.
        # For other sizes, UpBlock resizes to match each skip.
        h = F.interpolate(h, size=(16, 16), mode='bilinear', align_corners=False)
        h = self.up1(h, s4)
        h = self.up2(h, s3)
        h = self.up3(h, s2)
        h = self.up4(h, s1)
        return self.outc(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

    def forward(self, x):
        """Training pass: returns ``(reconstruction, mu, logvar)`` using a sampled ``z``."""
        mu, logvar, skips = self.encode(x)
        z = self.reparameterize(mu, logvar)
        out = self.decode(z, skips)
        return out, mu, logvar

    @torch.no_grad()
    def reconstruct(self, x):
        """Deterministically reconstruct ``x`` by decoding ``mu`` (no sampling).

        Runs in eval mode without gradients. The module's previous train/eval
        mode is restored afterwards, so calling this mid-training does not leave
        the model stuck in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            mu, _, skips = self.encode(x)
            return self.decode(mu, skips)
        finally:
            self.train(was_training)
|
|
|
|
| |
| |
| |
|
|
class StyleTransformBank:
    """Eight fixed (non-learned) photo "looks" applied to a PIL image.

    Each ``style_<i>_<name>(img)`` classmethod converts the image to a float32
    RGB array of shape (H, W, 3) in [0, 1], applies numpy/scipy operations and
    returns an RGB PIL image. ``STYLE_NAMES[i]`` is the display name of style
    ``i``, and ``apply`` dispatches by that index.
    """

    STYLE_NAMES = [
        "Youthful", "Aged / Mature", "Dramatic Light", "Soft Glow",
        "Intense / Bold", "Warm Golden Hour", "Cool / Moody", "Sketch / Artistic",
    ]

    # ---- array helpers -------------------------------------------------

    @classmethod
    def _np(cls, img):
        """PIL image -> float32 RGB array with values in [0, 1]."""
        return np.array(img.convert('RGB'), dtype=np.float32) / 255.0

    @classmethod
    def _pil(cls, arr):
        """Float array in [0, 1] -> 8-bit PIL image (clipped, then truncated to uint8)."""
        return Image.fromarray((arr.clip(0, 1) * 255).astype(np.uint8))

    @classmethod
    def _blur(cls, arr, sigma=2.0):
        """Gaussian-blur each of the three colour channels independently."""
        return np.stack([gaussian_filter(arr[..., c], sigma=sigma) for c in range(3)], axis=-1)

    @classmethod
    def _saturate(cls, arr, factor):
        """Scale each channel's distance from the pixel's RGB mean by ``factor``.

        ``factor < 1`` desaturates and ``factor > 1`` boosts saturation. The
        result is clipped to [0, 1].
        """
        grey = arr.mean(axis=-1, keepdims=True)
        return (grey + (arr - grey) * factor).clip(0, 1)

    @classmethod
    def _vignette(cls, arr, strength=0.5):
        """Darken towards the borders.

        Each pixel is multiplied by ``1 - strength * d``, where ``d`` is the
        normalised elliptical distance from the centre, clipped to 1.
        """
        H, W = arr.shape[:2]
        Y, X = np.ogrid[:H, :W]
        cx, cy = W / 2, H / 2
        dist = np.sqrt(((X - cx) / cx) ** 2 + ((Y - cy) / cy) ** 2)
        mask = 1 - strength * dist.clip(0, 1)
        return (arr * mask[..., None]).clip(0, 1)

    @classmethod
    def _grain(cls, arr, amount=0.03, seed=0):
        """Add Gaussian noise with std ``amount``.

        A fixed ``seed`` is used, so the grain is identical on every call.
        """
        rng = np.random.default_rng(seed)
        return (arr + rng.normal(0, amount, arr.shape)).clip(0, 1)

    # ---- styles --------------------------------------------------------

    @classmethod
    def style_0_youthful(cls, img):
        """More saturation, less red / more blue, a light blur blend and raised blacks."""
        arr = cls._np(img)
        arr = cls._saturate(arr, 1.35)
        arr[..., 0] = (arr[..., 0] * 0.95).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 1.12).clip(0, 1)
        blr = cls._blur(arr, sigma=0.7)
        arr = arr * 0.88 + blr * 0.12
        arr = (arr * 0.9 + 0.06).clip(0, 1)
        return cls._pil(arr)

    @classmethod
    def style_1_aged_mature(cls, img):
        """Desaturated and warmed, with grain, vignette and reduced contrast."""
        arr = cls._np(img)
        arr = cls._saturate(arr, 0.45)
        arr[..., 0] = (arr[..., 0] * 1.08 + 0.03).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 0.80).clip(0, 1)
        arr = cls._grain(arr, amount=0.035, seed=42)
        arr = cls._vignette(arr, 0.55)
        arr = ((arr - 0.5) * 0.88 + 0.5).clip(0, 1)
        return cls._pil(arr)

    @classmethod
    def style_2_dramatic_light(cls, img):
        """Left-dark / right-bright warm light ramp, 3x contrast and a strong vignette."""
        arr = cls._np(img)
        H, W = arr.shape[:2]
        x = np.linspace(0, 1, W)
        grad = np.power(x, 1.5)
        arr[..., 0] = (arr[..., 0] * (0.3 + 1.4*grad)).clip(0, 1)
        arr[..., 1] = (arr[..., 1] * (0.15 + 1.2*grad)).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * (0.1 + 0.9*grad)).clip(0, 1)
        arr = ((arr - 0.5) * 3.0 + 0.5).clip(0, 1)
        arr = cls._vignette(arr, 0.7)
        return cls._pil(arr)

    @classmethod
    def style_3_soft_glow(cls, img):
        """Adds a wide (sigma=8) blur on top of the image, then warms it and lifts blacks."""
        arr = cls._np(img)
        blr = cls._blur(arr, sigma=8.0)
        arr = (arr + blr * 0.55).clip(0, 1)
        arr[..., 0] = (arr[..., 0] * 1.08).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 0.88).clip(0, 1)
        arr = (arr * 0.85 + 0.08).clip(0, 1)
        return cls._pil(arr)

    @classmethod
    def style_4_intense_bold(cls, img):
        """Saturation x2.5, contrast x1.8, then PIL sharpening x3."""
        arr = cls._np(img)
        arr = cls._saturate(arr, 2.5)
        arr = ((arr - 0.5) * 1.8 + 0.5).clip(0, 1)
        pil = cls._pil(arr)
        return ImageEnhance.Sharpness(pil).enhance(3.0)

    @classmethod
    def style_5_warm_golden(cls, img):
        """Boost red and green, cut blue to 55%."""
        arr = cls._np(img)
        arr[..., 0] = (arr[..., 0] * 1.25 + 0.05).clip(0, 1)
        arr[..., 1] = (arr[..., 1] * 1.10 + 0.02).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 0.55).clip(0, 1)
        return cls._pil(arr)

    @classmethod
    def style_6_cool_moody(cls, img):
        """Cut red, boost blue, add grain and a vignette."""
        arr = cls._np(img)
        arr[..., 0] = (arr[..., 0] * 0.72).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 1.30 + 0.05).clip(0, 1)
        arr = cls._grain(arr, amount=0.03, seed=99)
        arr = cls._vignette(arr, 0.45)
        return cls._pil(arr)

    @classmethod
    def style_7_sketch_artistic(cls, img):
        """Posterise to 5 levels, darken high-frequency edges, multiply by a grainy paper tone."""
        arr = cls._np(img)
        post = (arr * 4).astype(int) / 4.0
        grey = post.mean(axis=-1)
        blr = gaussian_filter(grey, sigma=1.5)
        edge = np.abs(grey - blr) * 6.0
        canvas = cls._grain(np.ones_like(arr) * 0.92, amount=0.02, seed=12)
        result = post * canvas * (1 - edge[..., None] * 0.8)
        return cls._pil(result)

    # ---- dispatch ------------------------------------------------------

    @classmethod
    def apply(cls, img, style_idx):
        """Apply style ``style_idx`` (an index into STYLE_NAMES) to ``img``.

        Raises:
            IndexError: if ``style_idx`` is outside ``0 .. 7``. Negative indices
                are rejected instead of silently counting from the end.
        """
        methods = (
            cls.style_0_youthful, cls.style_1_aged_mature,
            cls.style_2_dramatic_light, cls.style_3_soft_glow,
            cls.style_4_intense_bold, cls.style_5_warm_golden,
            cls.style_6_cool_moody, cls.style_7_sketch_artistic,
        )
        if not 0 <= style_idx < len(methods):
            raise IndexError(
                f"style_idx must be in [0, {len(methods) - 1}], got {style_idx!r}")
        return methods[style_idx](img)

    @classmethod
    def apply_with_strength(cls, recon_pil, style_idx, alpha):
        """Blend ``recon_pil`` with its styled version: ``(1 - alpha) * recon + alpha * styled``.

        ``alpha`` is not clamped. Values outside [0, 1] extrapolate, and the
        result is clipped to 0-255. ``recon_pil`` is converted to RGB so its
        array shape matches the styled image, which is always RGB.
        """
        styled = cls.apply(recon_pil, style_idx)
        rec_arr = np.array(recon_pil.convert('RGB'), dtype=np.float32)
        styled_arr = np.array(styled, dtype=np.float32)
        blended = rec_arr * (1 - alpha) + styled_arr * alpha
        return Image.fromarray(blended.clip(0, 255).astype(np.uint8))
|
|
|
|
| |
| |
| |
|
|
def load_model(path, device):
    """Load a FaceVAE checkpoint onto ``device`` and switch it to eval mode.

    The checkpoint is a dict with ``'config'`` (must provide ``latent_dim``,
    ``base_filters`` and ``image_size``) and ``'model_state'`` (the state_dict).

    Returns:
        ``(model, image_size, config)``.

    Security: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
    Python objects, which can execute code embedded in the file. Only load
    trusted checkpoints.
    """
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    config = checkpoint['config']
    model = FaceVAE(config['latent_dim'], config['base_filters'])
    model = model.to(device)
    model.load_state_dict(checkpoint['model_state'])
    model.eval()
    return model, config['image_size'], config
|
|
|
|
def preprocess_image(pil_img, size):
    """Turn a PIL image into a ``(1, 3, size, size)`` tensor in [-1, 1].

    The image is converted to RGB and resized to a square ``size x size``
    (aspect ratio is not preserved). It is converted to [0, 1] and then
    normalised with mean 0.5 / std 0.5.
    """
    rgb = pil_img.convert('RGB')
    pipeline = transforms.Compose([
        transforms.Resize((size, size), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    tensor = pipeline(rgb)
    return tensor.unsqueeze(0)
|
|
|
|
def tensor_to_pil(t):
    """Convert a decoder output tensor in [-1, 1] to a PIL image.

    A leading size-1 batch dimension is dropped and the tensor is moved to the
    CPU. It is clamped to [-1, 1] and mapped to [0, 1] for ToPILImage.
    Presumably ``t`` is (1, 3, H, W); confirm with callers.
    """
    image = t.squeeze(0).cpu().clamp(-1, 1)
    unit_range = (image + 1.0) / 2.0
    converter = transforms.ToPILImage()
    return converter(unit_range)
|
|
|
|
def pil_to_b64(pil_img):
    """Encode an image (converted to RGB) as a PNG ``data:`` URL string."""
    with io.BytesIO() as buffer:
        pil_img.convert('RGB').save(buffer, format='PNG')
        encoded = base64.b64encode(buffer.getvalue()).decode()
    return 'data:image/png;base64,' + encoded
|
|
|
|
def b64_to_pil(data_url):
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts a data URL (``data:<mime>;base64,<payload>``, everything after the
    first comma is the payload) or a bare base64 payload with no header.

    Raises:
        binascii.Error (a ValueError): if the payload is not valid base64.
        Errors raised by PIL if the decoded bytes are not a readable image.
    """
    header, sep, payload = data_url.partition(',')
    if not sep:
        # No comma at all: treat the whole string as the base64 payload.
        payload = header
    raw = base64.b64decode(payload)
    # convert() returns a new, fully loaded image, so the source can be closed.
    with Image.open(io.BytesIO(raw)) as img:
        return img.convert('RGB')
|
|
|
|
| |
| |
| |
|
|
# Process-wide session state. It is a single module-level dict read and written
# by every route below, so all clients of this server share it.
STATE = dict(
    model=None,              # FaceVAE once a checkpoint is loaded
    model_config={},         # the checkpoint's 'config' dict
    image_size=256,          # square size images are resized to before encoding
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    orig_pil=None,           # last uploaded image (RGB PIL)
    recon_pil=None,          # model reconstruction of orig_pil
    styled_pils=[None] * 8,  # one slot per StyleTransformBank style
    mu=None,                 # latent mean from the last /generate
    skips=None,              # encoder skip tensors from the last /generate
)
|
|
|
|
| |
| |
| |
|
|
@app.route('/status', methods=['GET'])
def status():
    """Report model/image readiness, the compute device and the torch version as JSON."""
    device = STATE['device']
    if device.type == 'cuda':
        device_label = f"GPU: {torch.cuda.get_device_name(0)}"
    else:
        device_label = "CPU"
    payload = {
        'model_loaded': STATE['model'] is not None,
        'model_config': STATE['model_config'],
        'image_loaded': STATE['orig_pil'] is not None,
        'device': device_label,
        'torch_version': torch.__version__,
    }
    return jsonify(payload)
|
|
|
|
@app.route('/load_model', methods=['POST'])
def api_load_model():
    """Load a FaceVAE checkpoint from a server-side path given in the JSON body.

    Expects ``{"path": "<file>"}``. On success the new model replaces the
    current one. Every cached result derived from the previous model is cleared
    (reconstruction, styles, ``mu`` and encoder skips), so later requests never
    decode one model's latents with another. The uploaded image is kept.

    Responses: 400 for a missing/empty/non-string path, 404 if the file does
    not exist, 500 with error text and traceback if loading fails.

    SECURITY NOTE(review): ``path`` is client-controlled and ``load_model``
    unpickles the file (``weights_only=False``). Anyone who can place a file on
    the server can therefore run code through this route. Consider restricting
    paths to BASE_DIR or disabling the route in public deployments.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    raw_path = data.get('path', '')
    if not isinstance(raw_path, str):
        return jsonify({'error': "'path' must be a string"}), 400
    path = raw_path.strip()
    if not path:
        return jsonify({'error': 'No path provided'}), 400
    if not os.path.isfile(path):
        return jsonify({'error': f'File not found: {path}'}), 404
    try:
        model, image_size, cfg = load_model(path, STATE['device'])
        STATE['model'] = model
        STATE['image_size'] = image_size
        STATE['model_config'] = cfg
        # Results computed with the previous model are no longer valid.
        STATE['recon_pil'] = None
        STATE['styled_pils'] = [None] * 8
        STATE['mu'] = None
        STATE['skips'] = None
        return jsonify({
            'success': True,
            'config': cfg,
            'image_size': image_size,
            'filename': os.path.basename(path),
        })
    except Exception as e:
        return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500
|
|
|
|
@app.route('/load_image', methods=['POST'])
def api_load_image():
    """Decode an uploaded image and make it the current working image.

    Expects JSON ``{"image_b64": "<data URL or base64>"}``. On success the
    previous reconstruction, styles and latents are cleared, since they belong
    to the old image. The decoded image is returned as a PNG data-URL preview.

    Responses: 400 when the body is missing, not a JSON object, or has no
    ``image_b64``. 500 with the error text when decoding fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image_b64' not in data:
        return jsonify({'error': 'No image data'}), 400
    try:
        pil = b64_to_pil(data['image_b64'])
        STATE['orig_pil'] = pil
        STATE['recon_pil'] = None
        STATE['styled_pils'] = [None] * 8
        STATE['mu'] = None
        STATE['skips'] = None
        return jsonify({'success': True, 'preview': pil_to_b64(pil)})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
@app.route('/generate', methods=['POST'])
def api_generate():
    """Encode image, produce reconstruction + all 8 styles.

    The uploaded image is encoded, ``mu`` is decoded deterministically, and every
    StyleTransformBank style is applied to the reconstruction. All results are
    computed first and only then written to STATE together. A failure part-way
    through therefore leaves the previous results untouched instead of a mix of
    old and new styles.

    Responses: 400 if no model or no image is loaded; 500 with error text and
    traceback on failure.
    """
    model = STATE['model']
    if model is None:
        return jsonify({'error': 'No model loaded'}), 400
    orig = STATE['orig_pil']
    if orig is None:
        return jsonify({'error': 'No image loaded'}), 400
    try:
        x = preprocess_image(orig, STATE['image_size']).to(STATE['device'])
        with torch.no_grad():
            mu, _logvar, skips = model.encode(x)
            recon_t = model.decode(mu, skips)
        recon_pil = tensor_to_pil(recon_t)
        styled = [StyleTransformBank.apply(recon_pil, i)
                  for i in range(len(StyleTransformBank.STYLE_NAMES))]

        # Commit everything at once so STATE is never half-updated.
        STATE['recon_pil'] = recon_pil
        STATE['mu'] = mu
        STATE['skips'] = skips
        STATE['styled_pils'] = styled

        return jsonify({
            'success': True,
            'recon': pil_to_b64(recon_pil),
            'styles': [pil_to_b64(pil) for pil in styled],
        })
    except Exception as e:
        return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500
|
|
|
|
@app.route('/strength_preview', methods=['POST'])
def api_strength_preview():
    """Alpha-blend reconstruction with a style.

    Expects JSON ``{"style_idx": int, "alpha": float}`` (defaults 0 and 1.0).
    Inputs are validated before any work is done. A non-numeric value, an
    out-of-range style index, or a NaN/infinite alpha gets a 400 JSON error
    rather than an unhandled exception or a silently wrapped negative index.
    """
    recon = STATE['recon_pil']
    if recon is None:
        return jsonify({'error': 'Generate styles first'}), 400
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    try:
        style_idx = int(data.get('style_idx', 0))
        alpha = float(data.get('alpha', 1.0))
    except (TypeError, ValueError):
        return jsonify({'error': 'style_idx must be an integer and alpha a number'}), 400
    if not 0 <= style_idx < len(StyleTransformBank.STYLE_NAMES):
        return jsonify({'error': f'style_idx out of range: {style_idx}'}), 400
    if not math.isfinite(alpha):
        return jsonify({'error': 'alpha must be a finite number'}), 400
    try:
        blend = StyleTransformBank.apply_with_strength(recon, style_idx, alpha)
        return jsonify({'success': True, 'blend': pil_to_b64(blend)})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
@app.route('/latent_walk', methods=['POST'])
def api_latent_walk():
    """Latent walk between two styles.

    Decodes ``mu + t * d`` for ``steps`` values of ``t`` evenly spaced in
    [-2.5, 2.5], reusing the stored encoder skips. ``d`` is a random unit
    direction drawn fresh on every call. Frame ``i`` blends style ``style_a``
    into ``style_b`` with ``alpha = i / (steps - 1)``.

    Expects JSON ``{"style_a": int, "style_b": int, "steps": int}`` (defaults
    0, 1, 7). ``steps`` is clamped to [2, 16]. Non-integer or out-of-range
    inputs get a 400 JSON error.
    """
    if STATE['model'] is None:
        return jsonify({'error': 'No model loaded'}), 400
    if STATE['mu'] is None:
        return jsonify({'error': 'Generate styles first'}), 400
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    try:
        style_a = int(data.get('style_a', 0))
        style_b = int(data.get('style_b', 1))
        steps = int(data.get('steps', 7))
    except (TypeError, ValueError):
        return jsonify({'error': 'style_a, style_b and steps must be integers'}), 400
    n_styles = len(StyleTransformBank.STYLE_NAMES)
    for name, idx in (('style_a', style_a), ('style_b', style_b)):
        if not 0 <= idx < n_styles:
            return jsonify({'error': f'{name} out of range: {idx}'}), 400
    steps = max(2, min(steps, 16))
    try:
        model = STATE['model']
        mu = STATE['mu']
        skips = STATE['skips']
        direction = F.normalize(torch.randn_like(mu), dim=-1)
        frames_b64 = []
        with torch.no_grad():
            for i, t in enumerate(torch.linspace(-2.5, 2.5, steps).tolist()):
                z_walk = mu + t * direction
                recon = tensor_to_pil(model.decode(z_walk, skips))
                # steps >= 2 after clamping, so the denominator is never zero.
                alpha = i / (steps - 1)
                arr_a = np.array(StyleTransformBank.apply(recon, style_a), dtype=np.float32)
                arr_b = np.array(StyleTransformBank.apply(recon, style_b), dtype=np.float32)
                blended_arr = arr_a * (1 - alpha) + arr_b * alpha
                frame = Image.fromarray(blended_arr.clip(0, 255).astype(np.uint8))
                frames_b64.append({'img': pil_to_b64(frame), 'alpha': round(alpha, 2)})
        return jsonify({'success': True, 'frames': frames_b64})
    except Exception as e:
        return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500
|
|
|
|
@app.route('/save_style', methods=['POST'])
def api_save_style():
    """Return a single styled image as downloadable b64.

    Expects JSON ``{"style_idx": int}`` (default 0). Non-integer or out-of-range
    indices get a 400 JSON error instead of an unhandled IndexError, and
    negative indices are no longer silently wrapped.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    try:
        idx = int(data.get('style_idx', 0))
    except (TypeError, ValueError):
        return jsonify({'error': 'style_idx must be an integer'}), 400
    styled = STATE['styled_pils']
    if not 0 <= idx < len(styled):
        return jsonify({'error': f'style_idx out of range: {idx}'}), 400
    if styled[idx] is None:
        return jsonify({'error': 'Style not generated yet'}), 400
    return jsonify({'success': True, 'image': pil_to_b64(styled[idx])})
|
|
|
|
@app.route('/save_all', methods=['GET'])
def api_save_all():
    """Return all generated images as b64.

    The ``images`` map may contain ``original``, ``reconstruction`` and
    ``style_<i>``; only images that currently exist are included. Presence is
    tested with ``is not None`` rather than the image's truthiness.
    """
    images = {}
    if STATE['orig_pil'] is not None:
        images['original'] = pil_to_b64(STATE['orig_pil'])
    if STATE['recon_pil'] is not None:
        images['reconstruction'] = pil_to_b64(STATE['recon_pil'])
    for i, pil in enumerate(STATE['styled_pils']):
        if pil is not None:
            images[f'style_{i}'] = pil_to_b64(pil)
    return jsonify({'success': True, 'images': images,
                    'style_names': StyleTransformBank.STYLE_NAMES})
|
|
|
|
| |
| |
| |
| |
| |
|
|
# Default checkpoint location, next to this file. auto_load_model() tries it
# when the server is started as a script.
MODEL_PATH = os.path.join(BASE_DIR, 'shukgen_v3_final.pth')
|
|
def auto_load_model():
    """Load MODEL_PATH into STATE if the checkpoint file exists.

    Best-effort: a missing file or a failed load is reported on stdout (with a
    traceback for failures) and the function returns normally. The server keeps
    running without a model, and one can still be loaded via ``/load_model``.
    """
    if not os.path.isfile(MODEL_PATH):
        print(" ⚠️ No shukgen_v3_final.pth found in app directory.")
        print(" Upload shukgen_v3_final.pth to the Space repo and restart.")
        return
    print(" 📂 Found shukgen_v3_final.pth — loading...")
    try:
        model, image_size, cfg = load_model(MODEL_PATH, STATE['device'])
    except Exception as e:
        print(f" ❌ Failed to auto-load model: {e}")
        traceback.print_exc()
        return
    STATE['model'] = model
    STATE['image_size'] = image_size
    STATE['model_config'] = cfg
    print(f" ✅ Model loaded! latent_dim={cfg['latent_dim']} image_size={image_size}")
|
|
|
|
if __name__ == '__main__':
    # Startup banner, built programmatically so the box always lines up.
    title = "ShukGEN v3 — HF Spaces Edition (port 7860)"
    inner_width = len(title) + 4
    print("╔" + "═" * inner_width + "╗")
    print("║  " + title + "  ║")
    print("╚" + "═" * inner_width + "╝")
    print(f" PyTorch : {torch.__version__}")
    device_label = ('GPU — ' + torch.cuda.get_device_name(0)
                    if torch.cuda.is_available() else 'CPU')
    print(f" Device : {device_label}")
    print()
    # NOTE: auto-loading only happens here, i.e. when this file is run directly.
    # A WSGI server that imports `app` skips this block and starts with no model.
    auto_load_model()
    print()
    app.run(host='0.0.0.0', port=7860, debug=False)