Spaces:
Sleeping
Sleeping
| import os | |
| import math | |
| import typing as tp | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| import warnings | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
| # ============================================================ | |
| # π§ PART 1: FSRCNN Image Upscaling | |
| # ============================================================ | |
| class FSRCNN(nn.Module): | |
| def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4): | |
| super(FSRCNN, self).__init__() | |
| self.first_part = nn.Sequential( | |
| nn.Conv2d(num_channels, d, kernel_size=5, padding=5 // 2), | |
| nn.PReLU(d) | |
| ) | |
| self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)] | |
| for _ in range(m): | |
| self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3 // 2), nn.PReLU(s)]) | |
| self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)]) | |
| self.mid_part = nn.Sequential(*self.mid_part) | |
| self.last_part = nn.ConvTranspose2d( | |
| d, num_channels, kernel_size=9, | |
| stride=scale_factor, padding=9 // 2, | |
| output_padding=scale_factor - 1 | |
| ) | |
| self._initialize_weights() | |
| def _initialize_weights(self): | |
| for m in self.first_part: | |
| if isinstance(m, nn.Conv2d): | |
| nn.init.normal_(m.weight.data, mean=0.0, | |
| std=math.sqrt(2 / (m.out_channels * m.weight.data[0][0].numel()))) | |
| nn.init.zeros_(m.bias.data) | |
| for m in self.mid_part: | |
| if isinstance(m, nn.Conv2d): | |
| nn.init.normal_(m.weight.data, mean=0.0, | |
| std=math.sqrt(2 / (m.out_channels * m.weight.data[0][0].numel()))) | |
| nn.init.zeros_(m.bias.data) | |
| nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001) | |
| nn.init.zeros_(self.last_part.bias.data) | |
| def forward(self, x): | |
| x = self.first_part(x) | |
| x = self.mid_part(x) | |
| x = self.last_part(x) | |
| return x | |
| Device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| MODEL_CACHE: dict[int, tuple[FSRCNN, bool]] = {} | |
| WEIGHTS_PATHS = {2: "models/fsrcnn_x2.pth", 3: "models/fsrcnn_x3.pth", 4: "models/fsrcnn_x4.pth"} | |
| def try_load_weights(model, weights_path): | |
| if not weights_path or not os.path.isfile(weights_path): | |
| print(f"[FSRCNN] No valid weights at {weights_path}. Falling back to Bicubic.") | |
| return False | |
| try: | |
| checkpoint = torch.load(weights_path, map_location=Device, weights_only=False) | |
| model.load_state_dict(checkpoint, strict=True) | |
| print(f"[FSRCNN] Loaded weights from {weights_path}") | |
| return True | |
| except Exception as e: | |
| print(f"[FSRCNN] Failed to load weights: {e}") | |
| return False | |
| def get_model(scale, weights_path=None): | |
| if scale not in MODEL_CACHE: | |
| model = FSRCNN(scale_factor=scale).to(Device).eval() | |
| has_weights = try_load_weights(model, weights_path) | |
| MODEL_CACHE[scale] = (model, has_weights) | |
| return MODEL_CACHE[scale] | |
| def rgb_to_ycbcr(img_rgb: np.ndarray) -> np.ndarray: | |
| return cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YCrCb) | |
| def ycbcr_to_rgb(img_ycrcb: np.ndarray) -> np.ndarray: | |
| return cv2.cvtColor(img_ycrcb, cv2.COLOR_YCrCb2RGB) | |
| def run_fsrcnn_on_y(y: np.ndarray, model: FSRCNN) -> np.ndarray: | |
| y_f = y.astype(np.float32) / 255.0 | |
| tens = torch.from_numpy(y_f).unsqueeze(0).unsqueeze(0).to(Device) | |
| with torch.inference_mode(): | |
| out = model(tens) | |
| out_np = out.squeeze(0).squeeze(0).clamp(0.0, 1.0).cpu().numpy() | |
| return (out_np * 255.0 + 0.5).astype(np.uint8) | |
| def fsrcnn_upscale_rgb(img_rgb: np.ndarray, scale: int, weights: tp.Optional[str] = None) -> np.ndarray: | |
| h, w = img_rgb.shape[:2] | |
| model, has_weights = get_model(scale, weights) | |
| if not has_weights: | |
| return cv2.resize(img_rgb, (w * scale, h * scale), interpolation=cv2.INTER_CUBIC) | |
| ycrcb = rgb_to_ycbcr(img_rgb) | |
| y, cr, cb = ycrcb[..., 0], ycrcb[..., 1], ycrcb[..., 2] | |
| y_sr = run_fsrcnn_on_y(y, model) | |
| new_w, new_h = w * scale, h * scale | |
| cr_up = cv2.resize(cr, (new_w, new_h), interpolation=cv2.INTER_CUBIC) | |
| cb_up = cv2.resize(cb, (new_w, new_h), interpolation=cv2.INTER_CUBIC) | |
| ycrcb_up = np.stack([y_sr, cr_up, cb_up], axis=-1) | |
| return ycbcr_to_rgb(ycrcb_up) | |
| def maybe_downscale_for_memory(img_rgb: np.ndarray, max_pixels: int = 8_000_000) -> np.ndarray: | |
| h, w = img_rgb.shape[:2] | |
| if h * w <= max_pixels: | |
| return img_rgb | |
| scale = (max_pixels / (h * w)) ** 0.5 | |
| new_w, new_h = int(w * scale), int(h * scale) | |
| return cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA) | |
| def upscale_ui(image: np.ndarray, scale_factor: int, method: str): | |
| if image is None: | |
| return None, "Please upload an image." | |
| if image.dtype != np.uint8: | |
| image = np.clip(image, 0, 255).astype(np.uint8) | |
| if image.ndim == 2: | |
| image = np.stack([image]*3, axis=-1) | |
| elif image.shape[2] == 4: | |
| image = image[..., :3] | |
| image = maybe_downscale_for_memory(image) | |
| weights_path = WEIGHTS_PATHS.get(scale_factor) | |
| if method == "FSRCNN (Y channel)": | |
| out = fsrcnn_upscale_rgb(image, scale_factor, weights_path) | |
| status = f"Used FSRCNN x{scale_factor} (bundled weights)." | |
| else: | |
| out = cv2.resize(image, (image.shape[1]*scale_factor, image.shape[0]*scale_factor), interpolation=cv2.INTER_CUBIC) | |
| status = f"Used Bicubic x{scale_factor}." | |
| return out, status | |
| # ============================================================ | |
| # π PART 2: Multilingual Translator | |
| # ============================================================ | |
| HF_TOKEN = os.environ.get("HF_TOKEN", None) | |
| client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN) | |
| lang_map = { | |
| "English": "en_XX", "French": "fr_XX", "Spanish": "es_XX", "German": "de_DE", | |
| "Hindi": "hi_IN", "Chinese": "zh_CN", "Japanese": "ja_XX", "Korean": "ko_KR", | |
| "Tamil": "ta_IN", "Telugu": "te_IN", "Arabic": "ar_AR", "Russian": "ru_RU" | |
| # You can add full map from your existing file if needed | |
| } | |
| def translate_text(text, src_lang, tgt_lang): | |
| if not text.strip(): | |
| return "Please enter any text to translate π" | |
| try: | |
| src_code, tgt_code = lang_map[src_lang], lang_map[tgt_lang] | |
| result = client.translation( | |
| text, | |
| model="facebook/mbart-large-50-many-to-many-mmt", | |
| src_lang=src_code, | |
| tgt_lang=tgt_code | |
| ) | |
| return result.translation_text | |
| except Exception as e: | |
| return f"Error in translation: {str(e)}" | |
| # ============================================================ | |
| # π¨ Combine into One Interface with Tabs | |
| # ============================================================ | |
| custom_theme = gr.themes.Default().set( | |
| button_primary_background_fill="#1769aa", | |
| button_primary_text_color="#ffffff", | |
| button_secondary_background_fill="#e0e0e0", | |
| button_secondary_text_color="#222222" | |
| ) | |
| with gr.Blocks(theme=custom_theme, title="AI Multi-Tool: FSRCNN & Translator") as demo: | |
| gr.Markdown("# π AI Multi-Tool Suite\nChoose an application below π") | |
| with gr.Tabs(): | |
| # Tab 1: FSRCNN Upscaler | |
| with gr.Tab("πΌοΈ Image Upscaling"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| inp_img = gr.Image(type="numpy", label="Input Image") | |
| scale = gr.Dropdown([2, 3, 4], value=2, label="Upscale Factor") | |
| method = gr.Radio(["FSRCNN (Y channel)", "Bicubic"], value="FSRCNN (Y channel)") | |
| run_btn = gr.Button("Upscale", variant="primary") | |
| clear_btn = gr.Button("Clear", variant="secondary") | |
| status_box = gr.Textbox(label="Status") | |
| with gr.Column(): | |
| out_img = gr.Image(type="numpy", label="Upscaled Output") | |
| run_btn.click(fn=upscale_ui, inputs=[inp_img, scale, method], outputs=[out_img, status_box]) | |
| clear_btn.click(fn=lambda: (None, 2, "FSRCNN (Y channel)", None, ""), | |
| outputs=[inp_img, scale, method, out_img, status_box]) | |
| # Tab 2: Translator | |
| with gr.Tab("π Text Translator"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| src_lang = gr.Dropdown(choices=list(lang_map.keys()), value="English", label="Source Language") | |
| input_text = gr.Textbox(lines=4, label="Enter Text") | |
| with gr.Column(): | |
| tgt_lang = gr.Dropdown(choices=list(lang_map.keys()), value="French", label="Target Language") | |
| output_text = gr.Textbox(lines=4, label="Translation", interactive=False) | |
| translate_btn = gr.Button("Translate β¨", variant="primary") | |
| clear_btn2 = gr.Button("Clear", variant="secondary") | |
| translate_btn.click(fn=translate_text, inputs=[input_text, src_lang, tgt_lang], outputs=output_text) | |
| clear_btn2.click(fn=lambda: ("", "English", "French", ""), outputs=[input_text, src_lang, tgt_lang, output_text]) | |
| if __name__ == "__main__": | |
| demo.launch() | |