| import gradio as gr |
| import torch |
| import numpy as np |
| import cv2 |
| import base64 |
| import os |
| import requests |
| from PIL import Image |
|
|
| |
| |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class Upsample(nn.Sequential):
    """Sub-pixel upsampling stack: log2(scale) rounds of conv + PixelShuffle(2).

    Each round expands the channels 4x with a 3x3 convolution and then folds
    them into a 2x spatial enlargement with ``nn.PixelShuffle(2)``, so the
    channel count is back to ``num_feat`` after every round.

    Args:
        scale: Overall upscale factor; must be a positive power of two
            (1, 2, 4, 8, ...). ``scale == 1`` yields an empty stack.
        num_feat: Number of input (and output) feature channels.

    Raises:
        ValueError: if ``scale`` is not a positive power of two.
    """

    def __init__(self, scale, num_feat):
        # `scale & (scale - 1) == 0` alone also accepts 0, which then crashed
        # with OverflowError in int(np.log2(0)); require scale > 0 explicitly.
        if scale <= 0 or scale & (scale - 1):
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n')
        layers = []
        # For a power of two, bit_length() - 1 is its exact integer log2
        # (no float rounding); int() also accepts numpy integer scalars.
        for _ in range(int(scale).bit_length() - 1):
            layers.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(2))
        super().__init__(*layers)
|
|
| |
| |
| |
|
|
def download_model(url, save_path, timeout=60, chunk_size=8192):
    """Download ``url`` to ``save_path`` unless that file already exists.

    The body is streamed into a sibling ``.part`` file and only moved onto
    ``save_path`` once it has been fully written. Because the existence of
    ``save_path`` is what skips future downloads, neither an interrupted
    transfer nor an HTTP error page may ever end up there.

    Args:
        url: Address to fetch.
        save_path: Destination file path (str or path-like).
        timeout: Connect/read timeout in seconds passed to ``requests.get``.
        chunk_size: Bytes per streamed chunk.

    Returns:
        True if ``save_path`` exists on return; False if the server answered
        with an error status (nothing is written in that case).

    Raises:
        requests.RequestException: on connection failures or timeouts.
        OSError: if the file cannot be written.
    """
    if os.path.exists(save_path):
        return True

    print(f"Mengunduh model dari {url}...")
    tmp_path = os.fspath(save_path) + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        if not response.ok:
            # Previously the error body was saved as the model file and, since
            # the file then existed, the download was never retried.
            print(f"Gagal mengunduh ({response.status_code}): {url}")
            return False
        try:
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        except BaseException:
            # Don't leave a half-written temp file behind, then propagate.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # Atomic on the same filesystem: save_path is either absent or complete.
    os.replace(tmp_path, save_path)
    print("Selesai mengunduh.")
    return True
|
|
| |
# Local cache location for the Swin2SR checkpoint file.
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)

MODEL_NAME = "Swin2SR_ClassicalSR_X4_64.pth"
MODEL_URL = f"https://github.com/mvassell/Swin2SR/releases/download/v0.0.1/{MODEL_NAME}"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

# Use the GPU when PyTorch can see one; setup_swinir() and process_upscale()
# read this module-level `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Menjalankan dengan device: {device}")

# NOTE(review): MODEL_PATH is not read anywhere else in this file; inference
# uses the Hugging Face checkpoint loaded in setup_swinir(). The download is
# therefore best-effort: a network failure here should not stop the app,
# which already falls back to Lanczos when no model is available.
try:
    download_model(MODEL_URL, MODEL_PATH)
except (requests.RequestException, OSError) as e:
    print(f"Gagal mengunduh model: {e}")
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import sys |
| import subprocess |
|
|
def setup_swinir():
    """Load the Hugging Face Swin2SR x4 model and its image processor.

    ``transformers`` is installed with pip only if it cannot be imported. The
    previous version ran ``pip install`` on every start, which is slow and
    fails without network access even when the package is already present.

    Returns:
        ``(processor, model)`` with the model moved to the module-level
        ``device``, or ``(None, None)`` if any step fails. Callers check for
        None and fall back to Lanczos resampling.
    """
    try:
        try:
            from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
        except ImportError:
            import importlib

            subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers"])
            # Finders may have cached site-packages listings from before the
            # install; refresh them so the new package is importable now.
            importlib.invalidate_caches()
            from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor

        processor = Swin2SRImageProcessor()
        model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
        model = model.to(device)
        return processor, model
    except Exception as e:
        # Best-effort: any failure (install, download, load) degrades to the
        # Lanczos path instead of crashing the app.
        print(f"Gagal memuat Swin2SR Transformer: {e}")
        return None, None


processor, model = setup_swinir()
|
|
| |
def decode_base64_to_pil(base64_string):
    """Decode base64 image text (bare, or a ``data:`` URL) into an RGB PIL image.

    Everything up to and including the first ``"base64,"`` marker is
    discarded. Both ``"data:image/png;base64,iVBOR..."`` and ``"iVBOR..."``
    are accepted. A payload whose trailing ``=`` padding was stripped is
    accepted too.

    Args:
        base64_string: Base64-encoded image bytes, optionally data-URL prefixed.

    Returns:
        ``PIL.Image.Image`` converted to mode ``"RGB"``.

    Raises:
        binascii.Error: if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not an image that
            PIL can open.
    """
    _, marker, payload = base64_string.partition("base64,")
    if not marker:
        payload = base64_string
    # b64decode skips non-alphabet characters such as whitespace, but it
    # rejects input with missing "=" padding. Remove whitespace first so it
    # does not skew the length, then restore the padding.
    payload = "".join(payload.split())
    payload += "=" * (-len(payload) % 4)
    img_data = base64.b64decode(payload)
    return Image.open(import_io.BytesIO(img_data)).convert('RGB')
| |
| import io as import_io |
|
|
def encode_pil_to_base64(img_pil):
    """Serialise an image as PNG and return it as a ``data:`` URL string.

    Args:
        img_pil: The image to encode. Only its ``save(fp, format="PNG")``
            method is called.

    Returns:
        ``"data:image/png;base64,<payload>"``.
    """
    with import_io.BytesIO() as png_buffer:
        img_pil.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return f"data:image/png;base64,{encoded}"
|
|
# Longest-side caps, in pixels, for the supported ``target_res`` presets.
_RESOLUTION_CAPS = {'2k': 2560, '4k': 3840}


def _swin2sr_upscale(img_pil):
    """Upscale ``img_pil`` with the global Swin2SR ``model``/``processor``.

    Returns an RGB PIL image of exactly ``upscale`` times the input size.
    """
    inputs = processor(img_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # After squeeze() the array is channel-first (C, H, W), clamped to [0, 1];
    # move channels last for PIL. (The old RGB->BGR reindex followed by
    # cv2 BGR2RGB cancelled out, so both steps are dropped.)
    rgb = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    rgb = np.transpose(rgb, (1, 2, 0))
    rgb = (rgb * 255.0).round().astype(np.uint8)

    # Swin2SRImageProcessor pads the input on the bottom/right before
    # inference, and the model upsamples that padding as well. The raw
    # reconstruction is therefore larger than scale * input, with mirrored
    # edges; crop back to the true size.
    scale = getattr(model.config, "upscale", 4)
    w, h = img_pil.size
    rgb = rgb[: h * scale, : w * scale]
    return Image.fromarray(np.ascontiguousarray(rgb))


def _lanczos_upscale(img_pil, scale=4):
    """Fallback upscale: Lanczos resampling by ``scale`` plus an unsharp mask."""
    from PIL import ImageFilter

    w, h = img_pil.size
    upscaled = img_pil.resize((w * scale, h * scale), Image.Resampling.LANCZOS)
    # Sharpen to offset the softness of plain interpolation.
    return upscaled.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))


def _cap_longest_side(img, target_res):
    """Lanczos-downscale ``img`` so its longest side fits the ``target_res`` cap.

    Unknown presets, and images already within the cap, are returned as is.
    """
    max_size = _RESOLUTION_CAPS.get((target_res or "").strip().lower())
    w, h = img.size
    if max_size is None or max(w, h) <= max_size:
        return img
    ratio = max_size / max(w, h)
    # round() rather than int(): float error could otherwise leave the longest
    # side one pixel under the cap; max(1, ...) guards extremely thin images.
    new_size = (max(1, round(w * ratio)), max(1, round(h * ratio)))
    return img.resize(new_size, Image.Resampling.LANCZOS)


def process_upscale(base64_input, target_res):
    """Upscale a base64-encoded image 4x, then cap it to a target resolution.

    Steps:
        1. Decode the input.
        2. Upscale with Swin2SR x4, or with Lanczos x4 plus an unsharp mask
           when the model failed to load.
        3. Shrink so the longest side does not exceed the ``target_res`` cap.
        4. Re-encode as a PNG data URL.

    Args:
        base64_input: Base64 image text or ``data:`` URL.
        target_res: ``"2k"`` caps the longest side at 2560 px and ``"4k"`` at
            3840 px (case-insensitive, surrounding whitespace ignored). Any
            other value, including None, leaves the upscaled image uncapped.

    Returns:
        ``"data:image/png;base64,..."`` on success, or ``"ERROR: <message>"``
        if any step raises (the traceback is printed).
    """
    print(f"Memproses resolusi target: {target_res}")

    try:
        img_pil = decode_base64_to_pil(base64_input)

        if model is not None and processor is not None:
            print("Menggunakan AI Swin2SR Transformer...")
            output_img = _swin2sr_upscale(img_pil)
        else:
            print("Peringatan: Menggunakan Algoritma High-Quality Lanczos (Swin2SR gagal dimuat).")
            output_img = _lanczos_upscale(img_pil)

        output_img = _cap_longest_side(output_img, target_res)

        print(f"Upscale berhasil: {output_img.size}")
        return encode_pil_to_base64(output_img)

    except Exception as e:
        # API boundary: report failures to the caller as a string instead of
        # raising through Gradio.
        import traceback
        traceback.print_exc()
        return f"ERROR: {str(e)}"
|
|
| |
# Minimal text-only UI: a base64 image and a resolution preset go in, and a
# base64 PNG data URL comes out. The click handler is registered under the
# API endpoint name "predict".
with gr.Blocks() as demo:
    input_text = gr.Textbox(label="Base64 Input")
    res_text = gr.Textbox(value="2k", label="Resolution")
    output_text = gr.Textbox(label="Base64 Output")

    btn = gr.Button("Upscale")
    btn.click(
        fn=process_upscale,
        inputs=[input_text, res_text],
        outputs=output_text,
        api_name="predict",
    )


if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()