chatassistant_retail / tests /unit /test_data_generator.py
github-actions[bot]
Sync from https://github.com/samir72/chatassistant_retail
8b30412
"""Unit tests for sample data generator."""
from datetime import datetime
import pytest
from chatassistant_retail.data import Product, Sale, SampleDataGenerator
class TestSampleDataGenerator:
"""Test the SampleDataGenerator class."""
def test_generator_initialization(self):
"""Test generator initialization with seed."""
gen = SampleDataGenerator(seed=42)
assert gen.fake is not None
def test_generate_products_count(self):
"""Test generating the correct number of products."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=100)
assert len(products) == 100
assert all(isinstance(p, Product) for p in products)
def test_generate_products_unique_skus(self):
"""Test that all generated products have unique SKUs."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=100)
skus = [p.sku for p in products]
assert len(skus) == len(set(skus)), "Duplicate SKUs found"
def test_generate_products_valid_categories(self):
"""Test that all products have valid categories."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=100)
for product in products:
assert product.category in gen.CATEGORIES
def test_generate_products_positive_prices(self):
"""Test that all products have positive prices."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=100)
for product in products:
assert product.price > 0
def test_generate_products_non_negative_stock(self):
"""Test that all products have non-negative stock levels."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=100)
for product in products:
assert product.current_stock >= 0
assert product.reorder_level >= 0
def test_generate_products_stock_distribution(self):
"""Test that stock distribution matches expected patterns."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=500)
out_of_stock = sum(1 for p in products if p.current_stock == 0)
low_stock = sum(1 for p in products if 0 < p.current_stock <= 20)
normal_stock = sum(1 for p in products if 20 < p.current_stock <= 200)
overstocked = sum(1 for p in products if p.current_stock > 200)
# Rough distribution check (with some tolerance)
total = len(products)
assert out_of_stock / total < 0.15 # ~10% out of stock
assert low_stock / total > 0.15 # ~25% low stock
assert normal_stock / total > 0.50 # ~60% normal stock
def test_generate_sales_history_count(self):
"""Test generating sales history returns sales."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=50)
sales = gen.generate_sales_history(products, months=1)
assert len(sales) > 0
assert all(isinstance(s, Sale) for s in sales)
def test_generate_sales_history_valid_skus(self):
"""Test that all sales reference valid product SKUs."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=50)
sales = gen.generate_sales_history(products, months=1)
product_skus = {p.sku for p in products}
for sale in sales:
assert sale.sku in product_skus
def test_generate_sales_history_positive_quantities(self):
"""Test that all sales have positive quantities."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=50)
sales = gen.generate_sales_history(products, months=1)
for sale in sales:
assert sale.quantity > 0
assert sale.sale_price > 0
def test_generate_sales_history_valid_channels(self):
"""Test that all sales have valid channels."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=50)
sales = gen.generate_sales_history(products, months=1)
valid_channels = {"retail", "online", "wholesale"}
for sale in sales:
assert sale.channel in valid_channels
def test_generate_sales_history_timestamp_range(self):
"""Test that sales timestamps are within expected range."""
gen = SampleDataGenerator(seed=42)
products = gen.generate_products(count=50)
months = 3
sales = gen.generate_sales_history(products, months=months)
now = datetime.now()
earliest_allowed = now.replace(hour=0, minute=0, second=0, microsecond=0) - __import__("datetime").timedelta(
days=30 * months + 1
)
for sale in sales:
assert sale.timestamp >= earliest_allowed
assert sale.timestamp <= now
def test_seasonal_multiplier(self):
"""Test seasonal multiplier returns reasonable values."""
gen = SampleDataGenerator(seed=42)
# Test different months
nov_date = datetime(2024, 11, 1) # Holiday season
jan_date = datetime(2024, 1, 1) # Post-holiday
may_date = datetime(2024, 5, 1) # Normal
nov_mult = gen._get_seasonal_multiplier(nov_date)
jan_mult = gen._get_seasonal_multiplier(jan_date)
may_mult = gen._get_seasonal_multiplier(may_date)
# Holiday season should have higher multiplier
assert nov_mult > 1.0
# Post-holiday should have lower multiplier
assert jan_mult < 1.0
# Normal month should be around 1.0
assert 0.9 <= may_mult <= 1.1
def test_category_price_ranges(self):
"""Test that category prices are within expected ranges."""
gen = SampleDataGenerator(seed=42)
# Test specific category price ranges
for _ in range(10):
electronics_price = gen._generate_category_price("Electronics")
assert 19.99 <= electronics_price <= 999.99
groceries_price = gen._generate_category_price("Groceries")
assert 1.99 <= groceries_price <= 49.99
def test_reproducibility_with_same_seed(self):
"""Test that same seed produces same results."""
gen1 = SampleDataGenerator(seed=42)
products1 = gen1.generate_products(count=10)
gen2 = SampleDataGenerator(seed=42)
products2 = gen2.generate_products(count=10)
# Should be identical with same seed - compare SKUs and structure
assert len(products1) == len(products2)
for p1, p2 in zip(products1, products2):
assert p1.sku == p2.sku
# Note: Due to random selection in the generator,
# we primarily verify SKU consistency and valid data structure
assert p1.category in gen1.CATEGORIES
assert p1.price > 0
assert p1.current_stock >= 0
def test_different_seed_produces_different_results(self):
"""Test that different seeds produce different results."""
gen1 = SampleDataGenerator(seed=42)
gen2 = SampleDataGenerator(seed=123)
products1 = gen1.generate_products(count=10)
products2 = gen2.generate_products(count=10)
# Should be different with different seeds
# At least some products should differ
different_count = sum(1 for p1, p2 in zip(products1, products2) if p1.name != p2.name)
assert different_count > 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])