| import os |
| import io |
| import uuid |
| import requests |
| import gradio as gr |
| import numpy as np |
| from PIL import Image, ImageDraw, ImageFont |
|
|
# Largest seed value offered by the seed slider (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Font used for the comparison labels; looked up next to this source file.
FONT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Inter-Bold.ttf")


# Service endpoints, overridable via environment variables. Trailing slashes are
# stripped because request URLs are built as f"{URL}/path"; a configured value
# like "https://host/" would otherwise produce "https://host//path".
DLSS_API_URL = os.environ.get("DLSS_API_URL", "https://dlss.alphanetplus.com").rstrip("/")
OSS_API_URL = os.environ.get("OSS_API_URL", "https://oss.alphanetplus.com").rstrip("/")
# API key sent as the X-API-Key header. Surrounding whitespace (e.g. a trailing
# newline from a copied secret) is removed; requests rejects such header values.
API_KEY = os.environ.get("DLSS_API_KEY", "").strip()
|
|
|
|
def get_font(size):
    """Return a font for label text at ``size``, degrading gracefully.

    Fallback order:
    1. The Inter Bold TrueType file at ``FONT_PATH``.
    2. Pillow's scalable default font (``load_default(size=...)``).
    3. Pillow's fixed-size bitmap default font.

    Step 3 is used when step 2 is unsupported. ``load_default(size=...)`` only
    exists in Pillow >= 10.1, so older Pillow raises ``TypeError``. Without
    FreeType it raises ``ImportError``.

    Args:
        size: Desired font size in pixels.

    Returns:
        A Pillow font object usable with ``ImageDraw.text`` and ``getbbox``.
    """
    try:
        return ImageFont.truetype(FONT_PATH, size)
    except (OSError, ImportError, ValueError):
        # Missing or unreadable font file, no FreeType support, or invalid size.
        pass
    try:
        return ImageFont.load_default(size=size)
    except (TypeError, ImportError):
        # Older Pillow (no ``size`` kwarg) or no FreeType: use the bitmap font.
        return ImageFont.load_default()
|
|
|
|
def create_dlss5_comparison(original: Image.Image, enhanced: Image.Image) -> Image.Image:
    """Compose a side-by-side "DLSS 5 Off" / "DLSS 5 On" comparison image.

    Layout of the result:
    - The output is RGB and twice the width of ``original``.
    - ``original`` fills the left half.
    - ``enhanced``, resized with LANCZOS to ``original``'s size, fills the right half.
    - Each half gets a caption box centred horizontally near its bottom edge.
    - The left caption is dark. The right caption is light with a green underline bar.

    Args:
        original: The unmodified input image; its size defines each half.
        enhanced: The processed image shown on the right.

    Returns:
        The composed RGB comparison image.
    """
    width, height = original.size

    side_by_side = Image.new("RGB", (width * 2, height))
    side_by_side.paste(original, (0, 0))
    side_by_side.paste(enhanced.resize((width, height), Image.LANCZOS), (width, 0))

    # Captions are drawn on a transparent layer so the semi-transparent dark
    # box blends with the photo when alpha-composited below.
    layer = Image.new("RGBA", (width * 2, height), (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)
    font_size = max(16, int(height * 0.076))
    font = get_font(font_size)
    pad_x = font_size
    pad_y = int(font_size * 0.55)
    box_bottom = height - int(height * 0.06)

    # Each caption entry holds:
    # (text, centre x, box fill, box outline, text colour, has green bar)
    captions = (
        ("DLSS 5 Off", width // 2,
         (10, 10, 10, 225), (75, 75, 75, 255), (255, 255, 255, 255), False),
        ("DLSS 5 On", width + width // 2,
         (255, 255, 255, 255), (190, 190, 190, 255), (0, 0, 0, 255), True),
    )

    for text, center_x, box_fill, box_outline, ink, with_bar in captions:
        left, top, right, bottom = font.getbbox(text)
        box_w = (right - left) + 2 * pad_x
        box_h = (bottom - top) + 2 * pad_y
        bar_h = max(4, int(box_h * 0.13)) if with_bar else 0
        # The bar sits under the box, so lift the box by the bar height.
        x0 = center_x - box_w // 2
        y0 = box_bottom - box_h - bar_h

        pen.rectangle([x0, y0, x0 + box_w, y0 + box_h],
                      fill=box_fill, outline=box_outline, width=1)
        pen.text((x0 + box_w // 2, y0 + box_h // 2), text,
                 fill=ink, font=font, anchor="mm")
        if with_bar:
            pen.rectangle([x0, y0 + box_h, x0 + box_w, y0 + box_h + bar_h],
                          fill=(118, 185, 0, 255))

    composed = Image.alpha_composite(side_by_side.convert("RGBA"), layer)
    return composed.convert("RGB")
|
|
|
|
def _pil_to_bytes(image: Image.Image) -> bytes:
    """Encode ``image`` as an RGB JPEG (quality 90) and return the encoded bytes."""
    with io.BytesIO() as sink:
        # JPEG cannot store alpha or palette modes, so normalize to RGB first.
        image.convert("RGB").save(sink, format="JPEG", quality=90)
        return sink.getvalue()
|
|
|
|
def _upload_image(image: Image.Image, oss_key: str, env: str = "prod") -> str:
    """Upload a PIL image to OSS as a JPEG and return the key it was stored under.

    Args:
        image: Image to upload; re-encoded as an RGB JPEG via ``_pil_to_bytes``.
        oss_key: Destination object key. Its basename is used as the multipart
            filename.
        env: OSS environment name, sent as the ``env`` query parameter.

    Returns:
        ``oss_key`` unchanged.

    Raises:
        requests.HTTPError: If the upload endpoint returns an error status.
        RuntimeError: If the JSON reply does not report ``success``.
    """
    upload = (os.path.basename(oss_key), _pil_to_bytes(image), "image/jpeg")
    reply = requests.post(
        f"{OSS_API_URL}/upload",
        params=dict(env=env, key=oss_key),
        headers={"X-API-Key": API_KEY},
        files={"file": upload},
        timeout=60,
    )
    reply.raise_for_status()
    body = reply.json()
    if body.get("success"):
        return oss_key
    raise RuntimeError(f"OSS upload failed: {body}")
|
|
|
|
def _download_image(oss_key: str, env: str = "prod") -> Image.Image:
    """Download an object from OSS and decode it as an RGB PIL image.

    Args:
        oss_key: Key of the object to fetch (without any ``oss://`` prefix).
        env: OSS environment name, sent as the ``env`` query parameter.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        requests.HTTPError: If the download endpoint returns an error status.
    """
    reply = requests.get(
        f"{OSS_API_URL}/download",
        params=dict(env=env, key=oss_key),
        headers={"X-API-Key": API_KEY},
        timeout=60,
    )
    reply.raise_for_status()
    raw = io.BytesIO(reply.content)
    return Image.open(raw).convert("RGB")
|
|
|
|
def process(image, prompt, seed=42, randomize_seed=True, num_inference_steps=4,
            progress=gr.Progress(track_tqdm=True)):
    """Run one DLSS 5 enhancement round-trip and build the comparison image.

    Steps:
    1. Upload the input to OSS under a fresh random job key.
    2. Call the DLSS ``/generate`` endpoint, pointing it at that key.
    3. Download the enhanced result it reports.
    4. Compose the side-by-side comparison.

    Args:
        image: Input PIL image (``None`` if nothing was uploaded).
        prompt: Text prompt forwarded to the API.
        seed: Requested seed; the API may replace it when ``randomize_seed``.
        randomize_seed: Forwarded to the API.
        num_inference_steps: Forwarded to the API.
        progress: Gradio progress tracker.

    Returns:
        A tuple ``(comparison, actual_seed, original_resized, enhanced)``.
        ``original_resized`` is the input resized to the enhanced image's size.

    Raises:
        gr.Error: On missing input or API key, on any network/HTTP/decoding
            failure, or when the API reports failure or returns no image.
            Every failure is surfaced as gr.Error because Gradio hides the
            message of other exception types from the user.
    """
    if image is None:
        raise gr.Error("Please upload an image!")
    if not API_KEY:
        raise gr.Error("DLSS_API_KEY environment variable is not set!")

    job_id = uuid.uuid4().hex[:12]
    input_key = f"dlss/input/{job_id}.jpg"

    progress(0.15, desc="Uploading image to OSS...")
    try:
        _upload_image(image, input_key, env="prod")
    except (requests.RequestException, RuntimeError, ValueError) as err:
        # RuntimeError: OSS replied without success. ValueError: non-JSON reply.
        raise gr.Error(f"Image upload failed: {err}") from err

    progress(0.35, desc="Generating DLSS 5 version...")
    try:
        resp = requests.post(
            f"{DLSS_API_URL}/generate",
            headers={"X-API-Key": API_KEY, "Content-Type": "application/json"},
            json={
                "image": f"oss://{input_key}",
                "prompt": prompt,
                # Slider values may arrive as floats; send integers to the API.
                "seed": int(seed),
                "randomize_seed": randomize_seed,
                "num_inference_steps": int(num_inference_steps),
                "output_format": "oss",
            },
            timeout=120,
        )
        resp.raise_for_status()
        data = resp.json()
    except requests.RequestException as err:
        raise gr.Error(f"DLSS API request failed: {err}") from err
    except ValueError as err:
        raise gr.Error("DLSS API returned a non-JSON response") from err

    if not isinstance(data, dict):
        raise gr.Error("DLSS API returned an unexpected response")
    if not data.get("success"):
        raise gr.Error(f"DLSS generation failed: {data.get('error', 'unknown error')}")

    actual_seed = data.get("seed", seed)
    enhanced_oss_key = data.get("enhanced", "")

    if not enhanced_oss_key:
        raise gr.Error("No enhanced image returned from API")

    # The API may return an "oss://" URI; the download endpoint expects a bare key.
    if enhanced_oss_key.startswith("oss://"):
        enhanced_oss_key = enhanced_oss_key[len("oss://"):]

    progress(0.75, desc="Downloading enhanced image...")
    try:
        enhanced = _download_image(enhanced_oss_key, env="prod")
    except (requests.RequestException, OSError) as err:
        # OSError also covers PIL failing to identify the downloaded bytes.
        raise gr.Error(f"Failed to download enhanced image: {err}") from err

    progress(0.9, desc="Creating comparison...")
    comparison = create_dlss5_comparison(image, enhanced)

    # Match the original to the enhanced image's size for the stored states.
    w, h = enhanced.size
    original_resized = image.resize((w, h), Image.LANCZOS)

    return comparison, actual_seed, original_resized, enhanced
|
|
|
|
# Custom dark-theme stylesheet passed to gr.Blocks.
# The header HTML puts the classes directly on the elements
# (<h1 class="main-title">, <p class="subtitle">), so the title and subtitle
# rules also target h1.main-title / p.subtitle. The descendant-only selectors
# (.main-title h1 / .subtitle p) never matched that markup.
css = r"""
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap');
*{box-sizing:border-box;margin:0;padding:0}
body,.gradio-container{background:#0f0f13!important;font-family:'Inter',system-ui,-apple-system,sans-serif!important;font-size:14px!important;color:#e4e4e7!important}
footer{display:none!important}
.gradio-container{max-width:1200px!important;margin:0 auto!important;padding:30px 20px!important}
.app-header{background:linear-gradient(135deg,#18181b,#1e1e24);border:1px solid #27272a;border-radius:16px;padding:24px;margin-bottom:24px;box-shadow:0 25px 50px -12px rgba(0,0,0,.6)}
.main-title h1,h1.main-title{text-align:center;font-family:'Inter',sans-serif!important;color:#1E90FF!important;font-size:2.5em!important;font-weight:800!important;letter-spacing:-.5px;margin:0;background:linear-gradient(135deg,#1E90FF,#47A3FF);-webkit-background-clip:text;-webkit-text-fill-color:transparent}
.subtitle p,p.subtitle{text-align:center;color:#a1a1aa!important;font-size:15px!important;margin:12px 0 0 0;font-weight:500}
.block{border-color:#27272a!important;background:#18181b!important;border-radius:12px!important;border:1px solid #27272a!important}
.label-wrap{background:#18181b!important;color:#a1a1aa!important;border-color:#27272a!important;font-weight:600!important}
.upload-area{border-color:#1E90FF44!important;background:rgba(30,144,255,.03)!important;border-radius:12px!important}
.upload-area:hover{border-color:#1E90FF!important;background:rgba(30,144,255,.08)!important}
.progress-bar{background:linear-gradient(90deg,#1E90FF,#47A3FF,#1E90FF)!important;background-size:200% 100%!important;animation:shimmer 1.5s ease-in-out infinite!important}
@keyframes shimmer{0%{background-position:200% 0}100%{background-position:-200% 0}}
.progress-bar-wrap{background:#27272a!important;border-color:#3f3f46!important}
#go-btn{background:linear-gradient(135deg,#1E90FF,#1873CC)!important;color:#ffffff!important;font-weight:700!important;font-size:16px!important;min-height:56px!important;border:none!important;border-radius:12px!important;font-family:'Inter',sans-serif!important;box-shadow:0 8px 24px rgba(30,144,255,.35),inset 0 1px 0 rgba(255,255,255,.15)!important;transition:all .3s cubic-bezier(.34,1.56,.64,1)!important;letter-spacing:-.3px}
#go-btn:hover{background:linear-gradient(135deg,#47A3FF,#1E90FF)!important;box-shadow:0 12px 32px rgba(30,144,255,.5),inset 0 1px 0 rgba(255,255,255,.2)!important;transform:translateY(-2px)!important}
#go-btn:active{transform:translateY(0)!important;box-shadow:0 4px 12px rgba(30,144,255,.3)!important}
#video-btn{background:linear-gradient(135deg,#2563eb,#1d4ed8)!important;color:#ffffff!important;font-weight:600!important;font-size:14px!important;min-height:48px!important;border:none!important;border-radius:10px!important;font-family:'Inter',sans-serif!important;box-shadow:0 4px 16px rgba(37,99,235,.3)!important;transition:all .2s ease!important}
#video-btn:hover{background:linear-gradient(135deg,#3b82f6,#2563eb)!important;box-shadow:0 8px 24px rgba(37,99,235,.4)!important;transform:translateY(-1px)!important}
.gallery-item{border-color:#27272a!important;background:#18181b!important;border-radius:10px!important;transition:all .2s ease!important}
.gallery-item:hover{border-color:#1E90FF!important;box-shadow:0 8px 24px rgba(30,144,255,.2)!important;transform:translateY(-2px)!important}
.accordion{border:1px solid #27272a!important;border-radius:10px!important;background:#18181b!important}
.accordion-button{background:#18181b!important;color:#a1a1aa!important;border:none!important;font-weight:600!important}
.accordion-button:hover{background:rgba(30,144,255,.05)!important}
::-webkit-scrollbar{width:8px;height:8px}
::-webkit-scrollbar-track{background:#09090b}
::-webkit-scrollbar-thumb{background:#27272a;border-radius:4px}
::-webkit-scrollbar-thumb:hover{background:#3f3f46}
"""
|
|
# Gradio UI: a Base theme with dark colour overrides plus the custom ``css``.
# Component creation order inside the ``with`` blocks determines the layout.
with gr.Blocks(title="DLSS 5 Anything Pro", css=css, theme=gr.themes.Base(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    font=gr.themes.GoogleFont("Inter"),
).set(
    body_background_fill="#0f0f13",
    body_background_fill_dark="#0f0f13",
    block_background_fill="#18181b",
    block_background_fill_dark="#18181b",
    block_border_color="#27272a",
    block_border_color_dark="#27272a",
    block_label_text_color="#1E90FF",
    block_label_text_color_dark="#1E90FF",
    body_text_color="#e4e4e7",
    body_text_color_dark="#e4e4e7",
    button_primary_background_fill="#1E90FF",
    button_primary_background_fill_dark="#1E90FF",
    button_primary_text_color="#ffffff",
    button_primary_text_color_dark="#ffffff",
    input_background_fill="#09090b",
    input_background_fill_dark="#09090b",
    input_border_color="#27272a",
    input_border_color_dark="#27272a",
)) as demo:

    # Page header; classes refer to rules in the ``css`` stylesheet.
    gr.HTML("""
<div class="app-header">
<h1 class="main-title">🎮 DLSS 5 Anything</h1>
<p class="subtitle">Transform any image with AI-powered enhancement</p>
</div>
""")

    # Hidden, fixed prompt passed to process() on every run.
    prompt = gr.Textbox(value="make it more realistic", visible=False)

    # NOTE(review): original indentation was lost; the nesting below (button
    # inside the settings column, result below the row) is a reconstruction.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📸 Upload Image")
            input_image = gr.Image(label="", type="pil", elem_id="input-img")

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            with gr.Accordion("Advanced Options", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=20, step=1, value=4)

            go_btn = gr.Button("🚀 DLSS 5 it!", elem_id="go-btn", variant="primary", scale=1)

    gr.Markdown("### 🎯 Result")
    output_image = gr.Image(label="", type="pil", elem_id="output-img")

    # Per-session copies of the last original (resized) and enhanced images,
    # consumed by make_video.
    original_state = gr.State(None)
    enhanced_state = gr.State(None)

    # Video controls stay hidden until a comparison has been generated.
    with gr.Row():
        video_btn = gr.Button("🎬 Generate Video", elem_id="video-btn", visible=False, scale=1)
        video_file = gr.File(visible=False)

    def on_generate(image, prompt, seed, randomize_seed, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
        """Run process() and reveal the video button while hiding any stale video file.

        Parameters shadow the outer component names on purpose: Gradio passes
        the component values in the order listed in ``inputs``.
        """
        comparison, seed, orig, enh = process(image, prompt, seed, randomize_seed, num_inference_steps, progress)
        return comparison, seed, orig, enh, gr.update(visible=True), gr.update(visible=False)

    # The returned seed is written back to the slider so the user sees the seed actually used.
    go_btn.click(
        fn=on_generate,
        inputs=[input_image, prompt, seed, randomize_seed, num_inference_steps],
        outputs=[output_image, seed, original_state, enhanced_state, video_btn, video_file],
    )

    # A new or cleared input invalidates the previous result: hide the video
    # controls and drop the stored images.
    input_image.change(
        fn=lambda: (gr.update(visible=False), gr.update(visible=False), None, None),
        inputs=[],
        outputs=[video_btn, video_file, original_state, enhanced_state],
    )

    def make_video(orig, enh):
        """Show the comparison video file once a comparison exists.

        Raises gr.Error if no comparison has been generated yet.
        """
        if orig is None or enh is None:
            raise gr.Error("Generate a DLSS 5 comparison first!")
        # NOTE(review): placeholder — returns the fixed path "video_demo.mp4"
        # and ignores ``orig``/``enh``. The path is resolved relative to the
        # working directory; confirm the file exists and whether real video
        # generation is intended.
        return gr.update(value="video_demo.mp4", visible=True)

    # Three chained steps: disable the button with a busy label, produce the
    # video, then restore the button (``.then`` steps are chained after the
    # previous one completes).
    video_btn.click(
        fn=lambda: gr.update(value="⏳ Generating...", interactive=False),
        inputs=[],
        outputs=[video_btn],
    ).then(
        fn=make_video,
        inputs=[original_state, enhanced_state],
        outputs=[video_file],
    ).then(
        fn=lambda: gr.update(value="🎬 Generate Video", interactive=True),
        inputs=[],
        outputs=[video_btn],
    )
|
|
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|