# Upload metadata (Hugging Face page residue, kept as a comment so the file parses):
# ShreyasGosavi's picture
# Upload 37 files
# 53bec59 verified
"""
Production Configuration Management
Handles environment-based settings, secrets, and feature flags
"""
import os
from pathlib import Path
from typing import List, Optional
from pydantic import Field, PostgresDsn, RedisDsn, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application configuration loaded from environment variables and `.env`.

    Plain fields map 1:1 to environment variables via ``validation_alias``.
    Derived fields (DATABASE_URL, REDIS_URL, CELERY_BROKER_URL,
    CELERY_RESULT_BACKEND) are assembled from their component fields when not
    supplied explicitly; they declare ``validate_default=True`` because
    pydantic v2 skips field validators for defaulted values otherwise, which
    would leave them as ``None`` when the env var is unset.
    """

    # --- Application ---
    APP_NAME: str = "Multimodal Misinformation Detection API"
    APP_VERSION: str = "1.0.0"
    API_V1_PREFIX: str = "/api/v1"
    DEBUG: bool = Field(default=False, validation_alias="DEBUG")
    ENVIRONMENT: str = Field(default="production", validation_alias="ENVIRONMENT")

    # --- Server ---
    HOST: str = Field(default="0.0.0.0", validation_alias="HOST")
    PORT: int = Field(default=8000, validation_alias="PORT")
    WORKERS: int = Field(default=4, validation_alias="WORKERS")
    RELOAD: bool = Field(default=False, validation_alias="RELOAD")

    # --- Security ---
    # Placeholder default is rejected by validate_production_config() when
    # ENVIRONMENT == "production".
    SECRET_KEY: str = Field(
        default="CHANGE-ME-IN-PRODUCTION-USE-OPENSSL-RAND-HEX-32",
        validation_alias="SECRET_KEY"
    )
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    REFRESH_TOKEN_EXPIRE_DAYS: int = 7
    ALGORITHM: str = "HS256"  # JWT signing algorithm

    # --- CORS ---
    BACKEND_CORS_ORIGINS: List[str] = Field(
        default=["http://localhost:3000", "http://localhost:8000"],
        validation_alias="BACKEND_CORS_ORIGINS"
    )

    @field_validator("BACKEND_CORS_ORIGINS", mode="before")
    @classmethod
    def parse_cors_origins(cls, v):
        """Allow a comma-separated string (e.g. "a.com,b.com") as well as a list.

        NOTE(review): pydantic-settings attempts JSON decoding of complex env
        values before field validators run, so a bare comma-separated env var
        may still be rejected earlier — confirm against the deployed
        pydantic-settings version.
        """
        if isinstance(v, str):
            return [origin.strip() for origin in v.split(",")]
        return v

    # --- Database ---
    POSTGRES_SERVER: str = Field(default="localhost", validation_alias="POSTGRES_SERVER")
    POSTGRES_USER: str = Field(default="postgres", validation_alias="POSTGRES_USER")
    POSTGRES_PASSWORD: str = Field(default="postgres", validation_alias="POSTGRES_PASSWORD")
    POSTGRES_DB: str = Field(default="misinformation_detection", validation_alias="POSTGRES_DB")
    POSTGRES_PORT: int = Field(default=5432, validation_alias="POSTGRES_PORT")
    # BUG FIX: validate_default=True is required so the assembling validator
    # below runs even when DATABASE_URL is not provided; without it the field
    # silently stayed None.
    DATABASE_URL: Optional[str] = Field(default=None, validate_default=True)

    @field_validator("DATABASE_URL", mode="before")
    @classmethod
    def assemble_db_connection(cls, v, info):
        """Build a postgresql:// DSN from the POSTGRES_* fields unless set."""
        if isinstance(v, str) and v:
            return v  # explicit value wins
        data = info.data  # fields declared above are already validated
        return f"postgresql://{data.get('POSTGRES_USER')}:{data.get('POSTGRES_PASSWORD')}@{data.get('POSTGRES_SERVER')}:{data.get('POSTGRES_PORT')}/{data.get('POSTGRES_DB')}"

    # --- Redis ---
    REDIS_HOST: str = Field(default="localhost", validation_alias="REDIS_HOST")
    REDIS_PORT: int = Field(default=6379, validation_alias="REDIS_PORT")
    REDIS_PASSWORD: Optional[str] = Field(default=None, validation_alias="REDIS_PASSWORD")
    REDIS_DB: int = Field(default=0, validation_alias="REDIS_DB")
    # BUG FIX: validate_default=True so the URL is assembled when unset.
    REDIS_URL: Optional[str] = Field(default=None, validate_default=True)

    @field_validator("REDIS_URL", mode="before")
    @classmethod
    def assemble_redis_connection(cls, v, info):
        """Build a redis:// URL from the REDIS_* fields unless set explicitly."""
        if isinstance(v, str) and v:
            return v
        data = info.data
        password_part = f":{data.get('REDIS_PASSWORD')}@" if data.get('REDIS_PASSWORD') else ""
        return f"redis://{password_part}{data.get('REDIS_HOST')}:{data.get('REDIS_PORT')}/{data.get('REDIS_DB')}"

    # --- Cache ---
    CACHE_TTL: int = Field(default=3600, validation_alias="CACHE_TTL")  # seconds (1 hour)
    CACHE_PREDICTIONS: bool = Field(default=True, validation_alias="CACHE_PREDICTIONS")

    # --- Rate Limiting ---
    RATE_LIMIT_ENABLED: bool = Field(default=True, validation_alias="RATE_LIMIT_ENABLED")
    RATE_LIMIT_PER_MINUTE: int = Field(default=60, validation_alias="RATE_LIMIT_PER_MINUTE")
    RATE_LIMIT_PER_HOUR: int = Field(default=1000, validation_alias="RATE_LIMIT_PER_HOUR")

    # --- File Upload ---
    MAX_UPLOAD_SIZE: int = Field(default=10 * 1024 * 1024, validation_alias="MAX_UPLOAD_SIZE")  # bytes (10MB)
    ALLOWED_IMAGE_TYPES: List[str] = Field(
        default=["image/jpeg", "image/png", "image/webp"],
        validation_alias="ALLOWED_IMAGE_TYPES"
    )
    ALLOWED_VIDEO_TYPES: List[str] = Field(
        default=["video/mp4", "video/mpeg", "video/quicktime"],
        validation_alias="ALLOWED_VIDEO_TYPES"
    )

    # --- ML Models ---
    # Default resolves relative to this file: <repo>/models
    MODEL_CACHE_DIR: Path = Field(
        default=Path(__file__).parent.parent.parent / "models",
        validation_alias="MODEL_CACHE_DIR"
    )
    DEVICE: str = Field(default="cpu", validation_alias="DEVICE")  # "cpu" or "cuda"
    BATCH_SIZE: int = Field(default=32, validation_alias="BATCH_SIZE")

    # Model identifiers (Hugging Face hub names)
    DEEPFAKE_MODEL: str = Field(
        default="timm/efficientnet_b4.ra2_in1k",
        validation_alias="DEEPFAKE_MODEL"
    )
    TEXT_CLASSIFIER_MODEL: str = Field(
        default="roberta-base",
        validation_alias="TEXT_CLASSIFIER_MODEL"
    )
    PERPLEXITY_MODEL: str = Field(
        default="gpt2",
        validation_alias="PERPLEXITY_MODEL"
    )

    # --- Logging ---
    LOG_LEVEL: str = Field(default="INFO", validation_alias="LOG_LEVEL")
    LOG_FORMAT: str = Field(default="json", validation_alias="LOG_FORMAT")  # "json" or "text"
    LOG_FILE: Optional[Path] = Field(default=None, validation_alias="LOG_FILE")

    # --- Monitoring ---
    ENABLE_METRICS: bool = Field(default=True, validation_alias="ENABLE_METRICS")
    ENABLE_TRACING: bool = Field(default=False, validation_alias="ENABLE_TRACING")
    METRICS_PORT: int = Field(default=9090, validation_alias="METRICS_PORT")

    # --- Feature Flags ---
    ENABLE_VIDEO_ANALYSIS: bool = Field(default=True, validation_alias="ENABLE_VIDEO_ANALYSIS")
    ENABLE_AUDIO_ANALYSIS: bool = Field(default=True, validation_alias="ENABLE_AUDIO_ANALYSIS")
    ENABLE_BATCH_PROCESSING: bool = Field(default=True, validation_alias="ENABLE_BATCH_PROCESSING")
    ENABLE_ASYNC_TASKS: bool = Field(default=True, validation_alias="ENABLE_ASYNC_TASKS")

    # --- Celery (async tasks) ---
    # BUG FIX: validate_default=True so these fall back to REDIS_URL when unset.
    CELERY_BROKER_URL: Optional[str] = Field(default=None, validate_default=True)
    CELERY_RESULT_BACKEND: Optional[str] = Field(default=None, validate_default=True)

    @field_validator("CELERY_BROKER_URL", mode="before")
    @classmethod
    def set_celery_broker(cls, v, info):
        """Default the Celery broker to the (already-assembled) REDIS_URL."""
        if isinstance(v, str) and v:
            return v
        return info.data.get("REDIS_URL")

    @field_validator("CELERY_RESULT_BACKEND", mode="before")
    @classmethod
    def set_celery_backend(cls, v, info):
        """Default the Celery result backend to the REDIS_URL."""
        if isinstance(v, str) and v:
            return v
        return info.data.get("REDIS_URL")

    # --- Email (notifications) ---
    SMTP_HOST: Optional[str] = Field(default=None, validation_alias="SMTP_HOST")
    SMTP_PORT: int = Field(default=587, validation_alias="SMTP_PORT")
    SMTP_USER: Optional[str] = Field(default=None, validation_alias="SMTP_USER")
    SMTP_PASSWORD: Optional[str] = Field(default=None, validation_alias="SMTP_PASSWORD")
    EMAILS_FROM_EMAIL: Optional[str] = Field(default=None, validation_alias="EMAILS_FROM_EMAIL")

    # --- Admin bootstrap user ---
    # Defaults are rejected in production by validate_production_config().
    FIRST_SUPERUSER_EMAIL: str = Field(
        default="admin@example.com",
        validation_alias="FIRST_SUPERUSER_EMAIL"
    )
    FIRST_SUPERUSER_PASSWORD: str = Field(
        default="changeme",
        validation_alias="FIRST_SUPERUSER_PASSWORD"
    )

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=True,
        extra="allow"
    )

    @property
    def is_production(self) -> bool:
        """Check if running in production environment"""
        return self.ENVIRONMENT.lower() == "production"

    @property
    def is_development(self) -> bool:
        """Check if running in development environment"""
        return self.ENVIRONMENT.lower() == "development"

    @property
    def is_testing(self) -> bool:
        """Check if running in testing environment"""
        return self.ENVIRONMENT.lower() == "testing"
# Global settings instance
# NOTE: constructed at import time, so the environment / .env file is read
# exactly once; later environment changes are not picked up by this object.
settings = Settings()
# Validate critical production settings
def validate_production_config():
    """Fail fast when production is running with unsafe default settings.

    No-op outside production. In production, collects every violated rule and
    raises a single ValueError listing all of them.

    Raises:
        ValueError: if any production-critical setting is left at an unsafe
            default (secret key, superuser password, DEBUG, DB password).
    """
    if not settings.is_production:
        return

    # (violated?, message) pairs — declarative so new rules are one line each.
    rules = (
        (settings.SECRET_KEY == "CHANGE-ME-IN-PRODUCTION-USE-OPENSSL-RAND-HEX-32",
         "SECRET_KEY must be changed in production"),
        (settings.FIRST_SUPERUSER_PASSWORD == "changeme",
         "FIRST_SUPERUSER_PASSWORD must be changed in production"),
        (settings.DEBUG,
         "DEBUG must be False in production"),
        (not settings.POSTGRES_PASSWORD or settings.POSTGRES_PASSWORD == "postgres",
         "Strong POSTGRES_PASSWORD required in production"),
    )
    errors = [message for violated, message in rules if violated]
    if errors:
        raise ValueError(
            "Production configuration errors:\n" + "\n".join(f"  - {err}" for err in errors)
        )
if __name__ == "__main__":
    # Smoke test: print the resolved configuration values to stdout.
    report = (
        ("Environment", settings.ENVIRONMENT),
        ("Database URL", settings.DATABASE_URL),
        ("Redis URL", settings.REDIS_URL),
        ("Debug Mode", settings.DEBUG),
        ("Rate Limiting", settings.RATE_LIMIT_ENABLED),
    )
    for label, value in report:
        print(f"{label}: {value}")