# Hugging Face Space: multimodal A/B-test predictor (Gradio app).
| import os | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import numpy as np | |
| import pandas as pd | |
| from transformers import AutoProcessor, ViTModel, AutoTokenizer, AutoModel | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| import pytesseract # For OCR | |
| import spaces | |
| import random | |
| import time | |
| import subprocess | |
| import re | |
# Load environment variables from a .env file when python-dotenv is available
# (local development convenience; production relies on real env vars).
try:
    from dotenv import load_dotenv
except ImportError:
    print("โน๏ธ python-dotenv not installed, using system environment variables only")
else:
    load_dotenv()  # Load .env file if it exists
    print("โ Loaded .env file for local development")
# --- 1. Configuration (mirrored from the training scripts) ---
# These must stay identical to the training environment so the checkpoint,
# category mappings and tokenizer settings line up at inference time.
# AI API calls were removed; categorical inputs are provided directly.
MODEL_DIR = "model"
MODEL_SAVE_PATH = os.path.join(MODEL_DIR, "multimodal_gated_model_2.7_GGG.pth")
CAT_MAPPINGS_SAVE_PATH = os.path.join(MODEL_DIR, "multimodal_cat_mappings_GGG.json")

# Hugging Face Model Hub configuration (points at the model repo, not the Space).
HF_MODEL_REPO = "nitish-spz/ABTestPredictor"  # Your model repository
HF_MODEL_FILENAME = "multimodal_gated_model_2.7_GGG.pth"
HF_MAPPINGS_FILENAME = "multimodal_cat_mappings_GGG.json"

# Backbone encoders and the tokenizer truncation limit.
VISION_MODEL_NAME = "google/vit-base-patch16-224-in21k"
TEXT_MODEL_NAME = "distilbert-base-uncased"
MAX_TEXT_LENGTH = 512

# CSV column names consumed by the batch-prediction path.
CONTROL_IMAGE_URL_COLUMN = "controlImage"
VARIANT_IMAGE_URL_COLUMN = "variantImage"

# Categorical side-inputs, in the order the model's embedding layers expect,
# and the embedding width assigned to each feature.
CATEGORICAL_FEATURES = [
    "Business Model", "Customer Type", "grouped_conversion_type",
    "grouped_industry", "grouped_page_type",
]
CATEGORICAL_EMBEDDING_DIMS = {
    "Business Model": 10, "Customer Type": 10, "grouped_conversion_type": 25,
    "grouped_industry": 50, "grouped_page_type": 25,
}
GATED_FUSION_DIM = 64
# --- 2. Model Architecture (Exact Replica from your training script) ---
# This class must be defined to load the saved model weights correctly.
class SupervisedSiameseMultimodal(nn.Module):
    """
    Updated model architecture matching the new GGG version.
    Includes fusion block, BatchNorm, and enhanced directional features.

    Shared (siamese) ViT and text encoders embed both the control and variant
    inputs; categorical embeddings drive a learned softmax gate that weights
    the vision vs. text directional features before a fused MLP head produces
    a single win/lose logit per pair.
    """
    def __init__(self, vision_model_name, text_model_name, cat_mappings, cat_embedding_dims):
        """
        Args:
            vision_model_name: HF hub id for the ViT vision backbone.
            text_model_name: HF hub id for the text backbone.
            cat_mappings: dict feature -> {'num_categories': int, ...}.
            cat_embedding_dims: dict feature -> embedding width.
        """
        super().__init__()
        self.vision_model = ViTModel.from_pretrained(vision_model_name)
        self.text_model = AutoModel.from_pretrained(text_model_name)
        vision_dim = self.vision_model.config.hidden_size
        text_dim = self.text_model.config.hidden_size
        # One embedding table per categorical feature, in CATEGORICAL_FEATURES order.
        self.embedding_layers = nn.ModuleList()
        total_cat_emb_dim = 0
        for feature in CATEGORICAL_FEATURES:
            # Safely handle cases where a feature might not be in mappings
            if feature in cat_mappings:
                num_cats = cat_mappings[feature]['num_categories']
                emb_dim = cat_embedding_dims[feature]
                self.embedding_layers.append(nn.Embedding(num_cats, emb_dim))
                total_cat_emb_dim += emb_dim
        # Gate controller: categorical context -> 2 logits (vision vs. text weight).
        self.gate_controller = nn.Sequential(
            nn.Linear(total_cat_emb_dim, GATED_FUSION_DIM),
            nn.ReLU(),
            nn.Linear(GATED_FUSION_DIM, 2)
        )
        # Updated in_dim calculation to match new architecture:
        # 2 raw embeddings + a 2*dim directional feature per modality (=4*dim each),
        # plus the concatenated categorical embeddings and a 2-dim role marker.
        in_dim = (vision_dim * 4) + (text_dim * 4) + total_cat_emb_dim + 2
        # Add the fusion block
        self.fusion_block = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        # Updated prediction head with BatchNorm
        self.prediction_head = nn.Sequential(
            nn.BatchNorm1d(in_dim),
            nn.Linear(in_dim, vision_dim),
            nn.GELU(),
            nn.LayerNorm(vision_dim),
            nn.Dropout(0.2),
            nn.Linear(vision_dim, vision_dim // 2),
            nn.GELU(),
            nn.LayerNorm(vision_dim // 2),
            nn.Dropout(0.1),
            nn.Linear(vision_dim // 2, 1)
        )

    def forward(self, c_pix, v_pix, c_tok, c_attn, v_tok, v_attn, cat_feats):
        """
        Args:
            c_pix, v_pix: control/variant pixel tensors from the image processor.
            c_tok, c_attn, v_tok, v_attn: tokenized OCR text ids and attention masks.
            cat_feats: int64 tensor of categorical indices, one column per feature.

        Returns:
            Raw logits, one per batch row (apply sigmoid for a win probability).
        """
        # Enhanced forward pass with directional features
        emb_c_vision = self.vision_model(pixel_values=c_pix).pooler_output
        emb_v_vision = self.vision_model(pixel_values=v_pix).pooler_output
        # Antisymmetric difference features encode which direction the pair differs.
        direction_feat_vision = torch.cat([emb_c_vision - emb_v_vision, emb_v_vision - emb_c_vision], dim=1)
        c_text_out = self.text_model(input_ids=c_tok, attention_mask=c_attn).last_hidden_state
        v_text_out = self.text_model(input_ids=v_tok, attention_mask=v_attn).last_hidden_state
        # Mean-pool token embeddings to one vector per document.
        emb_c_text = c_text_out.mean(dim=1)
        emb_v_text = v_text_out.mean(dim=1)
        direction_feat_text = torch.cat([emb_c_text - emb_v_text, emb_v_text - emb_c_text], dim=1)
        cat_embeddings = [layer(cat_feats[:, i]) for i, layer in enumerate(self.embedding_layers)]
        final_cat_embedding = torch.cat(cat_embeddings, dim=1)
        # Softmax gate decides how much vision vs. text evidence to trust.
        gates = F.softmax(self.gate_controller(final_cat_embedding), dim=-1)
        vision_gate = gates[:, 0].unsqueeze(1)
        text_gate = gates[:, 1].unsqueeze(1)
        weighted_vision = direction_feat_vision * vision_gate
        weighted_text = direction_feat_text * text_gate
        batch_size = c_pix.shape[0]
        # Constant [1, 0] role marker replicated per row (control-first ordering).
        role_embedding = torch.tensor([[1, 0]] * batch_size, dtype=torch.float32, device=c_pix.device)
        final_vector = torch.cat([
            emb_c_vision, emb_v_vision,
            emb_c_text, emb_v_text,
            weighted_vision, weighted_text,
            final_cat_embedding,
            role_embedding
        ], dim=1)
        # Pass through the fusion block before the final prediction head
        fused_vector = self.fusion_block(final_vector)
        return self.prediction_head(fused_vector).squeeze(-1)
# --- 3. Loading Models and Processors (Done once on startup) ---
# Optimized for L4 GPU setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"๐ Using device: {device}")
if torch.cuda.is_available():
    print(f"๐ฅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"๐พ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    # AGGRESSIVE optimizations for 4x L4 GPU
    torch.backends.cudnn.benchmark = True  # autotune conv kernels (inputs are fixed-size 224x224)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False  # Allow non-deterministic for speed
    # Aggressive memory management
    torch.cuda.empty_cache()
    # Enable tensor core usage for maximum performance (TF32 trades a little
    # matmul precision for throughput on Ampere+ GPUs).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Ensure the model directory and category mappings exist so the app can boot
# on a fresh environment (e.g. first run inside a Space container).
os.makedirs(MODEL_DIR, exist_ok=True)  # idempotent; replaces the exists()+makedirs() race
if not os.path.exists(CAT_MAPPINGS_SAVE_PATH):
    print("โ ๏ธ GGG Category mappings not found. Creating default mappings...")
    # Create the standard category mappings expected by the model.
    # Category ORDER matters: list positions become nn.Embedding indices.
    default_mappings = {
        "Business Model": {"num_categories": 4, "categories": ["E-Commerce", "Lead Generation", "Other*", "SaaS"]},
        "Customer Type": {"num_categories": 4, "categories": ["B2B", "B2C", "Both", "Other*"]},
        "grouped_conversion_type": {"num_categories": 6, "categories": ["Direct Purchase", "High-Intent Lead Gen", "Info/Content Lead Gen", "Location Search", "Non-Profit/Community", "Other Conversion"]},
        "grouped_industry": {"num_categories": 14, "categories": ["Automotive & Transportation", "B2B Services", "B2B Software & Tech", "Consumer Services", "Consumer Software & Apps", "Education", "Finance, Insurance & Real Estate", "Food, Hospitality & Travel", "Health & Wellness", "Industrial & Manufacturing", "Media & Entertainment", "Non-Profit & Government", "Other", "Retail & E-commerce"]},
        "grouped_page_type": {"num_categories": 5, "categories": ["Awareness & Discovery", "Consideration & Evaluation", "Conversion", "Internal & Navigation", "Post-Conversion & Other"]}
    }
    with open(CAT_MAPPINGS_SAVE_PATH, 'w') as f:
        json.dump(default_mappings, f, indent=2)
    print(f"โ Created default category mappings at {CAT_MAPPINGS_SAVE_PATH}")
# Load the (possibly just-created) mappings used everywhere below.
with open(CAT_MAPPINGS_SAVE_PATH, 'r') as f:
    category_mappings = json.load(f)
# Load mapping.json for converting specific values to parent groups
def load_value_mappings():
    """Load mapping.json for converting industry, page_type, and conversion_type to parent groups"""
    try:
        # Prefer the file next to this script; fall back to the working directory.
        candidate = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mapping.json')
        print(f"๐ Looking for mapping file at: {candidate}")
        if not os.path.exists(candidate):
            print(f"โ ๏ธ Mapping file not found, trying fallback location...")
            candidate = 'mapping.json'
        with open(candidate, 'r') as fh:
            data = json.load(fh)
        print(f"โ Successfully loaded mapping.json with {len(data)} mapping types")
        return data
    except Exception as exc:
        # Best effort: a missing/corrupt file degrades to empty mappings.
        print(f"โ ๏ธ Error loading mapping.json: {exc}")
        import traceback
        traceback.print_exc()
        return {}
def convert_to_parent_group(value, mapping_type, value_mappings):
    """
    Convert a specific value to its parent group using mapping.json.

    Args:
        value: The specific value (e.g., "Accounting Services").
        mapping_type: Type of mapping ("industry_mappings", "page_type_mappings",
            "conversion_type_mappings").
        value_mappings: The loaded mapping.json data
            (parent group -> list of child values, per mapping type).

    Returns:
        The parent group name (e.g., "B2B Services"); the input value unchanged
        when it is already a parent group or cannot be found at all.
    """
    if mapping_type not in value_mappings:
        print(f"โ ๏ธ Mapping type '{mapping_type}' not found in mapping.json")
        return value
    mappings = value_mappings[mapping_type]
    # Search for the value in all parent groups
    for parent_group, child_values in mappings.items():
        if value in child_values:
            print(f"โ Mapped '{value}' -> '{parent_group}'")
            return parent_group
    # If not found, check if the value itself is a parent group
    # (idiom: `in mappings` rather than `in mappings.keys()`)
    if value in mappings:
        print(f"โน๏ธ '{value}' is already a parent group")
        return value
    print(f"โ ๏ธ Value '{value}' not found in {mapping_type}, returning as-is")
    return value
# Load confidence scores directly from JSON file
def load_confidence_scores():
    """Load confidence scores from confidence_scores.json"""
    # Get the directory where this script is located
    script_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        confidence_file = os.path.join(script_dir, 'confidence_scores.json')
        print(f"๐ Script directory: {script_dir}")
        print(f"๐ Looking for confidence file at: {confidence_file}")
        print(f"๐ File exists: {os.path.exists(confidence_file)}")
        if not os.path.exists(confidence_file):
            # Try current directory as fallback
            print(f"โ ๏ธ Confidence file not found, trying fallback location...")
            confidence_file = 'confidence_scores.json'
            print(f"๐ Fallback path: {confidence_file}")
            print(f"๐ Fallback exists: {os.path.exists(confidence_file)}")
        with open(confidence_file, 'r') as fh:
            confidence_data = json.load(fh)
        print(f"โ Successfully loaded {len(confidence_data)} confidence score combinations")
        # Print a sample to verify data
        sample_key = next(iter(confidence_data)) if confidence_data else None
        if sample_key:
            print(f"๐ Sample entry: {sample_key} = {confidence_data[sample_key]}")
        return confidence_data
    except FileNotFoundError as e:
        print(f"โ Confidence file not found: {e}")
        print(f"๐ Current working directory: {os.getcwd()}")
        print(f"๐ Files in script dir: {os.listdir(script_dir) if os.path.exists(script_dir) else 'N/A'}")
        return {}
    except Exception as e:
        print(f"โ ๏ธ Error loading confidence scores: {e}")
        import traceback
        traceback.print_exc()
        return {}
# Load value mappings for converting specific values to parent groups.
# Both loads run at import time and degrade to empty dicts on failure so the
# Space still starts without the optional JSON files.
try:
    print("=" * 50)
    print("๐ LOADING VALUE MAPPINGS...")
    print("=" * 50)
    value_mappings = load_value_mappings()
    print(f"โ Value mappings loaded successfully")
    print("=" * 50)
except Exception as e:
    print(f"โ ๏ธ Error loading value mappings: {e}")
    value_mappings = {}
# Load confidence scores (historical accuracy keyed on "industry|page_type").
try:
    print("=" * 50)
    print("๐ LOADING CONFIDENCE SCORES...")
    print("=" * 50)
    confidence_scores = load_confidence_scores()
    print(f"โ Confidence scores loaded successfully: {len(confidence_scores)} combinations")
    print(f"๐ confidence_scores is empty: {len(confidence_scores) == 0}")
    print(f"๐ confidence_scores type: {type(confidence_scores)}")
    print("=" * 50)
except Exception as e:
    print(f"โ ๏ธ Error loading confidence scores: {e}")
    confidence_scores = {}
    print(f"โ confidence_scores set to empty dict: {confidence_scores}")
def get_confidence_data(business_model, customer_type, conversion_type, industry, page_type):
    """Get confidence data based on Industry + Page Type combination (more reliable than 5-feature combinations)"""
    # Only industry and page_type participate in the lookup key; the other
    # parameters are accepted for interface symmetry with the prediction path.
    lookup_key = f"{industry}|{page_type}"
    print(f"๐ Looking for confidence key: '{lookup_key}'")
    print(f"๐ Total confidence_scores loaded: {len(confidence_scores)}")
    print(f"๐ Key exists: {lookup_key in confidence_scores}")
    if lookup_key in confidence_scores:
        matched = confidence_scores[lookup_key]
        print(f"โ Found confidence data: {matched}")
        return matched
    print(f"โ ๏ธ Key '{lookup_key}' not found, using fallback")
    print(f"๐ Available keys with '{industry}': {[k for k in confidence_scores.keys() if industry in k]}")
    # Neutral fallback when no historical stats exist for this combination.
    return {
        'accuracy': 0.5,  # Default fallback
        'count': 0,
        'training_data_count': 0,
        'correct_predictions': 0,
        'actual_wins': 0,
        'predicted_wins': 0
    }
# Instantiate the model with the loaded mappings; the checkpoint weights are
# loaded further below once they have been located locally or downloaded.
model = SupervisedSiameseMultimodal(
    VISION_MODEL_NAME, TEXT_MODEL_NAME, category_mappings, CATEGORICAL_EMBEDDING_DIMS
)
# Download model from Hugging Face Model Hub
def download_model_from_hub():
    """Download model and mappings from Hugging Face Model Hub.

    Returns:
        Path to the downloaded checkpoint, or MODEL_SAVE_PATH pointing at
        freshly-saved dummy weights when the download fails (demo fallback).
    """
    try:
        print(f"๐ฅ Downloading GGG model from Hugging Face Model Hub: {HF_MODEL_REPO}")
        # Download model file
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILENAME,
            cache_dir=MODEL_DIR
        )
        print(f"โ Model downloaded to: {model_path}")
        # Download category mappings if not exists locally
        if not os.path.exists(CAT_MAPPINGS_SAVE_PATH):
            try:
                mappings_path = hf_hub_download(
                    repo_id=HF_MODEL_REPO,
                    filename=HF_MAPPINGS_FILENAME,
                    cache_dir=MODEL_DIR
                )
                print(f"โ Category mappings downloaded to: {mappings_path}")
                # Copy to expected location (hf_hub_download caches elsewhere)
                import shutil
                shutil.copy(mappings_path, CAT_MAPPINGS_SAVE_PATH)
            except Exception as e:
                # Non-fatal: the default mappings created at startup remain usable.
                print(f"โ ๏ธ Could not download mappings from hub: {e}")
        return model_path
    except Exception as e:
        # Demo fallback: persist the randomly-initialized weights so the app
        # can still run end-to-end without network access.
        print(f"โ ๏ธ Error downloading from Model Hub: {e}")
        print(f"๐ง Creating dummy weights for demo...")
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        return MODEL_SAVE_PATH
# Use local model if available, otherwise download from hub
if os.path.exists(MODEL_SAVE_PATH):
    model_path = MODEL_SAVE_PATH
    print(f"โ Using local GGG model at {MODEL_SAVE_PATH}")
else:
    print(f"๐ฅ Model not found locally, downloading from Model Hub...")
    model_path = download_model_from_hub()
# Load the weights (falls back to the randomly-initialized model on failure).
try:
    print(f"๐ Loading GGG model weights from {model_path}")
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print("โ Successfully loaded GGG model weights from Hugging Face Model Hub")
except Exception as e:
    print(f"โ ๏ธ Error loading model weights: {e}")
    print("๐ง Using initialized weights for demo...")
model.to(device)
model.eval()
# Warm up the model with a dummy forward pass so the first real request does
# not pay CUDA kernel/cuDNN autotune initialization costs.
if torch.cuda.is_available():
    with torch.no_grad():
        dummy_c_pix = torch.randn(1, 3, 224, 224).to(device)
        dummy_v_pix = torch.randn(1, 3, 224, 224).to(device)
        dummy_c_tok = torch.randint(0, 1000, (1, MAX_TEXT_LENGTH)).to(device)
        dummy_c_attn = torch.ones(1, MAX_TEXT_LENGTH).to(device)
        dummy_v_tok = torch.randint(0, 1000, (1, MAX_TEXT_LENGTH)).to(device)
        dummy_v_attn = torch.ones(1, MAX_TEXT_LENGTH).to(device)
        dummy_cat_feats = torch.randint(0, 2, (1, len(CATEGORICAL_FEATURES))).to(device)
        _ = model(
            c_pix=dummy_c_pix, v_pix=dummy_v_pix,
            c_tok=dummy_c_tok, c_attn=dummy_c_attn,
            v_tok=dummy_v_tok, v_attn=dummy_v_attn,
            cat_feats=dummy_cat_feats
        )
    print("๐ฅ Model warmed up successfully!")
# Load the processors for images and text
image_processor = AutoProcessor.from_pretrained(VISION_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
print("โ Model and processors loaded successfully.")
| # --- 4. Prediction Functions --- | |
| def get_image_path_from_url(image_url: str, base_dir: str) -> str | None: | |
| """Constructs a local image path from a URL-like string.""" | |
| try: | |
| stem = os.path.splitext(os.path.basename(str(image_url)))[0] | |
| return os.path.join(base_dir, f"{stem}.jpeg") | |
| except (TypeError, ValueError): | |
| return None | |
# Maximum allowed duration on free tier
def predict_with_categorical_data(control_image, variant_image, business_model, customer_type, conversion_type, industry, page_type):
    """Make prediction with provided categorical data (no AI API calls)"""
    if control_image is None or variant_image is None:
        return {"error": "Please provide both control and variant images"}
    started = time.time()
    print(f"๐ Original categories from API: {business_model} | {customer_type} | {conversion_type} | {industry} | {page_type}")
    # Map caller-supplied specific values onto the parent groups the model knows.
    grouped_conversion_type = convert_to_parent_group(conversion_type, "conversion_type_mappings", value_mappings)
    grouped_industry = convert_to_parent_group(industry, "industry_mappings", value_mappings)
    grouped_page_type = convert_to_parent_group(page_type, "page_type_mappings", value_mappings)
    print(f"๐ Mapped to parent groups: {business_model} | {customer_type} | {grouped_conversion_type} | {grouped_industry} | {grouped_page_type}")
    # Core multimodal prediction on the grouped categorical data.
    prediction_result = predict_single(control_image, variant_image, business_model, customer_type, grouped_conversion_type, grouped_industry, grouped_page_type)
    # Bundle the prediction with both the raw and grouped categories plus timing.
    return {
        "predictionResults": prediction_result,
        "providedCategories": {
            "businessModel": business_model,
            "customerType": customer_type,
            "conversionType": conversion_type,
            "industry": industry,
            "pageType": page_type
        },
        "groupedCategories": {
            "businessModel": business_model,
            "customerType": customer_type,
            "conversionType": grouped_conversion_type,
            "industry": grouped_industry,
            "pageType": grouped_page_type
        },
        "processingInfo": {
            "totalProcessingTime": f"{time.time() - started:.2f}s",
            "confidenceSource": f"{grouped_industry} | {grouped_page_type}"
        }
    }
# Maximum allowed duration on free tier
def predict_single(control_image, variant_image, business_model, customer_type, conversion_type, industry, page_type):
    """
    Orchestrates the prediction for a single pair of images and features.

    Pipeline: OCR both screenshots -> preprocess pixels/text -> encode the
    five categorical labels as indices -> run the multimodal model -> wrap the
    sigmoid probability with historical confidence statistics.

    Note: This function expects GROUPED values for conversion_type, industry, and page_type.
    If calling from API, use predict_with_categorical_data() which handles the conversion automatically.

    Args:
        control_image, variant_image: numpy arrays (gr.Image type="numpy").
        business_model, customer_type, conversion_type, industry, page_type:
            grouped labels; each must appear in category_mappings or the
            .index() lookup raises and the error payload is returned.

    Returns:
        dict of prediction/confidence fields, or an error payload with a
        neutral 50% confidence on any failure.
    """
    try:
        if control_image is None or variant_image is None:
            return {"Error": 1.0, "Please upload both images": 0.0}
        start_time = time.time()
        print(f"๐ Starting prediction with categories: {business_model} | {customer_type} | {conversion_type} | {industry} | {page_type}")
        c_img = Image.fromarray(control_image).convert("RGB")
        v_img = Image.fromarray(variant_image).convert("RGB")
        # Extract OCR text from both images (this is crucial for model performance)
        try:
            c_text_str = pytesseract.image_to_string(c_img)
            v_text_str = pytesseract.image_to_string(v_img)
            print(f"๐ OCR extracted - Control: {len(c_text_str)} chars, Variant: {len(v_text_str)} chars")
        except pytesseract.TesseractNotFoundError:
            # Degrade gracefully: empty text still flows through the text encoder.
            print("๐ Tesseract is not installed or not in your PATH. Skipping OCR.")
            c_text_str, v_text_str = "", ""
        # Get confidence data for this combination (keyed on industry|page_type)
        confidence_data = get_confidence_data(business_model, customer_type, conversion_type, industry, page_type)
        print(f"๐ Confidence data loaded: {confidence_data}")
        with torch.no_grad():
            c_pix = image_processor(images=c_img, return_tensors="pt").pixel_values.to(device)
            v_pix = image_processor(images=v_img, return_tensors="pt").pixel_values.to(device)
            # Process OCR text through the text model
            c_text = tokenizer(c_text_str, padding='max_length', truncation=True, max_length=MAX_TEXT_LENGTH, return_tensors='pt').to(device)
            v_text = tokenizer(v_text_str, padding='max_length', truncation=True, max_length=MAX_TEXT_LENGTH, return_tensors='pt').to(device)
            # Encode categorical labels to the integer indices used at training time.
            cat_inputs = [business_model, customer_type, conversion_type, industry, page_type]
            cat_codes = [category_mappings[name]['categories'].index(val) for name, val in zip(CATEGORICAL_FEATURES, cat_inputs)]
            cat_feats = torch.tensor([cat_codes], dtype=torch.int64).to(device)
            # Run the multimodal model prediction
            logits = model(
                c_pix=c_pix, v_pix=v_pix,
                c_tok=c_text['input_ids'], c_attn=c_text['attention_mask'],
                v_tok=v_text['input_ids'], v_attn=v_text['attention_mask'],
                cat_feats=cat_feats
            )
            probability = torch.sigmoid(logits).item()
        processing_time = time.time() - start_time
        # Log GPU memory usage for monitoring
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.memory_allocated() / 1024**3
            print(f"๐ Prediction completed in {processing_time:.2f}s | GPU Memory: {gpu_memory:.1f}GB")
        else:
            print(f"๐ Prediction completed in {processing_time:.2f}s")
        # Determine winner
        # NOTE(review): 'winner' is computed but never included in the returned
        # payload — confirm whether it should be surfaced to callers.
        winner = "VARIANT WINS" if probability > 0.5 else "CONTROL WINS"
        confidence_percentage = confidence_data['accuracy'] * 100
        # Create enhanced output with confidence scores and training data info
        result = {
            "probability": f"{probability:.3f}",
            "modelConfidence": f"{confidence_percentage:.1f}",
            "trainingDataSamples": confidence_data['training_data_count'],
            "totalPredictions": confidence_data['count'],
            "correctPredictions": confidence_data['correct_predictions'],
            "totalWinPrediction": confidence_data['actual_wins'],
            "totalLosePrediction": confidence_data['count'] - confidence_data['actual_wins']
        }
        print(f"๐ฏ Final result: {result}")
        return result
    except Exception as e:
        print(f"โ ERROR in predict_single: {e}")
        print(f"๐ Error type: {type(e).__name__}")
        import traceback
        traceback.print_exc()
        # Return error result with fallback confidence data
        return {
            "error": f"Prediction failed: {str(e)}",
            "modelConfidence": "50.0",
            "trainingDataSamples": 0,
            "totalPredictions": 0,
            "correctPredictions": 0,
            "totalWinPrediction": 0,
            "totalLosePrediction": 0
        }
def get_all_possible_values(mappings=None):
    """
    Get all possible values (both specific and grouped) for industry, page_type,
    and conversion_type. This is useful for API documentation and validation.

    Args:
        mappings: Optional mapping dict shaped like mapping.json
            (mapping type -> {parent group: [child values]}). Defaults to the
            module-level ``value_mappings`` for backward compatibility.

    Returns:
        dict with keys "industry", "page_type", "conversion_type"; each value
        is a list of parent-group names each followed by its child values.
    """
    if mappings is None:
        mappings = value_mappings
    # (output key, mapping.json section) pairs — one loop replaces three copies.
    sections = [
        ("industry", "industry_mappings"),
        ("page_type", "page_type_mappings"),
        ("conversion_type", "conversion_type_mappings"),
    ]
    all_values = {key: [] for key, _ in sections}
    for key, section in sections:
        for parent_group, child_values in mappings.get(section, {}).items():
            all_values[key].append(parent_group)
            all_values[key].extend(child_values)
    return all_values
def predict_batch(csv_path, control_img_dir, variant_img_dir, num_samples):
    """
    Handles batch prediction from a CSV file.

    Note: CSV should contain grouped values (not specific values) for:
    - grouped_conversion_type
    - grouped_industry
    - grouped_page_type

    Args:
        csv_path: path to the experiments CSV.
        control_img_dir, variant_img_dir: folders holding the .jpeg images.
        num_samples: how many random rows to score (seeded for reproducibility).

    Returns:
        DataFrame of the sampled rows plus a 'predicted_win_probability'
        column, or a single-column "Error" DataFrame on invalid input.
    """
    if not all([csv_path, control_img_dir, variant_img_dir, num_samples]):
        return pd.DataFrame({"Error": ["Please fill in all fields."]})
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        return pd.DataFrame({"Error": [f"CSV file not found at: {csv_path}"]})
    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to read CSV: {e}"]})
    # gr.Number delivers floats; DataFrame.sample requires an integer n.
    num_samples = int(num_samples)
    if num_samples > len(df):
        print(f"โ ๏ธ Requested {num_samples} samples, but CSV only has {len(df)} rows. Using all rows.")
        num_samples = len(df)
    sample_df = df.sample(n=num_samples, random_state=42)
    results = []
    for _, row in sample_df.iterrows():
        try:
            # Construct image paths
            c_path = get_image_path_from_url(row[CONTROL_IMAGE_URL_COLUMN], control_img_dir)
            v_path = get_image_path_from_url(row[VARIANT_IMAGE_URL_COLUMN], variant_img_dir)
            if not c_path or not os.path.exists(c_path):
                raise FileNotFoundError(f"Control image not found: {c_path}")
            if not v_path or not os.path.exists(v_path):
                raise FileNotFoundError(f"Variant image not found: {v_path}")
            # Get categorical features from the row (expects grouped values in CSV)
            cat_features_from_row = [row[f] for f in CATEGORICAL_FEATURES]
            # Use the core prediction logic
            prediction = predict_single(
                control_image=np.array(Image.open(c_path)),
                variant_image=np.array(Image.open(v_path)),
                business_model=cat_features_from_row[0],
                customer_type=cat_features_from_row[1],
                conversion_type=cat_features_from_row[2],
                industry=cat_features_from_row[3],
                page_type=cat_features_from_row[4]
            )
            result_row = row.to_dict()
            # BUG FIX: predict_single returns the score under 'probability'
            # (there is no 'Win' key), so the old lookup was always 0.0.
            result_row['predicted_win_probability'] = prediction.get('probability', 0.0)
            results.append(result_row)
        except Exception as e:
            # Keep going on per-row failures; surface the error in the output column.
            print(f"๐ Error processing row: {e}")
            error_row = row.to_dict()
            error_row['predicted_win_probability'] = f"ERROR: {e}"
            results.append(error_row)
    return pd.DataFrame(results)
# --- 5. Build the Gradio Interface ---
# Three tabs share the same prediction core: an API-oriented JSON tab, a
# manual-selection tab, and a CSV batch tab.
with gr.Blocks() as iface:
    gr.Markdown("# ๐ Multimodal A/B Test Predictor")
    gr.Markdown("""
    ### Predict A/B test outcomes using:
    - ๐ผ๏ธ **Image Analysis**: Visual features from control & variant images
    - ๐ **OCR Text Extraction**: Automatically extracts and analyzes text from images
    - ๐ **Categorical Features**: Business context (industry, page type, etc.)
    - ๐ฏ **Smart Confidence Scores**: Based on Industry + Page Type combinations with high sample counts
    **Enhanced Reliability**: Confidence scores use Industry + Page Type combinations (avg 160 samples) instead of low-count 5-feature combinations!
    """)
    with gr.Tab("๐ฏ API Prediction"):
        gr.Markdown("### ๐ Predict with Categorical Data")
        gr.Markdown("""
        Upload images and provide categorical data for prediction.
        **Note:** For Industry, Page Type, and Conversion Type, you can provide either:
        - Specific values (e.g., "Accounting Services") - will be automatically converted to parent group (e.g., "B2B Services")
        - Parent group values (e.g., "B2B Services") - will be used directly
        The model uses parent groups internally, but the API accepts both for convenience.
        """)
        with gr.Row():
            with gr.Column():
                # type="numpy" hands the callbacks raw arrays (see predict_single).
                api_control_image = gr.Image(label="Control Image", type="numpy")
                api_variant_image = gr.Image(label="Variant Image", type="numpy")
            with gr.Column():
                # Dropdown choices come straight from the trained category vocabulary.
                api_business_model = gr.Dropdown(choices=category_mappings["Business Model"]['categories'], label="Business Model", value=category_mappings["Business Model"]['categories'][0])
                api_customer_type = gr.Dropdown(choices=category_mappings["Customer Type"]['categories'], label="Customer Type", value=category_mappings["Customer Type"]['categories'][0])
                api_conversion_type = gr.Dropdown(choices=category_mappings["grouped_conversion_type"]['categories'], label="Conversion Type", value=category_mappings["grouped_conversion_type"]['categories'][0])
                api_industry = gr.Dropdown(choices=category_mappings["grouped_industry"]['categories'], label="Industry", value=category_mappings["grouped_industry"]['categories'][0])
                api_page_type = gr.Dropdown(choices=category_mappings["grouped_page_type"]['categories'], label="Page Type", value=category_mappings["grouped_page_type"]['categories'][0])
        api_predict_btn = gr.Button("๐ฏ Predict with Categorical Data", variant="primary", size="lg")
        api_output_json = gr.JSON(label="๐ฏ Prediction Results with Confidence Scores")
    with gr.Tab("๐ Manual Selection"):
        gr.Markdown("### Manual Category Selection")
        gr.Markdown("Select categories manually if you prefer precise control.")
        with gr.Row():
            with gr.Column():
                s_control_image = gr.Image(label="Control Image", type="numpy")
                s_variant_image = gr.Image(label="Variant Image", type="numpy")
            with gr.Column():
                s_business_model = gr.Dropdown(choices=category_mappings["Business Model"]['categories'], label="Business Model", value=category_mappings["Business Model"]['categories'][0])
                s_customer_type = gr.Dropdown(choices=category_mappings["Customer Type"]['categories'], label="Customer Type", value=category_mappings["Customer Type"]['categories'][0])
                s_conversion_type = gr.Dropdown(choices=category_mappings["grouped_conversion_type"]['categories'], label="Conversion Type", value=category_mappings["grouped_conversion_type"]['categories'][0])
                s_industry = gr.Dropdown(choices=category_mappings["grouped_industry"]['categories'], label="Industry", value=category_mappings["grouped_industry"]['categories'][0])
                s_page_type = gr.Dropdown(choices=category_mappings["grouped_page_type"]['categories'], label="Page Type", value=category_mappings["grouped_page_type"]['categories'][0])
        s_predict_btn = gr.Button("๐ฎ Predict A/B Test Winner", variant="secondary")
        # NOTE(review): predict_single returns stringified numbers; gr.Label
        # normally expects label->float confidences — verify rendering.
        s_output_label = gr.Label(num_top_classes=6, label="๐ฏ Prediction Results & Confidence Analysis")
    with gr.Tab("Batch Prediction from CSV"):
        gr.Markdown("Provide paths to your data to get predictions for multiple random samples.")
        b_csv_path = gr.Textbox(label="Path to CSV file", placeholder="/path/to/your/data.csv")
        b_control_dir = gr.Textbox(label="Path to Control Images Folder", placeholder="/path/to/control_images/")
        b_variant_dir = gr.Textbox(label="Path to Variant Images Folder", placeholder="/path/to/variant_images/")
        b_num_samples = gr.Number(label="Number of random samples to predict", value=10)
        b_predict_btn = gr.Button("Run Batch Prediction")
        b_output_df = gr.DataFrame(label="Batch Prediction Results")
    # Wire up the components
    api_predict_btn.click(
        fn=predict_with_categorical_data,
        inputs=[api_control_image, api_variant_image, api_business_model, api_customer_type, api_conversion_type, api_industry, api_page_type],
        outputs=api_output_json
    )
    s_predict_btn.click(
        fn=predict_single,
        inputs=[s_control_image, s_variant_image, s_business_model, s_customer_type, s_conversion_type, s_industry, s_page_type],
        outputs=s_output_label
    )
    b_predict_btn.click(
        fn=predict_batch,
        inputs=[b_csv_path, b_control_dir, b_variant_dir, b_num_samples],
        outputs=b_output_df
    )
# Launch the application
if __name__ == "__main__":
    # AGGRESSIVE optimization for 4x L4 GPU - push to maximum limits
    iface.queue(
        max_size=128,  # Much larger queue for heavy concurrent load
        default_concurrency_limit=64  # Push all 4 GPUs to maximum capacity
    ).launch(
        server_name="0.0.0.0",  # bind all interfaces (required inside a Space container)
        server_port=7860,  # Hugging Face Spaces' expected port
        show_error=True  # Show detailed errors for debugging
    )