ABTestPredictor / app.py
nitish-spz's picture
Mapping of grouped metadata
5b49b49
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import pandas as pd
from transformers import AutoProcessor, ViTModel, AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
import gradio as gr
import pytesseract # For OCR
import spaces
import random
import time
import subprocess
import re
# Optional .env support for local development. A missing python-dotenv
# package is reported and otherwise ignored.
try:
    from dotenv import load_dotenv
    load_dotenv()  # Load .env file if it exists
    print("โœ… Loaded .env file for local development")
except ImportError:
    print("โ„น๏ธ python-dotenv not installed, using system environment variables only")

# --- 1. Configuration (Mirrored from your scripts) ---
# This ensures consistency with the model's training environment.

# Hugging Face Model Hub Configuration
# Point to your model repository (not the Space)
HF_MODEL_REPO = "nitish-spz/ABTestPredictor"  # Your model repository
HF_MODEL_FILENAME = "multimodal_gated_model_2.7_GGG.pth"
HF_MAPPINGS_FILENAME = "multimodal_cat_mappings_GGG.json"

# Local copies live under MODEL_DIR and use the same file names as the hub.
MODEL_DIR = "model"
MODEL_SAVE_PATH = os.path.join(MODEL_DIR, HF_MODEL_FILENAME)
CAT_MAPPINGS_SAVE_PATH = os.path.join(MODEL_DIR, HF_MAPPINGS_FILENAME)

# API Configuration - AI API calls removed, using direct categorical inputs

# Pretrained encoders and the token budget used when encoding OCR text.
VISION_MODEL_NAME = "google/vit-base-patch16-224-in21k"
TEXT_MODEL_NAME = "distilbert-base-uncased"
MAX_TEXT_LENGTH = 512

# Columns from testing script
CONTROL_IMAGE_URL_COLUMN = "controlImage"
VARIANT_IMAGE_URL_COLUMN = "variantImage"

# Embedding width per categorical feature. The insertion order of this dict
# is the feature order, so CATEGORICAL_FEATURES is derived from it.
CATEGORICAL_EMBEDDING_DIMS = {
    "Business Model": 10,
    "Customer Type": 10,
    "grouped_conversion_type": 25,
    "grouped_industry": 50,
    "grouped_page_type": 25,
}
CATEGORICAL_FEATURES = list(CATEGORICAL_EMBEDDING_DIMS)
GATED_FUSION_DIM = 64
# --- 2. Model Architecture (Exact Replica from your training script) ---
# This class must be defined to load the saved model weights correctly.
class SupervisedSiameseMultimodal(nn.Module):
    """
    Siamese multimodal scorer for a (control, variant) pair, GGG version.

    One vision encoder and one text encoder are shared by both sides of the
    pair. Their embeddings, gated control/variant difference features, the
    categorical embeddings and a constant two-value role vector are
    concatenated, passed through a fusion block, and reduced to one logit per
    sample by a prediction head that starts with BatchNorm.
    """

    def __init__(self, vision_model_name, text_model_name, cat_mappings, cat_embedding_dims):
        """
        Args:
            vision_model_name: Checkpoint name passed to ViTModel.from_pretrained.
            text_model_name: Checkpoint name passed to AutoModel.from_pretrained.
            cat_mappings: Per-feature dicts; 'num_categories' sizes each embedding.
            cat_embedding_dims: Embedding width for each categorical feature.
        """
        super().__init__()
        self.vision_model = ViTModel.from_pretrained(vision_model_name)
        self.text_model = AutoModel.from_pretrained(text_model_name)
        vision_dim = self.vision_model.config.hidden_size
        text_dim = self.text_model.config.hidden_size

        # Features missing from the mappings get no embedding layer.
        # NOTE(review): forward() reads column i of cat_feats for the i-th
        # layer, so skipping a feature would shift later columns - confirm
        # that every feature is always present in cat_mappings.
        known_features = [name for name in CATEGORICAL_FEATURES if name in cat_mappings]
        self.embedding_layers = nn.ModuleList([
            nn.Embedding(cat_mappings[name]['num_categories'], cat_embedding_dims[name])
            for name in known_features
        ])
        total_cat_emb_dim = sum(cat_embedding_dims[name] for name in known_features)

        # Two gate logits, one for the vision branch and one for the text branch.
        self.gate_controller = nn.Sequential(
            nn.Linear(total_cat_emb_dim, GATED_FUSION_DIM),
            nn.ReLU(),
            nn.Linear(GATED_FUSION_DIM, 2),
        )

        # Width of the vector built in forward(): 2 raw + 2 directional slices
        # per modality, the categorical embedding, and the 2-value role vector.
        in_dim = (vision_dim * 4) + (text_dim * 4) + total_cat_emb_dim + 2

        # Add the fusion block
        self.fusion_block = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Updated prediction head with BatchNorm
        half_dim = vision_dim // 2
        self.prediction_head = nn.Sequential(
            nn.BatchNorm1d(in_dim),
            nn.Linear(in_dim, vision_dim),
            nn.GELU(),
            nn.LayerNorm(vision_dim),
            nn.Dropout(0.2),
            nn.Linear(vision_dim, half_dim),
            nn.GELU(),
            nn.LayerNorm(half_dim),
            nn.Dropout(0.1),
            nn.Linear(half_dim, 1),
        )

    def forward(self, c_pix, v_pix, c_tok, c_attn, v_tok, v_attn, cat_feats):
        """
        Score a batch of pairs and return one logit per sample.

        c_* arguments belong to the control side and v_* to the variant side:
        pixel values for the vision model, token ids and attention masks for
        the text model. Column i of cat_feats is looked up in the i-th
        embedding layer.
        """
        # Vision branch: pooled embeddings plus both signed differences.
        emb_c_vision = self.vision_model(pixel_values=c_pix).pooler_output
        emb_v_vision = self.vision_model(pixel_values=v_pix).pooler_output
        direction_feat_vision = torch.cat(
            [emb_c_vision - emb_v_vision, emb_v_vision - emb_c_vision], dim=1
        )

        # Text branch: mean over dim 1 of the last hidden state (padding
        # positions are not masked out of the average).
        emb_c_text = self.text_model(input_ids=c_tok, attention_mask=c_attn).last_hidden_state.mean(dim=1)
        emb_v_text = self.text_model(input_ids=v_tok, attention_mask=v_attn).last_hidden_state.mean(dim=1)
        direction_feat_text = torch.cat(
            [emb_c_text - emb_v_text, emb_v_text - emb_c_text], dim=1
        )

        # Categorical context drives a softmax gate over the two modalities.
        final_cat_embedding = torch.cat(
            [layer(cat_feats[:, idx]) for idx, layer in enumerate(self.embedding_layers)],
            dim=1,
        )
        gates = F.softmax(self.gate_controller(final_cat_embedding), dim=-1)
        vision_gate, text_gate = gates[:, 0:1], gates[:, 1:2]

        # Constant [1, 0] row for every sample in the batch.
        role_embedding = torch.tensor(
            [[1, 0]] * c_pix.shape[0], dtype=torch.float32, device=c_pix.device
        )

        final_vector = torch.cat(
            [
                emb_c_vision, emb_v_vision,
                emb_c_text, emb_v_text,
                direction_feat_vision * vision_gate,
                direction_feat_text * text_gate,
                final_cat_embedding,
                role_embedding,
            ],
            dim=1,
        )

        # Pass through the fusion block before the final prediction head
        return self.prediction_head(self.fusion_block(final_vector)).squeeze(-1)
# --- 3. Loading Models and Processors (Done once on startup) ---
# Optimized for L4 GPU setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"๐Ÿš€ Using device: {device}")

# Everything below only applies when a CUDA device was selected above.
if device.type == "cuda":
    print(f"๐Ÿ”ฅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"๐Ÿ’พ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    # AGGRESSIVE optimizations for 4x L4 GPU
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False  # Allow non-deterministic for speed

    # Aggressive memory management
    torch.cuda.empty_cache()

    # Enable tensor core usage for maximum performance
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
# Create dummy files if they don't exist for the app to run.
# exist_ok=True replaces the separate exists() check, so a directory created
# between the check and the call can no longer make startup fail.
os.makedirs(MODEL_DIR, exist_ok=True)

if not os.path.exists(CAT_MAPPINGS_SAVE_PATH):
    print(f"โš ๏ธ GGG Category mappings not found. Creating default mappings...")
    # Create the standard category mappings expected by the model.
    # The position of each name in 'categories' is the integer code that is
    # fed to the embedding layers, so the order of these lists matters.
    default_mappings = {
        "Business Model": {"num_categories": 4, "categories": ["E-Commerce", "Lead Generation", "Other*", "SaaS"]},
        "Customer Type": {"num_categories": 4, "categories": ["B2B", "B2C", "Both", "Other*"]},
        "grouped_conversion_type": {"num_categories": 6, "categories": ["Direct Purchase", "High-Intent Lead Gen", "Info/Content Lead Gen", "Location Search", "Non-Profit/Community", "Other Conversion"]},
        "grouped_industry": {"num_categories": 14, "categories": ["Automotive & Transportation", "B2B Services", "B2B Software & Tech", "Consumer Services", "Consumer Software & Apps", "Education", "Finance, Insurance & Real Estate", "Food, Hospitality & Travel", "Health & Wellness", "Industrial & Manufacturing", "Media & Entertainment", "Non-Profit & Government", "Other", "Retail & E-commerce"]},
        "grouped_page_type": {"num_categories": 5, "categories": ["Awareness & Discovery", "Consideration & Evaluation", "Conversion", "Internal & Navigation", "Post-Conversion & Other"]}
    }
    # Explicit UTF-8 keeps the file independent of the platform's locale.
    with open(CAT_MAPPINGS_SAVE_PATH, 'w', encoding='utf-8') as f:
        json.dump(default_mappings, f, indent=2)
    print(f"โœ… Created default category mappings at {CAT_MAPPINGS_SAVE_PATH}")

# Whichever file is on disk now (pre-existing or just created) is what the
# rest of the module uses.
with open(CAT_MAPPINGS_SAVE_PATH, 'r', encoding='utf-8') as f:
    category_mappings = json.load(f)
# Load mapping.json for converting specific values to parent groups
def load_value_mappings():
    """
    Read mapping.json and return its parsed content.

    The file is looked up next to this script first and, if it is not there,
    relative to the current working directory. Any failure is printed
    together with a traceback and an empty dict is returned instead.
    """
    try:
        mapping_file = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'mapping.json'
        )
        print(f"๐Ÿ“‚ Looking for mapping file at: {mapping_file}")

        found_next_to_script = os.path.exists(mapping_file)
        if not found_next_to_script:
            print(f"โš ๏ธ Mapping file not found, trying fallback location...")
            mapping_file = 'mapping.json'

        with open(mapping_file, 'r') as f:
            mapping_data = json.load(f)
    except Exception as e:
        # Best effort: callers get an empty mapping rather than an exception.
        print(f"โš ๏ธ Error loading mapping.json: {e}")
        import traceback
        traceback.print_exc()
        return {}
    else:
        try:
            print(f"โœ… Successfully loaded mapping.json with {len(mapping_data)} mapping types")
        except Exception as e:
            # len() can fail when the JSON root is a scalar; same fallback.
            print(f"โš ๏ธ Error loading mapping.json: {e}")
            import traceback
            traceback.print_exc()
            return {}
        return mapping_data
def convert_to_parent_group(value, mapping_type, value_mappings):
    """
    Resolve a specific category value to the parent group that contains it.

    Args:
        value: The specific value (e.g., "Accounting Services")
        mapping_type: Section of the mapping data to search
            ("industry_mappings", "page_type_mappings", "conversion_type_mappings")
        value_mappings: The loaded mapping.json data
    Returns:
        The parent group name (e.g., "B2B Services"). The input value is
        returned unchanged when the section is missing, when the value is
        itself a parent group, or when it is not found at all.
    """
    if mapping_type not in value_mappings:
        print(f"โš ๏ธ Mapping type '{mapping_type}' not found in mapping.json")
        return value

    mappings = value_mappings[mapping_type]

    # First parent (in mapping order) whose children contain the value.
    # Children are checked before parent names, as in the lookup order below.
    no_match = object()
    parent_group = next(
        (group for group, members in mappings.items() if value in members),
        no_match,
    )
    if parent_group is not no_match:
        print(f"โœ… Mapped '{value}' -> '{parent_group}'")
        return parent_group

    if value in mappings:
        print(f"โ„น๏ธ '{value}' is already a parent group")
    else:
        print(f"โš ๏ธ Value '{value}' not found in {mapping_type}, returning as-is")
    return value
# Load confidence scores directly from JSON file
def load_confidence_scores():
    """
    Read confidence_scores.json and return its parsed content.

    The file is looked up next to this script first, then relative to the
    current working directory. Diagnostic details are printed along the way.
    A missing file or any other failure yields an empty dict.
    """
    try:
        # Get the directory where this script is located
        script_dir = os.path.dirname(os.path.abspath(__file__))
        confidence_file = os.path.join(script_dir, 'confidence_scores.json')
        print(f"๐Ÿ“‚ Script directory: {script_dir}")
        print(f"๐Ÿ“‚ Looking for confidence file at: {confidence_file}")
        print(f"๐Ÿ“‚ File exists: {os.path.exists(confidence_file)}")

        if not os.path.exists(confidence_file):
            print(f"โš ๏ธ Confidence file not found, trying fallback location...")
            # Try current directory as fallback
            confidence_file = 'confidence_scores.json'
            print(f"๐Ÿ“‚ Fallback path: {confidence_file}")
            print(f"๐Ÿ“‚ Fallback exists: {os.path.exists(confidence_file)}")

        with open(confidence_file, 'r') as f:
            confidence_data = json.load(f)
        print(f"โœ… Successfully loaded {len(confidence_data)} confidence score combinations")

        # Print a sample to verify data
        sample_key = list(confidence_data.keys())[0] if confidence_data else None
        if sample_key:
            print(f"๐Ÿ“Š Sample entry: {sample_key} = {confidence_data[sample_key]}")

        return confidence_data
    except FileNotFoundError as e:
        # Neither location had the file: print where we looked and what is there.
        print(f"โŒ Confidence file not found: {e}")
        print(f"๐Ÿ“‚ Current working directory: {os.getcwd()}")
        print(f"๐Ÿ“‚ Files in script dir: {os.listdir(script_dir) if os.path.exists(script_dir) else 'N/A'}")
        return {}
    except Exception as e:
        # Anything else (bad JSON, unexpected root type, ...) is logged with a
        # traceback and also degrades to an empty result.
        print(f"โš ๏ธ Error loading confidence scores: {e}")
        import traceback
        traceback.print_exc()
        return {}
# Separator line printed around the startup log sections below.
_LOG_RULE = "=" * 50

# Load value mappings for converting specific values to parent groups.
# On any exception the table falls back to an empty dict so that startup
# continues.
try:
    print(_LOG_RULE)
    print("๐Ÿš€ LOADING VALUE MAPPINGS...")
    print(_LOG_RULE)
    value_mappings = load_value_mappings()
    print(f"โœ… Value mappings loaded successfully")
    print(_LOG_RULE)
except Exception as e:
    print(f"โš ๏ธ Error loading value mappings: {e}")
    value_mappings = {}

# Load confidence scores, with the same empty-dict fallback.
try:
    print(_LOG_RULE)
    print("๐Ÿš€ LOADING CONFIDENCE SCORES...")
    print(_LOG_RULE)
    confidence_scores = load_confidence_scores()
    print(f"โœ… Confidence scores loaded successfully: {len(confidence_scores)} combinations")
    print(f"๐Ÿ” confidence_scores is empty: {len(confidence_scores) == 0}")
    print(f"๐Ÿ” confidence_scores type: {type(confidence_scores)}")
    print(_LOG_RULE)
except Exception as e:
    print(f"โš ๏ธ Error loading confidence scores: {e}")
    confidence_scores = {}
    print(f"โŒ confidence_scores set to empty dict: {confidence_scores}")
def get_confidence_data(business_model, customer_type, conversion_type, industry, page_type):
    """
    Return the stored confidence record for an Industry + Page Type pair.

    Only industry and page_type take part in the lookup key
    ("<industry>|<page_type>"); the other three parameters are accepted but
    not used. When the key is unknown, a default record with accuracy 0.5
    and zero counts is returned.
    """
    key = f"{industry}|{page_type}"
    print(f"๐Ÿ” Looking for confidence key: '{key}'")
    print(f"๐Ÿ” Total confidence_scores loaded: {len(confidence_scores)}")
    print(f"๐Ÿ” Key exists: {key in confidence_scores}")

    if key not in confidence_scores:
        print(f"โš ๏ธ Key '{key}' not found, using fallback")
        # Listing keys that mention the industry helps spot near-miss page types.
        print(f"๐Ÿ” Available keys with '{industry}': {[k for k in confidence_scores.keys() if industry in k]}")
        fallback = {
            'accuracy': 0.5,  # Default fallback
            'count': 0,
            'training_data_count': 0,
            'correct_predictions': 0,
            'actual_wins': 0,
            'predicted_wins': 0
        }
        return fallback

    data = confidence_scores[key]
    print(f"โœ… Found confidence data: {data}")
    return data
# Instantiate the model with the loaded mappings. Weights are loaded into it
# further down; at this point it only has its constructor-time parameters.
model = SupervisedSiameseMultimodal(
    vision_model_name=VISION_MODEL_NAME,
    text_model_name=TEXT_MODEL_NAME,
    cat_mappings=category_mappings,
    cat_embedding_dims=CATEGORICAL_EMBEDDING_DIMS,
)
# Download model from Hugging Face Model Hub
def download_model_from_hub():
    """
    Fetch the model weights (and, if absent locally, the category mappings)
    from the Hugging Face Model Hub.

    Returns:
        Path of the downloaded weights file. If the download fails for any
        reason, the current state of the global model is saved to
        MODEL_SAVE_PATH and that path is returned instead.
    """
    # Arguments shared by both hub downloads.
    hub_location = {"repo_id": HF_MODEL_REPO, "cache_dir": MODEL_DIR}
    try:
        print(f"๐Ÿ“ฅ Downloading GGG model from Hugging Face Model Hub: {HF_MODEL_REPO}")
        # Download model file
        model_path = hf_hub_download(filename=HF_MODEL_FILENAME, **hub_location)
        print(f"โœ… Model downloaded to: {model_path}")

        # Download category mappings if not exists locally
        if os.path.exists(CAT_MAPPINGS_SAVE_PATH):
            return model_path
        try:
            mappings_path = hf_hub_download(filename=HF_MAPPINGS_FILENAME, **hub_location)
            print(f"โœ… Category mappings downloaded to: {mappings_path}")
            # Copy to expected location
            import shutil
            shutil.copy(mappings_path, CAT_MAPPINGS_SAVE_PATH)
        except Exception as e:
            # The mappings are optional here; the weights path is still returned.
            print(f"โš ๏ธ Could not download mappings from hub: {e}")
        return model_path
    except Exception as e:
        print(f"โš ๏ธ Error downloading from Model Hub: {e}")
        print(f"๐Ÿ”ง Creating dummy weights for demo...")
        # NOTE(review): this writes the model's current weights to the same
        # path that is treated as "the local model" - confirm that later
        # startups picking this file up instead of retrying the download is
        # intended.
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        return MODEL_SAVE_PATH
# Use local model if available, otherwise download from hub
if not os.path.exists(MODEL_SAVE_PATH):
    print(f"๐Ÿ“ฅ Model not found locally, downloading from Model Hub...")
    model_path = download_model_from_hub()
else:
    model_path = MODEL_SAVE_PATH
    print(f"โœ… Using local GGG model at {MODEL_SAVE_PATH}")

# Load the weights. A failure here is logged and the model keeps whatever
# parameters it was constructed with.
try:
    print(f"๐Ÿš€ Loading GGG model weights from {model_path}")
    # NOTE(review): torch.load unpickles the checkpoint, which may have been
    # downloaded - consider weights_only=True once the minimum supported
    # torch version is confirmed.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # NOTE(review): this message names the Model Hub even when the file came
    # from the local path.
    print("โœ… Successfully loaded GGG model weights from Hugging Face Model Hub")
except Exception as e:
    print(f"โš ๏ธ Error loading model weights: {e}")
    print("๐Ÿ”ง Using initialized weights for demo...")

# Move to the selected device and switch to inference mode.
model.to(device)
model.eval()
# Warm up the model with a dummy forward pass for better performance.
# Only done when CUDA is available; the output is discarded.
if torch.cuda.is_available():
    with torch.no_grad():
        _token_shape = (1, MAX_TEXT_LENGTH)
        # One random sample shaped like a real request. Entries are created
        # in the order forward() names its parameters.
        _warmup_batch = {
            "c_pix": torch.randn(1, 3, 224, 224).to(device),
            "v_pix": torch.randn(1, 3, 224, 224).to(device),
            "c_tok": torch.randint(0, 1000, _token_shape).to(device),
            "c_attn": torch.ones(_token_shape).to(device),
            "v_tok": torch.randint(0, 1000, _token_shape).to(device),
            "v_attn": torch.ones(_token_shape).to(device),
            # Codes 0/1 are used for every categorical column.
            "cat_feats": torch.randint(0, 2, (1, len(CATEGORICAL_FEATURES))).to(device),
        }
        _ = model(**_warmup_batch)
    print("๐Ÿ”ฅ Model warmed up successfully!")

# Load the processors for images and text; both are keyed by the same
# checkpoint names the model was built from.
image_processor = AutoProcessor.from_pretrained(VISION_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
print("โœ… Model and processors loaded successfully.")
# --- 4. Prediction Functions ---
def get_image_path_from_url(image_url: str, base_dir: str) -> str | None:
"""Constructs a local image path from a URL-like string."""
try:
stem = os.path.splitext(os.path.basename(str(image_url)))[0]
return os.path.join(base_dir, f"{stem}.jpeg")
except (TypeError, ValueError):
return None
@spaces.GPU(duration=50)  # Maximum allowed duration on free tier
def predict_with_categorical_data(control_image, variant_image, business_model, customer_type, conversion_type, industry, page_type):
    """
    Make a prediction from two images plus categorical data (no AI API calls).

    conversion_type, industry and page_type are first mapped to their parent
    groups via the loaded value mappings; business_model and customer_type
    are passed through unchanged. Returns a dict with the prediction, the
    categories as provided, the grouped categories, and processing info - or
    a dict with an "error" key when an image is missing.
    """
    if control_image is None or variant_image is None:
        return {"error": "Please provide both control and variant images"}

    start_time = time.time()
    print(f"๐Ÿ“‹ Original categories from API: {business_model} | {customer_type} | {conversion_type} | {industry} | {page_type}")

    # Convert specific values to parent groups using mapping.json.
    # The generator is consumed in order: conversion type, industry, page type.
    grouped_conversion_type, grouped_industry, grouped_page_type = (
        convert_to_parent_group(raw_value, section, value_mappings)
        for raw_value, section in (
            (conversion_type, "conversion_type_mappings"),
            (industry, "industry_mappings"),
            (page_type, "page_type_mappings"),
        )
    )
    print(f"๐Ÿ“‹ Mapped to parent groups: {business_model} | {customer_type} | {grouped_conversion_type} | {grouped_industry} | {grouped_page_type}")

    # Run the prediction with grouped categorical data
    prediction_result = predict_single(
        control_image, variant_image,
        business_model, customer_type,
        grouped_conversion_type, grouped_industry, grouped_page_type,
    )

    # Assemble the response; the elapsed time is taken after the prediction.
    provided = {
        "businessModel": business_model,
        "customerType": customer_type,
        "conversionType": conversion_type,
        "industry": industry,
        "pageType": page_type
    }
    grouped = {
        "businessModel": business_model,
        "customerType": customer_type,
        "conversionType": grouped_conversion_type,
        "industry": grouped_industry,
        "pageType": grouped_page_type
    }
    return {
        "predictionResults": prediction_result,
        "providedCategories": provided,
        "groupedCategories": grouped,
        "processingInfo": {
            "totalProcessingTime": f"{time.time() - start_time:.2f}s",
            "confidenceSource": f"{grouped_industry} | {grouped_page_type}"
        }
    }
@spaces.GPU(duration=60)  # Maximum allowed duration on free tier
def predict_single(control_image, variant_image, business_model, customer_type, conversion_type, industry, page_type):
    """
    Run the model on one control/variant image pair with its categories.

    Note: This function expects GROUPED values for conversion_type, industry, and page_type.
    If calling from API, use predict_with_categorical_data() which handles the conversion automatically.

    Returns a dict with the sigmoid probability (formatted string) and the
    confidence statistics for the industry/page-type pair. Any exception is
    caught and reported through an "error" entry with fallback statistics.
    """
    def _to_pixels(img):
        # Image -> pixel_values tensor on the active device.
        return image_processor(images=img, return_tensors="pt").pixel_values.to(device)

    def _to_tokens(text):
        # Text -> fixed-length token batch on the active device.
        return tokenizer(text, padding='max_length', truncation=True, max_length=MAX_TEXT_LENGTH, return_tensors='pt').to(device)

    try:
        if control_image is None or variant_image is None:
            return {"Error": 1.0, "Please upload both images": 0.0}

        start_time = time.time()
        print(f"๐Ÿ” Starting prediction with categories: {business_model} | {customer_type} | {conversion_type} | {industry} | {page_type}")

        c_img = Image.fromarray(control_image).convert("RGB")
        v_img = Image.fromarray(variant_image).convert("RGB")

        # Extract OCR text from both images (this is crucial for model performance)
        try:
            c_text_str = pytesseract.image_to_string(c_img)
            v_text_str = pytesseract.image_to_string(v_img)
            print(f"๐Ÿ“ OCR extracted - Control: {len(c_text_str)} chars, Variant: {len(v_text_str)} chars")
        except pytesseract.TesseractNotFoundError:
            # Without Tesseract the text branch gets empty strings.
            print("๐Ÿ›‘ Tesseract is not installed or not in your PATH. Skipping OCR.")
            c_text_str, v_text_str = "", ""

        # Get confidence data for this combination
        confidence_data = get_confidence_data(business_model, customer_type, conversion_type, industry, page_type)
        print(f"๐Ÿ“Š Confidence data loaded: {confidence_data}")

        with torch.no_grad():
            c_pix = _to_pixels(c_img)
            v_pix = _to_pixels(v_img)

            # Process OCR text through the text model
            c_text = _to_tokens(c_text_str)
            v_text = _to_tokens(v_text_str)

            # Each category name becomes its position in the mapping's list;
            # an unknown name raises ValueError, handled by the outer except.
            selected = (business_model, customer_type, conversion_type, industry, page_type)
            cat_codes = [
                category_mappings[feature]['categories'].index(choice)
                for feature, choice in zip(CATEGORICAL_FEATURES, selected)
            ]
            cat_feats = torch.tensor([cat_codes], dtype=torch.int64).to(device)

            # Run the multimodal model prediction
            logits = model(
                c_pix=c_pix, v_pix=v_pix,
                c_tok=c_text['input_ids'], c_attn=c_text['attention_mask'],
                v_tok=v_text['input_ids'], v_attn=v_text['attention_mask'],
                cat_feats=cat_feats
            )
            probability = torch.sigmoid(logits).item()

        processing_time = time.time() - start_time

        # Log GPU memory usage for monitoring
        if not torch.cuda.is_available():
            print(f"๐Ÿš€ Prediction completed in {processing_time:.2f}s")
        else:
            gpu_memory = torch.cuda.memory_allocated() / 1024**3
            print(f"๐Ÿš€ Prediction completed in {processing_time:.2f}s | GPU Memory: {gpu_memory:.1f}GB")

        # Determine winner (computed here but not part of the returned payload)
        winner = "VARIANT WINS" if probability > 0.5 else "CONTROL WINS"
        confidence_percentage = confidence_data['accuracy'] * 100

        # Create enhanced output with confidence scores and training data info
        total_predictions = confidence_data['count']
        actual_wins = confidence_data['actual_wins']
        result = {
            "probability": f"{probability:.3f}",
            "modelConfidence": f"{confidence_percentage:.1f}",
            "trainingDataSamples": confidence_data['training_data_count'],
            "totalPredictions": total_predictions,
            "correctPredictions": confidence_data['correct_predictions'],
            "totalWinPrediction": actual_wins,
            "totalLosePrediction": total_predictions - actual_wins
        }
        print(f"๐ŸŽฏ Final result: {result}")
        return result
    except Exception as e:
        print(f"โŒ ERROR in predict_single: {e}")
        print(f"๐Ÿ” Error type: {type(e).__name__}")
        import traceback
        traceback.print_exc()
        # Return error result with fallback confidence data
        return {
            "error": f"Prediction failed: {str(e)}",
            "modelConfidence": "50.0",
            "trainingDataSamples": 0,
            "totalPredictions": 0,
            "correctPredictions": 0,
            "totalWinPrediction": 0,
            "totalLosePrediction": 0
        }
def get_all_possible_values():
    """
    Collect every accepted value for industry, page_type and conversion_type.

    For each of the three fields, the "<field>_mappings" section of the
    loaded value mappings is walked in order: each parent group name is
    followed by its specific child values. A field whose section is missing
    keeps an empty list. This is useful for API documentation and validation.
    """
    all_values = {
        "industry": [],
        "page_type": [],
        "conversion_type": []
    }
    for field, collected in all_values.items():
        section = f"{field}_mappings"
        if section not in value_mappings:
            continue
        # Parent group first, then the specific values that roll up into it.
        for parent_group, child_values in value_mappings[section].items():
            collected.append(parent_group)
            collected.extend(child_values)
    return all_values
@spaces.GPU
def predict_batch(csv_path, control_img_dir, variant_img_dir, num_samples):
    """
    Predict a random sample of rows from a CSV file.

    Note: CSV should contain grouped values (not specific values) for:
    - grouped_conversion_type
    - grouped_industry
    - grouped_page_type

    Args:
        csv_path: Path of the CSV file to read.
        control_img_dir: Folder holding the control images.
        variant_img_dir: Folder holding the variant images.
        num_samples: Number of rows to sample (capped at the CSV length).
    Returns:
        A DataFrame with the sampled rows plus 'predicted_win_probability':
        a float for successful rows, or an "ERROR: ..." string for rows that
        failed. Input problems are reported as a one-column "Error" frame.
    """
    if not all([csv_path, control_img_dir, variant_img_dir, num_samples]):
        return pd.DataFrame({"Error": ["Please fill in all fields."]})

    # UI number fields can deliver floats (10.0), which DataFrame.sample
    # rejects; normalise to a positive int before using it.
    try:
        num_samples = int(num_samples)
    except (TypeError, ValueError, OverflowError):
        num_samples = 0
    if num_samples < 1:
        return pd.DataFrame({"Error": ["Number of samples must be a positive integer."]})

    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        return pd.DataFrame({"Error": [f"CSV file not found at: {csv_path}"]})
    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to read CSV: {e}"]})

    if num_samples > len(df):
        print(f"โš ๏ธ Requested {num_samples} samples, but CSV only has {len(df)} rows. Using all rows.")
        num_samples = len(df)
    # Fixed seed: the same CSV and sample size always pick the same rows.
    sample_df = df.sample(n=num_samples, random_state=42)

    results = []
    for _, row in sample_df.iterrows():
        result_row = row.to_dict()
        try:
            # Construct image paths
            c_path = get_image_path_from_url(row[CONTROL_IMAGE_URL_COLUMN], control_img_dir)
            v_path = get_image_path_from_url(row[VARIANT_IMAGE_URL_COLUMN], variant_img_dir)
            if not c_path or not os.path.exists(c_path):
                raise FileNotFoundError(f"Control image not found: {c_path}")
            if not v_path or not os.path.exists(v_path):
                raise FileNotFoundError(f"Variant image not found: {v_path}")

            # Get categorical features from the row (expects grouped values in CSV)
            cat_features_from_row = [row[f] for f in CATEGORICAL_FEATURES]

            # Use the core prediction logic. The context managers close both
            # image files once their pixels have been copied into arrays.
            with Image.open(c_path) as c_img, Image.open(v_path) as v_img:
                prediction = predict_single(
                    control_image=np.array(c_img),
                    variant_image=np.array(v_img),
                    business_model=cat_features_from_row[0],
                    customer_type=cat_features_from_row[1],
                    conversion_type=cat_features_from_row[2],
                    industry=cat_features_from_row[3],
                    page_type=cat_features_from_row[4]
                )

            # predict_single reports failures through an "error" entry rather
            # than raising; route them to the same per-row error handling.
            if "error" in prediction:
                raise RuntimeError(prediction["error"])
            # The probability is returned under "probability" as a formatted string.
            result_row['predicted_win_probability'] = float(prediction["probability"])
        except Exception as e:
            print(f"๐Ÿ›‘ Error processing row: {e}")
            result_row['predicted_win_probability'] = f"ERROR: {e}"
        results.append(result_row)
    return pd.DataFrame(results)
# --- 5. Build the Gradio Interface ---
def _category_dropdowns():
    """
    Create the five category dropdowns in model feature order.

    Each dropdown offers the categories stored for its feature in
    category_mappings and preselects the first one. Must be called inside
    the layout context where the dropdowns should appear.
    """
    labelled_features = (
        ("Business Model", "Business Model"),
        ("Customer Type", "Customer Type"),
        ("grouped_conversion_type", "Conversion Type"),
        ("grouped_industry", "Industry"),
        ("grouped_page_type", "Page Type"),
    )
    dropdowns = []
    for feature, label in labelled_features:
        choices = category_mappings[feature]['categories']
        dropdowns.append(gr.Dropdown(choices=choices, label=label, value=choices[0]))
    return dropdowns


with gr.Blocks() as iface:
    gr.Markdown("# ๐Ÿš€ Multimodal A/B Test Predictor")
    gr.Markdown("""
### Predict A/B test outcomes using:
- ๐Ÿ–ผ๏ธ **Image Analysis**: Visual features from control & variant images
- ๐Ÿ“ **OCR Text Extraction**: Automatically extracts and analyzes text from images
- ๐Ÿ“Š **Categorical Features**: Business context (industry, page type, etc.)
- ๐ŸŽฏ **Smart Confidence Scores**: Based on Industry + Page Type combinations with high sample counts
**Enhanced Reliability**: Confidence scores use Industry + Page Type combinations (avg 160 samples) instead of low-count 5-feature combinations!
""")

    with gr.Tab("๐ŸŽฏ API Prediction"):
        gr.Markdown("### ๐Ÿ“Š Predict with Categorical Data")
        gr.Markdown("""
Upload images and provide categorical data for prediction.
**Note:** For Industry, Page Type, and Conversion Type, you can provide either:
- Specific values (e.g., "Accounting Services") - will be automatically converted to parent group (e.g., "B2B Services")
- Parent group values (e.g., "B2B Services") - will be used directly
The model uses parent groups internally, but the API accepts both for convenience.
""")
        with gr.Row():
            with gr.Column():
                api_control_image = gr.Image(label="Control Image", type="numpy")
                api_variant_image = gr.Image(label="Variant Image", type="numpy")
            with gr.Column():
                (api_business_model, api_customer_type, api_conversion_type,
                 api_industry, api_page_type) = _category_dropdowns()
        api_predict_btn = gr.Button("๐ŸŽฏ Predict with Categorical Data", variant="primary", size="lg")
        api_output_json = gr.JSON(label="๐ŸŽฏ Prediction Results with Confidence Scores")

    with gr.Tab("๐Ÿ“‹ Manual Selection"):
        gr.Markdown("### Manual Category Selection")
        gr.Markdown("Select categories manually if you prefer precise control.")
        with gr.Row():
            with gr.Column():
                s_control_image = gr.Image(label="Control Image", type="numpy")
                s_variant_image = gr.Image(label="Variant Image", type="numpy")
            with gr.Column():
                (s_business_model, s_customer_type, s_conversion_type,
                 s_industry, s_page_type) = _category_dropdowns()
        s_predict_btn = gr.Button("๐Ÿ”ฎ Predict A/B Test Winner", variant="secondary")
        # predict_single returns a dict mixing formatted strings and integer
        # counts. gr.Label expects label -> float confidences and sorts by
        # value, so that payload cannot be rendered there; show it as JSON.
        s_output_label = gr.JSON(label="๐ŸŽฏ Prediction Results & Confidence Analysis")

    with gr.Tab("Batch Prediction from CSV"):
        gr.Markdown("Provide paths to your data to get predictions for multiple random samples.")
        b_csv_path = gr.Textbox(label="Path to CSV file", placeholder="/path/to/your/data.csv")
        b_control_dir = gr.Textbox(label="Path to Control Images Folder", placeholder="/path/to/control_images/")
        b_variant_dir = gr.Textbox(label="Path to Variant Images Folder", placeholder="/path/to/variant_images/")
        b_num_samples = gr.Number(label="Number of random samples to predict", value=10)
        b_predict_btn = gr.Button("Run Batch Prediction")
        b_output_df = gr.DataFrame(label="Batch Prediction Results")

    # Wire up the components. Input order matches each handler's parameters.
    api_predict_btn.click(
        fn=predict_with_categorical_data,
        inputs=[api_control_image, api_variant_image, api_business_model, api_customer_type, api_conversion_type, api_industry, api_page_type],
        outputs=api_output_json
    )
    s_predict_btn.click(
        fn=predict_single,
        inputs=[s_control_image, s_variant_image, s_business_model, s_customer_type, s_conversion_type, s_industry, s_page_type],
        outputs=s_output_label
    )
    b_predict_btn.click(
        fn=predict_batch,
        inputs=[b_csv_path, b_control_dir, b_variant_dir, b_num_samples],
        outputs=b_output_df
    )
# Launch the application (only when run as a script, not when imported)
if __name__ == "__main__":
    # AGGRESSIVE optimization for 4x L4 GPU - push to maximum limits
    queued_app = iface.queue(
        max_size=128,  # Much larger queue for heavy concurrent load
        default_concurrency_limit=64,  # Push all 4 GPUs to maximum capacity
    )
    # Listen on all interfaces on port 7860.
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,  # Show detailed errors for debugging
    )