# SubjectiveIQE / UI.py
# Uploaded by gongnq via huggingface_hub (commit 5e8fb45, verified)
"""
Image Quality Scoring and Interpreting Gradio Interface
- Single image scoring
- Quality interpretation chat
- Multi-GPU distribution for 7B model
- Auto-load model on startup
"""
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, AutoTokenizer
import gc
# Global variables for model
# All three are populated by load_model() and remain None until it runs.
model = None       # LlavaOnevisionForConditionalGeneration, loaded in fp16
processor = None   # AutoProcessor used to build multimodal chat prompts
tokenizer = None   # AutoTokenizer used for rating-token id lookups
def load_model(use_multi_gpu=True):
    """Load the Q-SIT model with optional multi-GPU support.

    Frees any previously loaded model first, then populates the
    module-level ``model``, ``processor`` and ``tokenizer`` globals.

    Args:
        use_multi_gpu: When True and more than one CUDA device is
            visible, shard the model across all GPUs with
            device_map="auto"; otherwise load onto GPU 0 (or CPU as a
            last resort).

    Returns:
        str: Human-readable summary of model path, placement and GPU count.
    """
    global model, processor, tokenizer

    # Clear previous model if it exists, so two copies never coexist in memory.
    if model is not None:
        del model
        model = None  # keep the global well-defined if the reload below fails
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Local model path (no hub download)
    model_id = "models/q-sit"
    gpu_count = torch.cuda.device_count()
    print(f"Loading model from: {model_id}")
    print(f"Available GPUs: {gpu_count}")

    if use_multi_gpu and gpu_count > 1:
        print(f"Using device_map='auto' to distribute across {gpu_count} GPUs")
        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            local_files_only=True,  # use local files only
        )
        device_info = "multi-GPU (auto)"
    else:
        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            local_files_only=True,  # use local files only
        )
        if torch.cuda.is_available():
            # Single-GPU placement on device 0 (original behavior).
            model = model.to(0)
            device_info = "GPU:0"
        else:
            # Robustness fix: the previous unconditional .to(0) crashed on
            # CPU-only machines; fall back to CPU instead.
            device_info = "CPU"

    processor = AutoProcessor.from_pretrained(model_id, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)

    # Print per-device memory usage as a quick sanity check.
    if torch.cuda.is_available():
        for i in range(gpu_count):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"GPU {i}: {allocated:.2f}GB / {total:.2f}GB")

    print(f"Model loaded successfully on {device_info}!")
    return f"Model loaded from {model_id} on {device_info}\nGPUs: {gpu_count}"
def wa5(logits):
    """Compute a weighted-average quality score from 5-level rating logits.

    Applies a softmax over the five rating logits, then takes the inner
    product with fixed level weights:
        Excellent: 1.0, Good: 0.75, Fair: 0.5, Poor: 0.25, Bad: 0.0

    Args:
        logits: dict mapping the rating names ("Excellent", "Good",
            "Fair", "Poor", "Bad") to their raw logit values.

    Returns:
        tuple: (score, probs) — score is a float in [0, 1]; probs is the
        softmax probability array ordered Excellent -> Bad.
    """
    logprobs = np.array([
        logits["Excellent"],
        logits["Good"],
        logits["Fair"],
        logits["Poor"],
        logits["Bad"],
    ])
    # Fix: subtract the max before exponentiating for numerical stability.
    # The original np.exp(logprobs) overflowed to inf (and the score to NaN)
    # for large logit magnitudes; softmax is invariant to this shift.
    shifted = logprobs - np.max(logprobs)
    exp = np.exp(shifted)
    probs = exp / exp.sum()
    weights = np.array([1.0, 0.75, 0.5, 0.25, 0.0])
    return np.inner(probs, weights), probs
def score_single_image(image):
    """Run the quality-rating prompt on one image.

    Returns (None, None) when the model is not loaded or no image is
    given; otherwise (score, probs) where score is the softmax-weighted
    average over the five rating levels (see wa5) and probs is the
    probability distribution ordered Excellent -> Bad.
    """
    if model is None or image is None:
        return None, None

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # First token id of each rating word; their logits are read out below.
    levels = ["Excellent", "Good", "Fair", "Poor", "Bad"]
    level_ids = [ids[0] for ids in tokenizer(levels)["input_ids"]]

    # Single-turn scoring prompt with the image attached.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": """Assume you are an image quality evaluator.
Your rating should be chosen from the following five categories: Excellent, Good, Fair, Poor, and Bad (from high to low).
How would you rate the quality of this image?"""},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    raw_inputs = processor(images=image, text=prompt, return_tensors='pt')

    # Move tensors to the model's device, downcasting floats to fp16 to
    # match the model weights.
    device = next(model.parameters()).device
    inputs = {}
    for key, tensor in raw_inputs.items():
        if tensor.dtype in (torch.float32, torch.float64):
            inputs[key] = tensor.to(device, torch.float16)
        else:
            inputs[key] = tensor.to(device)

    # Seed the answer with a fixed prefix so the single generated token
    # is exactly the rating word.
    prefix_ids = tokenizer("The quality of this image is ", return_tensors="pt")["input_ids"].to(device)
    inputs["input_ids"] = torch.cat([inputs["input_ids"], prefix_ids], dim=-1)
    inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_new_tokens=1,
            output_logits=True,
            return_dict_in_generate=True,
        )

    # Logits of the one generated position, keyed by rating name.
    final_logits = generated.logits[-1][0].cpu()
    logits_by_level = {lvl: final_logits[tid].item() for lvl, tid in zip(levels, level_ids)}
    return wa5(logits_by_level)
def get_quality_score(image):
    """Score one image and format the result for the UI.

    Returns:
        tuple: (markdown string with score and rating, table rows of
        [level, probability]); (None, None) when the model is not loaded,
        no image is given, or scoring failed.
    """
    if model is None or image is None:
        return None, None

    score, probs = score_single_image(image)
    if score is None:
        return None, None

    levels = ["Excellent", "Good", "Fair", "Poor", "Bad"]
    # Report the most probable level as the overall rating.
    rating = levels[int(np.argmax(probs))]
    score_text = f"**Quality Score:** {score * 100:.2f}/100\n\n**Rating:** {rating}"
    table_data = [[lvl, f"{p * 100:.1f}%"] for lvl, p in zip(levels, probs)]
    return score_text, table_data
def chat_about_quality(image, message, history):
    """Multi-turn conversation about the quality of the uploaded image.

    Args:
        image: PIL image or numpy array under discussion.
        message: the user's new message.
        history: list of [user_msg, assistant_msg] pairs so far.

    Returns:
        tuple: (updated history for the chatbot, updated history state).
    """
    if model is None:
        return history + [[message, "Please load the model first!"]], history
    if image is None:
        return history + [[message, "Please upload an image first!"]], history
    if not message.strip():
        return history, history

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Rebuild the full conversation. The image is attached only to the
    # first user turn; later turns are text-only.
    conversation = []
    for i, (user_msg, assistant_msg) in enumerate(history):
        user_content = [{"type": "text", "text": user_msg}]
        if i == 0:
            user_content.append({"type": "image"})
        conversation.append({"role": "user", "content": user_content})
        conversation.append({
            "role": "assistant",
            "content": [{"type": "text", "text": assistant_msg}],
        })
    new_content = [{"type": "text", "text": message}]
    if not history:
        new_content.append({"type": "image"})
    conversation.append({"role": "user", "content": new_content})

    # Generate response
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=prompt, return_tensors='pt')
    device = next(model.parameters()).device
    inputs = {k: v.to(device, torch.float16) if v.dtype in [torch.float32, torch.float64] else v.to(device)
              for k, v in inputs.items()}

    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding for reproducible answers
        )
    # Fix: decode only the newly generated tokens. The previous approach
    # (splitting the full decode on "assistant") returned a truncated
    # answer whenever the reply itself contained the word "assistant".
    response = processor.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    new_history = history + [[message, response]]
    return new_history, new_history
def clear_chat():
    """Reset both the chatbot display and the stored conversation state."""
    empty_display, empty_state = [], []
    return empty_display, empty_state
def create_app():
    """Build and return the Gradio Blocks interface.

    Layout: (1) a single-image scoring panel plus a quality chat, and
    (2) a three-slot image comparison section. Event handlers call the
    module-level scoring/chat functions, so load_model() must have run
    before any button is used.
    """
    # Create orange theme
    orange_theme = gr.themes.Soft(
        primary_hue="orange",
        secondary_hue="orange",
        neutral_hue="gray",
    )
    # Custom CSS for orange color #ff9900 and larger score display
    custom_css = """
    .gradio-container {
        --color-accent: #ff9900 !important;
        --color-accent-soft: #fed7aa !important;
    }
    button.primary {
        background-color: #ff9900 !important;
        border-color: #ff9900 !important;
    }
    button.primary:hover {
        background-color: #e68a00 !important;
        border-color: #e68a00 !important;
    }
    .tab-nav button.selected {
        border-color: #ff9900 !important;
        color: #ff9900 !important;
    }
    a {
        color: #ff9900 !important;
    }
    /* Larger font for score display */
    #score_display .prose, #compare_score1 .prose, #compare_score2 .prose, #compare_score3 .prose {
        font-size: 1.5em !important;
    }
    #score_display .prose strong, #compare_score1 .prose strong, #compare_score2 .prose strong, #compare_score3 .prose strong {
        font-size: 1.2em !important;
        color: #ff9900 !important;
    }
    """
    with gr.Blocks(title="Image Quality Assessment", theme=orange_theme, css=custom_css) as app:
        gr.Markdown("""
        # Image Quality Scoring and Interpreting
        Unifies image quality **scoring** and **interpreting** in one model.
        """)
        # ========== UNIFIED INTERFACE ==========
        with gr.Row():
            # Left: Image upload and scoring
            with gr.Column(scale=1):
                gr.Markdown("### Upload & Score")
                main_image = gr.Image(label="Main Image", type="pil")
                score_btn = gr.Button("Get Quality Score", variant="primary")
                # Score and rank as text
                score_display = gr.Markdown(label="Score & Rating", elem_id="score_display")
                # Probability distribution as table
                prob_table = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Probability Distribution"
                )
                score_btn.click(
                    get_quality_score,
                    inputs=[main_image],
                    outputs=[score_display, prob_table]
                )
            # Right: Chat about quality (shares main_image with the scorer)
            with gr.Column(scale=1):
                gr.Markdown("### Chat About Quality")
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=300,
                    bubble_full_width=False
                )
                # Raw [user, assistant] pairs fed back into
                # chat_about_quality on every turn.
                chat_state = gr.State([])
                with gr.Row():
                    chat_input = gr.Textbox(
                        label="Your Question",
                        placeholder="e.g., 'What distortions can you see?'",
                        scale=4
                    )
                    chat_btn = gr.Button("Send", variant="primary", scale=1)
                clear_btn = gr.Button("Clear Chat")
                # The chained .then(...) clears the textbox after the
                # reply is produced.
                chat_btn.click(
                    chat_about_quality,
                    inputs=[main_image, chat_input, chat_state],
                    outputs=[chatbot, chat_state]
                ).then(lambda: "", outputs=[chat_input])
                chat_input.submit(
                    chat_about_quality,
                    inputs=[main_image, chat_input, chat_state],
                    outputs=[chatbot, chat_state]
                ).then(lambda: "", outputs=[chat_input])
                clear_btn.click(clear_chat, outputs=[chatbot, chat_state])
        gr.Markdown("""
        ---
        ### Example Questions for Chat
        - "How is the sharpness of this image?"
        - "Is there any noise or grain?"
        - "How is the exposure/brightness?"
        - "What quality issues can you identify?"
        - "How could this image be improved?"
        - "Compare the quality of center vs corners"
        """)
        gr.Markdown("---")
        # ========== MULTI-IMAGE COMPARISON ==========
        gr.Markdown("## Compare Multiple Images")
        gr.Markdown("Upload 1-3 images to compare their quality scores. You can use just one window or all three.")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Image 1")
                compare_img1 = gr.Image(label="Image 1", type="pil")
                compare_score1 = gr.Markdown(elem_id="compare_score1")
                compare_table1 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution"
                )
            with gr.Column(scale=1):
                gr.Markdown("### Image 2")
                compare_img2 = gr.Image(label="Image 2", type="pil")
                compare_score2 = gr.Markdown(elem_id="compare_score2")
                compare_table2 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution"
                )
            with gr.Column(scale=1):
                gr.Markdown("### Image 3")
                compare_img3 = gr.Image(label="Image 3", type="pil")
                compare_score3 = gr.Markdown(elem_id="compare_score3")
                compare_table3 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution"
                )
        compare_btn = gr.Button("Compare All Images", variant="primary", size="lg")

        def compare_images(img1, img2, img3):
            """Score up to 3 images; empty slots yield (None, None)."""
            results = []
            for img in [img1, img2, img3]:
                if img is None:
                    results.append((None, None))
                else:
                    score_text, table_data = get_quality_score(img)
                    results.append((score_text, table_data))
            # Flatten to the 6 outputs wired below (score + table per slot).
            return results[0][0], results[0][1], results[1][0], results[1][1], results[2][0], results[2][1]
        compare_btn.click(
            compare_images,
            inputs=[compare_img1, compare_img2, compare_img3],
            outputs=[compare_score1, compare_table1, compare_score2, compare_table2, compare_score3, compare_table3]
        )
    return app
if __name__ == "__main__":
    # Auto-load the model at startup so the UI is immediately usable.
    banner = "=" * 50
    print(banner)
    print("Loading Q-SIT model...")
    print(banner)
    load_model(use_multi_gpu=True)
    print(banner)
    print("Model loaded! Starting Gradio interface...")
    print(banner)

    demo = create_app()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )