import gradio as gr
import os
import json
import tempfile
from pathlib import Path
from PIL import Image
from huggingface_hub import hf_hub_download
from model.FGResQ import FGResQ
# Resolve the path to the model weights: prefer the checkpoint hosted on the
# Hugging Face Hub, falling back to a local copy when the download fails
# (e.g. running offline).
try:
    MODEL_PATH = hf_hub_download(
        repo_id="orpheus0429/FGResQ",
        filename="weights/FGResQ.pth"
    )
except Exception as e:
    print(f"Failed to download model: {e}")
    # Fallback to local weights if available
    MODEL_PATH = "weights/FGResQ.pth"
print(f"Loading model from {MODEL_PATH}")
# Instantiate the model once at import time so both Gradio handlers reuse it.
# NOTE(review): assumes FGResQ loads its weights eagerly from model_path —
# confirm against model/FGResQ.py.
model = FGResQ(model_path=MODEL_PATH)
def _save_temp_image(img):
if img is None:
return None
if isinstance(img, str) and Path(img).exists():
return img
if isinstance(img, Image.Image):
tf = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
tf.close()
img.save(tf.name)
return tf.name
try:
pil = Image.fromarray(img)
tf = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
tf.close()
pil.save(tf.name)
return tf.name
except Exception:
return None
def predict_single(image, task):
    """Gradio handler: score one image with the FGResQ model.

    Args:
        image: PIL image / array / file path from the Gradio image component.
        task: restoration-task label from the UI. Unused by the model call
            but kept so the handler signature matches its Gradio inputs.

    Returns:
        The quality score formatted to four decimals, or a human-readable
        error string (the Gradio textbox displays whatever we return).
    """
    img_path = _save_temp_image(image)
    if img_path is None:
        return "No image provided"
    try:
        score = model.predict_single(img_path)
        if score is None:
            return "Model returned None"
        return f"{score:.4f}"
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # The original `finally: pass` leaked every temp file created above.
        # Remove the file only when _save_temp_image created it (i.e. it is
        # not the caller-supplied path itself).
        if img_path != image:
            try:
                os.remove(img_path)
            except OSError:
                pass
def predict_pair(image1, image2, task):
    """Gradio handler: compare two images with the FGResQ model.

    Args:
        image1: first image (PIL image / array / file path).
        image2: second image (same accepted forms).
        task: restoration-task label from the UI. Unused by the model call
            but kept so the handler signature matches its Gradio inputs.

    Returns:
        The model's comparison verdict, or a human-readable error string
        for display in the Gradio textbox.
    """
    p1 = _save_temp_image(image1)
    p2 = _save_temp_image(image2)
    if not p1 or not p2:
        return "Missing images"
    try:
        result = model.predict_pair(p1, p2)
        if result is None:
            return "Model returned None"
        comparison = result.get("comparison")
        if comparison is None:
            # The original returned None here, which rendered as a blank
            # textbox; surface the problem like the other error paths do.
            return "Model result had no 'comparison' field"
        return comparison
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Clean up any temp files _save_temp_image created; leave
        # caller-supplied paths untouched.
        for src, path in ((image1, p1), (image2, p2)):
            if path and path != src:
                try:
                    os.remove(path)
                except OSError:
                    pass
def load_examples(json_path, mode="single"):
    """Load Gradio example rows from a JSON manifest.

    Args:
        json_path: path to a JSON file containing a list of example records.
        mode: ``"single"`` builds ``[image_path, task]`` rows from each
            record's ``image_name``; ``"pair"`` builds
            ``[image_path_A, image_path_B, task]`` rows from
            ``image_nameA``/``image_nameB``.

    Returns:
        A list of example rows. Records whose image files are missing are
        skipped, and a missing or unreadable manifest yields ``[]`` — a
        broken manifest must never prevent the app from starting (the
        original let ``json.load`` raise at import time).
    """
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} not found")
        return []
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        print(f"Warning: could not read {json_path}: {e}")
        return []
    examples = []
    for item in data:
        task = item.get("task", "")
        if mode == "single":
            # Example images live under fineData/ relative to the app root.
            img_path = os.path.join("fineData", item["image_name"])
            if os.path.exists(img_path):
                examples.append([img_path, task])
        elif mode == "pair":
            path_a = os.path.join("fineData", item["image_nameA"])
            path_b = os.path.join("fineData", item["image_nameB"])
            if os.path.exists(path_a) and os.path.exists(path_b):
                examples.append([path_a, path_b, task])
    return examples
# Preload the example rows for each tab; an empty list simply means the
# corresponding gr.Examples widget below is not rendered.
single_examples = load_examples("score_example.json", mode="single")
pair_examples = load_examples("rank_example.json", mode="pair")
# Build the two-tab Gradio UI: single-image quality scoring and pairwise
# comparison. (Nesting reconstructed from the flattened paste — the source
# indentation was stripped; layout follows the obvious Row/Column structure.)
with gr.Blocks(title="FGResQ Demo") as demo:
    gr.Markdown("Fine-grained Image Quality Assessment for Perceptual Image Restoration Demo")
    with gr.Tab("Single Image Mode"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Input Image")
                task_input = gr.Textbox(label="Task")
                submit_btn = gr.Button("Predict")
            with gr.Column():
                score_output = gr.Textbox(label="Quality Score")
        # Only show the Examples widget when example files were found on disk.
        if single_examples:
            gr.Examples(
                examples=single_examples,
                inputs=[image_input, task_input],
                label="Examples"
            )
        submit_btn.click(
            fn=predict_single,
            inputs=[image_input, task_input],
            outputs=score_output
        )
    with gr.Tab("Pairwise Mode"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_input_A = gr.Image(type="pil", label="Image A")
                    image_input_B = gr.Image(type="pil", label="Image B")
                task_input_pair = gr.Textbox(label="Task")
                submit_btn_pair = gr.Button("Compare")
            with gr.Column():
                compare_output = gr.Textbox(label="Comparison Result")
        # Only show the Examples widget when example pairs were found on disk.
        if pair_examples:
            gr.Examples(
                examples=pair_examples,
                inputs=[image_input_A, image_input_B, task_input_pair],
                label="Examples"
            )
        submit_btn_pair.click(
            fn=predict_pair,
            inputs=[image_input_A, image_input_B, task_input_pair],
            outputs=compare_output
        )
# Start the Gradio server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()