File size: 4,744 Bytes
94d75d1
9553ed0
94d75d1
9553ed0
94d75d1
9553ed0
86a1081
94d75d1
9553ed0
94d75d1
 
 
 
 
 
 
 
 
 
9553ed0
94d75d1
9553ed0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94d75d1
9553ed0
 
94d75d1
9553ed0
 
94d75d1
 
 
 
 
 
 
9553ed0
 
94d75d1
9553ed0
 
 
94d75d1
 
9553ed0
 
94d75d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9553ed0
94d75d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9553ed0
94d75d1
9553ed0
 
 
94d75d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9553ed0
 
 
94d75d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9553ed0
94d75d1
 
 
 
 
9553ed0
4a4586c
9553ed0
94d75d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import os
import json
import tempfile
from pathlib import Path
from PIL import Image
from huggingface_hub import hf_hub_download
from model.FGResQ import FGResQ

# Locate the FGResQ checkpoint: prefer the copy on the Hugging Face Hub and
# fall back to a local weights file if the download raises anything at all.
try:
    MODEL_PATH = hf_hub_download(repo_id="orpheus0429/FGResQ", filename="weights/FGResQ.pth")
except Exception as download_error:
    print(f"Failed to download model: {download_error}")
    MODEL_PATH = "weights/FGResQ.pth"

# Build the model once at import time; the request handlers below reuse it.
print(f"Loading model from {MODEL_PATH}")
model = FGResQ(model_path=MODEL_PATH)

def _save_temp_image(img):
    if img is None:
        return None
    if isinstance(img, str) and Path(img).exists():
        return img
    if isinstance(img, Image.Image):
        tf = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tf.close()
        img.save(tf.name)
        return tf.name
    try:
        pil = Image.fromarray(img)
        tf = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tf.close()
        pil.save(tf.name)
        return tf.name
    except Exception:
        return None

def predict_single(image, task):
    """Score a single image with the FGResQ model for the Gradio UI.

    Args:
        image: Value from the image component (PIL image, array, or an
            existing file path); converted via ``_save_temp_image``.
        task: Task text from the UI. Currently unused.

    Returns:
        The model score formatted to 4 decimal places, or a message string:
        ``"No image provided"``, ``"Model returned None"`` or ``"Error: ..."``.
    """
    img_path = _save_temp_image(image)
    if img_path is None:
        return "No image provided"

    try:
        score = model.predict_single(img_path)
        if score is None:
            return "Model returned None"
        return f"{score:.4f}"
    except Exception as e:
        # Surface any model failure as text in the UI instead of raising.
        return f"Error: {e}"
    finally:
        # _save_temp_image returns a caller-supplied path as the very same
        # object; any other result is a delete=False temp file it created for
        # this request, so remove it to avoid leaking one file per call.
        if img_path is not image:
            try:
                os.remove(img_path)
            except OSError:
                pass  # best-effort cleanup; never mask the prediction result

def predict_pair(image1, image2, task):
    """Compare two images with the FGResQ model for the Gradio UI.

    Args:
        image1: First image (PIL image, array, or existing file path).
        image2: Second image, same accepted forms as *image1*.
        task: Task text from the UI. Currently unused.

    Returns:
        ``result.get("comparison")`` from the model, passed through unchanged
        (so possibly ``None``), or a message string: ``"Missing images"``,
        ``"Model returned None"`` or ``"Error: ..."``.
    """
    p1 = _save_temp_image(image1)
    p2 = _save_temp_image(image2)
    # Paths that _save_temp_image created (rather than passed through as the
    # caller's own object); these are delete=False temp files we must remove.
    temp_paths = [
        path
        for path, source in ((p1, image1), (p2, image2))
        if path is not None and path is not source
    ]

    try:
        if not p1 or not p2:
            return "Missing images"

        result = model.predict_pair(p1, p2)
        if result is None:
            return "Model returned None"

        return result.get("comparison")
    except Exception as e:
        # Surface any model failure as text in the UI instead of raising.
        return f"Error: {e}"
    finally:
        # Runs on every path, including "Missing images" when only one of the
        # two images was converted.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; never mask the comparison result

def load_examples(json_path, mode="single", data_root="fineData"):
    """Build ``gr.Examples`` rows from an example-manifest JSON file.

    Args:
        json_path: Path to a JSON file containing a list of example records.
        mode: ``"single"``: each record needs ``"image_name"`` and yields
            ``[image_path, task]``. ``"pair"``: each record needs
            ``"image_nameA"`` and ``"image_nameB"`` and yields
            ``[path_A, path_B, task]``. Any other value yields no rows.
        data_root: Directory that the image names in the manifest are
            relative to (defaults to ``"fineData"``).

    Returns:
        A list of rows; ``task`` defaults to ``""`` when absent, and records
        whose image file(s) do not exist are skipped. Returns ``[]`` (after
        printing a warning) if *json_path* does not exist.

    Raises:
        KeyError: if a record lacks an image-name key required by *mode*.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} not found")
        return []

    # JSON text is UTF-8 by spec; don't depend on the platform locale.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if mode == "single":
        name_keys = ("image_name",)
    elif mode == "pair":
        name_keys = ("image_nameA", "image_nameB")
    else:
        return []

    examples = []
    for item in data:
        paths = [os.path.join(data_root, item[key]) for key in name_keys]
        # Only offer examples whose image files are actually present.
        if all(os.path.exists(p) for p in paths):
            examples.append([*paths, item.get("task", "")])
    return examples

# Example rows for each tab's gr.Examples gallery; an empty list (manifest
# missing or none of its images present) means that gallery is not shown.
single_examples, pair_examples = (
    load_examples("score_example.json", mode="single"),
    load_examples("rank_example.json", mode="pair"),
)

# Gradio UI with two tabs: "Single Image Mode" (score one image) and
# "Pairwise Mode" (compare two images). Inside gr.Blocks, components are
# placed in the order they are created within the nested layout contexts,
# so the statement order below is the page layout.
with gr.Blocks(title="FGResQ Demo") as demo:
    gr.Markdown("Fine-grained Image Quality Assessment for Perceptual Image Restoration Demo")

    with gr.Tab("Single Image Mode"):
        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                # type="pil" makes Gradio hand the handler a PIL image.
                image_input = gr.Image(type="pil", label="Input Image")
                # NOTE(review): forwarded to predict_single as `task`, which
                # does not currently use it.
                task_input = gr.Textbox(label="Task")
                submit_btn = gr.Button("Predict")
            # Right column: the formatted score or an error/status message.
            with gr.Column():
                score_output = gr.Textbox(label="Quality Score")
        
        # Only show the gallery when load_examples found usable examples.
        # No fn/outputs are given, so picking an example just fills the inputs.
        if single_examples:
            gr.Examples(
                examples=single_examples,
                inputs=[image_input, task_input],
                label="Examples"
            )
        
        # Run the single-image scorer when the button is clicked.
        submit_btn.click(
            fn=predict_single, 
            inputs=[image_input, task_input], 
            outputs=score_output
        )

    with gr.Tab("Pairwise Mode"):
        with gr.Row():
            # Left column: the two images side by side, task text, button.
            with gr.Column():
                with gr.Row():
                    image_input_A = gr.Image(type="pil", label="Image A")
                    image_input_B = gr.Image(type="pil", label="Image B")
                # NOTE(review): forwarded to predict_pair as `task`, which
                # does not currently use it.
                task_input_pair = gr.Textbox(label="Task")
                submit_btn_pair = gr.Button("Compare")
            # Right column: the model's comparison result or a status message.
            with gr.Column():
                compare_output = gr.Textbox(label="Comparison Result")
        
        # Same pattern as the single-image tab: gallery only fills inputs.
        if pair_examples:
            gr.Examples(
                examples=pair_examples,
                inputs=[image_input_A, image_input_B, task_input_pair],
                label="Examples"
            )

        # Run the pairwise comparison when the button is clicked.
        submit_btn_pair.click(
            fn=predict_pair, 
            inputs=[image_input_A, image_input_B, task_input_pair], 
            outputs=compare_output
        )


# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()