Spaces:
Runtime error
Runtime error
| """ | |
| A/B Test Predictor - API Client Examples | |
| ========================================== | |
| This file shows how to send requests to the A/B Test Predictor API. | |
| """ | |
| # ============================================================================ | |
| # Option 1: Gradio Python Client (Recommended) | |
| # ============================================================================ | |
| from gradio_client import Client | |
| from PIL import Image | |
| import json | |
# Module-level Gradio client used by the helper functions in this file.
# Default target: an app served locally on port 7860.
client = Client("http://localhost:7860")
# To target a Hugging Face Spaces deployment instead, use:
# client = Client("your-username/ABTestPredictor")
def predict_with_gradio_client(control_image_path, variant_image_path,
                               business_model, customer_type, conversion_type,
                               industry, page_type):
    """
    Request a prediction through the module-level Gradio client.

    Args:
        control_image_path: Path to control image file
        variant_image_path: Path to variant image file
        business_model: One of ["E-Commerce", "Lead Generation", "Other*", "SaaS"]
        customer_type: One of ["B2B", "B2C", "Both", "Other*"]
        conversion_type: One of ["Direct Purchase", "High-Intent Lead Gen",
            "Info/Content Lead Gen", "Location Search",
            "Non-Profit/Community", "Other Conversion"]
        industry: One of the 14 industry categories
        page_type: One of ["Awareness & Discovery", "Consideration & Evaluation",
            "Conversion", "Internal & Navigation", "Post-Conversion & Other"]

    Returns:
        The value returned by ``client.predict`` for the
        "/predict_with_categorical_data" endpoint (presumably a dict of
        prediction results with confidence scores -- confirm against the
        server implementation).
    """
    # Positional order matters: the two images first, then the five
    # dropdown values in the order the endpoint declares them.
    endpoint_inputs = (
        control_image_path,   # Control image file path
        variant_image_path,   # Variant image file path
        business_model,       # Business Model dropdown
        customer_type,        # Customer Type dropdown
        conversion_type,      # Conversion Type dropdown
        industry,             # Industry dropdown
        page_type,            # Page Type dropdown
    )
    return client.predict(
        *endpoint_inputs,
        api_name="/predict_with_categorical_data",  # The function endpoint
    )
# Example usage
if __name__ == "__main__":
    # Example 1: Basic prediction
    example_inputs = {
        "control_image_path": "path/to/control_image.jpg",
        "variant_image_path": "path/to/variant_image.jpg",
        "business_model": "SaaS",
        "customer_type": "B2B",
        "conversion_type": "High-Intent Lead Gen",
        "industry": "B2B Software & Tech",
        "page_type": "Awareness & Discovery",
    }
    result = predict_with_gradio_client(**example_inputs)

    # Dump the full response first ...
    print("Prediction Results:")
    print(json.dumps(result, indent=2))

    # ... then pull out the individual fields of interest.
    win_probability = result['predictionResults']['probability']
    confidence = result['predictionResults']['modelConfidence']
    print(f"\nWin Probability: {win_probability}")
    print(f"Model Confidence: {confidence}%")
| # ============================================================================ | |
| # Option 2: Direct HTTP POST Request (cURL equivalent in Python) | |
| # ============================================================================ | |
| import requests | |
| import base64 | |
def _image_file_to_data_uri(image_path):
    """Read the file at ``image_path`` and return it as a base64 ``data:`` URI."""
    import mimetypes  # local import so this block needs no file-level change

    # Derive the MIME type from the file extension instead of labelling every
    # file as JPEG; keep JPEG as the fallback when the extension is unknown
    # or does not map to an image type.
    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def predict_with_http_request(control_image_path, variant_image_path,
                              business_model, customer_type, conversion_type,
                              industry, page_type, api_url="http://localhost:7860",
                              timeout=None):
    """
    Send prediction request using direct HTTP POST.

    The images are read from disk and embedded in the JSON payload as base64
    data URIs, followed by the five category values.

    Args:
        control_image_path: Path to control image file
        variant_image_path: Path to variant image file
        business_model: Business model category value
        customer_type: Customer type category value
        conversion_type: Conversion type category value
        industry: Industry category value
        page_type: Page type category value
        api_url: Base URL of the app; a trailing slash is tolerated
        timeout: Passed to ``requests.post``. ``None`` (the default) waits
            indefinitely, which matches the previous behaviour; pass a number
            of seconds to bound the wait.

    Returns:
        The first element of the ``data`` list in the JSON response.

    Raises:
        OSError: If an image file cannot be read.
        RuntimeError: If the server answers with a status other than 200.

    NOTE(review): the request goes to ``/api/predict``; the route actually
    exposed depends on the Gradio version serving the app -- confirm.
    """
    # Prepare the request payload (Gradio format): images first, then the
    # category values in the same order as the endpoint's inputs.
    payload = {
        "data": [
            _image_file_to_data_uri(control_image_path),  # Control image
            _image_file_to_data_uri(variant_image_path),  # Variant image
            business_model,
            customer_type,
            conversion_type,
            industry,
            page_type,
        ]
    }

    # Send POST request to Gradio API
    response = requests.post(
        f"{api_url.rstrip('/')}/api/predict",
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )

    if response.status_code != 200:
        raise RuntimeError(
            f"API request failed: {response.status_code} - {response.text}"
        )
    return response.json()['data'][0]  # Gradio wraps response in 'data' array
| # ============================================================================ | |
| # Option 3: Using PIL Images (in-memory) | |
| # ============================================================================ | |
| import numpy as np | |
| from PIL import Image | |
def predict_with_pil_images(control_img, variant_img,
                            business_model, customer_type, conversion_type,
                            industry, page_type):
    """
    Send prediction with PIL Image objects (useful for programmatic image generation).

    The images are written to temporary PNG files and the file paths are
    passed to the Gradio client, because the client transmits image inputs as
    files rather than as in-memory arrays. The temporary files are removed
    once the call has returned.

    Args:
        control_img: PIL Image object
        variant_img: PIL Image object
        business_model: Business model category value
        customer_type: Customer type category value
        conversion_type: Conversion type category value
        industry: Industry category value
        page_type: Page type category value

    Returns:
        The value returned by ``client.predict`` for the
        "/predict_with_categorical_data" endpoint.
    """
    import os
    import tempfile

    # Image modes the PNG writer accepts; anything else (e.g. CMYK) is
    # converted to RGB before saving.
    png_modes = {"1", "L", "LA", "I", "P", "RGB", "RGBA"}

    with tempfile.TemporaryDirectory() as tmp_dir:
        image_paths = []
        for label, img in (("control", control_img), ("variant", variant_img)):
            if img.mode not in png_modes:
                img = img.convert("RGB")
            path = os.path.join(tmp_dir, f"{label}.png")
            img.save(path, format="PNG")
            image_paths.append(path)

        # The call is made inside the ``with`` block so the files still exist
        # while the client uploads them.
        # NOTE(review): newer gradio_client releases may require file inputs
        # to be wrapped with handle_file() -- confirm for the installed version.
        return client.predict(
            image_paths[0],
            image_paths[1],
            business_model,
            customer_type,
            conversion_type,
            industry,
            page_type,
            api_name="/predict_with_categorical_data",
        )
# Example with PIL
if __name__ == "__main__":
    # Load images using PIL. The context managers close the underlying image
    # files once the prediction has been made.
    with Image.open("control.jpg") as control_img, \
            Image.open("variant.jpg") as variant_img:
        result = predict_with_pil_images(
            control_img=control_img,
            variant_img=variant_img,
            business_model="SaaS",
            customer_type="B2B",
            conversion_type="High-Intent Lead Gen",
            industry="B2B Software & Tech",
            page_type="Awareness & Discovery",
        )
| # ============================================================================ | |
| # Option 4: Batch Processing Multiple Tests | |
| # ============================================================================ | |
def batch_predict(test_cases, output_file="results.json"):
    """
    Run a list of A/B tests one after another and save the outcomes as JSON.

    A test that raises is not fatal: the error text is recorded for that test
    and processing continues with the next one.

    Args:
        test_cases: List of dicts with test parameters
        output_file: Where to save results

    Returns:
        A list with one record per test, each holding "test_id" (1-based),
        "input" (the test dict) and either "prediction" or "error".

    Example test_cases:
        [
            {
                "control_image": "test1_control.jpg",
                "variant_image": "test1_variant.jpg",
                "business_model": "SaaS",
                "customer_type": "B2B",
                "conversion_type": "High-Intent Lead Gen",
                "industry": "B2B Software & Tech",
                "page_type": "Awareness & Discovery"
            },
            # ... more tests
        ]
    """
    results = []

    for test_id, test in enumerate(test_cases, start=1):
        print(f"Processing test {test_id}/{len(test_cases)}...")

        # Start the record with the fields common to success and failure,
        # then attach whichever outcome applies.
        record = {"test_id": test_id, "input": test}
        try:
            record["prediction"] = predict_with_gradio_client(
                control_image_path=test["control_image"],
                variant_image_path=test["variant_image"],
                business_model=test["business_model"],
                customer_type=test["customer_type"],
                conversion_type=test["conversion_type"],
                industry=test["industry"],
                page_type=test["page_type"],
            )
        except Exception as e:
            print(f"Error processing test {test_id}: {e}")
            record["error"] = str(e)
        results.append(record)

    # Save results
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nBatch processing complete! Results saved to {output_file}")
    return results
| # ============================================================================ | |
| # Valid Category Values (for reference) | |
| # ============================================================================ | |
# Accepted values for each categorical input, keyed by parameter name.
VALID_CATEGORIES = {
    "business_model": ["E-Commerce", "Lead Generation", "Other*", "SaaS"],
    "customer_type": ["B2B", "B2C", "Both", "Other*"],
    "conversion_type": [
        "Direct Purchase",
        "High-Intent Lead Gen",
        "Info/Content Lead Gen",
        "Location Search",
        "Non-Profit/Community",
        "Other Conversion",
    ],
    # 14 industry categories
    "industry": [
        "Automotive & Transportation",
        "B2B Services",
        "B2B Software & Tech",
        "Consumer Services",
        "Consumer Software & Apps",
        "Education",
        "Finance, Insurance & Real Estate",
        "Food, Hospitality & Travel",
        "Health & Wellness",
        "Industrial & Manufacturing",
        "Media & Entertainment",
        "Non-Profit & Government",
        "Other",
        "Retail & E-commerce",
    ],
    "page_type": [
        "Awareness & Discovery",
        "Consideration & Evaluation",
        "Conversion",
        "Internal & Navigation",
        "Post-Conversion & Other",
    ],
}
def validate_categories(business_model, customer_type, conversion_type,
                        industry, page_type):
    """
    Check every category value against VALID_CATEGORIES.

    Returns:
        True when all five values are valid.

    Raises:
        ValueError: Listing every invalid value, one per line.
    """
    # Pair each field name with the supplied value; the order here fixes the
    # order of the lines in the error message.
    supplied = (
        ("business_model", business_model),
        ("customer_type", customer_type),
        ("conversion_type", conversion_type),
        ("industry", industry),
        ("page_type", page_type),
    )
    errors = [
        f"Invalid {field}: {value}"
        for field, value in supplied
        if value not in VALID_CATEGORIES[field]
    ]
    if not errors:
        return True
    raise ValueError("Category validation failed:\n" + "\n".join(errors))
| # ============================================================================ | |
| # Error Handling Example | |
| # ============================================================================ | |
def safe_predict(control_image_path, variant_image_path,
                 business_model, customer_type, conversion_type,
                 industry, page_type):
    """
    Safe prediction with error handling and validation.

    The category values are validated first; only if they pass is the
    prediction request made. This function does not raise for ordinary
    errors -- failures are reported in the returned dict.

    Returns:
        On success: {"success": True, "result": <prediction>}.
        On failure: {"success": False, "error": <label>, "message": <text>},
        where the label is "Validation Error" when the categories are
        rejected and "API Error" for anything else.
    """
    # Validation gets its own try block so that only a ValueError raised by
    # the category check is labelled a validation error. A ValueError raised
    # while making the request is then reported as an API error.
    try:
        validate_categories(business_model, customer_type, conversion_type,
                            industry, page_type)
    except ValueError as e:
        return {
            "success": False,
            "error": "Validation Error",
            "message": str(e),
        }
    except Exception as e:
        # Unexpected failure inside validation: keep the catch-all label so
        # the function still never raises.
        return {
            "success": False,
            "error": "API Error",
            "message": str(e),
        }

    # Make prediction
    try:
        result = predict_with_gradio_client(
            control_image_path=control_image_path,
            variant_image_path=variant_image_path,
            business_model=business_model,
            customer_type=customer_type,
            conversion_type=conversion_type,
            industry=industry,
            page_type=page_type,
        )
    except Exception as e:
        return {
            "success": False,
            "error": "API Error",
            "message": str(e),
        }

    return {
        "success": True,
        "result": result,
    }