dd-poc / scripts /streamlit_cloud_config.py
Juan Salas
🚀 Refactor: Move datasets to HuggingFace repos
9a71b8f
#!/usr/bin/env python3
"""
Streamlit Cloud Configuration Module
Configures HuggingFace model caching for Streamlit Cloud deployment.
Points the HuggingFace cache at a fixed directory when running on Streamlit Cloud and
provides an optional model-preloading helper; initialization itself loads models on demand.
"""
import os
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
def _is_streamlit_cloud():
    """
    Detect whether the process appears to be running on Streamlit Cloud.

    The check is a heuristic. Any one of the following is treated as a sign
    of Streamlit Cloud:

    * ``/app`` exists and is writable by the current process,
    * ``STREAMLIT_SERVER_HEADLESS`` equals ``"true"`` (case-insensitive),
    * ``HOME`` starts with ``/app``.

    The outcome is logged at INFO level.

    Returns:
        bool: True if at least one indicator matched, False otherwise
        (treated as local development).
    """
    app_dir = Path("/app")
    indicators = (
        # A writable /app directory.
        app_dir.exists() and os.access(app_dir, os.W_OK),
        # NOTE(review): this is a general Streamlit server option, so other
        # headless deployments (e.g. a local container) also match - confirm
        # that treating them as "cloud" is intended.
        os.environ.get("STREAMLIT_SERVER_HEADLESS", "").lower() == "true",
        # A home directory located under /app.
        os.environ.get("HOME", "").startswith("/app"),
    )

    # A single positive indicator is enough.
    is_cloud = any(indicators)
    if is_cloud:
        logger.info("🌐 Detected Streamlit Cloud environment")
    else:
        logger.info("🏠 Detected local development environment")
    return is_cloud
def configure_streamlit_cloud_cache():
    """
    Configure HuggingFace cache locations for the current environment.

    On Streamlit Cloud (as judged by ``_is_streamlit_cloud``) the directory
    ``/app/.cache/huggingface`` is created if necessary and ``HF_HOME``,
    ``HF_HUB_CACHE``, ``HF_DATASETS_CACHE`` and ``TRANSFORMERS_CACHE`` are
    pointed at it. ``os.environ.setdefault`` is used throughout, so values
    already present in the environment are never overridden. If the directory
    cannot be created, a warning is logged and the HuggingFace variables are
    left untouched, so the libraries fall back to their default cache.

    For local development the HuggingFace variables are not modified.

    In every case ``TOKENIZERS_PARALLELISM`` defaults to ``"true"`` unless it
    is already set, including when the cache directory could not be created.

    Returns:
        None
    """
    if _is_streamlit_cloud():
        cache_base = Path("/app/.cache/huggingface")
        try:
            cache_base.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            # PermissionError is a subclass of OSError, so this covers both.
            logger.warning("Could not create Streamlit Cloud cache directory: %s", e)
            logger.warning("Falling back to default HuggingFace cache")
        else:
            logger.info("✅ Streamlit Cloud cache directory created: %s", cache_base)

            # NOTE(review): recent transformers releases reportedly deprecate
            # TRANSFORMERS_CACHE in favour of HF_HOME - confirm against the
            # pinned version before removing it.
            hf_cache_defaults = {
                "HF_HOME": cache_base,
                "HF_HUB_CACHE": cache_base / "hub",
                "HF_DATASETS_CACHE": cache_base / "datasets",
                "TRANSFORMERS_CACHE": cache_base / "transformers",
            }
            for variable, location in hf_cache_defaults.items():
                os.environ.setdefault(variable, str(location))
            logger.info("✅ Configured Streamlit Cloud HuggingFace cache: %s", cache_base)
    else:
        # Local development - use default HuggingFace cache, don't override.
        logger.info("🏠 Local development detected - using default HuggingFace cache")

    # Enable tokenizers parallelism for better performance. This is independent
    # of where the cache lives, so it is applied on every path.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
def preload_models():
    """
    Load the embedding and reranking models once and run a smoke test on each.

    Obtains ``sentence-transformers/all-mpnet-base-v2`` through
    ``get_cached_embeddings`` and ``cross-encoder/ms-marco-MiniLM-L-6-v2``
    through ``get_cached_cross_encoder`` (both from ``app.core.model_cache``;
    presumably memoizing loaders - see that module), then calls
    ``embed_query`` and ``predict`` on dummy input to confirm the models work.

    Returns:
        None

    Raises:
        Exception: any error raised while importing, loading or exercising
            the models is logged with its traceback and re-raised unchanged.
    """
    try:
        # Imported inside the function, so the import only happens when
        # preloading is actually requested.
        from app.core.model_cache import get_cached_embeddings, get_cached_cross_encoder

        logger.info("Preloading models for Streamlit Cloud...")

        # Main embedding model.
        embeddings = get_cached_embeddings("sentence-transformers/all-mpnet-base-v2")
        logger.info("✅ Main embedding model preloaded")

        # Cross-encoder used for reranking.
        cross_encoder = get_cached_cross_encoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        logger.info("✅ Cross-encoder model preloaded")

        # Exercise both models with dummy data to make sure they are usable.
        test_text = "This is a test document for model validation."
        embeddings.embed_query(test_text)
        test_pairs = [[test_text, "This is a relevant query."]]
        cross_encoder.predict(test_pairs)
        logger.info("✅ All models validated and cached")
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Failed to preload models: %s", e)
        raise
def initialize_for_streamlit_cloud():
    """
    Run start-up configuration for the application.

    Configures the HuggingFace cache through
    ``configure_streamlit_cloud_cache`` and logs progress. Models are
    deliberately NOT preloaded here; they are loaded on demand instead.

    Intended to be called at the very beginning of the main application file,
    so the cache configuration is in place before any model is loaded.

    Returns:
        None
    """
    logger.info("Initializing application...")

    # Configure caching.
    configure_streamlit_cloud_cache()

    # Skip model preloading to avoid PyTorch device placement issues in
    # containers. Models will be loaded on-demand via model_cache.py.
    logger.info("⏭️ Skipping model preloading - models loaded on-demand for better compatibility")
    logger.info("Application initialization complete")