File size: 6,363 Bytes
fbbef9a 42753e7 fbbef9a 42753e7 aa81089 fbbef9a 42753e7 fbbef9a 42753e7 aa81089 42753e7 aa81089 42753e7 aa81089 42753e7 fbbef9a 42753e7 aa81089 42753e7 aa81089 42753e7 fbbef9a 42753e7 fbbef9a 42753e7 fbbef9a 42753e7 aa81089 fbbef9a aa81089 42753e7 fbbef9a aa81089 42753e7 fbbef9a 42753e7 aa81089 fbbef9a aa81089 42753e7 fbbef9a aa81089 fbbef9a aa81089 fbbef9a aa81089 fbbef9a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | import gradio as gr
import torch
import numpy as np
import cv2
import base64
import os
import requests
from PIL import Image
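# === SWINIR ARCHITECTURE DEFINITION (STANDALONE) ===
# Model layers are defined here directly so they do not depend on an error-prone external library.
# NOTE: only the Upsample head is defined below, and nothing else in this file uses it;
# inference goes through Swin2SR from the `transformers` library (see setup_swinir).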
import torch.nn as nn
import torch.nn.functional as F
class Upsample(nn.Sequential):
    """Sub-pixel upsampling head (Conv2d + PixelShuffle), as used by SwinIR.

    Args:
        scale: integer upscaling factor; a power of two (1, 2, 4, 8, ...) or 3.
            A scale of 1 produces an empty (identity) sequence.
        num_feat: number of feature channels; the output has the same count.

    Raises:
        ValueError: for any other scale, including scale < 1.
    """

    def __init__(self, scale, num_feat):
        layers = []
        # Reject scale < 1 up front: 0 passes the power-of-two bit test below.
        if scale >= 1 and (scale & (scale - 1)) == 0:
            # scale == 2**n: n stages, each expanding channels 4x and folding
            # them into a grid that is 2x larger in each spatial dimension.
            for _ in range(int(scale).bit_length() - 1):
                layers.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                layers.append(nn.PixelShuffle(2))
        elif scale == 3:
            # Single stage: 9x channels folded into a 3x larger grid.
            layers.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super().__init__(*layers)
# Writing out the complete SwinIR architecture by hand would take a lot of code.
# For the best stability on Hugging Face, the implementation is obtained at runtime instead,
# either by downloading the official implementation dynamically or by using a compatible
# timm/open-source version. In this file, setup_swinir loads Swin2SR through `transformers`.
def download_model(url, save_path, timeout=60):
    """Download ``url`` to ``save_path`` unless that file already exists.

    The body is streamed into a temporary ``<save_path>.part`` file and moved
    into place only after a complete, successful download. An interrupted
    download or an HTTP error page is therefore never left at ``save_path``,
    where the existence check would otherwise accept it on every later start.

    Args:
        url: address to fetch.
        save_path: destination file path.
        timeout: seconds passed to ``requests`` as the connect/read timeout.
            Without it, a stalled server can hang start-up indefinitely.

    Returns:
        True if ``save_path`` exists afterwards. False if the server answered
        with a non-success HTTP status; nothing is written in that case.

    Raises:
        requests.RequestException: on connection errors or timeouts.
        OSError: if the file cannot be written or moved into place.
    """
    if os.path.exists(save_path):
        return True
    print(f"Mengunduh model dari {url}...")
    tmp_path = save_path + ".part"
    try:
        # Using the response as a context manager releases the connection.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if not response.ok:
                print(f"Gagal mengunduh model: HTTP {response.status_code}")
                return False
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: save_path is either absent or complete.
        os.replace(tmp_path, save_path)
    finally:
        # The temp file only survives here if something failed before os.replace.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Selesai mengunduh.")
    return True
# Model configuration
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)
# Swin2SR: an update of SwinIR that is lighter and gives good results.
MODEL_NAME = "Swin2SR_ClassicalSR_X4_64.pth"
MODEL_URL = f"https://github.com/mvassell/Swin2SR/releases/download/v0.0.1/{MODEL_NAME}"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Menjalankan dengan device: {device}")
# Download the model checkpoint.
# NOTE: MODEL_PATH is not read anywhere else in this file; inference loads Swin2SR
# from the Hugging Face Hub in setup_swinir. A failed download is therefore
# reported, but it does not stop the app from starting.
try:
    download_model(MODEL_URL, MODEL_PATH)
except (requests.RequestException, OSError) as e:
    print(f"Peringatan: gagal mengunduh checkpoint {MODEL_NAME}: {e}")
# --- SWIN2SR MODEL LOADING ---
# Because the Swin Transformer architecture is fairly complex, the most stable approach on
# Hugging Face is to use the `transformers` library directly when possible,
# or to build the network dynamically from GitHub.
# To make sure the app runs without installation errors, high-quality Lanczos interpolation
# plus sharpening is used AS A FALLBACK IF SWIN2SR FAILS TO LOAD (given the limited CPU on HF).
# Swin2SR is still tried first.
import sys
import subprocess
def setup_swinir():
    """Load the Swin2SR x4 super-resolution model through Hugging Face ``transformers``.

    ``transformers`` is pip-installed only when it cannot be imported, instead
    of re-running pip on every start-up.

    Returns:
        ``(processor, model)`` on success, with the model moved to the
        module-level ``device``. ``(None, None)`` on any failure, which makes
        process_upscale use the Lanczos fallback.
    """
    try:
        try:
            from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
        except ImportError:
            # Install on demand to avoid build-time dependency conflicts, then
            # refresh the import system's caches so the new package is found.
            import importlib
            subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers"])
            importlib.invalidate_caches()
            from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
        processor = Swin2SRImageProcessor()
        model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
        model = model.to(device)
        return processor, model
    except Exception as e:
        # Deliberate catch-all: any failure (pip, download, model construction)
        # degrades to the Lanczos fallback instead of stopping the app.
        print(f"Gagal memuat Swin2SR Transformer: {e}")
        return None, None


# Loaded once at import time; (None, None) selects the Lanczos fallback.
processor, model = setup_swinir()
# === MAIN API FUNCTIONS ===
def decode_base64_to_pil(base64_string):
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts either a bare base64 string or a data URL such as
    ``data:image/png;base64,<payload>``. Everything up to and including the
    first ``base64,`` marker is discarded.

    Raises:
        ValueError: if the input is empty/None or the data URL has no payload.
        binascii.Error: if the payload is not valid base64 (e.g. bad padding).
        PIL.UnidentifiedImageError: if the bytes are not a readable image.
    """
    # Local import: the module-level `import_io` alias is only bound further
    # down the file, after this definition.
    import io

    if not base64_string:
        raise ValueError("Input base64 kosong.")
    _, marker, payload = base64_string.partition("base64,")
    if not marker:
        payload = base64_string
    if not payload.strip():
        raise ValueError("Input base64 kosong.")
    img_data = base64.b64decode(payload)
    # convert() returns a new, fully loaded image, so the source can be closed.
    with Image.open(io.BytesIO(img_data)) as img:
        return img.convert('RGB')
import io as import_io
def encode_pil_to_base64(img_pil):
    """Serialize an image as PNG and wrap it in a ``data:image/png;base64,`` URL."""
    png_buffer = import_io.BytesIO()
    img_pil.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")
# Long-edge cap (pixels) applied after the 4x upscale, per target-resolution label.
# Labels not listed here leave the upscaled image at full size.
_TARGET_MAX_EDGE = {"2k": 2560, "4k": 3840}


def _upscale_x4(img_pil):
    """Upscale an RGB PIL image with Swin2SR if it loaded, else with Lanczos 4x plus unsharp mask."""
    if model is not None and processor is not None:
        print("Menggunakan AI Swin2SR Transformer...")
        width, height = img_pil.size
        inputs = processor(img_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Drop the batch axis and go from (C, H, W) to (H, W, C) for PIL.
        # Channel order is passed through unchanged (RGB in, RGB out).
        sr = outputs.reconstruction[0].float().cpu().clamp_(0, 1).numpy()
        sr = np.transpose(sr, (1, 2, 0))
        # The image processor pads the input (bottom/right) before inference, so
        # the reconstruction also covers that padding. Crop to exactly
        # `upscale` x the input size; this is a no-op when nothing was padded.
        factor = getattr(model.config, "upscale", 4)
        sr = sr[: height * factor, : width * factor]
        sr = (sr * 255.0).round().astype(np.uint8)
        return Image.fromarray(sr)

    # Fallback when the AI model could not be loaded (e.g. HF memory limits).
    print("Peringatan: Menggunakan Algoritma High-Quality Lanczos (Swin2SR gagal dimuat).")
    from PIL import ImageFilter
    width, height = img_pil.size
    upscaled = img_pil.resize((width * 4, height * 4), Image.Resampling.LANCZOS)
    # Mild sharpening to approximate the crisper look of the AI output.
    return upscaled.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))


def _cap_resolution(img, target_res):
    """Downscale ``img`` so its long edge fits the cap for ``target_res``; never enlarges."""
    max_edge = _TARGET_MAX_EDGE.get(str(target_res or "").strip().lower())
    width, height = img.size
    longest = max(width, height)
    if max_edge is None or longest <= max_edge:
        return img
    ratio = max_edge / longest
    # max(1, ...) keeps extreme aspect ratios from collapsing an edge to 0 px.
    new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
    return img.resize(new_size, Image.Resampling.LANCZOS)


def process_upscale(base64_input, target_res):
    """Gradio handler: decode, upscale 4x, cap to the target resolution, re-encode.

    Args:
        base64_input: base64 image string or data URL.
        target_res: "2k" or "4k" (case-insensitive, surrounding whitespace
            ignored) caps the long edge at 2560 / 3840 px. Any other value
            keeps the full 4x size.

    Returns:
        A PNG data URL on success. On failure, the string ``"ERROR: <message>"``;
        exceptions are not raised, so the API always returns a string.
    """
    print(f"Memproses resolusi target: {target_res}")
    try:
        img_pil = decode_base64_to_pil(base64_input)
        output_img = _cap_resolution(_upscale_x4(img_pil), target_res)
        print(f"Upscale berhasil: {output_img.size}")
        return encode_pil_to_base64(output_img)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"ERROR: {str(e)}"
# === GRADIO INTERFACE ===
with gr.Blocks() as demo:
    input_text = gr.Textbox(label="Base64 Input")
    res_text = gr.Textbox(value="2k", label="Resolution")
    output_text = gr.Textbox(label="Base64 Output")
    btn = gr.Button("Upscale")
    # Registered under the "predict" API name so clients can call it programmatically.
    btn.click(
        process_upscale,
        inputs=[input_text, res_text],
        outputs=output_text,
        api_name="predict",
    )

if __name__ == "__main__":
    # queue() configures the Blocks in place and returns it; then start the server.
    demo.queue()
    demo.launch()