Spaces:
Sleeping
Sleeping
Commit
·
d9e2abb
0
Parent(s):
Initial changes
Browse files

- app.py +228 -0
- deploy.sh +21 -0
- models/fsrcnn_x2.pth +0 -0
- models/fsrcnn_x3.pth +0 -0
- models/fsrcnn_x4.pth +0 -0
- requirements.txt +10 -0
- secrets.txt +1 -0
- tests/test2.py +104 -0
- tests/test_app.py +137 -0
app.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import typing as tp
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from huggingface_hub import InferenceClient
|
| 10 |
+
|
| 11 |
+
import warnings
|
| 12 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ============================================================
|
| 16 |
+
# 🧠 PART 1: FSRCNN Image Upscaling
|
| 17 |
+
# ============================================================
|
| 18 |
+
|
| 19 |
+
class FSRCNN(nn.Module):
    """FSRCNN super-resolution network: feature extraction, shrink/map/expand,
    then a learned transposed-convolution upsampling back to pixel space.

    Operates on a single channel (typically luma); output spatial dims are
    the input dims multiplied by ``scale_factor``.
    """

    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(FSRCNN, self).__init__()
        # Stage 1: feature extraction on the low-resolution input.
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=5 // 2),
            nn.PReLU(d),
        )
        # Stage 2: shrink (d -> s), m mapping layers, expand (s -> d).
        layers = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            layers += [nn.Conv2d(s, s, kernel_size=3, padding=3 // 2), nn.PReLU(s)]
        layers += [nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)]
        self.mid_part = nn.Sequential(*layers)
        # Stage 3: learned upsampling by scale_factor.
        self.last_part = nn.ConvTranspose2d(
            d, num_channels, kernel_size=9,
            stride=scale_factor, padding=9 // 2,
            output_padding=scale_factor - 1,
        )
        self._initialize_weights()

    def _initialize_weights(self):
        # Gaussian init scaled by fan-out for conv layers; tiny gaussian
        # for the deconv head (matches the reference FSRCNN setup).
        for module in list(self.first_part) + list(self.mid_part):
            if isinstance(module, nn.Conv2d):
                fan = module.out_channels * module.weight.data[0][0].numel()
                nn.init.normal_(module.weight.data, mean=0.0, std=math.sqrt(2 / fan))
                nn.init.zeros_(module.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        """Map a (N, C, H, W) tensor to (N, C, H*scale, W*scale)."""
        return self.last_part(self.mid_part(self.first_part(x)))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Run on GPU when available; models and tensors are moved to this device.
Device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-scale cache of (model, weights_loaded) so each checkpoint loads at most once.
MODEL_CACHE: dict[int, tuple[FSRCNN, bool]] = {}
# Bundled FSRCNN checkpoints keyed by upscale factor (paths relative to the app's CWD).
WEIGHTS_PATHS = {2: "models/fsrcnn_x2.pth", 3: "models/fsrcnn_x3.pth", 4: "models/fsrcnn_x4.pth"}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def try_load_weights(model, weights_path):
    """Load a checkpoint from *weights_path* into *model*.

    Returns True on success, False when the path is missing/invalid or the
    checkpoint cannot be loaded or applied; callers fall back to bicubic
    upscaling on False.
    """
    if not weights_path or not os.path.isfile(weights_path):
        print(f"[FSRCNN] No valid weights at {weights_path}. Falling back to Bicubic.")
        return False
    try:
        # weights_only=True restricts unpickling to tensors/containers, which
        # prevents arbitrary code execution from a tampered checkpoint file.
        # FSRCNN checkpoints are plain state dicts, so this is sufficient.
        checkpoint = torch.load(weights_path, map_location=Device, weights_only=True)
        # strict=True: any missing/unexpected key raises and we report failure.
        model.load_state_dict(checkpoint, strict=True)
        print(f"[FSRCNN] Loaded weights from {weights_path}")
        return True
    except Exception as e:
        print(f"[FSRCNN] Failed to load weights: {e}")
        return False
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_model(scale, weights_path=None):
    """Return the cached (model, has_weights) pair for *scale*, building it on first use.

    Note: *weights_path* only matters on the first call for a given scale;
    later calls return the cached pair unchanged.
    """
    cached = MODEL_CACHE.get(scale)
    if cached is None:
        net = FSRCNN(scale_factor=scale).to(Device).eval()
        cached = (net, try_load_weights(net, weights_path))
        MODEL_CACHE[scale] = cached
    return cached
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def rgb_to_ycbcr(img_rgb: np.ndarray) -> np.ndarray:
    """Convert RGB to OpenCV's YCrCb space (note: channel order is Y, Cr, Cb)."""
    converted = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YCrCb)
    return converted
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def ycbcr_to_rgb(img_ycrcb: np.ndarray) -> np.ndarray:
    """Convert an OpenCV YCrCb image back to RGB."""
    converted = cv2.cvtColor(img_ycrcb, cv2.COLOR_YCrCb2RGB)
    return converted
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def run_fsrcnn_on_y(y: np.ndarray, model: FSRCNN) -> np.ndarray:
    """Super-resolve a single uint8 luma plane with *model*; returns uint8.

    Scales to [0, 1] float, adds batch/channel dims, runs inference, then
    rounds back to uint8 with a +0.5 offset.
    """
    normalized = y.astype(np.float32) / 255.0
    batch = torch.from_numpy(normalized).unsqueeze(0).unsqueeze(0).to(Device)
    with torch.inference_mode():
        prediction = model(batch)
    plane = prediction.squeeze(0).squeeze(0).clamp(0.0, 1.0).cpu().numpy()
    return (plane * 255.0 + 0.5).astype(np.uint8)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def fsrcnn_upscale_rgb(img_rgb: np.ndarray, scale: int, weights: tp.Optional[str] = None) -> np.ndarray:
    """Upscale an RGB uint8 image by *scale* with FSRCNN; bicubic fallback without weights.

    FSRCNN runs on the luma channel only; chroma channels are upscaled with
    bicubic interpolation (standard practice for luma-only SR models).
    """
    height, width = img_rgb.shape[:2]
    target = (width * scale, height * scale)

    model, has_weights = get_model(scale, weights)
    if not has_weights:
        return cv2.resize(img_rgb, target, interpolation=cv2.INTER_CUBIC)

    ycrcb = rgb_to_ycbcr(img_rgb)
    luma_sr = run_fsrcnn_on_y(ycrcb[..., 0], model)
    cr_up = cv2.resize(ycrcb[..., 1], target, interpolation=cv2.INTER_CUBIC)
    cb_up = cv2.resize(ycrcb[..., 2], target, interpolation=cv2.INTER_CUBIC)
    return ycbcr_to_rgb(np.stack([luma_sr, cr_up, cb_up], axis=-1))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def maybe_downscale_for_memory(img_rgb: np.ndarray, max_pixels: int = 8_000_000) -> np.ndarray:
    """Cap *img_rgb*'s pixel count, shrinking proportionally when above *max_pixels*.

    Images already within budget are returned unchanged (same object).
    """
    height, width = img_rgb.shape[:2]
    total = height * width
    if total <= max_pixels:
        return img_rgb
    factor = (max_pixels / total) ** 0.5
    target = (int(width * factor), int(height * factor))
    # INTER_AREA is the recommended interpolation for shrinking.
    return cv2.resize(img_rgb, target, interpolation=cv2.INTER_AREA)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def upscale_ui(image: np.ndarray, scale_factor: int, method: str):
    """Gradio callback: normalize the input, upscale it, and return (image, status).

    Accepts grayscale (stacked to RGB), RGBA (alpha dropped), and non-uint8
    arrays (clipped and converted). Very large inputs are downscaled first
    to bound memory use.
    """
    if image is None:
        return None, "Please upload an image."

    # Normalize to uint8 RGB regardless of what the UI hands us.
    if image.dtype != np.uint8:
        image = np.clip(image, 0, 255).astype(np.uint8)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        image = image[..., :3]

    image = maybe_downscale_for_memory(image)
    weights_path = WEIGHTS_PATHS.get(scale_factor)

    if method == "FSRCNN (Y channel)":
        result = fsrcnn_upscale_rgb(image, scale_factor, weights_path)
        status = f"Used FSRCNN x{scale_factor} (bundled weights)."
    else:
        h, w = image.shape[:2]
        result = cv2.resize(image, (w * scale_factor, h * scale_factor), interpolation=cv2.INTER_CUBIC)
        status = f"Used Bicubic x{scale_factor}."
    return result, status
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ============================================================
|
| 150 |
+
# 🌍 PART 2: Multilingual Translator
|
| 151 |
+
# ============================================================
|
| 152 |
+
|
| 153 |
+
# Token is optional; without it, public Hugging Face inference rate limits apply.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)

# Display name -> mBART-50 language code (model-specific codes, not BCP-47).
lang_map = {
    "English": "en_XX", "French": "fr_XX", "Spanish": "es_XX", "German": "de_DE",
    "Hindi": "hi_IN", "Chinese": "zh_CN", "Japanese": "ja_XX", "Korean": "ko_KR",
    "Tamil": "ta_IN", "Telugu": "te_IN", "Arabic": "ar_AR", "Russian": "ru_RU"
    # You can add full map from your existing file if needed
}
|
| 162 |
+
|
| 163 |
+
def translate_text(text, src_lang, tgt_lang):
    """Translate *text* between the given display-language names via mBART-50.

    Returns the translation, a prompt for empty input, or an error message
    string if the remote inference call fails.
    """
    if not text.strip():
        return "Please enter any text to translate 😃"
    try:
        src_code = lang_map[src_lang]
        tgt_code = lang_map[tgt_lang]
        response = client.translation(
            text,
            model="facebook/mbart-large-50-many-to-many-mmt",
            src_lang=src_code,
            tgt_lang=tgt_code,
        )
        return response.translation_text
    except Exception as e:
        # Surface failures to the UI rather than crashing the callback.
        return f"Error in translation: {str(e)}"
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ============================================================
|
| 180 |
+
# 🎨 Combine into One Interface with Tabs
|
| 181 |
+
# ============================================================
|
| 182 |
+
|
| 183 |
+
# Shared theme: blue primary buttons, light-grey secondary buttons.
custom_theme = gr.themes.Default().set(
    button_primary_background_fill="#1769aa",
    button_primary_text_color="#ffffff",
    button_secondary_background_fill="#e0e0e0",
    button_secondary_text_color="#222222"
)

# Two independent tools share one Blocks app, one per tab.
with gr.Blocks(theme=custom_theme, title="AI Multi-Tool: FSRCNN & Translator") as demo:
    gr.Markdown("# 🚀 AI Multi-Tool Suite\nChoose an application below 👇")

    with gr.Tabs():
        # Tab 1: FSRCNN Upscaler
        with gr.Tab("🖼️ Image Upscaling"):
            with gr.Row():
                with gr.Column():
                    inp_img = gr.Image(type="numpy", label="Input Image")
                    scale = gr.Dropdown([2, 3, 4], value=2, label="Upscale Factor")
                    method = gr.Radio(["FSRCNN (Y channel)", "Bicubic"], value="FSRCNN (Y channel)")
                    run_btn = gr.Button("Upscale", variant="primary")
                    clear_btn = gr.Button("Clear", variant="secondary")
                    status_box = gr.Textbox(label="Status")
                with gr.Column():
                    out_img = gr.Image(type="numpy", label="Upscaled Output")

            run_btn.click(fn=upscale_ui, inputs=[inp_img, scale, method], outputs=[out_img, status_box])
            # Clear resets every control back to its default value.
            clear_btn.click(fn=lambda: (None, 2, "FSRCNN (Y channel)", None, ""),
                            outputs=[inp_img, scale, method, out_img, status_box])

        # Tab 2: Translator
        with gr.Tab("🌍 Text Translator"):
            with gr.Row():
                with gr.Column():
                    src_lang = gr.Dropdown(choices=list(lang_map.keys()), value="English", label="Source Language")
                    input_text = gr.Textbox(lines=4, label="Enter Text")
                with gr.Column():
                    tgt_lang = gr.Dropdown(choices=list(lang_map.keys()), value="French", label="Target Language")
                    output_text = gr.Textbox(lines=4, label="Translation", interactive=False)

            translate_btn = gr.Button("Translate ✨", variant="primary")
            clear_btn2 = gr.Button("Clear", variant="secondary")

            translate_btn.click(fn=translate_text, inputs=[input_text, src_lang, tgt_lang], outputs=output_text)
            # Reset text boxes and restore the default language pair.
            clear_btn2.click(fn=lambda: ("", "English", "French", ""), outputs=[input_text, src_lang, tgt_lang, output_text])

if __name__ == "__main__":
    demo.launch()
|
deploy.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Deploy the Image_upscaling app on a fresh host: install the OpenCV system
# dependency, sync the repo, build a virtualenv, and launch the app in the
# background (surviving the SSH session) with logs in ~/projects/project1.log.
#
# Exit on the first failed command so a broken step (e.g. failed clone or
# pip install) doesn't silently cascade into launching a stale/absent app.
set -e

sudo apt update
# sudo apt install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxrender1 libxext6 ffmpeg
sudo apt install -y libgl1-mesa-glx

cd ~/projects/
if [ -d "Image_upscaling" ]; then
    cd Image_upscaling && git pull origin main
else
    git clone https://github.com/harikp196/Image_upscaling.git
fi

# Virtualenv lives in ~/projects (shared across deploys), not inside the repo.
cd ~/projects
python3 -m venv venv
source venv/bin/activate

pip install --upgrade pip
pip install -r ~/projects/Image_upscaling/requirements.txt

# nohup + & so the app keeps running after logout; stdout+stderr to the log.
nohup python3 ~/projects/Image_upscaling/app.py > ~/projects/project1.log 2>&1 &
|
models/fsrcnn_x2.pth
ADDED
|
Binary file (55 kB). View file
|
|
|
models/fsrcnn_x3.pth
ADDED
|
Binary file (55 kB). View file
|
|
|
models/fsrcnn_x4.pth
ADDED
|
Binary file (55 kB). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
opencv-python
|
| 3 |
+
numpy
|
| 4 |
+
torch
|
| 5 |
+
tqdm
|
| 6 |
+
setuptools
|
| 7 |
+
torchvision
|
| 8 |
+
Pillow
|
| 9 |
+
natsort
|
| 10 |
+
huggingface_hub
|
secrets.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
HF_TOKEN==>REDACTED
|
tests/test2.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tests/test_fsrcnn.py
|
| 2 |
+
import os
|
| 3 |
+
import importlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
# Adjust this import if your file isn't named fsrcnn_app.py
|
| 9 |
+
import app as app
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@pytest.fixture(autouse=True)
def _reset_cache_between_tests():
    """Start and finish every test with an empty app-level model cache."""
    app.MODEL_CACHE.clear()
    yield
    app.MODEL_CACHE.clear()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_fsrcnn_forward_output_shape_cpu_only():
    """FSRCNN forward should upscale a 1-channel input by its scale factor."""
    net = app.FSRCNN(scale_factor=3).eval()
    batch = torch.randn(1, 1, 10, 12)  # (N, C, H, W)
    with torch.inference_mode():
        result = net(batch)
    assert result.shape == (1, 1, 30, 36), "Output shape must be (H*scale, W*scale)"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_run_fsrcnn_on_y_shape_and_dtype():
    """run_fsrcnn_on_y should return a uint8 image with upscaled spatial dims."""
    luma = np.random.randint(0, 256, (9, 7), dtype=np.uint8)
    net = app.FSRCNN(scale_factor=2).eval()
    result = app.run_fsrcnn_on_y(luma, net)
    assert result.dtype == np.uint8
    assert result.shape == (18, 14)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_bicubic_upscale_rgb_shape_and_dtype():
    """Bicubic path of upscale_ui returns uint8 with 4x spatial dims.

    Fixed: the original called app.bicubic_upscale_rgb, which does not exist
    in app.py; the bicubic path is exposed through upscale_ui instead.
    """
    rgb = np.random.randint(0, 256, (16, 24, 3), dtype=np.uint8)
    out, _status = app.upscale_ui(rgb, 4, "Bicubic")
    assert out.dtype == np.uint8
    assert out.shape == (16 * 4, 24 * 4, 3)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_rgb_ycbcr_roundtrip_close():
    """RGB -> YCrCb -> RGB roundtrip should be close (small max diff)."""
    original = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    restored = app.ycbcr_to_rgb(app.rgb_to_ycbcr(original))
    # Allow small numerical differences from color conversion
    diff = np.abs(restored.astype(int) - original.astype(int))
    assert diff.max() <= 2
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_fsrcnn_upscale_falls_back_to_bicubic_when_no_weights(tmp_path):
    """Without valid weights, fsrcnn_upscale_rgb must return the bicubic result.

    Fixed: the original compared against app.bicubic_upscale_rgb, which does
    not exist in app.py; the bicubic reference now comes from upscale_ui's
    "Bicubic" path, which performs the same cv2 INTER_CUBIC resize.
    """
    rgb = np.random.randint(0, 256, (12, 10, 3), dtype=np.uint8)
    scale = 3

    # Ensure a fresh cache so the "no-weights" path is exercised
    app.MODEL_CACHE.clear()

    out_fallback = app.fsrcnn_upscale_rgb(rgb, scale=scale, weights=None)
    out_bicubic, _status = app.upscale_ui(rgb, scale, "Bicubic")

    assert out_fallback.shape == out_bicubic.shape
    # Both paths call cv2.resize with identical arguments; byte-identical
    assert np.array_equal(out_fallback, out_bicubic)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_ui_accepts_grayscale_and_rgba_and_clips():
    """The UI helper should handle grayscale, RGBA, and non-uint8 inputs.

    Fixed: the original called upscale_ui with six positional arguments
    (app's signature takes three: image, scale_factor, method) and treated
    the (image, status) tuple return as a bare array.
    """
    # Grayscale -> stacked to RGB
    gray = np.random.randint(0, 256, (8, 8), dtype=np.uint8)
    out_gray, _ = app.upscale_ui(gray, 2, "Bicubic")
    assert out_gray.shape == (16, 16, 3)
    assert out_gray.dtype == np.uint8

    # RGBA -> drop alpha
    rgba = np.random.randint(0, 256, (8, 8, 4), dtype=np.uint8)
    out_rgba, _ = app.upscale_ui(rgba, 2, "Bicubic")
    assert out_rgba.shape == (16, 16, 3)
    assert out_rgba.dtype == np.uint8

    # Float input -> should clip/convert to uint8 internally
    f_rgb = np.random.randn(8, 8, 3).astype(np.float32) * 1000.0  # intentionally wild
    out_float, _ = app.upscale_ui(f_rgb, 2, "Bicubic")
    assert out_float.dtype == np.uint8
    assert out_float.shape == (16, 16, 3)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_maybe_downscale_for_memory_respects_limit():
    """Images above the pixel budget must be shrunk to fit within it."""
    huge = np.random.randint(0, 256, (4000, 4000, 3), dtype=np.uint8)  # 16M px
    shrunk = app.maybe_downscale_for_memory(huge, max_pixels=1_000_000)
    height, width = shrunk.shape[:2]
    assert height * width <= 1_000_000
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_get_model_cache_per_scale():
    """get_model should build and cache one model instance per scale factor."""
    model_x2, _has2 = app.get_model(2, weights_path=None)
    model_x3, _has3 = app.get_model(3, weights_path=None)

    # Cache populated for both scales
    assert 2 in app.MODEL_CACHE and 3 in app.MODEL_CACHE
    assert isinstance(model_x2, app.FSRCNN) and isinstance(model_x3, app.FSRCNN)
    assert model_x2 is not model_x3, "Different scales should use different model instances"
|
tests/test_app.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
import numpy as np
import cv2
import torch
# Fixed: the original also imported `bicubic_upscale_rgb`, which app.py does
# not define — the ImportError killed this entire test module at collection.
from app import FSRCNN, rgb_to_ycbcr, ycbcr_to_rgb, upscale_ui, try_load_weights


def bicubic_upscale_rgb(img_rgb, scale):
    """Local stand-in for the absent app helper: plain cv2 bicubic upscale."""
    h, w = img_rgb.shape[:2]
    return cv2.resize(img_rgb, (w * scale, h * scale), interpolation=cv2.INTER_CUBIC)
|
| 6 |
+
|
| 7 |
+
def test_fsrcnn_model_initialization():
    """Models for every supported scale expose the three FSRCNN stages."""
    for factor in (2, 3, 4):
        net = FSRCNN(scale_factor=factor)
        assert net is not None
        for stage in ('first_part', 'mid_part', 'last_part'):
            assert hasattr(net, stage)
|
| 14 |
+
|
| 15 |
+
def test_color_conversion():
    """RGB -> YCrCb -> RGB should be nearly lossless on average."""
    source = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
    restored = ycbcr_to_rgb(rgb_to_ycbcr(source))

    assert source.shape == restored.shape
    mean_error = np.mean(np.abs(source.astype(float) - restored.astype(float)))
    assert mean_error < 2.0
|
| 22 |
+
|
| 23 |
+
def test_bicubic_upscaling():
    """Bicubic upscaling through the app produces correctly scaled outputs.

    Fixed: the original called `bicubic_upscale_rgb`, which app.py does not
    define; the bicubic path is exercised through upscale_ui instead.
    """
    test_img = np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8)

    for scale in [1, 2, 3, 4]:
        upscaled, _status = upscale_ui(test_img, scale, "Bicubic")
        expected_shape = (16 * scale, 16 * scale, 3)
        assert upscaled.shape == expected_shape
|
| 30 |
+
|
| 31 |
+
def test_try_load_weights_error(tmp_path):
    """try_load_weights must return False for mismatched, missing, or corrupt weights.

    Fixed two defects: (1) the original asserted True for a state dict with
    missing/unexpected keys — app.try_load_weights loads with strict=True,
    which raises on key mismatch, so the helper returns False; (2) the fake
    checkpoint was written to "test.pth" in the CWD (leaking a file) instead
    of the pytest tmp_path.
    """
    model = FSRCNN(scale_factor=2)
    fake_weights = {
        "last_part.bias": torch.zeros(1),
        "a_fake_key.weight": torch.randn(10)
    }
    fake_weights_file = tmp_path / "test.pth"
    torch.save(fake_weights, fake_weights_file)
    # strict=True rejects the mismatched keys -> helper reports failure
    result = try_load_weights(model, str(fake_weights_file))
    assert result is False

    # Missing path -> False
    model = FSRCNN(scale_factor=2)
    assert try_load_weights(model, None) is False

    # Corrupt (non-pickle) file -> False
    model = FSRCNN(scale_factor=2)
    corrupted_file = tmp_path / "corrupted.pth"
    corrupted_file.write_text("this is just a text file, not a model!")
    result = try_load_weights(model, str(corrupted_file))
    assert result is False
|
| 51 |
+
|
| 52 |
+
def test_try_load_weights():
    """A path that does not exist must be rejected.

    Fixed: the original used "../models/fsrcnn_x2.pth", whose existence
    depends on pytest's working directory — the assertion flips to True
    when tests run from the tests/ directory. Use a path that can never
    resolve instead.
    """
    model = FSRCNN(scale_factor=2)
    assert try_load_weights(model, "nonexistent_dir/no_such_weights.pth") is False
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_model_forward_pass():
    """Each supported scale multiplies the spatial dims accordingly."""
    for factor in (2, 3, 4):
        net = FSRCNN(scale_factor=factor)
        sample = torch.from_numpy(np.random.rand(1, 1, 32, 32).astype(np.float32))
        prediction = net(sample)

        assert prediction.shape[2] == 32 * factor
        assert prediction.shape[3] == 32 * factor
|
| 67 |
+
|
| 68 |
+
def test_upscale_ui_noimage():
    """A missing image yields no output plus a helpful status message."""
    result = upscale_ui(None, 2, "FSRCNN (Y channel)")
    assert result == (None, 'Please upload an image.')
|
| 70 |
+
|
| 71 |
+
def test_upscale_ui():
    """Exercise upscale_ui across input dtypes/shapes and the memory cap."""
    # Float input
    float_image = np.random.rand(32, 32, 3).astype(np.float32)

    result = upscale_ui(
        image=float_image,
        scale_factor=2,
        method="Bicubic"
    )

    # result is an (image, status) tuple; index 0 is the upscaled array.
    assert result[0] is not None
    assert result[0].dtype == np.uint8
    assert result[0].shape == (64, 64, 3)

    """Test upscale_ui with grayscale (2D) input"""
    grayscale_image = np.random.randint(0, 255, (32, 32), dtype=np.uint8)
    result = upscale_ui(
        image=grayscale_image,
        scale_factor=2,
        method="FSRCNN (Y channel)",
    )

    assert result[0] is not None
    assert result[0].dtype == np.uint8
    assert result[0].shape == (64, 64, 3)

    """Test upscale_ui with RGBA input"""
    rgba_image = np.random.randint(0, 255, (32, 32, 4), dtype=np.uint8)

    result = upscale_ui(
        image=rgba_image,
        scale_factor=2,
        method="Bicubic",
    )
    assert result[0] is not None
    assert result[0].dtype == np.uint8
    assert result[0].shape == (64, 64, 3)

    # 4000x4000 = 16M px exceeds the 8M default cap, so the input is first
    # downscaled by sqrt(8/16) to 2828x2828, then upscaled x2 -> 5656x5656.
    downloadscale_image = np.random.randint(0, 255, (4000, 4000, 3), dtype=np.uint8)

    result = upscale_ui(
        image=downloadscale_image,
        scale_factor=2,
        method="FSRCNN (Y channel)",
    )

    assert result[0] is not None
    assert result[0].dtype == np.uint8
    assert result[0].shape == (5656, 5656, 3)



    # FSRCNN path at every supported scale factor.
    test_img = np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8)
    for scale in [2, 3, 4]:
        upscaled = upscale_ui(test_img, scale, "FSRCNN (Y channel)")
        expected_shape = (16 * scale, 16 * scale, 3)
        assert upscaled[0].shape == expected_shape
|
| 128 |
+
|
| 129 |
+
def test_upscale_ui_bicubic():
    """Bicubic path produces correctly scaled outputs for all factors."""
    sample = np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8)
    for factor in (2, 3, 4):
        output, _status = upscale_ui(sample, factor, "Bicubic")
        assert output.shape == (16 * factor, 16 * factor, 3)
|
| 135 |
+
|
| 136 |
+
# Allow running this test module directly: `python tests/test_app.py`
if __name__ == "__main__":
    pytest.main([__file__])
|