import gradio as gr
import pandas as pd
import json
import html
import tempfile
import traceback
from typing import Optional, Tuple
from PIL import Image
import torch
import numpy as np
from torchvision import transforms
import os
from models import (
    get_device,
    get_tokenizers,
    get_image_processor,
    load_merger_model,
    get_predicated_values,
)
from rembg import remove
from io import BytesIO

# Load environment variables (optional for local dev; Spaces use web UI for env vars)
if os.path.exists('.env'):
    from dotenv import load_dotenv
    load_dotenv()

# Global constants
ATTRIBUTES_LIST = ['sleeve', 'color', 'type', 'pattern', 'material', 'style', 'neck', 'gender', 'brand']
MAX_SEQ_LENGTH = 256
DECODER_MAX_SEQ_LENGTH = 64

# Global variables for model components (populated lazily on first use)
MODEL_COMPONENTS = None
MODEL_LOADED = False


def initialize_model_and_tokenizers():
    """Load the model, tokenizers and image processor once and cache them.

    Later calls return the cached ``MODEL_COMPONENTS`` dict without reloading.

    Returns:
        dict with keys 'model', 'bert_tokenizer', 'roberta_tokenizer',
        'image_processor' and 'device'.

    Raises:
        Whatever the ``models`` helpers raise; the error is printed first.
    """
    global MODEL_COMPONENTS, MODEL_LOADED

    if MODEL_LOADED and MODEL_COMPONENTS:
        return MODEL_COMPONENTS

    try:
        print("🔄 Loading AI model components...")
        device = get_device()
        bert_tokenizer, roberta_tokenizer = get_tokenizers()
        image_processor = get_image_processor()
        model = load_merger_model(bert_tokenizer, device)

        MODEL_COMPONENTS = {
            'model': model,
            'bert_tokenizer': bert_tokenizer,
            'roberta_tokenizer': roberta_tokenizer,
            'image_processor': image_processor,
            'device': device,
        }
        MODEL_LOADED = True
        print("✅ Model loaded successfully!")
        return MODEL_COMPONENTS
    except Exception as e:
        print(f"❌ Failed to load model: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise


def validate_inputs(image, text_input: str, category: str) -> Tuple[bool, str]:
    """Check that an image, a non-blank description and a category were given.

    Returns:
        ``(is_valid, message)``; ``message`` starts with "❌" on failure.
    """
    if image is None:
        return False, "❌ Please upload an image file"
    if not text_input or not text_input.strip():
        return False, "❌ Please provide a product description"
    if not category:
        return False, "❌ Please select a product category"
    return True, "✅ Inputs validated successfully"


def resize_image_for_display(image: Image.Image, target_size=(512, 512)) -> Image.Image:
    """Remove the image background with rembg, then fit it inside ``target_size``.

    The aspect ratio is preserved, so one side of the returned RGBA image may
    be shorter than ``target_size``. Images smaller than ``target_size`` are
    scaled up.
    """
    # Ensure image is RGBA
    image = image.convert('RGBA')

    # 1) Background removal: rembg expects encoded bytes in and returns bytes out
    with BytesIO() as inp, BytesIO() as out:
        image.save(inp, format="PNG")
        inp.seek(0)
        out.write(remove(inp.read()))
        out.seek(0)
        # convert() forces a full decode before the buffer is closed
        no_bg = Image.open(out).convert("RGBA")

    # 2) Compute new size preserving aspect ratio
    orig_w, orig_h = no_bg.size
    max_w, max_h = target_size
    scale = min(max_w / orig_w, max_h / orig_h)
    # Clamp to 1px so extreme aspect ratios never request a zero-sized image
    new_w = max(1, int(orig_w * scale))
    new_h = max(1, int(orig_h * scale))

    # 3) Resize with high-quality resampling
    return no_bg.resize((new_w, new_h), Image.Resampling.LANCZOS)


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a ``(1, 4, H, W)`` uint8 RGBA tensor."""
    if image.mode != 'RGBA':
        image = image.convert('RGBA')
    # np.array gives (H, W, 4); move channels first, then add a batch axis.
    image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1)
    return image_tensor.unsqueeze(0)


def _is_extracted(attr: dict) -> bool:
    """Return True when an attribute row carries a real (non-empty, non-"N/A") value."""
    value = attr.get("value")
    return bool(value) and value != "N/A"


def run_inference(image_tensor: torch.Tensor, description: str, category: str,
                  model_components: dict) -> dict:
    """Run model inference using the ``get_predicated_values`` API.

    Args:
        image_tensor: tensor produced by ``preprocess_image``.
        description: free-text product description.
        category: product category selected in the UI.
        model_components: dict returned by ``initialize_model_and_tokenizers``.

    Returns:
        dict with 'attributes' (the raw rows from ``get_predicated_values``),
        'total_attributes' (count of rows with a real value) and
        'avg_confidence' (mean confidence over those rows, 0.0 if none).
    """
    model = model_components['model']
    bert_tokenizer = model_components['bert_tokenizer']
    roberta_tokenizer = model_components['roberta_tokenizer']
    image_processor = model_components['image_processor']
    device = model_components['device']

    pil_img = transforms.ToPILImage()(image_tensor.squeeze(0).cpu())

    results = get_predicated_values(
        model, category, pil_img, description,
        image_processor, bert_tokenizer, roberta_tokenizer, device
    )

    extracted = [a for a in results if _is_extracted(a)]
    total_attributes = len(extracted)
    # float() so the summary is a plain Python number (JSON-serialisable)
    avg_confidence = float(np.mean([a["confidence"] for a in extracted])) if extracted else 0.0

    return {
        "attributes": results,
        "total_attributes": total_attributes,
        "avg_confidence": avg_confidence,
    }


def get_confidence_color(confidence: float) -> str:
    """Map a confidence in [0, 1] to a traffic-light hex colour."""
    if confidence >= 0.8:
        return "#28a745"  # Green
    elif confidence >= 0.6:
        return "#ffc107"  # Yellow
    else:
        return "#dc3545"  # Red


def format_results_html(results: dict) -> str:
    """Render extraction results as HTML styled by ``custom_css`` (dark theme).

    Model-produced names and values are HTML-escaped before interpolation.
    """
    if not results or results["total_attributes"] == 0:
        return """
<div class="no-results">
    <h3>🔍 No attributes extracted</h3>
    <p>Try with a different image or more detailed description.</p>
</div>
"""

    html_out = """
<div class="results-container">
    <h3 class="results-title">📊 Extracted Attributes</h3>
"""
    for attr in results["attributes"]:
        # Same filter as the count in run_inference, so rows and summary agree
        if not _is_extracted(attr):
            continue
        confidence = attr["confidence"]
        color = get_confidence_color(confidence)
        name = html.escape(str(attr["name"]).title())
        value = html.escape(str(attr["value"]))
        html_out += f"""
    <div class="attribute-item">
        <div>
            <span class="attribute-name">{name}</span>
            <span class="attribute-value">{value}</span>
        </div>
        <span class="confidence-badge" style="background-color: {color};">{confidence:.1%}</span>
    </div>
"""

    html_out += f"""
    <div class="summary-box">
        <h4>📈 Summary</h4>
        <p>{results["total_attributes"]} attributes extracted | {results["avg_confidence"]:.1%} avg confidence</p>
    </div>
</div>
"""
    return html_out


def _json_default(obj):
    """JSON fallback: turn numpy / torch scalars and arrays into native Python values."""
    if isinstance(obj, (np.generic, np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def create_download_files(results: dict) -> Tuple[Optional[str], Optional[str]]:
    """Write ``results`` to JSON and CSV files for download.

    Each call writes into a fresh temporary directory. This stops concurrent
    requests from overwriting each other's files, which happened when fixed
    file names were written to the working directory.

    Returns:
        ``(json_path, csv_path)``, or ``(None, None)`` when ``results`` is empty.
    """
    if not results:
        return None, None

    out_dir = tempfile.mkdtemp(prefix="attributes_")

    json_file = os.path.join(out_dir, "attributes.json")
    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=_json_default)

    csv_file = os.path.join(out_dir, "attributes.csv")
    pd.DataFrame(results["attributes"]).to_csv(csv_file, index=False)

    return json_file, csv_file


def process_inputs(image, category, description, progress=gr.Progress()):
    """Main processing pipeline: validate, lazily load the model, infer, format.

    Returns:
        ``(processed_image, status_message, results_html, json_path, csv_path)``.
        On failure everything except ``status_message`` is None, and the
        message starts with "❌".
    """
    global MODEL_COMPONENTS

    # Validate first so an incomplete form never pays for a model load.
    is_valid, validation_message = validate_inputs(image, description, category)
    if not is_valid:
        return None, validation_message, None, None, None

    if not MODEL_LOADED:
        progress(0.1, desc="Loading AI model...")
        try:
            MODEL_COMPONENTS = initialize_model_and_tokenizers()
        except Exception as e:
            error_msg = f"❌ Failed to load model: {str(e)}"
            return None, error_msg, None, None, None

    try:
        progress(0.3, desc="📸 Preprocessing image...")
        resized_image = resize_image_for_display(image, (512, 512))
        image_tensor = preprocess_image(resized_image)

        progress(0.7, desc="🧠 Running AI inference...")
        results = run_inference(image_tensor, description, category, MODEL_COMPONENTS)

        progress(0.9, desc="📊 Formatting results...")
        results_html = format_results_html(results)
        json_file, csv_file = create_download_files(results)

        progress(1.0, desc="✅ Processing complete!")
        success_msg = f"🎉 Successfully extracted {results['total_attributes']} attributes!"
        return resized_image, success_msg, results_html, json_file, csv_file
    except Exception as e:
        # Show a short message in the UI, but keep the full traceback in the logs.
        traceback.print_exc()
        error_msg = f"❌ Processing failed: {str(e)}"
        return None, error_msg, None, None, None


# Updated custom CSS for dark theme and refined layout
custom_css = """
/* Dark theme overrides */
body, .gradio-container { background-color: #1a1a1a !important; color: #e0e0e0 !important; }
.gr-blocks, .gr-row, .gr-column { background-color: #1a1a1a !important; }
.input-section, .results-section { background-color: #2a2a2a !important; padding: 20px; border-radius: 15px; margin-bottom: 20px; }
.gr-image, .gr-textbox, .gr-dropdown { background-color: #333 !important; color: #e0e0e0 !important; border: 1px solid #444 !important; }
.gr-button { background-color: #444 !important; color: #e0e0e0 !important; border: none !important; }
.gr-button:hover { background-color: #555 !important; }

/* Header styling */
.header { text-align: center; color: #e0e0e0; margin-bottom: 30px; }
.header h1 { font-size: 2em; }
.header p { font-size: 1.1em; color: #b0b0b0; }

/* Results styling */
.results-container { padding: 20px; }
.results-title { color: #e0e0e0; margin-bottom: 20px; font-size: 1.5em; }
.attribute-item { background-color: #333; padding: 15px; margin-bottom: 10px; border-radius: 10px; display: flex; justify-content: space-between; align-items: center; }
.attribute-name { color: #e0e0e0; font-weight: bold; font-size: 1.1em; }
.attribute-value { color: #b0b0b0; margin-left: 10px; }
.confidence-badge { color: white; padding: 4px 8px; border-radius: 12px; font-size: 0.8em; font-weight: bold; }
.summary-box { background-color: #444; color: #e0e0e0; padding: 15px; border-radius: 10px; margin-top: 20px; text-align: center; }
.no-results { padding: 20px; text-align: center; background-color: #333; border-radius: 10px; color: #e0e0e0; }

/* Status message styling */
.status-positive { background-color: #1a472a !important; color: #e0e0e0 !important; padding: 10px; border-radius: 8px; }
.status-negative { background-color: #471a1a !important; color: #e0e0e0 !important; padding: 10px; border-radius: 8px; }
"""


def create_interface():
    """Build the Gradio Blocks UI (dark theme, inputs on the left, results on the right)."""
    with gr.Blocks(css=custom_css, title="AI Attribute Extractor", theme=gr.themes.Soft()) as demo:
        # Header with dark theme styling
        gr.HTML("""
<div class="header">
    <h1>🔍 AI Attribute Extractor</h1>
    <p>Upload an image and provide text to extract detailed attributes using AI</p>
</div>
""")

        with gr.Row():
            # Left column - Input section
            with gr.Column(scale=1):
                gr.HTML("<h3>📤 Input Section</h3>")

                image_input = gr.Image(
                    label="Upload Product Image",
                    type="pil",
                    height=300,
                    elem_classes=["input-section"],
                )
                category_input = gr.Dropdown(
                    choices=["clothing", "bags", "shoes", "accessories"],
                    label="Product Category",
                    value="clothing",
                    elem_classes=["input-section"],
                )
                text_input = gr.Textbox(
                    label="Product Description",
                    placeholder="Describe the product in detail...",
                    lines=4,
                    elem_classes=["input-section"],
                )
                process_btn = gr.Button(
                    "🚀 Extract Attributes",
                    variant="primary",
                    size="lg",
                    elem_classes=["primary-button"],
                )
                status_msg = gr.HTML(label="Status")

            # Right column - Results section
            with gr.Column(scale=1):
                gr.HTML("<h3>📊 Results Section</h3>")

                processed_image = gr.Image(
                    label="Processed Image",
                    height=300,
                    elem_classes=["results-section"],
                )
                results_html = gr.HTML(
                    label="Extracted Attributes",
                    elem_classes=["results-section"],
                )

                with gr.Row():
                    json_download = gr.File(label="📄 Download JSON", visible=False)
                    csv_download = gr.File(label="📊 Download CSV", visible=False)

        # Event handlers
        def update_status(message: str, is_error: bool = False):
            """Wrap a status message in a div carrying the matching status CSS class."""
            class_name = "status-negative" if is_error else "status-positive"
            # Escape: the message may embed raw exception text.
            return f'<div class="{class_name}">{html.escape(message)}</div>'

        def process_and_update(image, category, description, progress=gr.Progress()):
            """Process inputs and update all outputs.

            The ``progress`` default must live on this event handler: Gradio only
            injects a live tracker into the function registered with ``click``.
            The tracker is then forwarded to ``process_inputs``.
            """
            processed_img, status, results, json_file, csv_file = process_inputs(
                image, category, description, progress
            )
            is_error = status.startswith("❌")
            styled_status = update_status(status, is_error)

            return (
                processed_img,
                styled_status,
                results,
                gr.update(value=json_file, visible=json_file is not None),
                gr.update(value=csv_file, visible=csv_file is not None),
            )

        process_btn.click(
            fn=process_and_update,
            inputs=[image_input, category_input, text_input],
            outputs=[processed_image, status_msg, results_html, json_download, csv_download],
        )

    return demo


if __name__ == "__main__":
    print("Initializing AI Attribute Extractor...")
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
        show_error=True,
        quiet=False,
    )