ShukGEN / app.py
shukdev3's picture
Update app.py
59ce6a0 verified
Raw
History Blame Contribute Delete
23.6 kB
"""
ShukGEN v3 — Flask Backend Server (Hugging Face Spaces Edition)
Serves both the UI (index.html) and the API on port 7860.
Model is auto-loaded from shukgen_v3_final.pth in the same directory at startup.
"""
import io, base64, os, sys, traceback, threading, time, math
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
try:
import numpy as np
print(f" numpy {np.__version__} βœ…")
except ImportError as e:
print(f"❌ numpy not found: {e}")
raise
try:
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw, ImageFont
except ImportError:
sys.exit("❌ Pillow not found. Run: pip install Pillow")
try:
from scipy.ndimage import gaussian_filter
except ImportError:
sys.exit("❌ scipy not found. Run: pip install scipy")
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
except ImportError:
sys.exit("❌ PyTorch not found. Run: pip install torch torchvision")
# Flask application object; the @app.route decorators in this file register
# their handlers on it.
app = Flask(__name__)
# Enable cross-origin requests through flask_cors with its default settings
# (presumably every origin is allowed — confirm against the flask_cors docs).
CORS(app)
# ─────────────────────────────────────────────────────────────────────────────
# SERVE FRONTEND — index.html is served at the root URL
# ─────────────────────────────────────────────────────────────────────────────
# Absolute path of the directory that contains this script; static assets and
# the model checkpoint are looked up relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


@app.route('/')
def index():
    """Serve the front-end page (index.html) from the script directory."""
    page = 'index.html'
    return send_from_directory(BASE_DIR, page)
# ─────────────────────────────────────────────────────────────────────────────
# MODEL DEFINITION (identical to new_gui.py)
# ─────────────────────────────────────────────────────────────────────────────
def norm(ch):
    """Build a GroupNorm layer for ``ch`` channels.

    The group count is the first of 8, 4, 2, 1 that divides ``ch`` evenly.
    """
    groups = next((g for g in (8, 4, 2, 1) if ch % g == 0), None)
    if groups is None:
        # Unreachable for integer ``ch`` (1 divides everything); mirrors the
        # implicit ``None`` the loop-based form falls through to.
        return None
    return nn.GroupNorm(groups, ch)
class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus identity."""

    def __init__(self, ch):
        super().__init__()
        # The position of each layer inside the Sequential determines its
        # state_dict key, so the order norm, act, conv, norm, act, conv is kept.
        stages = []
        for _ in range(2):
            stages.extend([
                norm(ch),
                nn.SiLU(),
                nn.Conv2d(ch, ch, 3, padding=1, bias=False),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.net(x)
        return x + residual
class SE(nn.Module):
    """Channel gate: global average pool -> two Linear layers -> sigmoid scale.

    The hidden width is ``ch // ratio`` but never smaller than 4.
    """

    def __init__(self, ch, ratio=8):
        super().__init__()
        hidden = max(ch // ratio, 4)
        squeeze = [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
        excite = [
            nn.Linear(ch, hidden), nn.SiLU(),
            nn.Linear(hidden, ch), nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*squeeze, *excite)

    def forward(self, x):
        """Scale every channel of ``x`` by its learned gate value in (0, 1)."""
        batch, channels = x.size(0), x.size(1)
        gate = self.net(x).view(batch, channels, 1, 1)
        return x * gate
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions with a residual add."""

    def __init__(self, ch):
        super().__init__()
        self.norm = norm(ch)
        # 1x1 convolutions act as per-pixel linear projections.
        self.q = nn.Conv2d(ch, ch, 1)
        self.k = nn.Conv2d(ch, ch, 1)
        self.v = nn.Conv2d(ch, ch, 1)
        self.proj = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        """Attend across all H*W positions and add the projected result to ``x``."""
        B, C, H, W = x.shape
        normed = self.norm(x)

        def as_tokens(feat):
            # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
            return feat.reshape(B, C, -1).permute(0, 2, 1)

        queries = as_tokens(self.q(normed))
        keys = as_tokens(self.k(normed))
        values = as_tokens(self.v(normed))
        attended = F.scaled_dot_product_attention(queries, keys, values)
        # Back from token layout to the (B, C, H, W) feature-map layout.
        attended = attended.permute(0, 2, 1).reshape(B, C, H, W)
        return x + self.proj(attended)
class DownBlock(nn.Module):
    """Encoder stage: conv stem -> ResBlock -> SE gate, then a stride-2 conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stem = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            norm(out_ch),
            nn.SiLU(),
        ]
        self.conv = nn.Sequential(*stem)
        self.res = ResBlock(out_ch)
        self.se = SE(out_ch)
        self.down = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1, bias=False)

    def forward(self, x):
        """Return ``(downsampled, skip)``; ``skip`` is the pre-downsample feature map."""
        features = self.conv(x)
        features = self.res(features)
        features = self.se(features)
        return self.down(features), features
class UpBlock(nn.Module):
    """Decoder stage: transposed-conv upsample, concat skip, conv + norm + ResBlock."""

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch, 4, stride=2, padding=1, bias=False)
        self.conv = nn.Conv2d(in_ch + skip_ch, out_ch, 3, padding=1, bias=False)
        self.norm = norm(out_ch)
        self.act = nn.SiLU()
        self.res = ResBlock(out_ch)

    def forward(self, x, skip):
        """Upsample ``x``, fuse it with ``skip`` along channels and refine."""
        upsampled = self.up(x)
        target_hw = skip.shape[2:]
        if upsampled.shape[2:] != target_hw:
            # Spatial sizes can disagree; resize so the concat below is valid.
            upsampled = F.interpolate(
                upsampled, size=target_hw, mode='bilinear', align_corners=False)
        fused = torch.cat([upsampled, skip], dim=1)
        fused = self.conv(fused)
        fused = self.norm(fused)
        fused = self.act(fused)
        return self.res(fused)
class FaceVAE(nn.Module):
    """U-Net shaped VAE: four down stages, a bottleneck, four up stages with skips.

    ``encode`` yields the latent mean, log-variance and the four skip tensors;
    ``decode`` needs both a latent vector and those skips.
    """

    def __init__(self, latent_dim=512, base_filters=48):
        super().__init__()
        f = base_filters
        self.f = f
        self.latent = latent_dim
        self.pool_sz = 4
        # Flattened size of the pooled bottleneck feature map.
        flat = f * 16 * self.pool_sz * self.pool_sz

        # --- encoder ---
        self.inc = nn.Sequential(
            nn.Conv2d(3, f, 3, padding=1, bias=False), norm(f), nn.SiLU())
        self.down1 = DownBlock(f, f * 2)
        self.down2 = DownBlock(f * 2, f * 4)
        self.down3 = DownBlock(f * 4, f * 8)
        self.down4 = DownBlock(f * 8, f * 16)
        self.btn = nn.Sequential(
            ResBlock(f * 16), AttnBlock(f * 16), ResBlock(f * 16), SE(f * 16))
        self.pool = nn.AdaptiveAvgPool2d(self.pool_sz)

        # --- latent heads ---
        self.fc_mu = nn.Linear(flat, latent_dim)
        self.fc_logvar = nn.Linear(flat, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, flat)

        # --- decoder ---
        self.up1 = UpBlock(f * 16, f * 16, f * 8)
        self.up2 = UpBlock(f * 8, f * 8, f * 4)
        self.up3 = UpBlock(f * 4, f * 4, f * 2)
        self.up4 = UpBlock(f * 2, f * 2, f)
        self.outc = nn.Sequential(
            norm(f), nn.SiLU(), nn.Conv2d(f, 3, 3, padding=1), nn.Tanh())

    def encode(self, x):
        """Return ``(mu, logvar, (s1, s2, s3, s4))`` for an image batch ``x``."""
        h = self.inc(x)
        skips = []
        for stage in (self.down1, self.down2, self.down3, self.down4):
            h, skip = stage(h)
            skips.append(skip)
        h = self.pool(self.btn(h)).flatten(1)
        return self.fc_mu(h), self.fc_logvar(h), tuple(skips)

    def decode(self, z, skips):
        """Decode latent ``z`` into an image using the encoder ``skips``."""
        s1, s2, s3, s4 = skips
        h = self.fc_dec(z).view(-1, self.f * 16, self.pool_sz, self.pool_sz)
        # Fixed 16x16 starting grid; each UpBlock resizes to its skip if needed.
        h = F.interpolate(h, size=(16, 16), mode='bilinear', align_corners=False)
        # Skips are consumed deepest-first.
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4),
                               (s4, s3, s2, s1)):
            h = stage(h, skip)
        return self.outc(h)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)``."""
        eps = torch.randn_like(mu)
        std = torch.exp(0.5 * logvar)
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` using a sampled latent."""
        mu, logvar, skips = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, skips), mu, logvar

    @torch.no_grad()
    def reconstruct(self, x):
        """Deterministic reconstruction from the latent mean (switches to eval mode)."""
        self.eval()
        mu, _, skips = self.encode(x)
        return self.decode(mu, skips)
# ─────────────────────────────────────────────────────────────────────────────
# STYLE TRANSFORMS (identical to new_gui.py)
# ─────────────────────────────────────────────────────────────────────────────
class StyleTransformBank:
    """Collection of eight hand-written image filters applied to PIL images.

    Every ``style_*`` method takes a PIL image and returns a new RGB PIL image.
    Internally the filters work on float arrays scaled to [0, 1].
    """

    STYLE_NAMES = [
        "Youthful", "Aged / Mature", "Dramatic Light", "Soft Glow",
        "Intense / Bold", "Warm Golden Hour", "Cool / Moody", "Sketch / Artistic",
    ]

    # Filter method names, in the same order as STYLE_NAMES.
    _STYLE_METHODS = (
        'style_0_youthful', 'style_1_aged_mature',
        'style_2_dramatic_light', 'style_3_soft_glow',
        'style_4_intense_bold', 'style_5_warm_golden',
        'style_6_cool_moody', 'style_7_sketch_artistic',
    )

    # ── array helpers ────────────────────────────────────────────────────────
    @classmethod
    def _np(cls, img):
        """Convert a PIL image to a float32 RGB array in [0, 1]."""
        return np.array(img.convert('RGB'), dtype=np.float32) / 255.0

    @classmethod
    def _pil(cls, arr):
        """Convert a [0, 1] float array back to an 8-bit PIL image (values clipped)."""
        return Image.fromarray((arr.clip(0, 1) * 255).astype(np.uint8))

    @classmethod
    def _blur(cls, arr, sigma=2.0):
        """Gaussian-blur each of the three colour channels independently."""
        return np.stack(
            [gaussian_filter(arr[..., c], sigma=sigma) for c in range(3)], axis=-1)

    @classmethod
    def _saturate(cls, arr, factor):
        """Scale each pixel's distance from its channel mean by ``factor``."""
        grey = arr.mean(axis=-1, keepdims=True)
        return (grey + (arr - grey) * factor).clip(0, 1)

    @classmethod
    def _vignette(cls, arr, strength=0.5):
        """Darken pixels in proportion to their normalised distance from centre."""
        H, W = arr.shape[:2]
        Y, X = np.ogrid[:H, :W]
        cx, cy = W / 2, H / 2
        dist = np.sqrt(((X - cx) / cx) ** 2 + ((Y - cy) / cy) ** 2)
        mask = 1 - strength * dist.clip(0, 1)
        return (arr * mask[..., None]).clip(0, 1)

    @classmethod
    def _grain(cls, arr, amount=0.03, seed=0):
        """Add seeded Gaussian noise, so the same seed gives the same grain."""
        rng = np.random.default_rng(seed)
        return (arr + rng.normal(0, amount, arr.shape)).clip(0, 1)

    # ── styles ───────────────────────────────────────────────────────────────
    @classmethod
    def style_0_youthful(cls, img):
        """More saturation, less red / more blue, slight blur mix and a lift."""
        arr = cls._np(img)
        arr = cls._saturate(arr, 1.35)
        arr[..., 0] = (arr[..., 0] * 0.95).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 1.12).clip(0, 1)
        blr = cls._blur(arr, sigma=0.7)
        arr = arr * 0.88 + blr * 0.12
        arr = (arr * 0.9 + 0.06).clip(0, 1)
        return cls._pil(arr)

    @classmethod
    def style_1_aged_mature(cls, img):
        """Desaturate, warm the tint, add grain and vignette, lower contrast."""
        arr = cls._np(img)
        arr = cls._saturate(arr, 0.45)
        arr[..., 0] = (arr[..., 0] * 1.08 + 0.03).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 0.80).clip(0, 1)
        arr = cls._grain(arr, amount=0.035, seed=42)
        arr = cls._vignette(arr, 0.55)
        arr = ((arr - 0.5) * 0.88 + 0.5).clip(0, 1)
        return cls._pil(arr)

    @classmethod
    def style_2_dramatic_light(cls, img):
        """Left-to-right brightness ramp, strong contrast and a heavy vignette."""
        arr = cls._np(img)
        width = arr.shape[1]
        # Per-column gain that rises from the left edge to the right edge.
        grad = np.power(np.linspace(0, 1, width), 1.5)
        arr[..., 0] = (arr[..., 0] * (0.3 + 1.4 * grad)).clip(0, 1)
        arr[..., 1] = (arr[..., 1] * (0.15 + 1.2 * grad)).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * (0.1 + 0.9 * grad)).clip(0, 1)
        arr = ((arr - 0.5) * 3.0 + 0.5).clip(0, 1)
        arr = cls._vignette(arr, 0.7)
        return cls._pil(arr)

    @classmethod
    def style_3_soft_glow(cls, img):
        """Add a wide blur on top of the image, warm the tint and lift slightly."""
        arr = cls._np(img)
        blr = cls._blur(arr, sigma=8.0)
        arr = (arr + blr * 0.55).clip(0, 1)
        arr[..., 0] = (arr[..., 0] * 1.08).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 0.88).clip(0, 1)
        arr = (arr * 0.85 + 0.08).clip(0, 1)
        return cls._pil(arr)

    @classmethod
    def style_4_intense_bold(cls, img):
        """Heavy saturation and contrast followed by PIL sharpening."""
        arr = cls._np(img)
        arr = cls._saturate(arr, 2.5)
        arr = ((arr - 0.5) * 1.8 + 0.5).clip(0, 1)
        pil = cls._pil(arr)
        return ImageEnhance.Sharpness(pil).enhance(3.0)

    @classmethod
    def style_5_warm_golden(cls, img):
        """Boost red and green, cut blue."""
        arr = cls._np(img)
        arr[..., 0] = (arr[..., 0] * 1.25 + 0.05).clip(0, 1)
        arr[..., 1] = (arr[..., 1] * 1.10 + 0.02).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 0.55).clip(0, 1)
        return cls._pil(arr)

    @classmethod
    def style_6_cool_moody(cls, img):
        """Cut red, boost blue, then add grain and a vignette."""
        arr = cls._np(img)
        arr[..., 0] = (arr[..., 0] * 0.72).clip(0, 1)
        arr[..., 2] = (arr[..., 2] * 1.30 + 0.05).clip(0, 1)
        arr = cls._grain(arr, amount=0.03, seed=99)
        arr = cls._vignette(arr, 0.45)
        return cls._pil(arr)

    @classmethod
    def style_7_sketch_artistic(cls, img):
        """Posterise, darken edges found by blur-difference, multiply by a grainy canvas."""
        arr = cls._np(img)
        # Quantise each channel to steps of 0.25.
        post = (arr * 4).astype(int) / 4.0
        grey = post.mean(axis=-1)
        blr = gaussian_filter(grey, sigma=1.5)
        # Difference from the blurred image is large near edges.
        edge = np.abs(grey - blr) * 6.0
        canvas = cls._grain(np.ones_like(arr) * 0.92, amount=0.02, seed=12)
        result = post * canvas * (1 - edge[..., None] * 0.8)
        return cls._pil(result)

    # ── dispatch ─────────────────────────────────────────────────────────────
    @classmethod
    def apply(cls, img, style_idx):
        """Apply the style at position ``style_idx`` (0-based) to ``img``.

        Raises:
            IndexError: if ``style_idx`` is outside ``0 .. len(STYLE_NAMES) - 1``.
                Negative values are rejected rather than wrapping around to a
                different style.
        """
        if not 0 <= style_idx < len(cls._STYLE_METHODS):
            raise IndexError(f'style_idx out of range: {style_idx}')
        return getattr(cls, cls._STYLE_METHODS[style_idx])(img)

    @classmethod
    def apply_with_strength(cls, recon_pil, style_idx, alpha):
        """Linearly blend ``recon_pil`` with its styled version.

        ``alpha`` = 0 returns the input unchanged, ``alpha`` = 1 the fully
        styled image; the result is clipped to the 8-bit range.
        """
        # Styled output is always RGB, so normalise the base image to RGB too;
        # otherwise the two arrays could differ in channel count.
        base = recon_pil.convert('RGB')
        styled = cls.apply(base, style_idx)
        rec_arr = np.array(base, dtype=np.float32)
        styled_arr = np.array(styled, dtype=np.float32)
        blended = rec_arr * (1 - alpha) + styled_arr * alpha
        return Image.fromarray(blended.clip(0, 255).astype(np.uint8))
# ─────────────────────────────────────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def load_model(path, device):
    """Load a FaceVAE checkpoint and return ``(model, image_size, config)``.

    The checkpoint must be a dict with a ``'config'`` entry (providing
    ``latent_dim``, ``base_filters`` and ``image_size``) and a
    ``'model_state'`` state dict. The model is returned in eval mode on
    ``device``.
    """
    # NOTE(security): weights_only=False lets torch.load un-pickle arbitrary
    # objects, which can execute code. Only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    config = checkpoint['config']
    model = FaceVAE(config['latent_dim'], config['base_filters'])
    model = model.to(device)
    model.load_state_dict(checkpoint['model_state'])
    model.eval()
    return model, config['image_size'], config
def preprocess_image(pil_img, size):
    """Turn a PIL image into a normalised ``(1, 3, size, size)`` tensor.

    The image is converted to RGB, resized to a square, scaled to [0, 1] by
    ``ToTensor`` and then mapped to [-1, 1] by the 0.5 / 0.5 normalisation.
    """
    steps = [
        transforms.Resize((size, size), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]
    pipeline = transforms.Compose(steps)
    rgb = pil_img.convert('RGB')
    # unsqueeze adds the leading batch dimension.
    return pipeline(rgb).unsqueeze(0)
def tensor_to_pil(t):
    """Convert a model output tensor in [-1, 1] to a PIL image.

    The leading batch dimension is squeezed away, values are clamped to
    [-1, 1] and then rescaled to [0, 1] before conversion.
    """
    clamped = t.squeeze(0).cpu().clamp(-1, 1)
    unit_range = (clamped + 1.0) / 2.0
    to_pil = transforms.ToPILImage()
    return to_pil(unit_range)
def pil_to_b64(pil_img):
    """Encode a PIL image as a ``data:image/png;base64,...`` URL string."""
    with io.BytesIO() as stream:
        pil_img.convert('RGB').save(stream, format='PNG')
        encoded = base64.b64encode(stream.getvalue())
    return f"data:image/png;base64,{encoded.decode()}"
def b64_to_pil(data_url):
    """Decode a base64 image string into an RGB PIL image.

    Accepts either a full data URL (``data:image/...;base64,<payload>``) or a
    bare base64 payload with no header.

    Raises:
        ValueError: if the payload is empty or is not valid base64
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    # Base64 text never contains a comma, so the first comma (if any) can only
    # be the header/payload separator of a data URL.
    head, sep, payload = data_url.partition(',')
    if not sep:
        payload = head  # no header present: the whole string is the payload
    if not payload.strip():
        raise ValueError('Empty image data')
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw)).convert('RGB')
# ─────────────────────────────────────────────────────────────────────────────
# APP STATE
# ─────────────────────────────────────────────────────────────────────────────
# Process-wide mutable state shared by all route handlers: one model and one
# working image at a time. Nothing here is keyed per client or per session.
STATE = {
    'model': None,  # loaded FaceVAE, or None until a checkpoint has been loaded
    'model_config': {},  # the checkpoint's 'config' dict
    'image_size': 256,  # square input size; replaced by the checkpoint's value on load
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),  # chosen once at import
    'orig_pil': None,  # last uploaded image (set by /load_image)
    'recon_pil': None,  # model reconstruction of orig_pil (set by /generate)
    'styled_pils': [None] * 8,  # one slot per style, filled by /generate
    'mu': None,  # latent mean from the last /generate call
    'skips': None,  # encoder skip tensors from the last /generate call
}
# ─────────────────────────────────────────────────────────────────────────────
# ROUTES
# ─────────────────────────────────────────────────────────────────────────────
@app.route('/status', methods=['GET'])
def status():
    """Report whether a model and an image are loaded, plus device and torch version."""
    device = STATE['device']
    if device.type == 'cuda':
        device_label = f"GPU: {torch.cuda.get_device_name(0)}"
    else:
        device_label = "CPU"
    payload = {
        'model_loaded': STATE['model'] is not None,
        'model_config': STATE['model_config'],
        'image_loaded': STATE['orig_pil'] is not None,
        'device': device_label,
        'torch_version': torch.__version__,
    }
    return jsonify(payload)
@app.route('/load_model', methods=['POST'])
def api_load_model():
    """Load a checkpoint from a server-side path given in the JSON body.

    Request JSON: ``{"path": "<file path on the server>"}``.
    Responses: 400 when no usable path is supplied, 404 when the file does
    not exist, 500 (with traceback) when loading fails, otherwise the model
    config, image size and file name.
    """
    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so malformed requests get a clean 400 below.
    data = request.get_json(silent=True)
    raw_path = data.get('path', '') if isinstance(data, dict) else ''
    if not isinstance(raw_path, str) or not raw_path.strip():
        return jsonify({'error': 'No path provided'}), 400
    path = raw_path.strip()
    if not os.path.isfile(path):
        return jsonify({'error': f'File not found: {path}'}), 404
    # NOTE(security): the path comes from the client and load_model un-pickles
    # the file (torch.load with weights_only=False). Consider restricting
    # this endpoint to a trusted directory before exposing it publicly.
    try:
        model, image_size, cfg = load_model(path, STATE['device'])
        STATE['model'] = model
        STATE['image_size'] = image_size
        STATE['model_config'] = cfg
        # Everything derived from the previous model is now stale; in
        # particular the old mu/skips may not match the new model's shapes.
        STATE['recon_pil'] = None
        STATE['styled_pils'] = [None] * 8
        STATE['mu'] = None
        STATE['skips'] = None
        return jsonify({
            'success': True,
            'config': cfg,
            'image_size': image_size,
            'filename': os.path.basename(path),
        })
    except Exception as e:
        return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500
@app.route('/load_image', methods=['POST'])
def api_load_image():
    """Store an uploaded image as the working image and return a PNG preview.

    Request JSON: ``{"image_b64": "<data URL>"}``.
    Responses: 400 when the body has no ``image_b64`` field, 500 when the
    image cannot be decoded, otherwise ``{"success": true, "preview": ...}``.
    """
    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so malformed requests get a clean 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image_b64' not in data:
        return jsonify({'error': 'No image data'}), 400
    try:
        pil = b64_to_pil(data['image_b64'])
        preview = pil_to_b64(pil)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    # Commit only after decoding and preview encoding both succeeded, so a
    # bad upload never clobbers the current working image.
    STATE['orig_pil'] = pil
    STATE['recon_pil'] = None
    STATE['styled_pils'] = [None] * 8
    STATE['mu'] = None
    STATE['skips'] = None
    return jsonify({'success': True, 'preview': preview})
@app.route('/generate', methods=['POST'])
def api_generate():
    """Encode the working image and return its reconstruction plus all 8 styles.

    Caches the reconstruction, latent mean, skips and styled images in STATE
    for use by the preview, walk and save endpoints.
    """
    model = STATE['model']
    if model is None:
        return jsonify({'error': 'No model loaded'}), 400
    source = STATE['orig_pil']
    if source is None:
        return jsonify({'error': 'No image loaded'}), 400
    try:
        batch = preprocess_image(source, STATE['image_size']).to(STATE['device'])
        with torch.no_grad():
            mu, _logvar, skips = model.encode(batch)
            # Decode from the mean itself (no sampling).
            recon_tensor = model.decode(mu, skips)
        recon_pil = tensor_to_pil(recon_tensor)
        STATE['recon_pil'] = recon_pil
        STATE['mu'] = mu
        STATE['skips'] = skips

        slots = STATE['styled_pils']
        encoded_styles = []
        for idx in range(8):
            styled = StyleTransformBank.apply(recon_pil, idx)
            slots[idx] = styled
            encoded_styles.append(pil_to_b64(styled))

        return jsonify({
            'success': True,
            'recon': pil_to_b64(recon_pil),
            'styles': encoded_styles,
        })
    except Exception as e:
        return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500
@app.route('/strength_preview', methods=['POST'])
def api_strength_preview():
    """Alpha-blend the cached reconstruction with one style.

    Request JSON: ``{"style_idx": int (default 0), "alpha": number (default 1.0)}``.
    Responses: 400 when nothing has been generated yet or the parameters are
    invalid, 500 when blending fails, otherwise ``{"success": true, "blend": ...}``.
    """
    if STATE['recon_pil'] is None:
        return jsonify({'error': 'Generate styles first'}), 400
    # silent=True yields None on a missing or invalid JSON body; fall back to
    # the defaults in that case.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    try:
        style_idx = int(data.get('style_idx', 0))
        alpha = float(data.get('alpha', 1.0))
    except (TypeError, ValueError, OverflowError):
        return jsonify({'error': 'style_idx must be an integer and alpha a number'}), 400
    # Reject negative indices explicitly: they would otherwise wrap around
    # and silently select a different style.
    if not 0 <= style_idx < len(StyleTransformBank.STYLE_NAMES):
        return jsonify({'error': f'Invalid style_idx: {style_idx}'}), 400
    # NaN or infinity would produce garbage pixels in the blend.
    if not math.isfinite(alpha):
        return jsonify({'error': 'alpha must be a finite number'}), 400
    try:
        blend = StyleTransformBank.apply_with_strength(STATE['recon_pil'], style_idx, alpha)
        return jsonify({'success': True, 'blend': pil_to_b64(blend)})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/latent_walk', methods=['POST'])
def api_latent_walk():
    """Walk the latent along a random direction while cross-fading two styles.

    Request JSON: ``{"style_a": int (default 0), "style_b": int (default 1),
    "steps": int (default 7, clamped to 2..16)}``.
    Each frame decodes ``mu + t * direction`` for ``t`` in [-2.5, 2.5] and
    blends style A into style B as the frame index grows. The direction is
    drawn afresh on every request.
    Responses: 400 when no model or latent is cached or parameters are
    invalid, 500 (with traceback) on failure, otherwise the list of frames.
    """
    model = STATE['model']
    if model is None:
        return jsonify({'error': 'No model loaded'}), 400
    mu, skips = STATE['mu'], STATE['skips']
    if mu is None or skips is None:
        return jsonify({'error': 'Generate styles first'}), 400

    # silent=True yields None on a missing or invalid JSON body; fall back to
    # the defaults in that case.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    try:
        style_a = int(data.get('style_a', 0))
        style_b = int(data.get('style_b', 1))
        steps = int(data.get('steps', 7))
    except (TypeError, ValueError, OverflowError):
        return jsonify({'error': 'style_a, style_b and steps must be integers'}), 400
    # Reject negative indices explicitly: they would otherwise wrap around
    # and silently select a different style.
    style_count = len(StyleTransformBank.STYLE_NAMES)
    if not (0 <= style_a < style_count and 0 <= style_b < style_count):
        return jsonify({'error': 'Invalid style index'}), 400
    steps = max(2, min(steps, 16))

    try:
        direction = F.normalize(torch.randn_like(mu), dim=-1)
        frames_b64 = []
        with torch.no_grad():
            for i, t in enumerate(torch.linspace(-2.5, 2.5, steps)):
                z_walk = mu + t.item() * direction
                rec_pil = tensor_to_pil(model.decode(z_walk, skips))
                # steps >= 2 after clamping, so the divisor is never zero.
                alpha = i / (steps - 1)
                arr_a = np.array(StyleTransformBank.apply(rec_pil, style_a), dtype=np.float32)
                arr_b = np.array(StyleTransformBank.apply(rec_pil, style_b), dtype=np.float32)
                blended_arr = arr_a * (1 - alpha) + arr_b * alpha
                frame = Image.fromarray(blended_arr.clip(0, 255).astype(np.uint8))
                frames_b64.append({'img': pil_to_b64(frame), 'alpha': round(alpha, 2)})
        return jsonify({'success': True, 'frames': frames_b64})
    except Exception as e:
        return jsonify({'error': str(e), 'traceback': traceback.format_exc()}), 500
@app.route('/save_style', methods=['POST'])
def api_save_style():
    """Return one cached styled image as a base64 PNG data URL.

    Request JSON: ``{"style_idx": int (default 0)}``.
    Responses: 400 when the index is invalid or that style has not been
    generated yet, otherwise ``{"success": true, "image": ...}``.
    """
    # silent=True yields None on a missing or invalid JSON body; fall back to
    # the default index in that case.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    try:
        idx = int(data.get('style_idx', 0))
    except (TypeError, ValueError, OverflowError):
        return jsonify({'error': 'style_idx must be an integer'}), 400
    slots = STATE['styled_pils']
    # An unchecked index would raise IndexError (unhandled 500) or, when
    # negative, silently return a different style.
    if not 0 <= idx < len(slots):
        return jsonify({'error': f'Invalid style_idx: {idx}'}), 400
    if slots[idx] is None:
        return jsonify({'error': 'Style not generated yet'}), 400
    return jsonify({'success': True, 'image': pil_to_b64(slots[idx])})
@app.route('/save_all', methods=['GET'])
def api_save_all():
    """Return every image currently cached in STATE as base64 PNG data URLs.

    Keys are ``original``, ``reconstruction`` and ``style_<i>``; entries that
    have not been produced yet are simply omitted.
    """
    images = {}
    for key, state_key in (('original', 'orig_pil'), ('reconstruction', 'recon_pil')):
        cached = STATE[state_key]
        if cached:
            images[key] = pil_to_b64(cached)
    images.update({
        f'style_{i}': pil_to_b64(styled)
        for i, styled in enumerate(STATE['styled_pils'])
        if styled
    })
    return jsonify({
        'success': True,
        'images': images,
        'style_names': StyleTransformBank.STYLE_NAMES,
    })
# ─────────────────────────────────────────────────────────────────────────────
# AUTO-LOAD MODEL ON STARTUP
# Place your model file as shukgen_v3_final.pth in the same directory as this script.
# It will be loaded automatically when the Space starts.
# ─────────────────────────────────────────────────────────────────────────────
# Checkpoint that is loaded automatically at startup when present.
MODEL_PATH = os.path.join(BASE_DIR, 'shukgen_v3_final.pth')


def auto_load_model():
    """Load MODEL_PATH into STATE if the file exists; otherwise print a hint.

    Load failures are reported on stdout with a traceback and do not
    propagate, so the server still starts without a model.
    """
    model_file = os.path.basename(MODEL_PATH)
    if not os.path.isfile(MODEL_PATH):
        print(f" ⚠️ No {model_file} found in app directory.")
        print(f" Upload {model_file} to the Space repo and restart.")
        return
    print(f" 🔍 Found {model_file} — loading...")
    try:
        model, image_size, cfg = load_model(MODEL_PATH, STATE['device'])
        STATE['model'] = model
        STATE['image_size'] = image_size
        STATE['model_config'] = cfg
        print(f" ✅ Model loaded! latent_dim={cfg['latent_dim']} image_size={image_size}")
    except Exception as e:
        # Deliberately best-effort: report the problem and keep starting up.
        print(f" ❌ Failed to auto-load model: {e}")
        traceback.print_exc()
if __name__ == '__main__':
    # Startup banner; the middle and bottom rows are sized from the top row
    # so the box always lines up.
    banner_top = "╔══════════════════════════════════════════════════════╗"
    inner_width = len(banner_top) - 2
    banner_title = "ShukGEN v3 — HF Spaces Edition (port 7860)"
    print(banner_top)
    print("║" + banner_title.center(inner_width) + "║")
    print("╚" + "═" * inner_width + "╝")
    print(f" PyTorch : {torch.__version__}")
    if torch.cuda.is_available():
        device_label = 'GPU — ' + torch.cuda.get_device_name(0)
    else:
        device_label = 'CPU'
    print(f" Device : {device_label}")
    print()
    auto_load_model()
    print()
    # host=0.0.0.0 is REQUIRED for HF Spaces — do not change to 127.0.0.1
    app.run(host='0.0.0.0', port=7860, debug=False)