import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import pandas as pd
from transformers import AutoProcessor, ViTModel, AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
import gradio as gr
import pytesseract  # For OCR
import spaces
import random
import time
import subprocess
import re

# Load environment variables from .env file (for local development)
try:
    from dotenv import load_dotenv
    load_dotenv()  # Load .env file if it exists
    print("✅ Loaded .env file for local development")
except ImportError:
    print("ℹ️ python-dotenv not installed, using system environment variables only")

# --- 1. Configuration (Mirrored from your scripts) ---
# This ensures consistency with the model's training environment.
MODEL_DIR = "model"
MODEL_SAVE_PATH = os.path.join(MODEL_DIR, "multimodal_gated_model_2.7_GGG.pth")
CAT_MAPPINGS_SAVE_PATH = os.path.join(MODEL_DIR, "multimodal_cat_mappings_GGG.json")

# API Configuration - AI API calls removed, using direct categorical inputs

# Hugging Face Model Hub Configuration
# Point to your model repository (not the Space)
HF_MODEL_REPO = "nitish-spz/ABTestPredictor"  # Your model repository
HF_MODEL_FILENAME = "multimodal_gated_model_2.7_GGG.pth"
HF_MAPPINGS_FILENAME = "multimodal_cat_mappings_GGG.json"

VISION_MODEL_NAME = "google/vit-base-patch16-224-in21k"
TEXT_MODEL_NAME = "distilbert-base-uncased"
MAX_TEXT_LENGTH = 512

# Columns from testing script
CONTROL_IMAGE_URL_COLUMN = "controlImage"
VARIANT_IMAGE_URL_COLUMN = "variantImage"

# Order matters: column i of the model's categorical input tensor is feature i here.
CATEGORICAL_FEATURES = [
    "Business Model",
    "Customer Type",
    "grouped_conversion_type",
    "grouped_industry",
    "grouped_page_type",
]

# Embedding width per categorical feature (must match the trained checkpoint).
CATEGORICAL_EMBEDDING_DIMS = {
    "Business Model": 10,
    "Customer Type": 10,
    "grouped_conversion_type": 25,
    "grouped_industry": 50,
    "grouped_page_type": 25,
}

# Hidden width of the categorical gate controller MLP.
GATED_FUSION_DIM = 64
# --- 2. Model Architecture (Exact Replica from your training script) ---
# This class must be defined to load the saved model weights correctly.
class SupervisedSiameseMultimodal(nn.Module):
    """Siamese multimodal scorer for a (control, variant) page pair.

    Both screenshots share one ViT encoder and both OCR texts share one text
    encoder. The model concatenates these features:

    - the pooled/mean embeddings,
    - the gated "directional" differences between control and variant,
    - the categorical embeddings,
    - a constant 2-dim role vector.

    The concatenation goes through a fusion MLP and a prediction head that
    emits one logit per pair.

    Attribute names and module ordering determine the ``state_dict`` keys,
    so they must stay in sync with the training script.
    """

    def __init__(self, vision_model_name, text_model_name, cat_mappings, cat_embedding_dims):
        super().__init__()
        self.vision_model = ViTModel.from_pretrained(vision_model_name)
        self.text_model = AutoModel.from_pretrained(text_model_name)

        vision_dim = self.vision_model.config.hidden_size
        text_dim = self.text_model.config.hidden_size

        # One embedding table per categorical feature that the mappings know
        # about, kept in CATEGORICAL_FEATURES order.
        known_features = [name for name in CATEGORICAL_FEATURES if name in cat_mappings]
        self.embedding_layers = nn.ModuleList(
            nn.Embedding(cat_mappings[name]['num_categories'], cat_embedding_dims[name])
            for name in known_features
        )
        cat_width = sum(cat_embedding_dims[name] for name in known_features)

        # Maps the categorical context to two softmax weights (vision, text).
        self.gate_controller = nn.Sequential(
            nn.Linear(cat_width, GATED_FUSION_DIM),
            nn.ReLU(),
            nn.Linear(GATED_FUSION_DIM, 2),
        )

        # 4 vision-sized parts: control, variant, two directional diffs.
        # Text likewise, plus categorical embeddings and the 2-dim role vector.
        fused_width = 4 * vision_dim + 4 * text_dim + cat_width + 2

        self.fusion_block = nn.Sequential(
            nn.Linear(fused_width, fused_width),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        half_dim = vision_dim // 2
        self.prediction_head = nn.Sequential(
            nn.BatchNorm1d(fused_width),
            nn.Linear(fused_width, vision_dim),
            nn.GELU(),
            nn.LayerNorm(vision_dim),
            nn.Dropout(0.2),
            nn.Linear(vision_dim, half_dim),
            nn.GELU(),
            nn.LayerNorm(half_dim),
            nn.Dropout(0.1),
            nn.Linear(half_dim, 1),
        )

    def forward(self, c_pix, v_pix, c_tok, c_attn, v_tok, v_attn, cat_feats):
        """Return one logit per (control, variant) pair in the batch."""
        # Image branch: pooled ViT output for each side, then both difference directions.
        vis_ctrl = self.vision_model(pixel_values=c_pix).pooler_output
        vis_var = self.vision_model(pixel_values=v_pix).pooler_output
        vis_dir = torch.cat([vis_ctrl - vis_var, vis_var - vis_ctrl], dim=1)

        # Text branch: mean over token positions of the last hidden state.
        txt_ctrl = self.text_model(input_ids=c_tok, attention_mask=c_attn).last_hidden_state.mean(dim=1)
        txt_var = self.text_model(input_ids=v_tok, attention_mask=v_attn).last_hidden_state.mean(dim=1)
        txt_dir = torch.cat([txt_ctrl - txt_var, txt_var - txt_ctrl], dim=1)

        # Categorical branch: column j of cat_feats feeds embedding table j.
        cat_vec = torch.cat(
            [emb(cat_feats[:, col]) for col, emb in enumerate(self.embedding_layers)],
            dim=1,
        )

        # Context-dependent weighting of the vision vs. text directional features.
        gate_weights = F.softmax(self.gate_controller(cat_vec), dim=-1)
        vis_weight = gate_weights[:, 0].unsqueeze(1)
        txt_weight = gate_weights[:, 1].unsqueeze(1)

        n_pairs = c_pix.shape[0]
        role_vec = torch.tensor([[1, 0]] * n_pairs, dtype=torch.float32, device=c_pix.device)

        parts = [
            vis_ctrl,
            vis_var,
            txt_ctrl,
            txt_var,
            vis_dir * vis_weight,
            txt_dir * txt_weight,
            cat_vec,
            role_vec,
        ]
        fused = self.fusion_block(torch.cat(parts, dim=1))
        return self.prediction_head(fused).squeeze(-1)
# --- 3. Loading Models and Processors (Done once on startup) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Using device: {device}")

if torch.cuda.is_available():
    print(f"🔥 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    # Trade bit-for-bit reproducibility for speed: let cuDNN autotune kernels
    # and allow non-deterministic algorithms.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False

    # Release cached allocator blocks before the model is loaded.
    torch.cuda.empty_cache()

    # Permit TF32 tensor-core math for matmuls and cuDNN convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

os.makedirs(MODEL_DIR, exist_ok=True)

# Built-in fallback category mappings. A value's position in "categories" is
# the integer code fed to the model's embedding tables (see _encode_categories),
# so this presumably has to mirror the training-time mappings.
DEFAULT_CATEGORY_MAPPINGS = {
    "Business Model": {"num_categories": 4, "categories": ["E-Commerce", "Lead Generation", "Other*", "SaaS"]},
    "Customer Type": {"num_categories": 4, "categories": ["B2B", "B2C", "Both", "Other*"]},
    "grouped_conversion_type": {"num_categories": 6, "categories": ["Direct Purchase", "High-Intent Lead Gen", "Info/Content Lead Gen", "Location Search", "Non-Profit/Community", "Other Conversion"]},
    "grouped_industry": {"num_categories": 14, "categories": ["Automotive & Transportation", "B2B Services", "B2B Software & Tech", "Consumer Services", "Consumer Software & Apps", "Education", "Finance, Insurance & Real Estate", "Food, Hospitality & Travel", "Health & Wellness", "Industrial & Manufacturing", "Media & Entertainment", "Non-Profit & Government", "Other", "Retail & E-commerce"]},
    "grouped_page_type": {"num_categories": 5, "categories": ["Awareness & Discovery", "Consideration & Evaluation", "Conversion", "Internal & Navigation", "Post-Conversion & Other"]},
}


def _ensure_category_mappings():
    """Make sure CAT_MAPPINGS_SAVE_PATH exists before the model is instantiated.

    If the file is missing, first try the mappings published next to the model
    on the Hub, so that the embedding sizes and category order match the
    checkpoint. Write DEFAULT_CATEGORY_MAPPINGS only if that download fails.
    Errors are printed, not raised.
    """
    if os.path.exists(CAT_MAPPINGS_SAVE_PATH):
        return
    try:
        import shutil
        mappings_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MAPPINGS_FILENAME,
            cache_dir=MODEL_DIR,
        )
        shutil.copy(mappings_path, CAT_MAPPINGS_SAVE_PATH)
        print(f"✅ Category mappings downloaded to: {mappings_path}")
        return
    except Exception as e:
        print(f"⚠️ Could not download mappings from hub: {e}")

    print(f"⚠️ GGG Category mappings not found. Creating default mappings...")
    with open(CAT_MAPPINGS_SAVE_PATH, 'w', encoding='utf-8') as f:
        json.dump(DEFAULT_CATEGORY_MAPPINGS, f, indent=2)
    print(f"✅ Created default category mappings at {CAT_MAPPINGS_SAVE_PATH}")


_ensure_category_mappings()

with open(CAT_MAPPINGS_SAVE_PATH, 'r', encoding='utf-8') as f:
    category_mappings = json.load(f)


def load_value_mappings():
    """Load mapping.json, which maps specific industry / page-type / conversion-type values to parent groups.

    Looks next to this script first, then in the current working directory.

    Returns:
        The parsed JSON dict, or ``{}`` if the file could not be read or parsed.
        Errors are printed, not raised.
    """
    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        mapping_file = os.path.join(script_dir, 'mapping.json')
        print(f"📂 Looking for mapping file at: {mapping_file}")

        if not os.path.exists(mapping_file):
            print(f"⚠️ Mapping file not found, trying fallback location...")
            mapping_file = 'mapping.json'

        with open(mapping_file, 'r', encoding='utf-8') as f:
            mapping_data = json.load(f)
        print(f"✅ Successfully loaded mapping.json with {len(mapping_data)} mapping types")
        return mapping_data
    except Exception as e:
        print(f"⚠️ Error loading mapping.json: {e}")
        import traceback
        traceback.print_exc()
        return {}


def convert_to_parent_group(value, mapping_type, value_mappings):
    """
    Convert a specific value to its parent group using mapping.json.

    Args:
        value: The specific value (e.g., "Accounting Services").
        mapping_type: Type of mapping ("industry_mappings", "page_type_mappings", "conversion_type_mappings").
        value_mappings: The loaded mapping.json data ({mapping_type: {parent_group: [child values]}}).

    Returns:
        The parent group name (e.g., "B2B Services"). If the value is already a
        parent group, or is not found at all, it is returned unchanged.
    """
    if mapping_type not in value_mappings:
        print(f"⚠️ Mapping type '{mapping_type}' not found in mapping.json")
        return value

    mappings = value_mappings[mapping_type]

    # Search for the value in all parent groups
    for parent_group, child_values in mappings.items():
        if value in child_values:
            print(f"✅ Mapped '{value}' -> '{parent_group}'")
            return parent_group

    # If not found, check if the value itself is a parent group
    if value in mappings:
        print(f"ℹ️ '{value}' is already a parent group")
        return value

    print(f"⚠️ Value '{value}' not found in {mapping_type}, returning as-is")
    return value


def load_confidence_scores():
    """Load confidence_scores.json (keyed "industry|page_type"), looking next to this script first, then in the CWD.

    Returns:
        The parsed JSON dict, or ``{}`` on any error. Errors are printed, not raised.
    """
    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        confidence_file = os.path.join(script_dir, 'confidence_scores.json')

        print(f"📂 Script directory: {script_dir}")
        print(f"📂 Looking for confidence file at: {confidence_file}")
        print(f"📂 File exists: {os.path.exists(confidence_file)}")

        if not os.path.exists(confidence_file):
            print(f"⚠️ Confidence file not found, trying fallback location...")
            confidence_file = 'confidence_scores.json'
            print(f"📂 Fallback path: {confidence_file}")
            print(f"📂 Fallback exists: {os.path.exists(confidence_file)}")

        with open(confidence_file, 'r', encoding='utf-8') as f:
            confidence_data = json.load(f)
        print(f"✅ Successfully loaded {len(confidence_data)} confidence score combinations")

        # Print one entry so the log shows what the data looks like.
        if confidence_data:
            sample_key = next(iter(confidence_data))
            print(f"📊 Sample entry: {sample_key} = {confidence_data[sample_key]}")

        return confidence_data
    except FileNotFoundError as e:
        print(f"❌ Confidence file not found: {e}")
        print(f"📂 Current working directory: {os.getcwd()}")
        print(f"📂 Files in script dir: {os.listdir(script_dir) if os.path.exists(script_dir) else 'N/A'}")
        return {}
    except Exception as e:
        print(f"⚠️ Error loading confidence scores: {e}")
        import traceback
        traceback.print_exc()
        return {}


# Load value mappings for converting specific values to parent groups
try:
    print("=" * 50)
    print("🚀 LOADING VALUE MAPPINGS...")
    print("=" * 50)
    value_mappings = load_value_mappings()
    print(f"✅ Value mappings loaded successfully")
    print("=" * 50)
except Exception as e:
    print(f"⚠️ Error loading value mappings: {e}")
    value_mappings = {}

# Load confidence scores
try:
    print("=" * 50)
    print("🚀 LOADING CONFIDENCE SCORES...")
    print("=" * 50)
    confidence_scores = load_confidence_scores()
    print(f"✅ Confidence scores loaded successfully: {len(confidence_scores)} combinations")
    print(f"🔍 confidence_scores is empty: {len(confidence_scores) == 0}")
    print(f"🔍 confidence_scores type: {type(confidence_scores)}")
    print("=" * 50)
except Exception as e:
    print(f"⚠️ Error loading confidence scores: {e}")
    confidence_scores = {}
    print(f"❌ confidence_scores set to empty dict: {confidence_scores}")


def get_confidence_data(business_model, customer_type, conversion_type, industry, page_type):
    """Return historical confidence stats for the Industry + Page Type combination.

    Only ``industry`` and ``page_type`` form the lookup key; the other arguments
    are accepted for signature compatibility. Unknown keys get a neutral
    fallback (accuracy 0.5, all counts 0).
    """
    key = f"{industry}|{page_type}"

    print(f"🔍 Looking for confidence key: '{key}'")
    print(f"🔍 Total confidence_scores loaded: {len(confidence_scores)}")
    print(f"🔍 Key exists: {key in confidence_scores}")

    if key in confidence_scores:
        data = confidence_scores[key]
        print(f"✅ Found confidence data: {data}")
        return data

    print(f"⚠️ Key '{key}' not found, using fallback")
    print(f"🔍 Available keys with '{industry}': {[k for k in confidence_scores.keys() if industry in k]}")
    return {
        'accuracy': 0.5,  # Default fallback
        'count': 0,
        'training_data_count': 0,
        'correct_predictions': 0,
        'actual_wins': 0,
        'predicted_wins': 0,
    }


# Instantiate the model with the loaded mappings
model = SupervisedSiameseMultimodal(
    VISION_MODEL_NAME, TEXT_MODEL_NAME, category_mappings, CATEGORICAL_EMBEDDING_DIMS
)


def download_model_from_hub():
    """Download the model checkpoint from the Hugging Face Model Hub.

    Returns:
        The local path of the downloaded checkpoint, or ``None`` if the download
        failed. On failure nothing is written to MODEL_SAVE_PATH: a file there
        is treated as a real local model on the next start-up, so persisting
        random weights would silently replace the real model.
    """
    try:
        print(f"📥 Downloading GGG model from Hugging Face Model Hub: {HF_MODEL_REPO}")
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILENAME,
            cache_dir=MODEL_DIR,
        )
        print(f"✅ Model downloaded to: {model_path}")
        return model_path
    except Exception as e:
        print(f"⚠️ Error downloading from Model Hub: {e}")
        return None


# Use local model if available, otherwise download from hub
if os.path.exists(MODEL_SAVE_PATH):
    model_path = MODEL_SAVE_PATH
    print(f"✅ Using local GGG model at {MODEL_SAVE_PATH}")
else:
    print(f"📥 Model not found locally, downloading from Model Hub...")
    model_path = download_model_from_hub()

# Load the weights (fall back to the freshly initialised weights for a demo).
if model_path is None:
    print("🔧 Using initialized weights for demo...")
else:
    try:
        print(f"🚀 Loading GGG model weights from {model_path}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"✅ Successfully loaded GGG model weights from {model_path}")
    except Exception as e:
        print(f"⚠️ Error loading model weights: {e}")
        print("🔧 Using initialized weights for demo...")

model.to(device)
model.eval()

# Warm up the model with a dummy forward pass so the first real request is not
# slowed down by one-off CUDA / cuDNN initialisation.
if torch.cuda.is_available():
    with torch.no_grad():
        dummy_c_pix = torch.randn(1, 3, 224, 224).to(device)
        dummy_v_pix = torch.randn(1, 3, 224, 224).to(device)
        dummy_c_tok = torch.randint(0, 1000, (1, MAX_TEXT_LENGTH)).to(device)
        dummy_c_attn = torch.ones(1, MAX_TEXT_LENGTH).to(device)
        dummy_v_tok = torch.randint(0, 1000, (1, MAX_TEXT_LENGTH)).to(device)
        dummy_v_attn = torch.ones(1, MAX_TEXT_LENGTH).to(device)
        dummy_cat_feats = torch.randint(0, 2, (1, len(CATEGORICAL_FEATURES))).to(device)

        _ = model(
            c_pix=dummy_c_pix, v_pix=dummy_v_pix,
            c_tok=dummy_c_tok, c_attn=dummy_c_attn,
            v_tok=dummy_v_tok, v_attn=dummy_v_attn,
            cat_feats=dummy_cat_feats,
        )
    print("🔥 Model warmed up successfully!")

# Load the processors for images and text
image_processor = AutoProcessor.from_pretrained(VISION_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
print("✅ Model and processors loaded successfully.")


# --- 4. Prediction Functions ---
def get_image_path_from_url(image_url: str, base_dir: str) -> str | None:
    """Build ``<base_dir>/<url file stem>.jpeg`` from a URL-like string; ``None`` if it cannot be parsed."""
    try:
        stem = os.path.splitext(os.path.basename(str(image_url)))[0]
        return os.path.join(base_dir, f"{stem}.jpeg")
    except (TypeError, ValueError):
        return None


def _encode_categories(values):
    """Translate categorical values (in CATEGORICAL_FEATURES order) into embedding indices.

    Raises:
        ValueError: if a value is not among the known categories of its feature.
    """
    codes = []
    for name, val in zip(CATEGORICAL_FEATURES, values):
        choices = category_mappings[name]['categories']
        if val not in choices:
            raise ValueError(f"Unknown value {val!r} for '{name}'. Expected one of: {choices}")
        codes.append(choices.index(val))
    return codes


@spaces.GPU(duration=50)  # Maximum allowed duration on free tier
def predict_with_categorical_data(control_image, variant_image, business_model, customer_type, conversion_type, industry, page_type):
    """Predict a winner from two images plus categorical context (no AI API calls).

    Specific conversion-type / industry / page-type values are first mapped to
    their parent groups via mapping.json, then passed to ``predict_single``.

    Returns:
        A dict with the prediction, the provided and grouped categories, and
        timing info, or ``{"error": ...}`` if an image is missing.
    """
    if control_image is None or variant_image is None:
        return {"error": "Please provide both control and variant images"}

    start_time = time.time()

    print(f"📋 Original categories from API: {business_model} | {customer_type} | {conversion_type} | {industry} | {page_type}")

    # Convert specific values to parent groups using mapping.json
    grouped_conversion_type = convert_to_parent_group(conversion_type, "conversion_type_mappings", value_mappings)
    grouped_industry = convert_to_parent_group(industry, "industry_mappings", value_mappings)
    grouped_page_type = convert_to_parent_group(page_type, "page_type_mappings", value_mappings)

    print(f"📋 Mapped to parent groups: {business_model} | {customer_type} | {grouped_conversion_type} | {grouped_industry} | {grouped_page_type}")

    prediction_result = predict_single(
        control_image, variant_image, business_model, customer_type,
        grouped_conversion_type, grouped_industry, grouped_page_type,
    )

    return {
        "predictionResults": prediction_result,
        "providedCategories": {
            "businessModel": business_model,
            "customerType": customer_type,
            "conversionType": conversion_type,
            "industry": industry,
            "pageType": page_type,
        },
        "groupedCategories": {
            "businessModel": business_model,
            "customerType": customer_type,
            "conversionType": grouped_conversion_type,
            "industry": grouped_industry,
            "pageType": grouped_page_type,
        },
        "processingInfo": {
            "totalProcessingTime": f"{time.time() - start_time:.2f}s",
            "confidenceSource": f"{grouped_industry} | {grouped_page_type}",
        },
    }


@spaces.GPU(duration=60)  # Maximum allowed duration on free tier
def predict_single(control_image, variant_image, business_model, customer_type, conversion_type, industry, page_type):
    """
    Orchestrates the prediction for a single pair of images and features.

    Note: This function expects GROUPED values for conversion_type, industry, and page_type.
    If calling from API, use predict_with_categorical_data() which handles the conversion automatically.

    Returns:
        A dict with "probability" (variant win probability as a 3-decimal
        string), "modelConfidence" and training-data statistics. Any exception
        yields a dict with an "error" key and zeroed statistics.
    """
    try:
        if control_image is None or variant_image is None:
            return {"Error": 1.0, "Please upload both images": 0.0}

        start_time = time.time()
        print(f"🔍 Starting prediction with categories: {business_model} | {customer_type} | {conversion_type} | {industry} | {page_type}")

        # Validate categories before the comparatively expensive OCR step.
        cat_codes = _encode_categories([business_model, customer_type, conversion_type, industry, page_type])

        c_img = Image.fromarray(control_image).convert("RGB")
        v_img = Image.fromarray(variant_image).convert("RGB")

        # OCR text is one of the model's inputs; without Tesseract it degrades to empty text.
        try:
            c_text_str = pytesseract.image_to_string(c_img)
            v_text_str = pytesseract.image_to_string(v_img)
            print(f"📝 OCR extracted - Control: {len(c_text_str)} chars, Variant: {len(v_text_str)} chars")
        except pytesseract.TesseractNotFoundError:
            print("🛑 Tesseract is not installed or not in your PATH. Skipping OCR.")
            c_text_str, v_text_str = "", ""

        confidence_data = get_confidence_data(business_model, customer_type, conversion_type, industry, page_type)
        print(f"📊 Confidence data loaded: {confidence_data}")

        with torch.no_grad():
            c_pix = image_processor(images=c_img, return_tensors="pt").pixel_values.to(device)
            v_pix = image_processor(images=v_img, return_tensors="pt").pixel_values.to(device)

            c_text = tokenizer(c_text_str, padding='max_length', truncation=True, max_length=MAX_TEXT_LENGTH, return_tensors='pt').to(device)
            v_text = tokenizer(v_text_str, padding='max_length', truncation=True, max_length=MAX_TEXT_LENGTH, return_tensors='pt').to(device)

            cat_feats = torch.tensor([cat_codes], dtype=torch.int64).to(device)

            logits = model(
                c_pix=c_pix, v_pix=v_pix,
                c_tok=c_text['input_ids'], c_attn=c_text['attention_mask'],
                v_tok=v_text['input_ids'], v_attn=v_text['attention_mask'],
                cat_feats=cat_feats,
            )
            probability = torch.sigmoid(logits).item()

        processing_time = time.time() - start_time
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.memory_allocated() / 1024**3
            print(f"🚀 Prediction completed in {processing_time:.2f}s | GPU Memory: {gpu_memory:.1f}GB")
        else:
            print(f"🚀 Prediction completed in {processing_time:.2f}s")

        confidence_percentage = confidence_data['accuracy'] * 100

        result = {
            "probability": f"{probability:.3f}",
            "modelConfidence": f"{confidence_percentage:.1f}",
            "trainingDataSamples": confidence_data['training_data_count'],
            "totalPredictions": confidence_data['count'],
            "correctPredictions": confidence_data['correct_predictions'],
            "totalWinPrediction": confidence_data['actual_wins'],
            "totalLosePrediction": confidence_data['count'] - confidence_data['actual_wins'],
        }

        print(f"🎯 Final result: {result}")
        return result

    except Exception as e:
        print(f"❌ ERROR in predict_single: {e}")
        print(f"🔍 Error type: {type(e).__name__}")
        import traceback
        traceback.print_exc()

        # Return error result with fallback confidence data
        return {
            "error": f"Prediction failed: {str(e)}",
            "modelConfidence": "50.0",
            "trainingDataSamples": 0,
            "totalPredictions": 0,
            "correctPredictions": 0,
            "totalWinPrediction": 0,
            "totalLosePrediction": 0,
        }


def get_all_possible_values():
    """
    List every accepted value (parent groups and their specific children) for
    industry, page_type and conversion_type, as read from mapping.json.
    Useful for API documentation and validation.
    """
    all_values = {"industry": [], "page_type": [], "conversion_type": []}
    for field, collected in all_values.items():
        # mapping.json keys are "<field>_mappings", e.g. "industry_mappings".
        for parent_group, child_values in value_mappings.get(f"{field}_mappings", {}).items():
            collected.append(parent_group)
            collected.extend(child_values)
    return all_values


@spaces.GPU
def predict_batch(csv_path, control_img_dir, variant_img_dir, num_samples):
    """
    Predict a random sample of rows from a CSV file.

    Note: CSV should contain grouped values (not specific values) for:
    - grouped_conversion_type
    - grouped_industry
    - grouped_page_type

    Returns:
        A DataFrame of the sampled rows with an added 'predicted_win_probability'
        column: a float, or an "ERROR: ..." string for rows that failed.
    """
    # NOTE(review): paths come straight from the UI and are read on the server;
    # on a public deployment this exposes the server filesystem.
    if not all([csv_path, control_img_dir, variant_img_dir, num_samples]):
        return pd.DataFrame({"Error": ["Please fill in all fields."]})

    # gr.Number delivers floats; DataFrame.sample requires an integer n.
    try:
        num_samples = int(num_samples)
    except (TypeError, ValueError):
        return pd.DataFrame({"Error": [f"Invalid number of samples: {num_samples}"]})
    if num_samples < 1:
        return pd.DataFrame({"Error": ["Number of samples must be at least 1."]})

    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        return pd.DataFrame({"Error": [f"CSV file not found at: {csv_path}"]})
    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to read CSV: {e}"]})

    if num_samples > len(df):
        print(f"⚠️ Requested {num_samples} samples, but CSV only has {len(df)} rows. Using all rows.")
        num_samples = len(df)

    sample_df = df.sample(n=num_samples, random_state=42)

    results = []
    for _, row in sample_df.iterrows():
        result_row = row.to_dict()
        try:
            c_path = get_image_path_from_url(row[CONTROL_IMAGE_URL_COLUMN], control_img_dir)
            v_path = get_image_path_from_url(row[VARIANT_IMAGE_URL_COLUMN], variant_img_dir)

            if not c_path or not os.path.exists(c_path):
                raise FileNotFoundError(f"Control image not found: {c_path}")
            if not v_path or not os.path.exists(v_path):
                raise FileNotFoundError(f"Variant image not found: {v_path}")

            # Convert to RGB before taking the array (palette/alpha images would
            # otherwise be misread), and close the files promptly.
            with Image.open(c_path) as c_im, Image.open(v_path) as v_im:
                control_arr = np.array(c_im.convert("RGB"))
                variant_arr = np.array(v_im.convert("RGB"))

            # Categorical features from the row (grouped values expected in the CSV)
            cat_values = [row[f] for f in CATEGORICAL_FEATURES]

            prediction = predict_single(
                control_image=control_arr,
                variant_image=variant_arr,
                business_model=cat_values[0],
                customer_type=cat_values[1],
                conversion_type=cat_values[2],
                industry=cat_values[3],
                page_type=cat_values[4],
            )

            # predict_single reports the win probability under "probability";
            # anything else is an error payload.
            if "probability" in prediction:
                result_row['predicted_win_probability'] = float(prediction["probability"])
            else:
                result_row['predicted_win_probability'] = f"ERROR: {prediction.get('error', prediction)}"
        except Exception as e:
            print(f"🛑 Error processing row: {e}")
            result_row['predicted_win_probability'] = f"ERROR: {e}"
        results.append(result_row)

    return pd.DataFrame(results)
# --- 5. Build the Gradio Interface ---
def _category_dropdown(feature, label, allow_custom_value=False):
    """Return a Dropdown listing the categories of *feature*, preselecting the first one."""
    choices = category_mappings[feature]['categories']
    return gr.Dropdown(
        choices=choices,
        label=label,
        value=choices[0],
        allow_custom_value=allow_custom_value,
    )


with gr.Blocks() as iface:
    gr.Markdown("# 🚀 Multimodal A/B Test Predictor")
    gr.Markdown("""
    ### Predict A/B test outcomes using:
    - 🖼️ **Image Analysis**: Visual features from control & variant images
    - 📝 **OCR Text Extraction**: Automatically extracts and analyzes text from images
    - 📊 **Categorical Features**: Business context (industry, page type, etc.)
    - 🎯 **Smart Confidence Scores**: Based on Industry + Page Type combinations with high sample counts

    **Enhanced Reliability**: Confidence scores use Industry + Page Type combinations (avg 160 samples) instead of low-count 5-feature combinations!
    """)

    with gr.Tab("🎯 API Prediction"):
        gr.Markdown("### 📊 Predict with Categorical Data")
        gr.Markdown("""
        Upload images and provide categorical data for prediction.

        **Note:** For Industry, Page Type, and Conversion Type, you can provide either:
        - Specific values (e.g., "Accounting Services") - will be automatically converted to parent group (e.g., "B2B Services")
        - Parent group values (e.g., "B2B Services") - will be used directly

        The model uses parent groups internally, but the API accepts both for convenience.
        """)

        with gr.Row():
            with gr.Column():
                api_control_image = gr.Image(label="Control Image", type="numpy")
                api_variant_image = gr.Image(label="Variant Image", type="numpy")

            with gr.Column():
                api_business_model = _category_dropdown("Business Model", "Business Model")
                api_customer_type = _category_dropdown("Customer Type", "Customer Type")
                # Custom values are allowed so that specific values (e.g. "Accounting
                # Services") reach predict_with_categorical_data, which maps them
                # to parent groups. Without this, Dropdown rejects non-choice values.
                api_conversion_type = _category_dropdown("grouped_conversion_type", "Conversion Type", allow_custom_value=True)
                api_industry = _category_dropdown("grouped_industry", "Industry", allow_custom_value=True)
                api_page_type = _category_dropdown("grouped_page_type", "Page Type", allow_custom_value=True)

        api_predict_btn = gr.Button("🎯 Predict with Categorical Data", variant="primary", size="lg")
        api_output_json = gr.JSON(label="🎯 Prediction Results with Confidence Scores")

    with gr.Tab("📋 Manual Selection"):
        gr.Markdown("### Manual Category Selection")
        gr.Markdown("Select categories manually if you prefer precise control.")

        with gr.Row():
            with gr.Column():
                s_control_image = gr.Image(label="Control Image", type="numpy")
                s_variant_image = gr.Image(label="Variant Image", type="numpy")

            with gr.Column():
                s_business_model = _category_dropdown("Business Model", "Business Model")
                s_customer_type = _category_dropdown("Customer Type", "Customer Type")
                s_conversion_type = _category_dropdown("grouped_conversion_type", "Conversion Type")
                s_industry = _category_dropdown("grouped_industry", "Industry")
                s_page_type = _category_dropdown("grouped_page_type", "Page Type")

        s_predict_btn = gr.Button("🔮 Predict A/B Test Winner", variant="secondary")
        # predict_single returns a dict with mixed string/int values, which
        # gr.Label (str -> float confidences) cannot render; show it as JSON.
        s_output_json = gr.JSON(label="🎯 Prediction Results & Confidence Analysis")

    with gr.Tab("Batch Prediction from CSV"):
        gr.Markdown("Provide paths to your data to get predictions for multiple random samples.")
        b_csv_path = gr.Textbox(label="Path to CSV file", placeholder="/path/to/your/data.csv")
        b_control_dir = gr.Textbox(label="Path to Control Images Folder", placeholder="/path/to/control_images/")
        b_variant_dir = gr.Textbox(label="Path to Variant Images Folder", placeholder="/path/to/variant_images/")
        # precision=0 makes Gradio deliver an integer sample count.
        b_num_samples = gr.Number(label="Number of random samples to predict", value=10, precision=0)
        b_predict_btn = gr.Button("Run Batch Prediction")
        b_output_df = gr.DataFrame(label="Batch Prediction Results")

    # Wire up the components
    api_predict_btn.click(
        fn=predict_with_categorical_data,
        inputs=[api_control_image, api_variant_image, api_business_model, api_customer_type, api_conversion_type, api_industry, api_page_type],
        outputs=api_output_json,
    )

    s_predict_btn.click(
        fn=predict_single,
        inputs=[s_control_image, s_variant_image, s_business_model, s_customer_type, s_conversion_type, s_industry, s_page_type],
        outputs=s_output_json,
    )

    b_predict_btn.click(
        fn=predict_batch,
        inputs=[b_csv_path, b_control_dir, b_variant_dir, b_num_samples],
        outputs=b_output_df,
    )

# Launch the application
if __name__ == "__main__":
    # Large request queue with high concurrency. NOTE(review): the model lives
    # on a single device (`device`), so concurrent requests share it; tune
    # default_concurrency_limit to what that device can actually sustain.
    iface.queue(
        max_size=128,
        default_concurrency_limit=64,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,  # Show detailed errors for debugging
    )