# NOTE: The lines originally captured here were Hugging Face Space page residue
# (status banner "Runtime error", "File size: 35,211 Bytes", per-line commit
# hashes and the line-number gutter 1-766). They are not part of the program.
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import pandas as pd
from transformers import AutoProcessor, ViTModel, AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
import gradio as gr
import pytesseract # For OCR
import spaces
import random
import time
import subprocess
import re
# Load environment variables from a local .env file (for local development).
# python-dotenv is optional: when it is not installed the ImportError branch
# runs and only the system environment variables are used.
try:
    from dotenv import load_dotenv
    load_dotenv()  # Load .env file if it exists
    print("✅ Loaded .env file for local development")
except ImportError:
    print("โน๏ธ python-dotenv not installed, using system environment variables only")
# --- 1. Configuration (Mirrored from your scripts) ---
# Paths, model identifiers and feature definitions. These values must stay in
# sync with the environment the checkpoint was trained in.

# API Configuration - AI API calls removed, using direct categorical inputs

# Hugging Face Model Hub Configuration
# Point to your model repository (not the Space)
HF_MODEL_REPO = "nitish-spz/ABTestPredictor"  # Your model repository
HF_MODEL_FILENAME = "multimodal_gated_model_2.7_GGG.pth"
HF_MAPPINGS_FILENAME = "multimodal_cat_mappings_GGG.json"

# Local copies live under MODEL_DIR and use the same file names as on the Hub.
MODEL_DIR = "model"
MODEL_SAVE_PATH = os.path.join(MODEL_DIR, HF_MODEL_FILENAME)
CAT_MAPPINGS_SAVE_PATH = os.path.join(MODEL_DIR, HF_MAPPINGS_FILENAME)

# Pretrained backbones and the fixed token length used for OCR text.
VISION_MODEL_NAME = "google/vit-base-patch16-224-in21k"
TEXT_MODEL_NAME = "distilbert-base-uncased"
MAX_TEXT_LENGTH = 512

# Columns from testing script
CONTROL_IMAGE_URL_COLUMN = "controlImage"
VARIANT_IMAGE_URL_COLUMN = "variantImage"

# Categorical inputs, in the column order the model consumes them.
CATEGORICAL_FEATURES = [
    "Business Model",
    "Customer Type",
    "grouped_conversion_type",
    "grouped_industry",
    "grouped_page_type",
]

# Embedding width per categorical feature.
CATEGORICAL_EMBEDDING_DIMS = {
    "Business Model": 10,
    "Customer Type": 10,
    "grouped_conversion_type": 25,
    "grouped_industry": 50,
    "grouped_page_type": 25,
}

# Hidden width of the gate controller MLP.
GATED_FUSION_DIM = 64
# --- 2. Model Architecture (Exact Replica from your training script) ---
# This class must be defined to load the saved model weights correctly.
class SupervisedSiameseMultimodal(nn.Module):
    """
    Updated model architecture matching the new GGG version.
    Includes fusion block, BatchNorm, and enhanced directional features.

    A shared vision encoder and a shared text encoder embed the control and
    the variant. Categorical embeddings drive a two-way softmax gate that
    weights the vision and text "direction" features before everything is
    concatenated, passed through the fusion block and scored by the
    prediction head.

    Submodule attribute names (vision_model, text_model, embedding_layers,
    gate_controller, fusion_block, prediction_head) are the state_dict keys,
    so they must not be renamed.
    """

    def __init__(self, vision_model_name, text_model_name, cat_mappings, cat_embedding_dims):
        """
        Args:
            vision_model_name: Identifier passed to ViTModel.from_pretrained.
            text_model_name: Identifier passed to AutoModel.from_pretrained.
            cat_mappings: Dict keyed by feature name; each entry provides
                'num_categories'.
            cat_embedding_dims: Dict mapping feature name to embedding width.
        """
        super().__init__()
        self.vision_model = ViTModel.from_pretrained(vision_model_name)
        self.text_model = AutoModel.from_pretrained(text_model_name)
        vision_dim = self.vision_model.config.hidden_size
        text_dim = self.text_model.config.hidden_size

        self.embedding_layers = nn.ModuleList()
        # Column of `cat_feats` that feeds each embedding layer. Tracked
        # explicitly so that a feature missing from `cat_mappings` does not
        # shift the remaining embeddings onto the wrong input column.
        # (Plain list: not part of the state_dict.)
        self._cat_feature_columns = []
        total_cat_emb_dim = 0
        for column, feature in enumerate(CATEGORICAL_FEATURES):
            # Safely handle cases where a feature might not be in mappings
            if feature not in cat_mappings:
                continue
            num_cats = cat_mappings[feature]['num_categories']
            emb_dim = cat_embedding_dims[feature]
            self.embedding_layers.append(nn.Embedding(num_cats, emb_dim))
            self._cat_feature_columns.append(column)
            total_cat_emb_dim += emb_dim

        # Maps the categorical embedding to two logits (vision gate, text gate).
        self.gate_controller = nn.Sequential(
            nn.Linear(total_cat_emb_dim, GATED_FUSION_DIM),
            nn.ReLU(),
            nn.Linear(GATED_FUSION_DIM, 2)
        )

        # Width of the concatenated vector built in forward():
        #   2 * vision (control, variant) + 2 * vision (direction features)
        # + 2 * text   (control, variant) + 2 * text   (direction features)
        # + categorical embeddings + 2 (role embedding)
        in_dim = (vision_dim * 4) + (text_dim * 4) + total_cat_emb_dim + 2

        # Add the fusion block
        self.fusion_block = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Updated prediction head with BatchNorm
        self.prediction_head = nn.Sequential(
            nn.BatchNorm1d(in_dim),
            nn.Linear(in_dim, vision_dim),
            nn.GELU(),
            nn.LayerNorm(vision_dim),
            nn.Dropout(0.2),
            nn.Linear(vision_dim, vision_dim // 2),
            nn.GELU(),
            nn.LayerNorm(vision_dim // 2),
            nn.Dropout(0.1),
            nn.Linear(vision_dim // 2, 1)
        )

    def forward(self, c_pix, v_pix, c_tok, c_attn, v_tok, v_attn, cat_feats):
        """
        Score one batch of control/variant pairs.

        Args:
            c_pix, v_pix: Pixel values for the control / variant images.
            c_tok, c_attn: Token ids and attention mask for the control text.
            v_tok, v_attn: Token ids and attention mask for the variant text.
            cat_feats: 2-D integer tensor; column i holds the category code of
                CATEGORICAL_FEATURES[i].

        Returns:
            One raw score per batch element (last dimension squeezed). The
            caller in this file applies a sigmoid to it.
        """
        # Enhanced forward pass with directional features
        emb_c_vision = self.vision_model(pixel_values=c_pix).pooler_output
        emb_v_vision = self.vision_model(pixel_values=v_pix).pooler_output
        direction_feat_vision = torch.cat([emb_c_vision - emb_v_vision, emb_v_vision - emb_c_vision], dim=1)

        # Text embedding = mean of the last hidden state over the sequence axis.
        c_text_out = self.text_model(input_ids=c_tok, attention_mask=c_attn).last_hidden_state
        v_text_out = self.text_model(input_ids=v_tok, attention_mask=v_attn).last_hidden_state
        emb_c_text = c_text_out.mean(dim=1)
        emb_v_text = v_text_out.mean(dim=1)
        direction_feat_text = torch.cat([emb_c_text - emb_v_text, emb_v_text - emb_c_text], dim=1)

        cat_embeddings = [
            layer(cat_feats[:, column])
            for column, layer in zip(self._cat_feature_columns, self.embedding_layers)
        ]
        final_cat_embedding = torch.cat(cat_embeddings, dim=1)

        # Softmax over the two gate logits: column 0 -> vision, column 1 -> text.
        gates = F.softmax(self.gate_controller(final_cat_embedding), dim=-1)
        vision_gate = gates[:, 0].unsqueeze(1)
        text_gate = gates[:, 1].unsqueeze(1)
        weighted_vision = direction_feat_vision * vision_gate
        weighted_text = direction_feat_text * text_gate

        # Constant [1, 0] role marker appended to every row of the batch.
        batch_size = c_pix.shape[0]
        role_embedding = torch.tensor([[1, 0]] * batch_size, dtype=torch.float32, device=c_pix.device)

        final_vector = torch.cat([
            emb_c_vision, emb_v_vision,
            emb_c_text, emb_v_text,
            weighted_vision, weighted_text,
            final_cat_embedding,
            role_embedding
        ], dim=1)

        # Pass through the fusion block before the final prediction head
        fused_vector = self.fusion_block(final_vector)
        return self.prediction_head(fused_vector).squeeze(-1)
# --- 3. Loading Models and Processors (Done once on startup) ---
# Pick the compute device once; every tensor and the model are moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"๐ Using device: {device}")
if device.type == "cuda":
    print(f"๐ฅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"๐พ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    # cuDNN: enable autotuning and allow non-deterministic kernels for speed.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    # Release cached allocator blocks before the model is loaded.
    torch.cuda.empty_cache()
    # Allow TF32 for matmuls and cuDNN convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Create dummy files if they don't exist for the app to run
os.makedirs(MODEL_DIR, exist_ok=True)  # no-op when the directory already exists
if not os.path.exists(CAT_MAPPINGS_SAVE_PATH):
    print(f"โ ๏ธ GGG Category mappings not found. Creating default mappings...")
    # Create the standard category mappings expected by the model.
    # The position of a value inside each 'categories' list is the integer
    # code that predict_single() feeds to the embedding layers.
    # NOTE(review): because this file is always written here first, the
    # os.path.exists() check in download_model_from_hub() will not trigger a
    # download of the Hub mappings - confirm that is intended.
    default_mappings = {
        "Business Model": {"num_categories": 4, "categories": ["E-Commerce", "Lead Generation", "Other*", "SaaS"]},
        "Customer Type": {"num_categories": 4, "categories": ["B2B", "B2C", "Both", "Other*"]},
        "grouped_conversion_type": {"num_categories": 6, "categories": ["Direct Purchase", "High-Intent Lead Gen", "Info/Content Lead Gen", "Location Search", "Non-Profit/Community", "Other Conversion"]},
        "grouped_industry": {"num_categories": 14, "categories": ["Automotive & Transportation", "B2B Services", "B2B Software & Tech", "Consumer Services", "Consumer Software & Apps", "Education", "Finance, Insurance & Real Estate", "Food, Hospitality & Travel", "Health & Wellness", "Industrial & Manufacturing", "Media & Entertainment", "Non-Profit & Government", "Other", "Retail & E-commerce"]},
        "grouped_page_type": {"num_categories": 5, "categories": ["Awareness & Discovery", "Consideration & Evaluation", "Conversion", "Internal & Navigation", "Post-Conversion & Other"]}
    }
    with open(CAT_MAPPINGS_SAVE_PATH, 'w', encoding='utf-8') as f:
        json.dump(default_mappings, f, indent=2)
    print(f"✅ Created default category mappings at {CAT_MAPPINGS_SAVE_PATH}")

# Mappings actually used by the model and the UI (either pre-existing or the
# defaults written above).
with open(CAT_MAPPINGS_SAVE_PATH, 'r', encoding='utf-8') as f:
    category_mappings = json.load(f)
# Load mapping.json for converting specific values to parent groups
def load_value_mappings():
    """
    Load mapping.json for converting industry, page_type, and conversion_type to parent groups.

    The file is looked up next to this script first and, if absent there,
    relative to the current working directory.

    Returns:
        The parsed JSON content, or an empty dict if the file cannot be read
        or parsed (the error and traceback are printed, never raised).
    """
    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        mapping_file = os.path.join(script_dir, 'mapping.json')
        print(f"๐ Looking for mapping file at: {mapping_file}")
        if not os.path.exists(mapping_file):
            print(f"โ ๏ธ Mapping file not found, trying fallback location...")
            mapping_file = 'mapping.json'
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(mapping_file, 'r', encoding='utf-8') as f:
            mapping_data = json.load(f)
        print(f"✅ Successfully loaded mapping.json with {len(mapping_data)} mapping types")
        return mapping_data
    except Exception as e:
        # Best effort: the app keeps running without value mappings.
        print(f"โ ๏ธ Error loading mapping.json: {e}")
        import traceback
        traceback.print_exc()
        return {}
def convert_to_parent_group(value, mapping_type, value_mappings):
    """
    Convert a specific value to its parent group using mapping.json

    Args:
        value: The specific value (e.g., "Accounting Services")
        mapping_type: Type of mapping ("industry_mappings", "page_type_mappings", "conversion_type_mappings")
        value_mappings: The loaded mapping.json data

    Returns:
        The parent group name (e.g., "B2B Services"). The input value is
        returned unchanged when the mapping type is unknown, when the value is
        itself a parent group, or when it is not found at all.
    """
    if mapping_type not in value_mappings:
        print(f"โ ๏ธ Mapping type '{mapping_type}' not found in mapping.json")
        return value

    mappings = value_mappings[mapping_type]

    # Search for the value in all parent groups (child match wins over a
    # parent-name match, as a value may appear in both roles).
    for parent_group, child_values in mappings.items():
        if value in child_values:
            print(f"✅ Mapped '{value}' -> '{parent_group}'")
            return parent_group

    # If not found, check if the value itself is a parent group
    if value in mappings:
        print(f"โน๏ธ '{value}' is already a parent group")
        return value

    print(f"โ ๏ธ Value '{value}' not found in {mapping_type}, returning as-is")
    return value
# Load confidence scores directly from JSON file
def load_confidence_scores():
    """
    Load confidence scores from confidence_scores.json

    The file is looked up next to this script first and, if absent there,
    relative to the current working directory. Diagnostic information is
    printed along the way.

    Returns:
        The parsed JSON content, or an empty dict if the file is missing or
        cannot be parsed (errors are printed, never raised).
    """
    try:
        # Get the directory where this script is located
        script_dir = os.path.dirname(os.path.abspath(__file__))
        confidence_file = os.path.join(script_dir, 'confidence_scores.json')
        print(f"๐ Script directory: {script_dir}")
        print(f"๐ Looking for confidence file at: {confidence_file}")
        print(f"๐ File exists: {os.path.exists(confidence_file)}")
        if not os.path.exists(confidence_file):
            print(f"โ ๏ธ Confidence file not found, trying fallback location...")
            # Try current directory as fallback
            confidence_file = 'confidence_scores.json'
            print(f"๐ Fallback path: {confidence_file}")
            print(f"๐ Fallback exists: {os.path.exists(confidence_file)}")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(confidence_file, 'r', encoding='utf-8') as f:
            confidence_data = json.load(f)
        print(f"✅ Successfully loaded {len(confidence_data)} confidence score combinations")
        # Print a sample to verify data
        sample_key = list(confidence_data.keys())[0] if confidence_data else None
        if sample_key:
            print(f"๐ Sample entry: {sample_key} = {confidence_data[sample_key]}")
        return confidence_data
    except FileNotFoundError as e:
        # Neither location had the file: dump some context to help debugging.
        print(f"โ Confidence file not found: {e}")
        print(f"๐ Current working directory: {os.getcwd()}")
        print(f"๐ Files in script dir: {os.listdir(script_dir) if os.path.exists(script_dir) else 'N/A'}")
        return {}
    except Exception as e:
        # Best effort: the app falls back to default confidence values.
        print(f"โ ๏ธ Error loading confidence scores: {e}")
        import traceback
        traceback.print_exc()
        return {}
# Load value mappings for converting specific values to parent groups.
# The loader already returns {} on failure; the try/except here is a last
# line of defence so that startup never aborts on this step.
try:
    print("=" * 50)
    print("๐ LOADING VALUE MAPPINGS...")
    print("=" * 50)
    value_mappings = load_value_mappings()
    print(f"✅ Value mappings loaded successfully")
    print("=" * 50)
except Exception as e:
    print(f"โ ๏ธ Error loading value mappings: {e}")
    value_mappings = {}

# Load confidence scores (same best-effort policy as above).
try:
    print("=" * 50)
    print("๐ LOADING CONFIDENCE SCORES...")
    print("=" * 50)
    confidence_scores = load_confidence_scores()
    print(f"✅ Confidence scores loaded successfully: {len(confidence_scores)} combinations")
    print(f"๐ confidence_scores is empty: {len(confidence_scores) == 0}")
    print(f"๐ confidence_scores type: {type(confidence_scores)}")
    print("=" * 50)
except Exception as e:
    print(f"โ ๏ธ Error loading confidence scores: {e}")
    confidence_scores = {}
    print(f"โ confidence_scores set to empty dict: {confidence_scores}")
def get_confidence_data(business_model, customer_type, conversion_type, industry, page_type):
    """
    Get confidence data based on Industry + Page Type combination (more reliable than 5-feature combinations)

    Only `industry` and `page_type` take part in the lookup key
    ("<industry>|<page_type>"); the other parameters are accepted for
    signature compatibility and are not used.

    Returns:
        The entry stored in the module-level `confidence_scores` for that key,
        or a default record (accuracy 0.5, all counters 0) when the key is
        absent.
    """
    key = f"{industry}|{page_type}"
    print(f"๐ Looking for confidence key: '{key}'")
    print(f"๐ Total confidence_scores loaded: {len(confidence_scores)}")
    print(f"๐ Key exists: {key in confidence_scores}")

    if key in confidence_scores:
        data = confidence_scores[key]
        print(f"✅ Found confidence data: {data}")
        return data

    print(f"โ ๏ธ Key '{key}' not found, using fallback")
    print(f"๐ Available keys with '{industry}': {[k for k in confidence_scores.keys() if industry in k]}")
    return {
        'accuracy': 0.5,  # Default fallback
        'count': 0,
        'training_data_count': 0,
        'correct_predictions': 0,
        'actual_wins': 0,
        'predicted_wins': 0
    }
# Build the network with the loaded category mappings; weights are loaded
# into it further below.
model = SupervisedSiameseMultimodal(
    vision_model_name=VISION_MODEL_NAME,
    text_model_name=TEXT_MODEL_NAME,
    cat_mappings=category_mappings,
    cat_embedding_dims=CATEGORICAL_EMBEDDING_DIMS,
)
# Download model from Hugging Face Model Hub
def download_model_from_hub():
    """
    Download model and mappings from Hugging Face Model Hub

    Returns:
        Path to the downloaded checkpoint. If the download fails, the current
        (untrained head) state of the module-level `model` is saved to
        MODEL_SAVE_PATH and that path is returned instead, so the demo can
        still start.
    """
    try:
        print(f"๐ฅ Downloading GGG model from Hugging Face Model Hub: {HF_MODEL_REPO}")
        # Download model file
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILENAME,
            cache_dir=MODEL_DIR
        )
        print(f"✅ Model downloaded to: {model_path}")

        # Download category mappings if not exists locally
        if not os.path.exists(CAT_MAPPINGS_SAVE_PATH):
            try:
                mappings_path = hf_hub_download(
                    repo_id=HF_MODEL_REPO,
                    filename=HF_MAPPINGS_FILENAME,
                    cache_dir=MODEL_DIR
                )
                print(f"✅ Category mappings downloaded to: {mappings_path}")
                # Copy to expected location
                import shutil
                shutil.copy(mappings_path, CAT_MAPPINGS_SAVE_PATH)
            except Exception as e:
                # Mappings are optional here; the checkpoint itself was fetched.
                print(f"โ ๏ธ Could not download mappings from hub: {e}")
        return model_path
    except Exception as e:
        print(f"โ ๏ธ Error downloading from Model Hub: {e}")
        print(f"๐ง Creating dummy weights for demo...")
        # NOTE(review): this writes placeholder weights to MODEL_SAVE_PATH, and
        # the startup code treats an existing file at that path as the local
        # model on later runs - confirm this persistence is acceptable.
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        return MODEL_SAVE_PATH
# Use local model if available, otherwise download from hub
if os.path.exists(MODEL_SAVE_PATH):
    model_path = MODEL_SAVE_PATH
    print(f"✅ Using local GGG model at {MODEL_SAVE_PATH}")
else:
    print(f"๐ฅ Model not found locally, downloading from Model Hub...")
    model_path = download_model_from_hub()

# Load the weights
try:
    print(f"๐ Loading GGG model weights from {model_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # Report the actual source: the path may be a local file, not the Hub.
    print(f"✅ Successfully loaded GGG model weights from {model_path}")
except Exception as e:
    # Best effort: keep serving with whatever weights the model was built with.
    print(f"โ ๏ธ Error loading model weights: {e}")
    print("๐ง Using initialized weights for demo...")

model.to(device)
model.eval()  # inference mode: disables dropout, BatchNorm uses running stats
# One throw-away forward pass on random inputs when a GPU is present, so the
# first real request does not pay the start-up cost.
if torch.cuda.is_available():
    with torch.no_grad():
        # Random image-shaped and token-shaped inputs with batch size 1.
        dummy_c_pix = torch.randn(1, 3, 224, 224).to(device)
        dummy_v_pix = torch.randn(1, 3, 224, 224).to(device)
        dummy_c_tok = torch.randint(low=0, high=1000, size=(1, MAX_TEXT_LENGTH)).to(device)
        dummy_c_attn = torch.ones(1, MAX_TEXT_LENGTH).to(device)
        dummy_v_tok = torch.randint(low=0, high=1000, size=(1, MAX_TEXT_LENGTH)).to(device)
        dummy_v_attn = torch.ones(1, MAX_TEXT_LENGTH).to(device)
        # Category codes 0/1, one column per categorical feature.
        dummy_cat_feats = torch.randint(low=0, high=2, size=(1, len(CATEGORICAL_FEATURES))).to(device)
        # Arguments follow the forward() parameter order; the result is discarded.
        _ = model(
            dummy_c_pix, dummy_v_pix,
            dummy_c_tok, dummy_c_attn,
            dummy_v_tok, dummy_v_attn,
            dummy_cat_feats,
        )
    print("๐ฅ Model warmed up successfully!")
# Load the processors for images and text (same checkpoints as the encoders)
image_processor = AutoProcessor.from_pretrained(VISION_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
print("✅ Model and processors loaded successfully.")
# --- 4. Prediction Functions ---
def get_image_path_from_url(image_url: str, base_dir: str) -> str | None:
"""Constructs a local image path from a URL-like string."""
try:
stem = os.path.splitext(os.path.basename(str(image_url)))[0]
return os.path.join(base_dir, f"{stem}.jpeg")
except (TypeError, ValueError):
return None
@spaces.GPU(duration=50)  # Maximum allowed duration on free tier
def predict_with_categorical_data(control_image, variant_image, business_model, customer_type, conversion_type, industry, page_type):
    """
    Make prediction with provided categorical data (no AI API calls)

    Conversion type, industry and page type may be specific values; they are
    mapped to their parent groups via `value_mappings` before the model runs.

    Returns:
        A dict with the prediction ("predictionResults"), the categories as
        provided, the grouped categories actually used, and processing info;
        or {"error": ...} when an image is missing.
    """
    if None in (control_image, variant_image):
        return {"error": "Please provide both control and variant images"}

    start_time = time.time()
    print(f"๐ Original categories from API: {business_model} | {customer_type} | {conversion_type} | {industry} | {page_type}")

    # Convert specific values to parent groups using mapping.json
    # (same call order as the fields appear: conversion type, industry, page type).
    raw_and_mapping = (
        (conversion_type, "conversion_type_mappings"),
        (industry, "industry_mappings"),
        (page_type, "page_type_mappings"),
    )
    grouped_conversion_type, grouped_industry, grouped_page_type = [
        convert_to_parent_group(raw_value, mapping_name, value_mappings)
        for raw_value, mapping_name in raw_and_mapping
    ]
    print(f"๐ Mapped to parent groups: {business_model} | {customer_type} | {grouped_conversion_type} | {grouped_industry} | {grouped_page_type}")

    # Run the prediction with grouped categorical data
    prediction_result = predict_single(
        control_image,
        variant_image,
        business_model,
        customer_type,
        grouped_conversion_type,
        grouped_industry,
        grouped_page_type,
    )

    # Assemble the response piece by piece: prediction, inputs as given,
    # inputs as used, and timing/confidence provenance.
    provided = {
        "businessModel": business_model,
        "customerType": customer_type,
        "conversionType": conversion_type,
        "industry": industry,
        "pageType": page_type
    }
    grouped = {
        "businessModel": business_model,
        "customerType": customer_type,
        "conversionType": grouped_conversion_type,
        "industry": grouped_industry,
        "pageType": grouped_page_type
    }
    processing_info = {
        "totalProcessingTime": f"{time.time() - start_time:.2f}s",
        "confidenceSource": f"{grouped_industry} | {grouped_page_type}"
    }
    return {
        "predictionResults": prediction_result,
        "providedCategories": provided,
        "groupedCategories": grouped,
        "processingInfo": processing_info
    }
@spaces.GPU(duration=60)  # Maximum allowed duration on free tier
def predict_single(control_image, variant_image, business_model, customer_type, conversion_type, industry, page_type):
    """
    Orchestrates the prediction for a single pair of images and features.

    Note: This function expects GROUPED values for conversion_type, industry, and page_type.
    If calling from API, use predict_with_categorical_data() which handles the conversion automatically.

    Args:
        control_image, variant_image: Image arrays (passed to Image.fromarray).
        business_model, customer_type, conversion_type, industry, page_type:
            Category values; each must be listed in `category_mappings` for
            the corresponding entry of CATEGORICAL_FEATURES.

    Returns:
        On success a dict with "probability" (sigmoid of the model output,
        formatted with 3 decimals), "modelConfidence" and the counters taken
        from the confidence data. On any failure a dict with an "error"
        message and fallback confidence values. If an image is missing, the
        legacy {"Error": 1.0, "Please upload both images": 0.0} dict.
    """
    try:
        if control_image is None or variant_image is None:
            return {"Error": 1.0, "Please upload both images": 0.0}

        start_time = time.time()
        print(f"๐ Starting prediction with categories: {business_model} | {customer_type} | {conversion_type} | {industry} | {page_type}")

        # Encode the categories first: an unknown value should fail fast with a
        # message naming the offending feature, before the slow OCR step runs.
        cat_inputs = [business_model, customer_type, conversion_type, industry, page_type]
        cat_codes = []
        for name, val in zip(CATEGORICAL_FEATURES, cat_inputs):
            known_values = category_mappings[name]['categories']
            if val not in known_values:
                raise ValueError(
                    f"Unknown value {val!r} for '{name}'. Expected one of: {known_values}"
                )
            # The position in the list is the integer code fed to the embedding.
            cat_codes.append(known_values.index(val))

        c_img = Image.fromarray(control_image).convert("RGB")
        v_img = Image.fromarray(variant_image).convert("RGB")

        # Extract OCR text from both images (this is crucial for model performance)
        try:
            c_text_str = pytesseract.image_to_string(c_img)
            v_text_str = pytesseract.image_to_string(v_img)
            print(f"๐ OCR extracted - Control: {len(c_text_str)} chars, Variant: {len(v_text_str)} chars")
        except pytesseract.TesseractNotFoundError:
            print("๐ Tesseract is not installed or not in your PATH. Skipping OCR.")
            c_text_str, v_text_str = "", ""

        # Get confidence data for this combination
        confidence_data = get_confidence_data(business_model, customer_type, conversion_type, industry, page_type)
        print(f"๐ Confidence data loaded: {confidence_data}")

        with torch.no_grad():
            c_pix = image_processor(images=c_img, return_tensors="pt").pixel_values.to(device)
            v_pix = image_processor(images=v_img, return_tensors="pt").pixel_values.to(device)

            # Process OCR text through the text model
            c_text = tokenizer(c_text_str, padding='max_length', truncation=True, max_length=MAX_TEXT_LENGTH, return_tensors='pt').to(device)
            v_text = tokenizer(v_text_str, padding='max_length', truncation=True, max_length=MAX_TEXT_LENGTH, return_tensors='pt').to(device)

            # Single row (batch of one) of category codes.
            cat_feats = torch.tensor([cat_codes], dtype=torch.int64).to(device)

            # Run the multimodal model prediction
            logits = model(
                c_pix=c_pix, v_pix=v_pix,
                c_tok=c_text['input_ids'], c_attn=c_text['attention_mask'],
                v_tok=v_text['input_ids'], v_attn=v_text['attention_mask'],
                cat_feats=cat_feats
            )
            probability = torch.sigmoid(logits).item()

        processing_time = time.time() - start_time

        # Log GPU memory usage for monitoring
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.memory_allocated() / 1024**3
            print(f"๐ Prediction completed in {processing_time:.2f}s | GPU Memory: {gpu_memory:.1f}GB")
        else:
            print(f"๐ Prediction completed in {processing_time:.2f}s")

        confidence_percentage = confidence_data['accuracy'] * 100

        # Create enhanced output with confidence scores and training data info
        result = {
            "probability": f"{probability:.3f}",
            "modelConfidence": f"{confidence_percentage:.1f}",
            "trainingDataSamples": confidence_data['training_data_count'],
            "totalPredictions": confidence_data['count'],
            "correctPredictions": confidence_data['correct_predictions'],
            "totalWinPrediction": confidence_data['actual_wins'],
            "totalLosePrediction": confidence_data['count'] - confidence_data['actual_wins']
        }
        print(f"๐ฏ Final result: {result}")
        return result
    except Exception as e:
        # Top-level boundary for the UI/API: report instead of raising.
        print(f"โ ERROR in predict_single: {e}")
        print(f"๐ Error type: {type(e).__name__}")
        import traceback
        traceback.print_exc()
        # Return error result with fallback confidence data
        return {
            "error": f"Prediction failed: {str(e)}",
            "modelConfidence": "50.0",
            "trainingDataSamples": 0,
            "totalPredictions": 0,
            "correctPredictions": 0,
            "totalWinPrediction": 0,
            "totalLosePrediction": 0
        }
def get_all_possible_values():
    """
    Collect every accepted value (parent groups and their specific values)
    for industry, page_type and conversion_type from `value_mappings`.
    This is useful for API documentation and validation.

    Returns:
        Dict with keys "industry", "page_type" and "conversion_type"; each
        holds a list in which every parent group is followed by its child
        values. A key maps to an empty list when its mapping type is absent.
    """
    # (output key, key inside value_mappings) - order fixes the result layout.
    sources = (
        ("industry", "industry_mappings"),
        ("page_type", "page_type_mappings"),
        ("conversion_type", "conversion_type_mappings"),
    )
    all_values = {}
    for output_key, mapping_key in sources:
        collected = []
        if mapping_key in value_mappings:
            for parent_group, child_values in value_mappings[mapping_key].items():
                collected.append(parent_group)
                collected.extend(child_values)
        all_values[output_key] = collected
    return all_values
@spaces.GPU
def predict_batch(csv_path, control_img_dir, variant_img_dir, num_samples):
    """
    Handles batch prediction from a CSV file.

    Note: CSV should contain grouped values (not specific values) for:
    - grouped_conversion_type
    - grouped_industry
    - grouped_page_type

    Args:
        csv_path: Path of the CSV to read.
        control_img_dir: Folder holding the control images.
        variant_img_dir: Folder holding the variant images.
        num_samples: Number of rows to sample (a float from the UI number
            field is accepted and truncated to int).

    Returns:
        A DataFrame with the sampled rows plus 'predicted_win_probability'
        (a float, or an "ERROR: ..." string for rows that failed); or a
        one-column "Error" DataFrame when the inputs are unusable.
    """
    if not all([csv_path, control_img_dir, variant_img_dir, num_samples]):
        return pd.DataFrame({"Error": ["Please fill in all fields."]})

    # The UI number field yields a float; DataFrame.sample needs an integer.
    try:
        num_samples = int(num_samples)
    except (TypeError, ValueError):
        return pd.DataFrame({"Error": [f"Invalid number of samples: {num_samples}"]})
    if num_samples < 1:
        return pd.DataFrame({"Error": ["Number of samples must be at least 1."]})

    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        return pd.DataFrame({"Error": [f"CSV file not found at: {csv_path}"]})
    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to read CSV: {e}"]})

    if num_samples > len(df):
        print(f"โ ๏ธ Requested {num_samples} samples, but CSV only has {len(df)} rows. Using all rows.")
        num_samples = len(df)

    # Fixed seed: the same CSV and sample count always select the same rows.
    sample_df = df.sample(n=num_samples, random_state=42)

    results = []
    for _, row in sample_df.iterrows():
        try:
            # Construct image paths
            c_path = get_image_path_from_url(row[CONTROL_IMAGE_URL_COLUMN], control_img_dir)
            v_path = get_image_path_from_url(row[VARIANT_IMAGE_URL_COLUMN], variant_img_dir)
            if not c_path or not os.path.exists(c_path):
                raise FileNotFoundError(f"Control image not found: {c_path}")
            if not v_path or not os.path.exists(v_path):
                raise FileNotFoundError(f"Variant image not found: {v_path}")

            # Get categorical features from the row (expects grouped values in CSV)
            cat_features_from_row = [row[f] for f in CATEGORICAL_FEATURES]

            # Read the pixels and close the files before the (slow) prediction.
            with Image.open(c_path) as c_file, Image.open(v_path) as v_file:
                control_array = np.array(c_file)
                variant_array = np.array(v_file)

            # Use the core prediction logic
            prediction = predict_single(
                control_image=control_array,
                variant_image=variant_array,
                business_model=cat_features_from_row[0],
                customer_type=cat_features_from_row[1],
                conversion_type=cat_features_from_row[2],
                industry=cat_features_from_row[3],
                page_type=cat_features_from_row[4]
            )
            # predict_single reports failures through an "error" entry rather
            # than raising; route them to the per-row error handling below.
            if "error" in prediction:
                raise RuntimeError(prediction["error"])

            result_row = row.to_dict()
            # predict_single returns the win probability under "probability"
            # as a formatted string.
            result_row['predicted_win_probability'] = float(prediction['probability'])
            results.append(result_row)
        except Exception as e:
            # One bad row must not abort the batch: record the error and go on.
            print(f"๐ Error processing row: {e}")
            error_row = row.to_dict()
            error_row['predicted_win_probability'] = f"ERROR: {e}"
            results.append(error_row)
    return pd.DataFrame(results)
# --- 5. Build the Gradio Interface ---
def _category_dropdown(feature, label, **dropdown_kwargs):
    """Build a dropdown listing the categories of `feature`, pre-selecting the first one."""
    choices = category_mappings[feature]['categories']
    return gr.Dropdown(choices=choices, label=label, value=choices[0], **dropdown_kwargs)


# NOTE(review): the container nesting below was reconstructed from the
# statement order; confirm the exact placement of buttons/outputs if the
# layout matters.
with gr.Blocks() as iface:
    gr.Markdown("# ๐ Multimodal A/B Test Predictor")
    gr.Markdown("""
### Predict A/B test outcomes using:
- ๐ผ๏ธ **Image Analysis**: Visual features from control & variant images
- ๐ **OCR Text Extraction**: Automatically extracts and analyzes text from images
- ๐ **Categorical Features**: Business context (industry, page type, etc.)
- ๐ฏ **Smart Confidence Scores**: Based on Industry + Page Type combinations with high sample counts
**Enhanced Reliability**: Confidence scores use Industry + Page Type combinations (avg 160 samples) instead of low-count 5-feature combinations!
""")

    with gr.Tab("๐ฏ API Prediction"):
        gr.Markdown("### ๐ Predict with Categorical Data")
        gr.Markdown("""
Upload images and provide categorical data for prediction.
**Note:** For Industry, Page Type, and Conversion Type, you can provide either:
- Specific values (e.g., "Accounting Services") - will be automatically converted to parent group (e.g., "B2B Services")
- Parent group values (e.g., "B2B Services") - will be used directly
The model uses parent groups internally, but the API accepts both for convenience.
""")
        with gr.Row():
            with gr.Column():
                api_control_image = gr.Image(label="Control Image", type="numpy")
                api_variant_image = gr.Image(label="Variant Image", type="numpy")
            with gr.Column():
                api_business_model = _category_dropdown("Business Model", "Business Model")
                api_customer_type = _category_dropdown("Customer Type", "Customer Type")
                # Custom values are allowed on these three so that the specific
                # (non-grouped) values promised in the note above are accepted;
                # they are mapped to parent groups by the handler.
                api_conversion_type = _category_dropdown("grouped_conversion_type", "Conversion Type", allow_custom_value=True)
                api_industry = _category_dropdown("grouped_industry", "Industry", allow_custom_value=True)
                api_page_type = _category_dropdown("grouped_page_type", "Page Type", allow_custom_value=True)
                api_predict_btn = gr.Button("๐ฏ Predict with Categorical Data", variant="primary", size="lg")
        api_output_json = gr.JSON(label="๐ฏ Prediction Results with Confidence Scores")

    with gr.Tab("๐ Manual Selection"):
        gr.Markdown("### Manual Category Selection")
        gr.Markdown("Select categories manually if you prefer precise control.")
        with gr.Row():
            with gr.Column():
                s_control_image = gr.Image(label="Control Image", type="numpy")
                s_variant_image = gr.Image(label="Variant Image", type="numpy")
            with gr.Column():
                s_business_model = _category_dropdown("Business Model", "Business Model")
                s_customer_type = _category_dropdown("Customer Type", "Customer Type")
                s_conversion_type = _category_dropdown("grouped_conversion_type", "Conversion Type")
                s_industry = _category_dropdown("grouped_industry", "Industry")
                s_page_type = _category_dropdown("grouped_page_type", "Page Type")
                s_predict_btn = gr.Button("๐ฎ Predict A/B Test Winner", variant="secondary")
        # predict_single returns a dict mixing strings and integers, which a
        # Label component (label -> float confidence) cannot display; show the
        # raw result as JSON instead.
        s_output_label = gr.JSON(label="๐ฏ Prediction Results & Confidence Analysis")

    with gr.Tab("Batch Prediction from CSV"):
        gr.Markdown("Provide paths to your data to get predictions for multiple random samples.")
        b_csv_path = gr.Textbox(label="Path to CSV file", placeholder="/path/to/your/data.csv")
        b_control_dir = gr.Textbox(label="Path to Control Images Folder", placeholder="/path/to/control_images/")
        b_variant_dir = gr.Textbox(label="Path to Variant Images Folder", placeholder="/path/to/variant_images/")
        # precision=0 makes the component deliver an integer sample count.
        b_num_samples = gr.Number(label="Number of random samples to predict", value=10, precision=0)
        b_predict_btn = gr.Button("Run Batch Prediction")
        b_output_df = gr.DataFrame(label="Batch Prediction Results")

    # Wire up the components
    api_predict_btn.click(
        fn=predict_with_categorical_data,
        inputs=[api_control_image, api_variant_image, api_business_model, api_customer_type, api_conversion_type, api_industry, api_page_type],
        outputs=api_output_json
    )
    s_predict_btn.click(
        fn=predict_single,
        inputs=[s_control_image, s_variant_image, s_business_model, s_customer_type, s_conversion_type, s_industry, s_page_type],
        outputs=s_output_label
    )
    b_predict_btn.click(
        fn=predict_batch,
        inputs=[b_csv_path, b_control_dir, b_variant_dir, b_num_samples],
        outputs=b_output_df
    )
# Launch the application
if __name__ == "__main__":
    # Enable request queuing first: up to 128 waiting requests and a default
    # of 64 concurrent executions per event (the original note targets a
    # 4x L4 GPU deployment).
    iface.queue(max_size=128, default_concurrency_limit=64)
    # Serve on all interfaces at port 7860 and surface detailed errors in the
    # UI for debugging.
    iface.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|