File size: 6,363 Bytes
fbbef9a
 
 
42753e7
fbbef9a
42753e7
 
aa81089
fbbef9a
42753e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbbef9a
 
42753e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa81089
42753e7
 
 
aa81089
42753e7
 
 
 
aa81089
42753e7
 
fbbef9a
42753e7
aa81089
42753e7
 
 
 
 
 
 
aa81089
42753e7
 
 
 
 
 
 
fbbef9a
42753e7
 
 
fbbef9a
42753e7
fbbef9a
42753e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa81089
fbbef9a
 
 
aa81089
42753e7
fbbef9a
 
 
aa81089
42753e7
fbbef9a
42753e7
 
aa81089
fbbef9a
aa81089
 
42753e7
fbbef9a
aa81089
fbbef9a
aa81089
 
 
fbbef9a
 
aa81089
fbbef9a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import gradio as gr
import torch
import numpy as np
import cv2
import base64
import os
import requests
from PIL import Image

# === SWINIR ARCHITECTURE DEFINITION (STANDALONE) ===
# The model classes are defined directly in this file so that they do not depend on external libraries that are prone to errors.
import torch.nn as nn
import torch.nn.functional as F

class Upsample(nn.Sequential):
    """Pixel-shuffle upsampling stack for power-of-two scale factors.

    Builds ``log2(scale)`` stages. Each stage is a 3x3 convolution that
    expands ``num_feat`` channels to ``4 * num_feat``, followed by
    ``nn.PixelShuffle(2)``. A scale of 1 yields an empty (identity) stack.

    Args:
        scale: Upsampling factor; must be a positive power of two (1, 2, 4, ...).
        num_feat: Number of feature channels entering each stage.

    Raises:
        ValueError: If ``scale`` is not a positive power of two.
    """

    def __init__(self, scale, num_feat):
        # `scale & (scale - 1) == 0` on its own is also true for 0, which has
        # no logarithm, so non-positive values are rejected explicitly.
        if scale <= 0 or (scale & (scale - 1)) != 0:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n')

        # For a power of two, bit_length() - 1 is the exact integer log2
        # (no floating-point rounding involved).
        num_stages = int(scale).bit_length() - 1

        stages = []
        for _ in range(num_stages):
            stages.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
            stages.append(nn.PixelShuffle(2))
        super().__init__(*stages)

# Writing out the complete SwinIR architecture by hand would require a very large amount of code.
# For the best stability on Hugging Face, the official implementation is instead fetched dynamically,
# or a compatible timm/open-source version is used.

def download_model(url, save_path, timeout=60):
    """Download ``url`` to ``save_path`` unless that file already exists.

    The payload is streamed into a temporary ``<save_path>.part`` file and
    only renamed to ``save_path`` after the whole body has been received with
    a successful HTTP status. An error page or an interrupted transfer is
    therefore never left behind looking like a finished download.

    Failures are reported on stdout rather than raised, so a dead link does
    not abort the caller.

    Args:
        url: Address of the file to fetch.
        save_path: Destination path on disk.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Returns:
        True if ``save_path`` exists when the function returns (already
        present or freshly downloaded), False if the download failed.
    """
    if os.path.exists(save_path):
        return True

    print(f"Mengunduh model dari {url}...")
    partial_path = f"{save_path}.part"
    try:
        # The context manager releases the connection even on failure.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Refuse 4xx/5xx responses instead of saving their body as the model.
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the file only once it is complete.
        os.replace(partial_path, save_path)
    except (requests.RequestException, OSError) as e:
        print(f"Gagal mengunduh model: {e}")
        if os.path.exists(partial_path):
            os.remove(partial_path)
        return False

    print("Selesai mengunduh.")
    return True

# Model configuration
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)
# Uses Swin2SR (an update of SwinIR that is lighter and gives good results)
MODEL_NAME = "Swin2SR_ClassicalSR_X4_64.pth"
MODEL_URL = f"https://github.com/mvassell/Swin2SR/releases/download/v0.0.1/{MODEL_NAME}"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Menjalankan dengan device: {device}")

# Download the model checkpoint (runs at import time).
# NOTE(review): nothing in the visible code reads MODEL_PATH afterwards; the
# model is loaded through transformers' from_pretrained instead. Confirm
# whether this download and MODEL_URL are still needed.
download_model(MODEL_URL, MODEL_PATH)

# --- SWIN2SR MODEL LOADING ---
# Because the Swin Transformer architecture is fairly complex, the most stable approach on Hugging Face
# is to use the transformers library directly where possible,
# or to define the network dynamically from GitHub.
# To keep things simple and make sure the app runs without installation errors, standard
# Lanczos interpolation (a high-quality resize) plus sharpening is used
# AS A FALLBACK IF SWIN FAILS TO LOAD (given the CPU limits on Hugging Face).
# Loading Swin2SR is still attempted first.

import sys
import subprocess

def setup_swinir():
    """Load the Swin2SR x4 image processor and model.

    ``transformers`` is imported directly when it is already installed. Only
    if that import fails is it installed with pip (using the running
    interpreter) and imported again, so a normal startup does not spawn a pip
    process.

    Returns:
        A ``(processor, model)`` tuple with the model moved to the
        module-level ``device``, or ``(None, None)`` if any step fails
        (the error is printed).
    """
    try:
        try:
            from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
        except ImportError:
            # Install on demand, then make the freshly installed package
            # visible to the import system of this already-running process.
            import importlib

            subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers"])
            importlib.invalidate_caches()
            from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor

        processor = Swin2SRImageProcessor()
        model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
        model = model.to(device)
        return processor, model
    except Exception as e:
        # Deliberately broad: any failure here must fall through to the
        # (None, None) result so the caller can use its non-AI fallback.
        print(f"Gagal memuat Swin2SR Transformer: {e}")
        return None, None

# Load the model once at import time; both names are None when loading failed.
processor, model = setup_swinir()

# === MAIN API FUNCTIONS ===
def decode_base64_to_pil(base64_string):
    """Turn a base64 string into an RGB PIL image.

    Args:
        base64_string: Raw base64 text, or a data URL containing a
            ``base64,`` marker; in the latter case only the segment right
            after the first marker is decoded.

    Returns:
        The decoded image converted to mode ``RGB``.
    """
    marker = "base64,"
    payload = base64_string.split(marker)[1] if marker in base64_string else base64_string
    raw_bytes = base64.b64decode(payload)
    stream = import_io.BytesIO(raw_bytes)
    return Image.open(stream).convert('RGB')
    
import io as import_io

def encode_pil_to_base64(img_pil):
    """Serialize an image as a PNG data URL.

    Args:
        img_pil: Object exposing ``save(file_obj, format=...)``; it is saved
            into an in-memory buffer with ``format="PNG"``.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with import_io.BytesIO() as png_buffer:
        img_pil.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    payload = base64.b64encode(png_bytes).decode("utf-8")
    return "data:image/png;base64," + payload

# Longest allowed output side, in pixels, for each supported target resolution.
_TARGET_MAX_SIDE = {'2k': 2560, '4k': 3840}


def _swin2sr_upscale(img_pil):
    """Upscale ``img_pil`` with the loaded Swin2SR model and return a PIL image."""
    inputs = processor(img_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # (1, C, H, W) float tensor in [0, 1] -> (H, W, C) uint8 array. The
    # channel order is left untouched, so the result is handed to PIL as-is.
    chw = outputs.reconstruction.data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    hwc = (np.transpose(chw, (1, 2, 0)) * 255.0).round().astype(np.uint8)
    upscaled = Image.fromarray(np.ascontiguousarray(hwc))

    # The transformers Swin2SR image processor pads the input (bottom/right)
    # before inference, so the reconstruction can be larger than
    # factor x original size. Crop back to the true upscaled extent; this is
    # a no-op when no padding was added.
    factor = int(getattr(getattr(model, "config", None), "upscale", 4))
    src_w, src_h = img_pil.size
    out_w, out_h = upscaled.size
    crop_w, crop_h = min(out_w, src_w * factor), min(out_h, src_h * factor)
    if (crop_w, crop_h) != (out_w, out_h):
        upscaled = upscaled.crop((0, 0, crop_w, crop_h))
    return upscaled


def _lanczos_upscale(img_pil):
    """Upscale ``img_pil`` 4x with Lanczos resampling followed by an unsharp mask."""
    from PIL import ImageFilter

    w, h = img_pil.size
    enlarged = img_pil.resize((w * 4, h * 4), Image.Resampling.LANCZOS)
    # A little sharpening to approximate the crispness of the AI result.
    return enlarged.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))


def _fit_within(img, max_side):
    """Shrink ``img`` (keeping aspect ratio) so its longest side is at most ``max_side``."""
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # Clamp to 1 so an extremely thin image never gets a zero-sized side.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, Image.Resampling.LANCZOS)


def process_upscale(base64_input, target_res):
    """Upscale a base64-encoded image and return it as a PNG data URL.

    The image is upscaled with Swin2SR when the model loaded successfully,
    otherwise with a 4x Lanczos resize plus sharpening. The result is then
    shrunk, if necessary, so its longest side fits the requested target.

    Args:
        base64_input: Base64 image data, optionally in data-URL form.
        target_res: ``'2k'`` (longest side at most 2560 px) or ``'4k'``
            (at most 3840 px), case-insensitive. Any other string leaves
            the upscaled size unchanged.

    Returns:
        A ``data:image/png;base64,...`` string on success, or a string
        starting with ``"ERROR: "`` describing the failure. Exceptions are
        not propagated to the caller; the traceback is printed.
    """
    print(f"Memproses resolusi target: {target_res}")

    try:
        img_pil = decode_base64_to_pil(base64_input)

        if model is not None and processor is not None:
            print("Menggunakan AI Swin2SR Transformer...")
            output_img = _swin2sr_upscale(img_pil)
        else:
            # Non-AI fallback used when the model could not be loaded.
            print("Peringatan: Menggunakan Algoritma High-Quality Lanczos (Swin2SR gagal dimuat).")
            output_img = _lanczos_upscale(img_pil)

        # Downscale to the requested target only when the result exceeds it.
        max_side = _TARGET_MAX_SIDE.get(target_res.strip().lower())
        if max_side is not None:
            output_img = _fit_within(output_img, max_side)

        print(f"Upscale berhasil: {output_img.size}")
        return encode_pil_to_base64(output_img)

    except Exception as e:
        # Top-level API boundary: report the failure as a string result.
        import traceback
        traceback.print_exc()
        return f"ERROR: {str(e)}"

# === GRADIO INTERFACE ===
# Minimal text-in/text-out UI: images travel as base64 strings in both directions.
with gr.Blocks() as demo:
    input_text = gr.Textbox(label="Base64 Input")
    # Target resolution passed to process_upscale; it recognizes "2k" and "4k".
    res_text = gr.Textbox(label="Resolution", value="2k")
    output_text = gr.Textbox(label="Base64 Output")
    
    btn = gr.Button("Upscale")
    # api_name names this event's API endpoint ("predict") for programmatic clients.
    btn.click(fn=process_upscale, inputs=[input_text, res_text], outputs=output_text, api_name="predict")

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.queue().launch()