# FGResQ / app.py
# Last change by orpheus0429: commit 4a4586c "Update app.py" (verified)
import gradio as gr
import os
import json
import tempfile
from pathlib import Path
from PIL import Image
from huggingface_hub import hf_hub_download
from model.FGResQ import FGResQ
# Resolve the checkpoint location: prefer the copy published on the Hugging
# Face Hub, and fall back to a repo-relative path if the download fails
# (e.g. no network access). The fallback path is not checked for existence here.
try:
    MODEL_PATH = hf_hub_download(repo_id="orpheus0429/FGResQ", filename="weights/FGResQ.pth")
except Exception as download_err:
    print(f"Failed to download model: {download_err}")
    MODEL_PATH = "weights/FGResQ.pth"

print(f"Loading model from {MODEL_PATH}")
# Built once at import time; predict_single / predict_pair reuse this instance.
model = FGResQ(model_path=MODEL_PATH)
def _write_temp_png(pil_image):
    """Save *pil_image* to a new temporary ``.png`` file and return its path.

    The file is created with ``delete=False``, so the caller owns it and must
    delete it. If saving fails, the file is removed before the exception
    propagates, so no empty temp file is left behind.
    """
    tf = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tf.close()  # release our handle so PIL can reopen the file by name
    try:
        pil_image.save(tf.name)
    except Exception:
        try:
            os.remove(tf.name)
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        raise
    return tf.name


def _save_temp_image(img):
    """Return a filesystem path for *img*, writing a temporary PNG when needed.

    Args:
        img: ``None``; a path (``str`` or ``os.PathLike``) to an existing file;
            a ``PIL.Image.Image``; or any object ``PIL.Image.fromarray``
            accepts (e.g. a NumPy array).

    Returns:
        ``None`` when *img* is ``None`` or cannot be converted to an image.
        For an existing file, the path itself (as returned by ``os.fspath``).
        Otherwise, the path of a new temporary PNG that the caller is
        responsible for deleting.

    Raises:
        Exception: whatever ``PIL.Image.Image.save`` raises when writing a
            ``PIL.Image.Image`` input fails. Array-like inputs return ``None``
            instead.
    """
    if img is None:
        return None
    if isinstance(img, (str, os.PathLike)) and Path(img).exists():
        # Existing files are used in place and never copied.
        return os.fspath(img)
    if isinstance(img, Image.Image):
        return _write_temp_png(img)
    try:
        return _write_temp_png(Image.fromarray(img))
    except Exception:
        # Unsupported type/dtype (or a failed save): report "no usable image".
        return None
def predict_single(image, task):
    """Score the quality of a single image with the loaded FGResQ model.

    Args:
        image: Value from the Gradio image input (a PIL image with
            ``type="pil"``), an array-like image, or a path to an existing file.
        task: Task text from the UI. It is accepted so the signature matches the
            Gradio inputs, but it is not forwarded to the model.

    Returns:
        The score formatted with four decimals. Otherwise a human-readable
        message: no usable image, the model returned ``None``, or the model
        call raised.
    """
    img_path = _save_temp_image(image)
    if img_path is None:
        return "No image provided"
    # Path inputs are never copied to a temp file, so only non-path inputs
    # produce a file that this call owns and must delete.
    owns_file = not isinstance(image, (str, os.PathLike))
    try:
        score = model.predict_single(img_path)
        if score is None:
            return "Model returned None"
        return f"{score:.4f}"
    except Exception as e:
        # Show the failure in the UI instead of failing the request.
        return f"Error: {str(e)}"
    finally:
        if owns_file:
            try:
                os.remove(img_path)
            except OSError:
                pass  # best-effort cleanup of the temp PNG
def predict_pair(image1, image2, task):
    """Compare two images with the loaded FGResQ model.

    Args:
        image1: First image (PIL image, array-like image, or existing path).
        image2: Second image, same accepted forms as *image1*.
        task: Task text from the UI. It is accepted so the signature matches the
            Gradio inputs, but it is not forwarded to the model.

    Returns:
        The value stored under ``"comparison"`` in the model's result (``None``
        if that key is absent). Otherwise a human-readable message: an image is
        missing, the model returned ``None``, or the call raised. The result is
        presumably a dict; any object without ``.get`` surfaces as an
        ``"Error: ..."`` message.
    """
    p1 = _save_temp_image(image1)
    p2 = _save_temp_image(image2)
    # Path inputs are never copied; only successfully converted non-path
    # inputs produced temp files that this call owns and must delete.
    owned = [
        path
        for path, src in ((p1, image1), (p2, image2))
        if path and not isinstance(src, (str, os.PathLike))
    ]
    try:
        if not p1 or not p2:
            return "Missing images"
        result = model.predict_pair(p1, p2)
        if result is None:
            return "Model returned None"
        return result.get("comparison")
    except Exception as e:
        # Show the failure in the UI instead of failing the request.
        return f"Error: {str(e)}"
    finally:
        # Runs on the "Missing images" path as well, so a lone converted
        # image does not leak.
        for path in owned:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup of the temp PNG
def load_examples(json_path, mode="single", data_root="fineData"):
    """Build a ``gr.Examples`` row list from a JSON manifest.

    The manifest is a JSON list of objects. In ``"single"`` mode each entry
    needs an ``image_name`` key; in ``"pair"`` mode it needs ``image_nameA`` and
    ``image_nameB``. An optional ``task`` key defaults to ``""``. Image names
    are resolved relative to *data_root*, and entries whose image file(s) do not
    exist are skipped.

    Args:
        json_path: Path to the JSON manifest.
        mode: ``"single"`` yields rows ``[img_path, task]``; ``"pair"`` yields
            rows ``[img_path_A, img_path_B, task]``.
        data_root: Directory the image names are relative to. The default is
            ``"fineData"``.

    Returns:
        A list of example rows. Empty (with a printed warning) if the manifest
        does not exist or *mode* is not recognised.

    Raises:
        KeyError: if an entry lacks an image key required by *mode*.
        json.JSONDecodeError: if the manifest is not valid JSON.
    """
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} not found")
        return []
    if mode not in ("single", "pair"):
        print(f"Warning: unknown example mode {mode!r}")
        return []
    # Explicit encoding so non-ASCII task text loads the same on every platform.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    examples = []
    for item in data:
        if mode == "single":
            img_path = os.path.join(data_root, item["image_name"])
            task = item.get("task", "")
            if os.path.exists(img_path):
                examples.append([img_path, task])
        else:  # mode == "pair"
            img_path_A = os.path.join(data_root, item["image_nameA"])
            img_path_B = os.path.join(data_root, item["image_nameB"])
            task = item.get("task", "")
            if os.path.exists(img_path_A) and os.path.exists(img_path_B):
                examples.append([img_path_A, img_path_B, task])
    return examples
# Example rows for the two tabs. A missing manifest yields an empty list, and
# the UI then skips that tab's gr.Examples section.
single_examples = load_examples(json_path="score_example.json", mode="single")
pair_examples = load_examples(json_path="rank_example.json", mode="pair")
# Gradio UI: one tab scores a single image, the other compares two images.
# Components are laid out in the order they are created inside the nested
# Blocks/Tab/Row/Column context managers.
with gr.Blocks(title="FGResQ Demo") as demo:
    gr.Markdown("Fine-grained Image Quality Assessment for Perceptual Image Restoration Demo")
    with gr.Tab("Single Image Mode"):
        with gr.Row():
            # Left column: image and task inputs plus the trigger button.
            with gr.Column():
                # type="pil": the handler receives a PIL image, which
                # _save_temp_image writes to a temporary PNG.
                image_input = gr.Image(type="pil", label="Input Image")
                # NOTE(review): the task text reaches predict_single, which does
                # not pass it on to the model.
                task_input = gr.Textbox(label="Task")
                submit_btn = gr.Button("Predict")
            # Right column: the formatted score or an error message.
            with gr.Column():
                score_output = gr.Textbox(label="Quality Score")
        # Rendered only when score_example.json produced at least one row.
        if single_examples:
            gr.Examples(
                examples=single_examples,
                inputs=[image_input, task_input],
                label="Examples"
            )
        submit_btn.click(
            fn=predict_single,
            inputs=[image_input, task_input],
            outputs=score_output
        )
    with gr.Tab("Pairwise Mode"):
        with gr.Row():
            with gr.Column():
                # The two candidate images sit side by side.
                with gr.Row():
                    image_input_A = gr.Image(type="pil", label="Image A")
                    image_input_B = gr.Image(type="pil", label="Image B")
                task_input_pair = gr.Textbox(label="Task")
                submit_btn_pair = gr.Button("Compare")
            with gr.Column():
                # Shows the model's "comparison" value or an error message.
                compare_output = gr.Textbox(label="Comparison Result")
        # Rendered only when rank_example.json produced at least one row.
        if pair_examples:
            gr.Examples(
                examples=pair_examples,
                inputs=[image_input_A, image_input_B, task_input_pair],
                label="Examples"
            )
        submit_btn_pair.click(
            fn=predict_pair,
            inputs=[image_input_A, image_input_B, task_input_pair],
            outputs=compare_output
        )

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()