# FGResQ demo — Hugging Face Space app (Gradio)
| import gradio as gr | |
| import os | |
| import json | |
| import tempfile | |
| from pathlib import Path | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from model.FGResQ import FGResQ | |
# ---------------------------------------------------------------------------
# Model initialization (runs once, at import time).
# Prefer the checkpoint hosted on the Hugging Face Hub; if the download fails
# (offline environment, auth issues, ...), fall back to a local copy.
# ---------------------------------------------------------------------------
try:
    MODEL_PATH = hf_hub_download(
        repo_id="orpheus0429/FGResQ",
        filename="weights/FGResQ.pth"
    )
except Exception as e:
    # Best-effort fallback: report the failure but keep starting up.
    print(f"Failed to download model: {e}")
    # Fallback to local weights if available
    MODEL_PATH = "weights/FGResQ.pth"

print(f"Loading model from {MODEL_PATH}")
# NOTE(review): if neither the Hub download nor the local weights file is
# available, FGResQ() presumably fails here and the app does not start —
# confirm that is the desired behavior.
model = FGResQ(model_path=MODEL_PATH)
| def _save_temp_image(img): | |
| if img is None: | |
| return None | |
| if isinstance(img, str) and Path(img).exists(): | |
| return img | |
| if isinstance(img, Image.Image): | |
| tf = tempfile.NamedTemporaryFile(suffix=".png", delete=False) | |
| tf.close() | |
| img.save(tf.name) | |
| return tf.name | |
| try: | |
| pil = Image.fromarray(img) | |
| tf = tempfile.NamedTemporaryFile(suffix=".png", delete=False) | |
| tf.close() | |
| pil.save(tf.name) | |
| return tf.name | |
| except Exception: | |
| return None | |
def predict_single(image, task):
    """Score one image and return the result as a string for the UI.

    Returns the quality score formatted to 4 decimals, or a human-readable
    error message (Gradio shows either in the output Textbox).

    ``task`` is accepted for interface parity with the UI inputs but is not
    used by the underlying model call.
    """
    img_path = _save_temp_image(image)
    if img_path is None:
        return "No image provided"
    try:
        score = model.predict_single(img_path)
        if score is None:
            return "Model returned None"
        return f"{score:.4f}"
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error: {str(e)}"
    finally:
        # Remove the temp file we created. Identity check: when the caller
        # passed an existing path, _save_temp_image returns that same object,
        # and we must not delete the caller's file.
        if img_path is not image:
            try:
                os.unlink(img_path)
            except OSError:
                pass
def predict_pair(image1, image2, task):
    """Compare two images and return the model's verdict as a string.

    Returns the ``"comparison"`` field of the model's result dict, or a
    human-readable error message. ``task`` is accepted for UI parity but is
    not used by the model call.
    """
    p1 = _save_temp_image(image1)
    p2 = _save_temp_image(image2)
    try:
        if not p1 or not p2:
            return "Missing images"
        result = model.predict_pair(p1, p2)
        if result is None:
            return "Model returned None"
        return result.get("comparison")
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error: {str(e)}"
    finally:
        # Clean up temp files we created; identity check skips
        # caller-supplied paths (see predict_single).
        for src, path in ((image1, p1), (image2, p2)):
            if path is not None and path is not src:
                try:
                    os.unlink(path)
                except OSError:
                    pass
def load_examples(json_path, mode="single"):
    """Load Gradio example rows from the JSON manifest *json_path*.

    mode="single": items need "image_name"            -> [image_path, task]
    mode="pair":   items need "image_nameA"/"…B"      -> [path_A, path_B, task]

    Image paths are resolved under the local ``fineData/`` directory; rows
    whose image files are missing — or whose required keys are absent — are
    silently dropped, so the gallery degrades gracefully in a partial checkout.
    Returns [] when the manifest file itself is absent.
    """
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} not found")
        return []
    # Explicit encoding: don't depend on the platform default.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    examples = []
    for item in data:
        task = item.get("task", "")
        if mode == "single":
            name = item.get("image_name")
            if not name:
                continue  # malformed entry: skip instead of raising KeyError
            img_path = os.path.join("fineData", name)
            if os.path.exists(img_path):
                examples.append([img_path, task])
        elif mode == "pair":
            name_a = item.get("image_nameA")
            name_b = item.get("image_nameB")
            if not name_a or not name_b:
                continue  # malformed entry: skip instead of raising KeyError
            img_path_a = os.path.join("fineData", name_a)
            img_path_b = os.path.join("fineData", name_b)
            if os.path.exists(img_path_a) and os.path.exists(img_path_b):
                examples.append([img_path_a, img_path_b, task])
    return examples
# Example galleries are optional: load_examples returns [] when the JSON
# manifests or the referenced fineData/ images are not present in the Space.
single_examples = load_examples("score_example.json", mode="single")
pair_examples = load_examples("rank_example.json", mode="pair")

# ---------------------------------------------------------------------------
# Gradio UI: two tabs over the same model.
#   * "Single Image Mode" -> predict_single (absolute quality score)
#   * "Pairwise Mode"     -> predict_pair   (A-vs-B comparison)
# NOTE(review): SOURCE indentation was flattened by extraction; the nesting
# below is reconstructed from the standard Gradio Blocks pattern — confirm
# against the original layout.
# ---------------------------------------------------------------------------
with gr.Blocks(title="FGResQ Demo") as demo:
    gr.Markdown("Fine-grained Image Quality Assessment for Perceptual Image Restoration Demo")
    with gr.Tab("Single Image Mode"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                task_input = gr.Textbox(label="Task")
                submit_btn = gr.Button("Predict")
            with gr.Column():
                score_output = gr.Textbox(label="Quality Score")
        # Only render the examples gallery when example assets were found.
        if single_examples:
            gr.Examples(
                examples=single_examples,
                inputs=[image_input, task_input],
                label="Examples"
            )
        submit_btn.click(
            fn=predict_single,
            inputs=[image_input, task_input],
            outputs=score_output
        )
    with gr.Tab("Pairwise Mode"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_input_A = gr.Image(type="pil", label="Image A")
                    image_input_B = gr.Image(type="pil", label="Image B")
                task_input_pair = gr.Textbox(label="Task")
                submit_btn_pair = gr.Button("Compare")
            with gr.Column():
                compare_output = gr.Textbox(label="Comparison Result")
        # Only render the examples gallery when example assets were found.
        if pair_examples:
            gr.Examples(
                examples=pair_examples,
                inputs=[image_input_A, image_input_B, task_input_pair],
                label="Examples"
            )
        submit_btn_pair.click(
            fn=predict_pair,
            inputs=[image_input_A, image_input_B, task_input_pair],
            outputs=compare_output
        )

if __name__ == "__main__":
    demo.launch()