# ABTestPredictor / API_CLIENT_EXAMPLES.py
# Last commit e93a798 by nitish-spz: "build error - fix 1"
"""
A/B Test Predictor - API Client Examples
==========================================
This file shows how to send requests to the A/B Test Predictor API.
"""
# ============================================================================
# Option 1: Gradio Python Client (Recommended)
# ============================================================================
from gradio_client import Client
from PIL import Image
import json
# Initialize the client.
# Defaults to a local deployment. Set the ABTEST_PREDICTOR_URL environment
# variable to target another deployment without editing this file, e.g. a
# Hugging Face Space:
#   ABTEST_PREDICTOR_URL="your-username/ABTestPredictor"
# NOTE(review): this runs at import time, so importing the module presumably
# requires the server to be reachable — confirm that is acceptable.
import os

client = Client(os.environ.get("ABTEST_PREDICTOR_URL", "http://localhost:7860"))
def predict_with_gradio_client(control_image_path, variant_image_path,
                               business_model, customer_type, conversion_type,
                               industry, page_type):
    """
    Send a prediction request to the Gradio endpoint using the module-level client.

    Args:
        control_image_path: Path (or URL) of the control image file
        variant_image_path: Path (or URL) of the variant image file
        business_model: One of ["E-Commerce", "Lead Generation", "Other*", "SaaS"]
        customer_type: One of ["B2B", "B2C", "Both", "Other*"]
        conversion_type: One of ["Direct Purchase", "High-Intent Lead Gen",
                        "Info/Content Lead Gen", "Location Search",
                        "Non-Profit/Community", "Other Conversion"]
        industry: One of the 14 industry categories
        page_type: One of ["Awareness & Discovery", "Consideration & Evaluation",
                  "Conversion", "Internal & Navigation", "Post-Conversion & Other"]

    Returns:
        The endpoint's output as decoded by gradio_client. The examples in this
        file index it as a dict with a 'predictionResults' key.
    """
    try:
        # Recent gradio_client releases no longer treat bare path strings as
        # files; file inputs must be wrapped with handle_file() to be uploaded.
        from gradio_client import handle_file
    except ImportError:
        # Older gradio_client versions take the file path string directly.
        def handle_file(path):
            return path

    result = client.predict(
        handle_file(control_image_path),   # Control image file
        handle_file(variant_image_path),   # Variant image file
        business_model,                    # Business Model dropdown
        customer_type,                     # Customer Type dropdown
        conversion_type,                   # Conversion Type dropdown
        industry,                          # Industry dropdown
        page_type,                         # Page Type dropdown
        api_name="/predict_with_categorical_data",  # The function endpoint
    )
    return result
# Example usage
if __name__ == "__main__":
    # Example 1: Basic prediction.
    # The image paths below are placeholders - point them at real files.
    example_inputs = {
        "control_image_path": "path/to/control_image.jpg",
        "variant_image_path": "path/to/variant_image.jpg",
        "business_model": "SaaS",
        "customer_type": "B2B",
        "conversion_type": "High-Intent Lead Gen",
        "industry": "B2B Software & Tech",
        "page_type": "Awareness & Discovery",
    }
    result = predict_with_gradio_client(**example_inputs)

    print("Prediction Results:")
    print(json.dumps(result, indent=2))

    # Pull individual fields out of the nested 'predictionResults' section
    prediction_section = result['predictionResults']
    win_probability = prediction_section['probability']
    confidence = prediction_section['modelConfidence']
    print(f"\nWin Probability: {win_probability}")
    print(f"Model Confidence: {confidence}%")
# ============================================================================
# Option 2: Direct HTTP POST Request (cURL equivalent in Python)
# ============================================================================
import requests
import base64
def _image_to_data_uri(image_path):
    """
    Read an image file and return it as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension so PNG/WebP/etc. files
    are labelled correctly. ``image/jpeg`` is used when the extension is
    unknown or does not map to an image type.
    """
    import mimetypes

    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def predict_with_http_request(control_image_path, variant_image_path,
                              business_model, customer_type, conversion_type,
                              industry, page_type, api_url="http://localhost:7860",
                              api_name="predict_with_categorical_data",
                              timeout=300):
    """
    Send prediction request using direct HTTP POST

    Note: This requires converting images to base64 for Gradio's API format.

    Args:
        control_image_path: Path to the control image file
        variant_image_path: Path to the variant image file
        business_model, customer_type, conversion_type, industry, page_type:
            Category values (see VALID_CATEGORIES).
        api_url: Base URL of the Gradio app.
        api_name: Endpoint name. The request is sent to
            ``{api_url}/api/{api_name}``. The default matches the endpoint
            the gradio_client examples in this file call.
            NOTE(review): ``/api/<name>`` is the Gradio 3.x REST route; newer
            Gradio versions use a different HTTP protocol. Confirm against
            the server's "Use via API" page.
        timeout: Connect/read timeout in seconds, passed to ``requests.post``,
            so a stalled server cannot hang the caller indefinitely.

    Returns:
        The first element of the response's ``data`` array.

    Raises:
        RuntimeError: if the server answers with a non-200 status.
        requests.RequestException: on connection failure or timeout.
    """
    # Prepare the request payload (Gradio format: positional inputs in 'data')
    payload = {
        "data": [
            _image_to_data_uri(control_image_path),  # Control image
            _image_to_data_uri(variant_image_path),  # Variant image
            business_model,
            customer_type,
            conversion_type,
            industry,
            page_type,
        ]
    }
    endpoint = f"{api_url.rstrip('/')}/api/{api_name.strip('/')}"
    # json= serializes the payload and sets Content-Type: application/json.
    response = requests.post(endpoint, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(f"API request failed: {response.status_code} - {response.text}")
    return response.json()['data'][0]  # Gradio wraps response in 'data' array
# ============================================================================
# Option 3: Using PIL Images (in-memory)
# ============================================================================
import numpy as np
from PIL import Image
def predict_with_pil_images(control_img, variant_img,
                            business_model, customer_type, conversion_type,
                            industry, page_type):
    """
    Send prediction with PIL Image objects (useful for programmatic image generation).

    gradio_client sends inputs as JSON and uploads images from file paths or
    URLs. In-memory images (and NumPy arrays, which are not JSON-serializable)
    therefore cannot be handed to ``client.predict`` directly. Each image is
    written to a temporary PNG file and sent through
    ``predict_with_gradio_client``. The temporary files are removed
    afterwards, even if the request fails.

    Args:
        control_img: PIL Image object
        variant_img: PIL Image object
        business_model, customer_type, conversion_type, industry, page_type:
            Category values, as for ``predict_with_gradio_client``.

    Returns:
        Whatever ``predict_with_gradio_client`` returns.
    """
    import os
    import tempfile

    # Modes PNG can store; anything else (e.g. CMYK) is converted to RGB first.
    png_modes = {"1", "L", "LA", "I", "P", "RGB", "RGBA"}
    temp_paths = []
    try:
        for img in (control_img, variant_img):
            if img.mode not in png_modes:
                img = img.convert("RGB")
            # delete=False keeps the file after it is closed here; it is
            # removed in the finally block below.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                temp_paths.append(tmp.name)
                img.save(tmp, format="PNG")
        return predict_with_gradio_client(
            control_image_path=temp_paths[0],
            variant_image_path=temp_paths[1],
            business_model=business_model,
            customer_type=customer_type,
            conversion_type=conversion_type,
            industry=industry,
            page_type=page_type,
        )
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; never mask the real result/error
# Example with PIL
if __name__ == "__main__":
    # Load images using PIL. Image.open leaves the file open until the image
    # is closed, so the context managers release both handles after the request.
    with Image.open("control.jpg") as control_img, Image.open("variant.jpg") as variant_img:
        result = predict_with_pil_images(
            control_img=control_img,
            variant_img=variant_img,
            business_model="SaaS",
            customer_type="B2B",
            conversion_type="High-Intent Lead Gen",
            industry="B2B Software & Tech",
            page_type="Awareness & Discovery"
        )

    print("PIL Prediction Results:")
    print(json.dumps(result, indent=2))
# ============================================================================
# Option 4: Batch Processing Multiple Tests
# ============================================================================
def batch_predict(test_cases, output_file="results.json"):
    """
    Process multiple A/B tests in batch.

    Each test case is sent through ``predict_with_gradio_client``. A failure
    in one test (missing key, unreadable image, API error, ...) is recorded in
    that test's entry and does not stop the rest of the batch.

    Args:
        test_cases: Iterable of dicts with test parameters (list, tuple, or
            generator). It is materialized once so progress can show a total.
        output_file: Where to save results (written as JSON)

    Example test_cases:
        [
            {
                "control_image": "test1_control.jpg",
                "variant_image": "test1_variant.jpg",
                "business_model": "SaaS",
                "customer_type": "B2B",
                "conversion_type": "High-Intent Lead Gen",
                "industry": "B2B Software & Tech",
                "page_type": "Awareness & Discovery"
            },
            # ... more tests
        ]

    Returns:
        list[dict]: one entry per test with "test_id" (1-based), "input", and
        either "prediction" or "error".
    """
    cases = list(test_cases)  # allows generators; len() needs a sized container
    total = len(cases)
    results = []
    for test_id, test in enumerate(cases, start=1):
        print(f"Processing test {test_id}/{total}...")
        try:
            prediction = predict_with_gradio_client(
                control_image_path=test["control_image"],
                variant_image_path=test["variant_image"],
                business_model=test["business_model"],
                customer_type=test["customer_type"],
                conversion_type=test["conversion_type"],
                industry=test["industry"],
                page_type=test["page_type"]
            )
        except Exception as e:  # per-test boundary: record the failure and continue
            print(f"Error processing test {test_id}: {e}")
            results.append({
                "test_id": test_id,
                "input": test,
                "error": str(e)
            })
        else:
            results.append({
                "test_id": test_id,
                "input": test,
                "prediction": prediction
            })

    # Save results
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nBatch processing complete! Results saved to {output_file}")
    return results
# ============================================================================
# Valid Category Values (for reference)
# ============================================================================
# Allowed values for each categorical input, keyed by the matching
# keyword-argument name. validate_categories checks membership with `in`,
# i.e. exact, case-sensitive string equality.
VALID_CATEGORIES = dict(
    business_model=["E-Commerce", "Lead Generation", "Other*", "SaaS"],
    customer_type=["B2B", "B2C", "Both", "Other*"],
    conversion_type=[
        "Direct Purchase",
        "High-Intent Lead Gen",
        "Info/Content Lead Gen",
        "Location Search",
        "Non-Profit/Community",
        "Other Conversion",
    ],
    industry=[
        "Automotive & Transportation",
        "B2B Services",
        "B2B Software & Tech",
        "Consumer Services",
        "Consumer Software & Apps",
        "Education",
        "Finance, Insurance & Real Estate",
        "Food, Hospitality & Travel",
        "Health & Wellness",
        "Industrial & Manufacturing",
        "Media & Entertainment",
        "Non-Profit & Government",
        "Other",
        "Retail & E-commerce",
    ],
    page_type=[
        "Awareness & Discovery",
        "Consideration & Evaluation",
        "Conversion",
        "Internal & Navigation",
        "Post-Conversion & Other",
    ],
)
def validate_categories(business_model, customer_type, conversion_type,
                        industry, page_type):
    """
    Check every categorical input against VALID_CATEGORIES.

    Returns True when all values are allowed. Otherwise raises ValueError
    whose message lists one "Invalid <field>: <value>" line per offending
    field, in argument order.
    """
    supplied = (
        ("business_model", business_model),
        ("customer_type", customer_type),
        ("conversion_type", conversion_type),
        ("industry", industry),
        ("page_type", page_type),
    )
    problems = [
        f"Invalid {field}: {value}"
        for field, value in supplied
        if value not in VALID_CATEGORIES[field]
    ]
    if not problems:
        return True
    raise ValueError("Category validation failed:\n" + "\n".join(problems))
# ============================================================================
# Error Handling Example
# ============================================================================
def safe_predict(control_image_path, variant_image_path,
                 business_model, customer_type, conversion_type,
                 industry, page_type):
    """
    Safe prediction with error handling and validation.

    Returns a status dict instead of raising for validation or request failures:
      - {"success": True, "result": <prediction>} on success
      - {"success": False, "error": "Validation Error", "message": ...} when
        validate_categories rejects an input (no request is sent)
      - {"success": False, "error": "API Error", "message": ...} when the
        prediction call raises any exception
    """
    # Validation gets its own try so a ValueError raised by the API call
    # itself (e.g. a bad file path) is not misreported as a validation error.
    try:
        validate_categories(business_model, customer_type, conversion_type,
                            industry, page_type)
    except ValueError as e:
        return {
            "success": False,
            "error": "Validation Error",
            "message": str(e)
        }

    try:
        result = predict_with_gradio_client(
            control_image_path=control_image_path,
            variant_image_path=variant_image_path,
            business_model=business_model,
            customer_type=customer_type,
            conversion_type=conversion_type,
            industry=industry,
            page_type=page_type
        )
    except Exception as e:  # top-level boundary: report the failure, don't raise
        return {
            "success": False,
            "error": "API Error",
            "message": str(e)
        }

    return {
        "success": True,
        "result": result
    }