# upscale1 / app.py
# Isasatu's picture
# Update app.py
# 42753e7 verified
import gradio as gr
import torch
import numpy as np
import cv2
import base64
import os
import requests
from PIL import Image
# === SWINIR ARCHITECTURE DEFINITION (STANDALONE) ===
# The model classes are defined directly here so the app does not depend on error-prone external libraries.
import torch.nn as nn
import torch.nn.functional as F
class Upsample(nn.Sequential):
    """Pixel-shuffle upsampler for power-of-two scale factors.

    Builds log2(scale) stages, each a 3x3 convolution that expands the
    channel count by 4 followed by ``PixelShuffle(2)``, so every stage
    doubles the spatial resolution while keeping ``num_feat`` channels.

    Args:
        scale: Integer upscaling factor; must be a positive power of two
            (1 yields an empty, identity-like ``Sequential``).
        num_feat: Number of feature channels entering and leaving each stage.

    Raises:
        ValueError: If ``scale`` is not a positive power of two.
    """

    def __init__(self, scale, num_feat):
        # `scale & (scale - 1)` is also 0 for scale == 0, so non-positive
        # values must be rejected explicitly before counting stages.
        if scale <= 0 or (scale & (scale - 1)) != 0:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n')

        # For a power of two 2**k, bit_length() is k + 1; integer math avoids
        # any float rounding from a logarithm.
        num_stages = int(scale).bit_length() - 1

        stages = []
        for _ in range(num_stages):
            stages.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
            stages.append(nn.PixelShuffle(2))
        super().__init__(*stages)
# Loading the full SwinIR architecture would require very long code if written by hand.
# For best stability on Hugging Face, we download the official implementation dynamically
# or use a compatible timm/open-source version.
def download_model(url, save_path):
    """Download ``url`` to ``save_path`` unless the file already exists.

    The payload is streamed into a temporary ``<save_path>.part`` file and
    renamed into place only after the transfer completed, so an interrupted
    download never leaves a truncated file that later runs would mistake for
    a valid checkpoint. A non-success HTTP status is reported and nothing is
    written, instead of saving the server's error page as the model.

    Args:
        url: HTTP(S) location of the file.
        save_path: Destination path on disk.

    Returns:
        True if the file is available at ``save_path`` afterwards, False if
        the server answered with an error status.

    Raises:
        requests.RequestException: On connection failures or timeouts.
        OSError: If the file cannot be written.
    """
    if os.path.exists(save_path):
        return True

    print(f"Mengunduh model dari {url}...")
    partial_path = f"{save_path}.part"
    try:
        # (connect, read) timeouts keep a stalled server from hanging startup;
        # the context manager releases the connection in every case.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            if not response.ok:
                print(f"Gagal mengunduh model: HTTP {response.status_code}")
                return False
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_path, save_path)
    finally:
        # Only present if the transfer failed midway; after a successful
        # os.replace the partial file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print("Selesai mengunduh.")
    return True
# Model configuration: checkpoints are stored under a local "models" directory.
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Uses Swin2SR (a successor of SwinIR described by the author as lighter and better).
MODEL_NAME = "Swin2SR_ClassicalSR_X4_64.pth"
MODEL_URL = "https://github.com/mvassell/Swin2SR/releases/download/v0.0.1/" + MODEL_NAME
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

# Prefer the GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Menjalankan dengan device: {device}")

# Fetch the model checkpoint (skipped when it is already on disk).
# NOTE(review): no code visible in this file reads MODEL_PATH afterwards - confirm
# this download is still needed and that MODEL_URL points at the intended release.
download_model(MODEL_URL, MODEL_PATH)
# --- SWIN2SR MODEL LOADING ---
# Because the Swin Transformer architecture is fairly complex, the most stable approach on Hugging Face
# is to use the transformers library directly where possible,
# or to define the network dynamically from GitHub.
# For simplicity and to guarantee the app runs without installation errors, we use
# standard Lanczos interpolation (very high quality resize) plus sharpening
# AS A FALLBACK IF SWIN FAILS TO LOAD (given the CPU limits on HF).
# Swin2SR is still attempted first.
import sys
import subprocess
def setup_swinir():
    """Load the Swin2SR image processor and x4 super-resolution model.

    ``transformers`` is imported directly when it is already installed; pip
    is only invoked as a fallback when the import fails, so a normal startup
    does not spawn an installer subprocess.

    Returns:
        A ``(processor, model)`` tuple with the model moved to the
        module-level ``device``, or ``(None, None)`` if anything fails
        (installation, import, or weight loading). Callers use the ``None``
        pair as the signal to fall back to non-AI upscaling.
    """
    try:
        try:
            from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
        except ImportError:
            import importlib

            # Install on demand, then refresh the import system's caches so
            # the freshly installed package is discoverable in this process.
            subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers"])
            importlib.invalidate_caches()
            from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor

        sr_processor = Swin2SRImageProcessor()
        sr_model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
        # The model is only used for inference, so keep it in eval mode.
        sr_model = sr_model.to(device).eval()
        return sr_processor, sr_model
    except Exception as e:
        # Deliberate best-effort: the app keeps running with the fallback path.
        print(f"Gagal memuat Swin2SR Transformer: {e}")
        return None, None
# Load the Swin2SR processor and model once at import time; both are None when loading failed.
processor, model = setup_swinir()
# === MAIN API FUNCTIONS ===
def decode_base64_to_pil(base64_string):
    """Decode a base64 string (optionally a ``data:...;base64,`` URI) into an RGB PIL image."""
    marker = "base64,"
    # Drop the data-URI header, keeping the segment right after the first marker.
    payload = base64_string.split(marker, 2)[1] if marker in base64_string else base64_string
    raw_bytes = base64.b64decode(payload)
    byte_stream = import_io.BytesIO(raw_bytes)
    return Image.open(byte_stream).convert('RGB')
import io as import_io
def encode_pil_to_base64(img_pil):
    """Serialize an image to PNG and return it as a ``data:image/png;base64,...`` URI."""
    with import_io.BytesIO() as png_buffer:
        img_pil.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
    payload = base64.b64encode(raw_png).decode("utf-8")
    return "data:image/png;base64," + payload
# Longest-side limit (in pixels) for each supported target resolution label.
_TARGET_MAX_SIDE = {"2k": 2560, "4k": 3840}


def _upscale_with_swin2sr(img_pil):
    """Run the loaded Swin2SR model on ``img_pil`` and return the upscaled RGB image."""
    src_w, src_h = img_pil.size
    inputs = processor(img_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # First (only) batch item as a (C, H, W) array with values clamped to [0, 1].
    chw = outputs.reconstruction[0].detach().float().cpu().clamp_(0, 1).numpy()

    # Per the transformers docs the Swin2SR processor pads the input before
    # inference, so the reconstruction can be larger than factor x source;
    # crop back so no padded border ends up in the result.
    factor = getattr(model.config, "upscale", 4)
    chw = chw[:, : src_h * factor, : src_w * factor]

    # Channels stay in the order the model produced them: the previous
    # RGB->BGR index swap followed by BGR->RGB conversion cancelled out.
    hwc = np.transpose(chw, (1, 2, 0))
    hwc = (hwc * 255.0).round().astype(np.uint8)
    return Image.fromarray(np.ascontiguousarray(hwc))


def _upscale_with_lanczos(img_pil):
    """Non-AI fallback: 4x Lanczos resize followed by an unsharp mask."""
    from PIL import ImageFilter

    w, h = img_pil.size
    enlarged = img_pil.resize((w * 4, h * 4), Image.Resampling.LANCZOS)
    # A little sharpening to imitate the crisper look of the AI output.
    return enlarged.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))


def _fit_to_target(img, target_res):
    """Downscale ``img`` so its longest side fits the limit for ``target_res``.

    Images already within the limit, and unknown target labels, are returned
    unchanged; this step never enlarges an image.
    """
    max_size = _TARGET_MAX_SIDE.get(str(target_res).strip().lower())
    if max_size is None:
        return img

    w, h = img.size
    longest = max(w, h)
    if longest <= max_size:
        return img

    scale = max_size / longest
    # Clamp to 1 so an extreme aspect ratio cannot round a side down to 0.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, Image.Resampling.LANCZOS)


def process_upscale(base64_input, target_res):
    """Upscale a base64-encoded image 4x and cap it at the requested resolution.

    Uses the Swin2SR model when it was loaded successfully and otherwise a
    Lanczos-plus-sharpening fallback. The result is then downscaled, if
    needed, so its longest side does not exceed 2560 px ('2k') or 3840 px
    ('4k'); any other ``target_res`` leaves the 4x result untouched.

    Args:
        base64_input: Image as a base64 string, optionally a data URI.
        target_res: Target label, '2k' or '4k' (case-insensitive).

    Returns:
        A ``data:image/png;base64,...`` string on success, or a string
        starting with ``"ERROR: "`` describing the failure. Exceptions are
        never propagated to the caller.
    """
    print(f"Memproses resolusi target: {target_res}")
    try:
        if not base64_input:
            raise ValueError("Input base64 kosong")

        img_pil = decode_base64_to_pil(base64_input)

        if model is not None and processor is not None:
            print("Menggunakan AI Swin2SR Transformer...")
            output_img = _upscale_with_swin2sr(img_pil)
        else:
            print("Peringatan: Menggunakan Algoritma High-Quality Lanczos (Swin2SR gagal dimuat).")
            output_img = _upscale_with_lanczos(img_pil)

        output_img = _fit_to_target(output_img, target_res)

        print(f"Upscale berhasil: {output_img.size}")
        return encode_pil_to_base64(output_img)
    except Exception as e:
        # API boundary: log the full traceback server-side and hand the
        # client a plain error string instead of raising.
        import traceback
        traceback.print_exc()
        return f"ERROR: {str(e)}"
# === GRADIO INTERFACE ===
# Text-in/text-out UI: a base64 image and a resolution label go in, a base64
# PNG data URI (or an error string) comes out. The click handler is also
# registered under the API name "predict".
with gr.Blocks() as demo:
    input_text = gr.Textbox(label="Base64 Input")
    res_text = gr.Textbox(label="Resolution", value="2k")
    output_text = gr.Textbox(label="Base64 Output")
    btn = gr.Button("Upscale")
    btn.click(
        fn=process_upscale,
        inputs=[input_text, res_text],
        outputs=output_text,
        api_name="predict",
    )


if __name__ == "__main__":
    # Enable request queuing before starting the server.
    demo.queue().launch()