"""
Image Quality Scoring and Interpreting Gradio Interface

- Single image scoring
- Quality interpretation chat
- Multi-GPU distribution for 7B model
- Auto-load model on startup
"""

import gc

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration

# Rating vocabulary, ordered from best to worst, and the weight each level
# contributes to the final score (see `wa5`).
RATING_LEVELS = ["Excellent", "Good", "Fair", "Poor", "Bad"]
RATING_WEIGHTS = np.array([1, 0.75, 0.5, 0.25, 0])

# Global variables for model
model = None
processor = None
tokenizer = None


def load_model(use_multi_gpu=True):
    """Load the Q-SIT model, processor and tokenizer into the module globals.

    Args:
        use_multi_gpu: When true and more than one CUDA device is visible,
            the model is loaded with ``device_map="auto"``; otherwise it is
            moved to device 0.

    Returns:
        A human-readable status string naming the model path, the placement
        and the number of visible GPUs.
    """
    global model, processor, tokenizer

    # Clear previous model if exists.
    if model is not None:
        # Rebind rather than `del`: the global name stays defined even if the
        # load below raises, so later `model is None` checks keep working.
        model = None
        gc.collect()
        torch.cuda.empty_cache()

    # Updated to local model path
    model_id = "models/q-sit"
    gpu_count = torch.cuda.device_count()
    print(f"Loading model from: {model_id}")
    print(f"Available GPUs: {gpu_count}")

    common_kwargs = {
        "torch_dtype": torch.float16,
        "low_cpu_mem_usage": True,
        "local_files_only": True,  # use local files only
    }

    if use_multi_gpu and gpu_count > 1:
        print(f"Using device_map='auto' to distribute across {gpu_count} GPUs")
        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            **common_kwargs,
        )
        device_info = "multi-GPU (auto)"
    else:
        # NOTE: this path moves the model to CUDA device 0 unconditionally.
        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_id,
            **common_kwargs,
        ).to(0)
        device_info = "GPU:0"

    processor = AutoProcessor.from_pretrained(model_id, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)

    # Print memory usage
    if torch.cuda.is_available():
        for i in range(gpu_count):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            print(f"GPU {i}: {allocated:.2f}GB / {total:.2f}GB")

    print(f"Model loaded successfully on {device_info}!")
    return f"Model loaded from {model_id} on {device_info}\nGPUs: {gpu_count}"


def wa5(logits):
    """Weighted average for 5-level scoring.

    Scoring formula: score = sum(probability_i * weight_i)

    Weights:
    - Excellent: 1.0
    - Good: 0.75
    - Fair: 0.5
    - Poor: 0.25
    - Bad: 0.0

    Args:
        logits: Mapping from each rating level name to its raw logit.

    Returns:
        ``(score, probs)`` where ``probs`` is the softmax over the five
        logits (ordered Excellent..Bad) and ``score`` is their weighted
        average, in [0, 1].
    """
    raw = np.array([logits[level] for level in RATING_LEVELS], dtype=np.float64)
    # Subtract the maximum before exponentiating: mathematically a no-op for
    # softmax, but it prevents overflow to inf/NaN for large logits.
    exps = np.exp(raw - raw.max())
    probs = exps / exps.sum()
    return np.inner(probs, RATING_WEIGHTS), probs


def _ensure_pil(image):
    """Return `image` as a PIL image, converting from an array if needed."""
    if isinstance(image, Image.Image):
        return image
    return Image.fromarray(image)


def _prepare_inputs(image, prompt):
    """Run the processor and move the resulting tensors to the model device.

    Floating-point tensors are cast to float16 to match the dtype the model
    was loaded with; all other tensors keep their dtype.

    Returns:
        ``(inputs, device)`` with `inputs` a plain dict of tensors.
    """
    batch = processor(images=image, text=prompt, return_tensors="pt")
    device = next(model.parameters()).device
    float_dtypes = (torch.float32, torch.float64)
    inputs = {
        key: value.to(device, torch.float16) if value.dtype in float_dtypes else value.to(device)
        for key, value in batch.items()
    }
    return inputs, device


def score_single_image(image):
    """Score a single image and return score + probabilities.

    Returns:
        ``(score, probs)`` as produced by `wa5`, or ``(None, None)`` when the
        model is not loaded or no image was given.
    """
    if model is None or image is None:
        return None, None

    image = _ensure_pil(image)

    # First token id of every rating word.
    # NOTE(review): assumes the tokenizer adds no BOS token and that each
    # rating word starts with a distinct token -- confirm for this tokenizer.
    rating_ids = [ids[0] for ids in tokenizer(RATING_LEVELS)["input_ids"]]

    # Build conversation for scoring.
    # NOTE(review): the line breaks inside this prompt were reconstructed to
    # match the published Q-SIT scoring prompt -- confirm against the original.
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": (
                        "Assume you are an image quality evaluator. \n"
                        "Your rating should be chosen from the following five categories: "
                        "Excellent, Good, Fair, Poor, and Bad (from high to low). \n"
                        "How would you rate the quality of this image?"
                    ),
                },
                {"type": "image"},
            ],
        },
    ]

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs, device = _prepare_inputs(image, prompt)

    # Append a fixed answer prefix so the very next token is the rating word.
    prefix_text = "The quality of this image is "
    prefix_ids = tokenizer(prefix_text, return_tensors="pt")["input_ids"].to(device)
    inputs["input_ids"] = torch.cat([inputs["input_ids"], prefix_ids], dim=-1)
    inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])

    # Generate exactly one token and keep its raw logits.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=1,
            output_logits=True,
            return_dict_in_generate=True,
        )

    # Extract the logits of the five rating tokens at the generated position.
    last_logits = output.logits[-1][0].cpu()
    logits_dict = {
        level: last_logits[token_id].item()
        for level, token_id in zip(RATING_LEVELS, rating_ids)
    }

    return wa5(logits_dict)


def get_quality_score(image):
    """Get quality score for a single image with detailed output.

    Returns:
        ``(score_text, table_data)``: a Markdown string with the 0-100 score
        and the most probable rating, plus ``[level, "xx.x%"]`` rows for the
        probability table. ``(None, None)`` when nothing could be scored.
    """
    # `score_single_image` already handles a missing model or image.
    score, probs = score_single_image(image)
    if score is None:
        return None, None

    score_100 = score * 100

    # Determine rank based on highest probability
    rank = RATING_LEVELS[int(np.argmax(probs))]

    # Score and rank as text
    score_text = f"**Quality Score:** {score_100:.2f}/100\n\n**Rating:** {rank}"

    # Probability distribution as table
    table_data = [[level, f"{prob*100:.1f}%"] for level, prob in zip(RATING_LEVELS, probs)]

    return score_text, table_data


def chat_about_quality(image, message, history):
    """Multi-turn conversation about image quality.

    Args:
        image: The image under discussion (PIL image or array), or None.
        message: The new user message.
        history: List of ``[user_msg, assistant_msg]`` pairs from prior turns.

    Returns:
        ``(chatbot_value, state_value)``. On a successful turn both are the
        history extended by the new exchange. When the model or image is
        missing, only the chatbot value carries the notice; the stored state
        is returned unchanged.
    """
    if model is None:
        return history + [[message, "Please load the model first!"]], history

    if image is None:
        return history + [[message, "Please upload an image first!"]], history

    if not message or not message.strip():
        return history, history

    image = _ensure_pil(image)

    # Build conversation with history. The image placeholder is attached to
    # the first user turn only, whether that turn comes from `history` or is
    # the current message.
    def user_turn(text, with_image):
        content = [{"type": "text", "text": text}]
        if with_image:
            content.append({"type": "image"})
        return {"role": "user", "content": content}

    conversation = []
    for index, (user_msg, assistant_msg) in enumerate(history):
        conversation.append(user_turn(user_msg, with_image=index == 0))
        conversation.append({
            "role": "assistant",
            "content": [
                {"type": "text", "text": assistant_msg},
            ],
        })
    conversation.append(user_turn(message, with_image=not history))

    # Generate response
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs, _ = _prepare_inputs(image, prompt)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on the word "assistant" would cut off any reply that itself contains
    # that word.
    prompt_length = inputs["input_ids"].shape[-1]
    response = processor.decode(output[0][prompt_length:], skip_special_tokens=True).strip()

    new_history = history + [[message, response]]
    return new_history, new_history


def clear_chat():
    """Return empty values for the chatbot display and the chat state."""
    return [], []


def create_app():
    """Build and return the Gradio Blocks application."""
    # Create orange theme
    orange_theme = gr.themes.Soft(
        primary_hue="orange",
        secondary_hue="orange",
        neutral_hue="gray",
    )

    # Custom CSS for orange color #ff9900 and larger score display
    custom_css = """
    .gradio-container {
        --color-accent: #ff9900 !important;
        --color-accent-soft: #fed7aa !important;
    }
    button.primary {
        background-color: #ff9900 !important;
        border-color: #ff9900 !important;
    }
    button.primary:hover {
        background-color: #e68a00 !important;
        border-color: #e68a00 !important;
    }
    .tab-nav button.selected {
        border-color: #ff9900 !important;
        color: #ff9900 !important;
    }
    a {
        color: #ff9900 !important;
    }
    /* Larger font for score display */
    #score_display .prose,
    #compare_score1 .prose,
    #compare_score2 .prose,
    #compare_score3 .prose {
        font-size: 1.5em !important;
    }
    #score_display .prose strong,
    #compare_score1 .prose strong,
    #compare_score2 .prose strong,
    #compare_score3 .prose strong {
        font-size: 1.2em !important;
        color: #ff9900 !important;
    }
    """

    with gr.Blocks(title="Image Quality Assessment", theme=orange_theme, css=custom_css) as app:
        gr.Markdown("""
        # Image Quality Scoring and Interpreting

        Unifies image quality **scoring** and **interpreting** in one model.
        """)

        # ========== UNIFIED INTERFACE ==========
        with gr.Row():
            # Left: Image upload and scoring
            with gr.Column(scale=1):
                gr.Markdown("### Upload & Score")
                main_image = gr.Image(label="Main Image", type="pil")
                score_btn = gr.Button("Get Quality Score", variant="primary")

                # Score and rank as text
                score_display = gr.Markdown(label="Score & Rating", elem_id="score_display")

                # Probability distribution as table
                prob_table = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Probability Distribution",
                )

                score_btn.click(
                    get_quality_score,
                    inputs=[main_image],
                    outputs=[score_display, prob_table],
                )

            # Right: Chat about quality
            with gr.Column(scale=1):
                gr.Markdown("### Chat About Quality")
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=300,
                    bubble_full_width=False,
                )
                chat_state = gr.State([])

                with gr.Row():
                    chat_input = gr.Textbox(
                        label="Your Question",
                        placeholder="e.g., 'What distortions can you see?'",
                        scale=4,
                    )
                    chat_btn = gr.Button("Send", variant="primary", scale=1)

                clear_btn = gr.Button("Clear Chat")

                # The Send button and pressing Enter in the textbox trigger
                # the same handler; the textbox is cleared afterwards.
                for trigger in (chat_btn.click, chat_input.submit):
                    trigger(
                        chat_about_quality,
                        inputs=[main_image, chat_input, chat_state],
                        outputs=[chatbot, chat_state],
                    ).then(lambda: "", outputs=[chat_input])

                clear_btn.click(clear_chat, outputs=[chatbot, chat_state])

        gr.Markdown("""
        ---

        ### Example Questions for Chat
        - "How is the sharpness of this image?"
        - "Is there any noise or grain?"
        - "How is the exposure/brightness?"
        - "What quality issues can you identify?"
        - "How could this image be improved?"
        - "Compare the quality of center vs corners"
        """)

        gr.Markdown("---")

        # ========== MULTI-IMAGE COMPARISON ==========
        gr.Markdown("## Compare Multiple Images")
        gr.Markdown("Upload 1-3 images to compare their quality scores. You can use just one window or all three.")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Image 1")
                compare_img1 = gr.Image(label="Image 1", type="pil")
                compare_score1 = gr.Markdown(elem_id="compare_score1")
                compare_table1 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution",
                )

            with gr.Column(scale=1):
                gr.Markdown("### Image 2")
                compare_img2 = gr.Image(label="Image 2", type="pil")
                compare_score2 = gr.Markdown(elem_id="compare_score2")
                compare_table2 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution",
                )

            with gr.Column(scale=1):
                gr.Markdown("### Image 3")
                compare_img3 = gr.Image(label="Image 3", type="pil")
                compare_score3 = gr.Markdown(elem_id="compare_score3")
                compare_table3 = gr.Dataframe(
                    headers=["Level", "Probability"],
                    label="Distribution",
                )

        compare_btn = gr.Button("Compare All Images", variant="primary", size="lg")

        def compare_images(img1, img2, img3):
            """Compare up to 3 images.

            Returns six values: (score_text, table_data) for each slot in
            order, with (None, None) for an empty slot.
            """
            flat_results = []
            for img in (img1, img2, img3):
                if img is None:
                    flat_results.extend((None, None))
                else:
                    flat_results.extend(get_quality_score(img))
            return tuple(flat_results)

        compare_btn.click(
            compare_images,
            inputs=[compare_img1, compare_img2, compare_img3],
            outputs=[
                compare_score1, compare_table1,
                compare_score2, compare_table2,
                compare_score3, compare_table3,
            ],
        )

    return app


if __name__ == "__main__":
    # Auto-load model on startup
    banner = "=" * 50
    print(banner)
    print("Loading Q-SIT model...")
    print(banner)

    load_model(use_multi_gpu=True)

    print(banner)
    print("Model loaded! Starting Gradio interface...")
    print(banner)

    app = create_app()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )