Spaces:
Sleeping
Sleeping
| import os | |
| import math | |
| import shutil | |
| import tempfile | |
| import threading | |
| import queue | |
| import zipfile | |
| import uuid | |
| import numpy as np | |
| import requests | |
| from PIL import Image | |
| import onnxruntime as ort | |
| import gradio as gr | |
# ================= CONFIG =================
# Directory holding the downloaded ONNX model files.
MODEL_DIR = "model"
# Local paths of the two Real-ESRGAN models (downloaded on first run).
MODEL_X2_PATH = os.path.join(MODEL_DIR, "Real-ESRGAN_x2plus.onnx")
MODEL_X4_PATH = os.path.join(MODEL_DIR, "Real-ESRGAN-x4plus.onnx")
# Google Drive file IDs for the two model weight files.
FILE_ID_X2 = "15xmXXZNH2wMyeQv4ie5hagT7eWK9MgP6"
FILE_ID_X4 = "1wDBHad9RCJgJDGsPdapLYl3cr8j-PMJ6"
# Inputs whose longest side exceeds this (pixels) are downscaled first,
# keeping CPU inference time bounded.
MAX_DIM = 1024
# Working directories: queued originals and finished results.
INPUT_DIR = "inputs"
OUTPUT_DIR = "outputs"
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
| # ================= MODEL DOWNLOAD ================= | |
def download_from_drive(file_id, dest):
    """Download a file from Google Drive to *dest*, streaming in chunks.

    Handles Drive's "file too large to virus-scan" confirmation page by
    re-requesting with the ``download_warning`` token found in the cookies.

    Args:
        file_id: Google Drive file ID.
        dest: Local filesystem path to write the file to.

    Raises:
        requests.HTTPError: if the final response has an error status.
    """
    url = "https://drive.google.com/uc?export=download"
    session = requests.Session()
    r = session.get(url, params={"id": file_id}, stream=True)
    # Large files trigger a confirmation page; the token lives in a cookie.
    token = None
    for k, v in r.cookies.items():
        if k.startswith("download_warning"):
            token = v
            break
    if token:
        r = session.get(url, params={"id": file_id, "confirm": token}, stream=True)
    # Bug fix: fail loudly on HTTP errors instead of silently writing an error
    # page to the model path, which would pass the os.path.exists() cache
    # check on the next launch but break ort.InferenceSession.
    r.raise_for_status()
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    with open(dest, "wb") as f:
        for chunk in r.iter_content(32768):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)
| print("Downloading models if not exist...") | |
| if not os.path.exists(MODEL_X2_PATH): | |
| download_from_drive(FILE_ID_X2, MODEL_X2_PATH) | |
| if not os.path.exists(MODEL_X4_PATH): | |
| download_from_drive(FILE_ID_X4, MODEL_X4_PATH) | |
| print("Models ready!") | |
# ================= ONNX SESSIONS =================
# CPU-only inference with a small thread budget (free-tier friendly).
opts = ort.SessionOptions()
opts.intra_op_num_threads = 2
opts.inter_op_num_threads = 2
sess_x2 = ort.InferenceSession(MODEL_X2_PATH, opts, providers=["CPUExecutionProvider"])
sess_x4 = ort.InferenceSession(MODEL_X4_PATH, opts, providers=["CPUExecutionProvider"])
# First (and only) graph input of each model; .name and .shape are used below.
meta_x2 = sess_x2.get_inputs()[0]
meta_x4 = sess_x4.get_inputs()[0]
# Tile height/width taken from each model's NCHW input shape.
# NOTE(review): this assumes static integer H/W in the ONNX graph — a model
# with dynamic axes would put strings/None here and break the tiling math in
# upscale_core. Confirm against the actual exported models.
_, _, H2, W2 = meta_x2.shape
_, _, H4, W4 = meta_x4.shape
| # ================= CORE LOGIC ================= | |
def run_tile(tile, session, meta):
    """Run one HWC float tile through *session* and return the upscaled HWC tile.

    Args:
        tile: (H, W, 3) float array, values in [0, 1].
        session: ONNX Runtime inference session.
        meta: the session's input metadata (provides the input tensor name).

    Returns:
        (H*scale, W*scale, 3) float array.
    """
    # HWC -> NCHW with a singleton batch dimension.
    batched = tile.transpose(2, 0, 1)[np.newaxis]
    result = session.run(None, {meta.name: batched})[0]
    # Drop the batch dimension and go back to HWC layout.
    return result[0].transpose(1, 2, 0)
def upscale_core(img: Image.Image, scale: int):
    """Upscale *img* by *scale* using tiled Real-ESRGAN ONNX inference.

    The image is split into fixed-size tiles matching the model's input
    shape, each tile is upscaled independently, and the results are
    stitched back together.

    Args:
        img: source image (converted to RGB internally).
        scale: 2 selects the x2 model; any other value selects the x4 model.

    Returns:
        A new PIL.Image upscaled by the model's factor (2 or 4).
    """
    # Pick the session, input-tile geometry, and scale factor per model.
    if scale == 2:
        H, W, sess, meta, S = H2, W2, sess_x2, meta_x2, 2
    else:
        H, W, sess, meta, S = H4, W4, sess_x4, meta_x4, 4
    w, h = img.size
    if max(w, h) > MAX_DIM:
        # Cap the working resolution so CPU inference stays tractable.
        r = MAX_DIM / max(w, h)
        img = img.resize((int(w*r), int(h*r)), Image.LANCZOS)
    # Normalize to float32 in [0, 1], HWC layout.
    arr = np.array(img.convert("RGB")).astype(np.float32) / 255.0
    h0, w0, _ = arr.shape
    # Number of tiles needed to cover the image in each direction.
    th = math.ceil(h0 / H)
    tw = math.ceil(w0 / W)
    # Pad up to a whole number of tiles; padding is cropped off at the end.
    # NOTE(review): np.pad(mode="reflect") raises if a pad amount is >= the
    # corresponding image dimension, i.e. for inputs smaller than roughly
    # half a tile — confirm tile size vs. the minimum expected input.
    pad = np.pad(arr, ((0, th*H-h0), (0, tw*W-w0), (0, 0)), mode="reflect")
    out = np.zeros((th*H*S, tw*W*S, 3), dtype=np.float32)
    # Upscale tile by tile and stitch results into the output canvas.
    for i in range(th):
        for j in range(tw):
            tile = pad[i*H:(i+1)*H, j*W:(j+1)*W]
            up = run_tile(tile, sess, meta)
            out[i*H*S:(i+1)*H*S, j*W*S:(j+1)*W*S] = up
    # Crop away the padded region and convert back to 8-bit.
    out = np.clip(out[:h0*S, :w0*S], 0, 1)
    return Image.fromarray((out * 255).astype(np.uint8))
# ================= BACKGROUND QUEUE =================
# FIFO of (input_path, mode, output_path) tasks, drained by background_worker.
task_queue = queue.Queue()
def background_worker():
    """Drain task_queue forever, upscaling each queued image to a PNG.

    Each task is a ``(input_path, mode, output_path)`` tuple where mode is
    "x2", "x4", or "x8". A ``None`` task is a shutdown sentinel. The input
    file is always deleted afterwards, even on failure, to save disk space.
    """
    while True:
        task = task_queue.get()
        if task is None:
            break
        in_path, mode, out_path = task
        try:
            print(f"Processing {in_path} -> mode {mode}")
            # Bug fix: close the image handle promptly. The original never
            # closed it, leaking a file descriptor per task and making the
            # os.remove() below fail on Windows (file still open).
            with Image.open(in_path) as img:
                if mode == "x2":
                    out = upscale_core(img, 2)
                elif mode == "x8":
                    # No x8 model exists: run x4, then Lanczos-resize the rest.
                    temp = upscale_core(img, 4)
                    out = temp.resize((img.width * 8, img.height * 8), Image.LANCZOS)
                else:
                    out = upscale_core(img, 4)
            out.save(out_path, format="PNG")
        except Exception as e:
            # Keep the worker alive: one bad image must not kill the queue.
            print(f"Error processing {in_path}: {e}")
        finally:
            if os.path.exists(in_path):
                os.remove(in_path)  # delete the original to save storage
            task_queue.task_done()
# Daemon thread: processing continues after the browser tab is closed, but
# does not block interpreter shutdown.
threading.Thread(target=background_worker, daemon=True).start()
| # ================= GRADIO UI FUNCTIONS ================= | |
def submit_images(files, mode):
    """Copy uploaded files into INPUT_DIR and enqueue them for upscaling.

    Args:
        files: list of uploaded file objects from gr.File (each has a .name
               path); falsy when nothing was uploaded.
        mode: "x2", "x4", or "x8" upscaling mode.

    Returns:
        A status message string for the UI.
    """
    if not files:
        return "β Please upload at least 1 image."
    if len(files) > 40:
        return "β Warning: Maximum 40 images allowed at a time!"
    # Process up to 40 files
    files = files[:40]
    count = 0
    for f in files:
        # Short random prefix keeps concurrent uploads from colliding.
        unique_name = str(uuid.uuid4())[:8]
        filename = os.path.basename(f.name)
        # Bug fix: the input path previously embedded a literal "(unknown)"
        # placeholder instead of the real filename, dropping the extension.
        in_path = os.path.join(INPUT_DIR, f"{unique_name}_{filename}")
        # Bug fix: splitext keeps multi-dot stems intact where
        # filename.split('.')[0] truncated "a.b.png" to "a".
        stem, _ = os.path.splitext(filename)
        out_path = os.path.join(OUTPUT_DIR, f"UP_{unique_name}_{stem}.png")
        shutil.copy(f.name, in_path)
        task_queue.put((in_path, mode, out_path))
        count += 1
    return f"β {count} Images successfully added to the background queue! (Mode: {mode})\n\nπ You can safely CLOSE THIS TAB now. The images will be processed one by one. Check the 'View & Download Results' tab later."
def get_results():
    """Return the paths of all finished images currently in OUTPUT_DIR."""
    image_exts = ('.png', '.jpg', '.jpeg')
    return [
        os.path.join(OUTPUT_DIR, name)
        for name in os.listdir(OUTPUT_DIR)
        if name.endswith(image_exts)
    ]
def create_zip():
    """Bundle every finished image into Upscaled_Images.zip.

    Returns the archive filename, or None when there are no results yet.
    """
    results = get_results()
    if not results:
        return None
    zip_name = "Upscaled_Images.zip"
    with zipfile.ZipFile(zip_name, 'w') as archive:
        for path in results:
            # Store flat (basename only) so the ZIP has no directory nesting.
            archive.write(path, os.path.basename(path))
    return zip_name
def clear_all():
    """Delete all output images and the ZIP bundle from the server.

    Returns:
        (empty gallery list, cleared file slot, status message) matching the
        Gradio outputs [gallery, zip_out, clear_status].
    """
    for path in get_results():
        try:
            os.remove(path)
        except OSError:
            # Bug fix: narrowed from a bare except. Only filesystem errors
            # are expected here, and best-effort cleanup should not hide
            # unrelated bugs (KeyboardInterrupt, NameError, ...).
            pass
    if os.path.exists("Upscaled_Images.zip"):
        os.remove("Upscaled_Images.zip")
    return [], None, "β All server images deleted successfully."
# ================= GRADIO APP LAYOUT =================
with gr.Blocks(title="SpectraGAN Auto-Upscaler", theme=gr.themes.Soft()) as app:
    gr.Markdown("# π SpectraGAN Background Batch Upscaler")
    gr.Markdown("Hugging Face Free Tier optimized! Upload max 40 images. Background queueing allows you to close the browser safely while your images upscale.")
    with gr.Tabs():
        # TAB 1: UPLOAD — enqueue images for the background worker.
        with gr.TabItem("β¬οΈ Upload & Process"):
            with gr.Row():
                with gr.Column(scale=2):
                    file_input = gr.File(file_count="multiple", label="Upload Images (Max 40)", type="filepath")
                    mode_input = gr.Radio(["x2", "x4", "x8"], value="x4", label="Upscaling Quality")
                    submit_btn = gr.Button("π Start Upscaling (Background)", variant="primary")
                with gr.Column(scale=1):
                    status_out = gr.Textbox(label="Status Log", lines=5, interactive=False)
            submit_btn.click(fn=submit_images, inputs=[file_input, mode_input], outputs=status_out)
        # TAB 2: RESULTS — browse finished images, download ZIP, or wipe them.
        with gr.TabItem("πΌοΈ View & Download Results"):
            refresh_btn = gr.Button("π Refresh Gallery (Check Progress)")
            gallery = gr.Gallery(label="Completed Images", columns=4, height="auto")
            with gr.Row():
                zip_btn = gr.Button("π¦ Create & Download ZIP", variant="primary")
                zip_out = gr.File(label="Download Zip File Here")
            clear_btn = gr.Button("ποΈ Delete All Images from Server", variant="stop")
            clear_status = gr.Markdown()
            # Wire buttons to the queue/result helpers defined above.
            refresh_btn.click(fn=get_results, inputs=[], outputs=gallery)
            zip_btn.click(fn=create_zip, inputs=[], outputs=zip_out)
            clear_btn.click(fn=clear_all, inputs=[], outputs=[gallery, zip_out, clear_status])

if __name__ == "__main__":
    # Bind to all interfaces on the standard Hugging Face Spaces port.
    app.launch(server_name="0.0.0.0", server_port=7860)