|
|
""" |
|
|
Image Quality Scoring and Interpreting Gradio Interface |
|
|
- Single image scoring |
|
|
- Quality interpretation chat |
|
|
- Multi-GPU distribution for 7B model |
|
|
- Auto-load model on startup |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, AutoTokenizer |
|
|
import gc |
|
|
|
|
|
|
|
|
# Lazily-initialized model components; populated by load_model() at startup.
model = None       # LlavaOnevisionForConditionalGeneration (7B Q-SIT checkpoint)
processor = None   # AutoProcessor: joint image+text preprocessing and chat template
tokenizer = None   # AutoTokenizer: token-id lookups for rating words and the answer prefix
|
|
|
|
|
def load_model(use_multi_gpu=True):
    """Load the Q-SIT model with optional multi-GPU support.

    Args:
        use_multi_gpu: When True and more than one CUDA device is visible,
            shard the model across all of them via ``device_map="auto"``.
            Otherwise the model is placed on GPU 0, or on the CPU when no
            CUDA device is available (previously this path crashed with
            ``.to(0)`` on CPU-only hosts).

    Returns:
        A human-readable status string with the model path, placement,
        and visible GPU count.
    """
    global model, processor, tokenizer

    # Release any previously loaded model before loading a fresh copy so we
    # never hold two 7B checkpoints in memory at once. Rebind to None rather
    # than `del` so the global name always stays defined for other functions.
    if model is not None:
        model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_id = "models/q-sit"
    gpu_count = torch.cuda.device_count()

    print(f"Loading model from: {model_id}")
    print(f"Available GPUs: {gpu_count}")

    if use_multi_gpu and gpu_count > 1:
        print(f"Using device_map='auto' to distribute across {gpu_count} GPUs")
        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            local_files_only=True,
        )
        device_info = "multi-GPU (auto)"
    else:
        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            local_files_only=True,
        )
        if torch.cuda.is_available():
            model = model.to(0)
            device_info = "GPU:0"
        else:
            # CPU fallback: fp16 on CPU is slow but keeps the app functional.
            device_info = "CPU"

    processor = AutoProcessor.from_pretrained(model_id, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)

    # Report per-GPU memory so multi-GPU sharding can be sanity-checked.
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"GPU {i}: {allocated:.2f}GB / {total:.2f}GB")

    print(f"Model loaded successfully on {device_info}!")
    return f"Model loaded from {model_id} on {device_info}\nGPUs: {torch.cuda.device_count()}"
|
|
|
|
|
def wa5(logits):
    """Collapse five quality-level logits into one score in [0, 1].

    A softmax is taken over the raw logits of the five rating words, and
    the resulting probabilities are averaged against fixed level weights:
    Excellent=1.0, Good=0.75, Fair=0.5, Poor=0.25, Bad=0.0.

    Args:
        logits: Mapping with keys "Excellent", "Good", "Fair", "Poor",
            "Bad" holding the raw (pre-softmax) logit values.

    Returns:
        Tuple of (weighted score, softmax probability array in the same
        Excellent-to-Bad order).
    """
    level_order = ("Excellent", "Good", "Fair", "Poor", "Bad")
    raw = np.array([logits[name] for name in level_order])

    # Plain softmax over the five logits.
    exps = np.exp(raw)
    probs = exps / exps.sum()

    weights = np.array([1, 0.75, 0.5, 0.25, 0])
    return np.inner(probs, weights), probs
|
|
|
|
|
def score_single_image(image):
    """Run one scoring pass on an image.

    Returns:
        Tuple of (weighted quality score in [0, 1], probability array over
        Excellent/Good/Fair/Poor/Bad), or (None, None) when the model is
        not loaded or no image was supplied.
    """
    if model is None or image is None:
        return None, None

    # Gradio may deliver a numpy array; the processor expects PIL.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    level_names = ["Excellent", "Good", "Fair", "Poor", "Bad"]
    # First token id of each rating word — used to read its logit below.
    level_ids = [encoded[0] for encoded in tokenizer(level_names)["input_ids"]]

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": """Assume you are an image quality evaluator.
Your rating should be chosen from the following five categories: Excellent, Good, Fair, Poor, and Bad (from high to low).
How would you rate the quality of this image?"""},
                {"type": "image"},
            ],
        },
    ]

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    batch = processor(images=image, text=prompt, return_tensors='pt')

    # Move tensors to the model's device, casting float tensors to fp16.
    target = next(model.parameters()).device
    batch = {
        key: tensor.to(target, torch.float16)
        if tensor.dtype in [torch.float32, torch.float64]
        else tensor.to(target)
        for key, tensor in batch.items()
    }

    # Seed the assistant's answer so the single generated token is the
    # rating word itself.
    prefix_text = "The quality of this image is "
    prefix_ids = tokenizer(prefix_text, return_tensors="pt")["input_ids"].to(target)
    batch["input_ids"] = torch.cat([batch["input_ids"], prefix_ids], dim=-1)
    batch["attention_mask"] = torch.ones_like(batch["input_ids"])

    with torch.no_grad():
        generated = model.generate(
            **batch,
            max_new_tokens=1,
            output_logits=True,
            return_dict_in_generate=True,
        )

    # Logits of the one generated position, read for the five rating words.
    final_logits = generated.logits[-1][0].cpu()
    logits_by_level = {
        name: final_logits[token_id].item()
        for name, token_id in zip(level_names, level_ids)
    }

    return wa5(logits_by_level)
|
|
|
|
|
def get_quality_score(image):
    """Score one image and format the result for the UI.

    Returns:
        Tuple of (markdown summary string, [[level, "xx.x%"], ...] table),
        or (None, None) when the model/image is missing or scoring fails.
    """
    if model is None:
        return None, None
    if image is None:
        return None, None

    score, probs = score_single_image(image)
    if score is None:
        return None, None

    levels = ["Excellent", "Good", "Fair", "Poor", "Bad"]

    # The displayed rating is the most probable level.
    top_level = levels[np.argmax(probs)]
    score_100 = score * 100

    summary = f"**Quality Score:** {score_100:.2f}/100\n\n**Rating:** {top_level}"
    distribution = [[name, f"{p*100:.1f}%"] for name, p in zip(levels, probs)]

    return summary, distribution
|
|
|
|
|
def _build_chat_conversation(message, history):
    """Rebuild the chat-template conversation from Gradio history.

    The image placeholder is attached only to the very first user turn,
    because the single uploaded image is passed to the processor once and
    must match exactly one image token in the prompt.
    """
    conversation = []

    if len(history) == 0:
        # First turn: current message carries the image placeholder.
        conversation.append({
            "role": "user",
            "content": [
                {"type": "text", "text": message},
                {"type": "image"},
            ],
        })
        return conversation

    # Replay prior turns; only turn 0 includes the image placeholder.
    for i, (user_msg, assistant_msg) in enumerate(history):
        user_content = [{"type": "text", "text": user_msg}]
        if i == 0:
            user_content.append({"type": "image"})
        conversation.append({"role": "user", "content": user_content})
        conversation.append({
            "role": "assistant",
            "content": [{"type": "text", "text": assistant_msg}],
        })

    # Append the new user message (text only — image already present above).
    conversation.append({
        "role": "user",
        "content": [{"type": "text", "text": message}],
    })
    return conversation


def chat_about_quality(image, message, history):
    """Multi-turn conversation about image quality.

    Args:
        image: PIL image or numpy array shown in the main image widget.
        message: The user's new question.
        history: Gradio chat history as a list of [user, assistant] pairs.

    Returns:
        Tuple of (updated chatbot history, updated state history) — the
        same list twice, as expected by the two Gradio outputs.
    """
    if model is None:
        return history + [[message, "Please load the model first!"]], history

    if image is None:
        return history + [[message, "Please upload an image first!"]], history

    # Guard against None as well as whitespace-only input.
    if not message or not message.strip():
        return history, history

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    conversation = _build_chat_conversation(message, history)

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=prompt, return_tensors='pt')

    # Move tensors to the model's device, casting float tensors to fp16.
    device = next(model.parameters()).device
    inputs = {k: v.to(device, torch.float16) if v.dtype in [torch.float32, torch.float64] else v.to(device)
              for k, v in inputs.items()}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )

    full_response = processor.decode(output[0], skip_special_tokens=True)
    # NOTE(review): splitting on the literal word "assistant" relies on the
    # chat template marking the assistant turn this way; it would misbehave
    # if the reply itself contained that word. Kept as-is.
    response = full_response.split("assistant")[-1].strip()

    new_history = history + [[message, response]]
    return new_history, new_history
|
|
|
|
|
def clear_chat():
    """Reset both the chatbot display and the stored conversation state."""
    empty_display, empty_state = [], []
    return empty_display, empty_state
|
|
|
|
|
def create_app():
    """Build the Gradio Blocks UI and return it.

    Layout: a top row with single-image scoring (left column) and a
    quality-interpretation chat (right column), followed by a three-way
    image comparison section. All handlers call the module-level model
    helpers (get_quality_score, chat_about_quality, clear_chat).
    """

    # Orange-accented theme for the whole app.
    orange_theme = gr.themes.Soft(
        primary_hue="orange",
        secondary_hue="orange",
        neutral_hue="gray",
    )

    # CSS overrides: force the orange accent onto buttons/tabs/links and
    # enlarge the markdown panels that display scores.
    custom_css = """
    .gradio-container {
        --color-accent: #ff9900 !important;
        --color-accent-soft: #fed7aa !important;
    }
    button.primary {
        background-color: #ff9900 !important;
        border-color: #ff9900 !important;
    }
    button.primary:hover {
        background-color: #e68a00 !important;
        border-color: #e68a00 !important;
    }
    .tab-nav button.selected {
        border-color: #ff9900 !important;
        color: #ff9900 !important;
    }
    a {
        color: #ff9900 !important;
    }
    /* Larger font for score display */
    #score_display .prose, #compare_score1 .prose, #compare_score2 .prose, #compare_score3 .prose {
        font-size: 1.5em !important;
    }
    #score_display .prose strong, #compare_score1 .prose strong, #compare_score2 .prose strong, #compare_score3 .prose strong {
        font-size: 1.2em !important;
        color: #ff9900 !important;
    }
    """

    with gr.Blocks(title="Image Quality Assessment", theme=orange_theme, css=custom_css) as app:
        gr.Markdown("""
        # Image Quality Scoring and Interpreting

        Unifies image quality **scoring** and **interpreting** in one model.
        """)

        with gr.Row():

            # Left column: upload an image and get a single quality score.
            with gr.Column(scale=1):
                gr.Markdown("### Upload & Score")

                main_image = gr.Image(label="Main Image", type="pil")
                score_btn = gr.Button("Get Quality Score", variant="primary")

                # elem_id hooks into the custom CSS above for larger text.
                score_display = gr.Markdown(label="Score & Rating", elem_id="score_display")

                prob_table = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Probability Distribution"
                )

                score_btn.click(
                    get_quality_score,
                    inputs=[main_image],
                    outputs=[score_display, prob_table]
                )

            # Right column: multi-turn chat about the same uploaded image.
            with gr.Column(scale=1):
                gr.Markdown("### Chat About Quality")

                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=300,
                    bubble_full_width=False
                )

                # Conversation history kept server-side as [user, bot] pairs.
                chat_state = gr.State([])

                with gr.Row():
                    chat_input = gr.Textbox(
                        label="Your Question",
                        placeholder="e.g., 'What distortions can you see?'",
                        scale=4
                    )
                    chat_btn = gr.Button("Send", variant="primary", scale=1)

                clear_btn = gr.Button("Clear Chat")

                # Both the Send button and pressing Enter submit the message;
                # the .then() chain clears the textbox afterwards.
                chat_btn.click(
                    chat_about_quality,
                    inputs=[main_image, chat_input, chat_state],
                    outputs=[chatbot, chat_state]
                ).then(lambda: "", outputs=[chat_input])

                chat_input.submit(
                    chat_about_quality,
                    inputs=[main_image, chat_input, chat_state],
                    outputs=[chatbot, chat_state]
                ).then(lambda: "", outputs=[chat_input])

                clear_btn.click(clear_chat, outputs=[chatbot, chat_state])

                gr.Markdown("""
                ---
                ### Example Questions for Chat
                - "How is the sharpness of this image?"
                - "Is there any noise or grain?"
                - "How is the exposure/brightness?"
                - "What quality issues can you identify?"
                - "How could this image be improved?"
                - "Compare the quality of center vs corners"
                """)

        gr.Markdown("---")

        # Comparison section: score up to three images side by side.
        gr.Markdown("## Compare Multiple Images")
        gr.Markdown("Upload 1-3 images to compare their quality scores. You can use just one window or all three.")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Image 1")
                compare_img1 = gr.Image(label="Image 1", type="pil")
                compare_score1 = gr.Markdown(elem_id="compare_score1")
                compare_table1 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution"
                )

            with gr.Column(scale=1):
                gr.Markdown("### Image 2")
                compare_img2 = gr.Image(label="Image 2", type="pil")
                compare_score2 = gr.Markdown(elem_id="compare_score2")
                compare_table2 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution"
                )

            with gr.Column(scale=1):
                gr.Markdown("### Image 3")
                compare_img3 = gr.Image(label="Image 3", type="pil")
                compare_score3 = gr.Markdown(elem_id="compare_score3")
                compare_table3 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution"
                )

        compare_btn = gr.Button("Compare All Images", variant="primary", size="lg")

        def compare_images(img1, img2, img3):
            """Compare up to 3 images"""
            # Empty slots yield (None, None) so their outputs stay blank.
            results = []
            for img in [img1, img2, img3]:
                if img is None:
                    results.append((None, None))
                else:
                    score_text, table_data = get_quality_score(img)
                    results.append((score_text, table_data))
            # Flatten into the six outputs expected by compare_btn.click.
            return results[0][0], results[0][1], results[1][0], results[1][1], results[2][0], results[2][1]

        compare_btn.click(
            compare_images,
            inputs=[compare_img1, compare_img2, compare_img3],
            outputs=[compare_score1, compare_table1, compare_score2, compare_table2, compare_score3, compare_table3]
        )

    return app
|
|
|
|
|
if __name__ == "__main__":
    # Visual divider for the startup log.
    divider = "=" * 50

    print(divider)
    print("Loading Q-SIT model...")
    print(divider)
    # Load eagerly at startup so the first request doesn't pay the cost.
    load_model(use_multi_gpu=True)
    print(divider)
    print("Model loaded! Starting Gradio interface...")
    print(divider)

    app = create_app()
    # Bind on all interfaces so the app is reachable from other hosts.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )