"""Gradio demo for FGResQ: single-image scoring and pairwise comparison."""

import json
import os
import tempfile
from pathlib import Path

import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image

from model.FGResQ import FGResQ

# Path to model weights
try:
    MODEL_PATH = hf_hub_download(
        repo_id="orpheus0429/FGResQ",
        filename="weights/FGResQ.pth",
    )
except Exception as e:
    # Deliberately broad: any download failure falls back to local weights.
    print(f"Failed to download model: {e}")
    # Fallback to local weights if available
    MODEL_PATH = "weights/FGResQ.pth"

print(f"Loading model from {MODEL_PATH}")
model = FGResQ(model_path=MODEL_PATH)


def _discard_temp_file(path):
    """Best-effort removal of a temporary file created by this module."""
    try:
        os.remove(path)
    except OSError:
        # The file may already be gone or still be locked; cleanup must never
        # mask the prediction result.
        pass


def _write_temp_png(pil_image):
    """Save ``pil_image`` to a new temporary ``.png`` file and return its path.

    The file is created with ``delete=False`` so that it survives being
    closed; the caller owns it and must remove it. If saving fails, the
    half-created file is removed and the exception is re-raised.
    """
    tf = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    # Close the handle before saving so the path can be reopened by PIL.
    tf.close()
    try:
        pil_image.save(tf.name)
    except Exception:
        _discard_temp_file(tf.name)
        raise
    return tf.name


def _resolve_image(img):
    """Turn an image input into a file path and an ownership flag.

    Args:
        img: ``None``, a path string, a ``PIL.Image.Image``, or any object
            accepted by ``Image.fromarray``.

    Returns:
        A ``(path, is_temp)`` tuple. ``path`` is ``None`` when no usable image
        could be produced. ``is_temp`` is ``True`` only when ``path`` is a
        temporary file created here (and therefore safe to delete); paths
        supplied by the caller are never flagged as temporary.
    """
    if img is None:
        return None, False
    if isinstance(img, str) and Path(img).exists():
        return img, False
    if isinstance(img, Image.Image):
        return _write_temp_png(img), True
    try:
        return _write_temp_png(Image.fromarray(img)), True
    except Exception:
        # Best effort: anything that cannot be converted counts as "no image".
        return None, False


def _save_temp_image(img):
    """Return a file path for ``img``, or ``None`` if it cannot be resolved.

    Existing path strings are returned unchanged; PIL images and array-like
    inputs are written to a temporary ``.png`` that the caller must delete.
    """
    path, _ = _resolve_image(img)
    return path


def predict_single(image, task):
    """Score one image with the model and return the result as display text.

    Args:
        image: Image input as accepted by ``_resolve_image``.
        task: Task text from the UI; currently not used by the prediction.

    Returns:
        The score formatted with four decimals, or a human-readable message
        when no image was given, the model returned ``None``, or it raised.
    """
    img_path, is_temp = _resolve_image(image)
    if img_path is None:
        return "No image provided"
    try:
        score = model.predict_single(img_path)
        if score is None:
            return "Model returned None"
        return f"{score:.4f}"
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Only remove files created here, never a caller-supplied path.
        if is_temp:
            _discard_temp_file(img_path)


def predict_pair(image1, image2, task):
    """Compare two images with the model and return the comparison result.

    Args:
        image1: First image input as accepted by ``_resolve_image``.
        image2: Second image input as accepted by ``_resolve_image``.
        task: Task text from the UI; currently not used by the prediction.

    Returns:
        The value stored under ``"comparison"`` in the model result (``None``
        if that key is absent), or a human-readable message when an image is
        missing, the model returned ``None``, or it raised.
    """
    temp_paths = []
    try:
        resolved = []
        for image in (image1, image2):
            path, is_temp = _resolve_image(image)
            if is_temp:
                temp_paths.append(path)
            resolved.append(path)
        p1, p2 = resolved
        if not p1 or not p2:
            return "Missing images"
        try:
            result = model.predict_pair(p1, p2)
            if result is None:
                return "Model returned None"
            return result.get("comparison")
        except Exception as e:
            return f"Error: {str(e)}"
    finally:
        # Runs on every exit path, including the "Missing images" early
        # return where only one of the two temp files was created.
        for path in temp_paths:
            _discard_temp_file(path)


def load_examples(json_path, mode="single"):
    """Load example rows for ``gr.Examples`` from a JSON file.

    The JSON is iterated as a sequence of mappings. In ``"single"`` mode each
    item must have ``"image_name"``; in ``"pair"`` mode each item must have
    ``"image_nameA"`` and ``"image_nameB"``. ``"task"`` is optional and
    defaults to an empty string. Image names are resolved relative to the
    ``fineData`` directory, and items whose image files do not exist are
    skipped.

    Args:
        json_path: Path to the JSON file describing the examples.
        mode: ``"single"`` or ``"pair"``; any other value yields no examples.

    Returns:
        A list of ``[image_path, task]`` rows (single mode) or
        ``[image_path_a, image_path_b, task]`` rows (pair mode). Empty when
        the JSON file is missing.
    """
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} not found")
        return []

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    examples = []
    for item in data:
        if mode == "single":
            # Prepend fineData/ to the path
            img_path = os.path.join("fineData", item["image_name"])
            task = item.get("task", "")
            if os.path.exists(img_path):
                examples.append([img_path, task])
        elif mode == "pair":
            img_path_a = os.path.join("fineData", item["image_nameA"])
            img_path_b = os.path.join("fineData", item["image_nameB"])
            task = item.get("task", "")
            if os.path.exists(img_path_a) and os.path.exists(img_path_b):
                examples.append([img_path_a, img_path_b, task])
    return examples


single_examples = load_examples("score_example.json", mode="single")
pair_examples = load_examples("rank_example.json", mode="pair")

with gr.Blocks(title="FGResQ Demo") as demo:
    gr.Markdown("Fine-grained Image Quality Assessment for Perceptual Image Restoration Demo")

    with gr.Tab("Single Image Mode"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                task_input = gr.Textbox(label="Task")
                submit_btn = gr.Button("Predict")
            with gr.Column():
                score_output = gr.Textbox(label="Quality Score")

        # Only render the examples panel when at least one example resolved.
        if single_examples:
            gr.Examples(
                examples=single_examples,
                inputs=[image_input, task_input],
                label="Examples",
            )

        submit_btn.click(
            fn=predict_single,
            inputs=[image_input, task_input],
            outputs=score_output,
        )

    with gr.Tab("Pairwise Mode"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_input_A = gr.Image(type="pil", label="Image A")
                    image_input_B = gr.Image(type="pil", label="Image B")
                task_input_pair = gr.Textbox(label="Task")
                submit_btn_pair = gr.Button("Compare")
            with gr.Column():
                compare_output = gr.Textbox(label="Comparison Result")

        # Only render the examples panel when at least one example resolved.
        if pair_examples:
            gr.Examples(
                examples=pair_examples,
                inputs=[image_input_A, image_input_B, task_input_pair],
                label="Examples",
            )

        submit_btn_pair.click(
            fn=predict_pair,
            inputs=[image_input_A, image_input_B, task_input_pair],
            outputs=compare_output,
        )

if __name__ == "__main__":
    demo.launch()