|
|
|
|
|
""" |
|
|
Streamlit Cloud Configuration Module |
|
|
|
|
|
Configures HuggingFace model caching for Streamlit Cloud deployment. |
|
|
Ensures models are cached persistently and preloaded to avoid downloads on each session. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def _is_streamlit_cloud() -> bool:
    """Detect whether the process is running on Streamlit Cloud.

    Uses three heuristics, any one of which marks the environment as cloud:
      * a writable ``/app`` directory (Streamlit Cloud's app root),
      * ``STREAMLIT_SERVER_HEADLESS=true`` in the environment,
      * ``HOME`` relocated under ``/app``.

    Returns:
        bool: True if any indicator suggests Streamlit Cloud, False for
        local development.
    """
    indicators = [
        # Streamlit Cloud mounts the application at /app with write access.
        (Path("/app").exists() and os.access("/app", os.W_OK)),
        # Cloud deployments always run the server headless.
        os.environ.get("STREAMLIT_SERVER_HEADLESS", "").lower() == "true",
        # HOME points under /app on Streamlit Cloud.
        os.environ.get("HOME", "").startswith("/app"),
    ]

    is_cloud = any(indicators)

    # NOTE(review): the original log messages contained mojibake emoji
    # ("π ..."); replaced with plain ASCII text.
    if is_cloud:
        logger.info("Detected Streamlit Cloud environment")
    else:
        logger.info("Detected local development environment")

    return is_cloud
|
|
|
|
|
def configure_streamlit_cloud_cache() -> None:
    """Configure HuggingFace caching for Streamlit Cloud deployment.

    On Streamlit Cloud, points all HuggingFace cache locations at a
    persistent directory under ``/app/.cache/huggingface`` so models are
    not re-downloaded on every session restart. Locally, the default
    HuggingFace cache is left untouched.

    All environment variables are set with ``setdefault`` so explicit
    user configuration always wins.
    """
    if _is_streamlit_cloud():
        cache_base = Path("/app/.cache/huggingface")

        try:
            cache_base.mkdir(parents=True, exist_ok=True)
            logger.info("Streamlit Cloud cache directory created: %s", cache_base)
        except OSError as e:
            # PermissionError is a subclass of OSError, so one clause covers
            # both. If the directory cannot be created, bail out and let
            # HuggingFace fall back to its default cache location.
            logger.warning("Could not create Streamlit Cloud cache directory: %s", e)
            logger.warning("Falling back to default HuggingFace cache")
            return

        # Point every HuggingFace cache knob at the persistent directory.
        # TRANSFORMERS_CACHE is deprecated in favor of HF_HOME/HF_HUB_CACHE
        # but is kept for older transformers versions.
        os.environ.setdefault("HF_HOME", str(cache_base))
        os.environ.setdefault("HF_HUB_CACHE", str(cache_base / "hub"))
        os.environ.setdefault("HF_DATASETS_CACHE", str(cache_base / "datasets"))
        os.environ.setdefault("TRANSFORMERS_CACHE", str(cache_base / "transformers"))

        logger.info("Configured Streamlit Cloud HuggingFace cache: %s", cache_base)
    else:
        logger.info("Local development detected - using default HuggingFace cache")

    # Allow parallel tokenization regardless of environment.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
|
|
|
|
|
def preload_models() -> None:
    """Preload embedding and reranking models so they are cached before use.

    Loads the main sentence-transformer embedding model and the
    cross-encoder reranker via the app's model cache, then runs a small
    smoke test on each to force full initialization. This prevents
    download/initialization delays during user interactions.

    Raises:
        Exception: re-raised from any failure while importing, loading,
        or validating the models.
    """
    # NOTE(review): removed unused `import sys` and a redundant local
    # `from pathlib import Path` (Path is imported at module level and
    # never used in this function).
    try:
        # Imported lazily so this module can be imported without pulling
        # in heavy ML dependencies.
        from app.core.model_cache import get_cached_embeddings, get_cached_cross_encoder

        logger.info("Preloading models for Streamlit Cloud...")

        embeddings = get_cached_embeddings("sentence-transformers/all-mpnet-base-v2")
        logger.info("Main embedding model preloaded")

        cross_encoder = get_cached_cross_encoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        logger.info("Cross-encoder model preloaded")

        # Exercise each model once so lazy initialization happens now,
        # not during the first user request.
        test_text = "This is a test document for model validation."
        embeddings.embed_query(test_text)

        test_pairs = [[test_text, "This is a relevant query."]]
        cross_encoder.predict(test_pairs)

        logger.info("All models validated and cached")

    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Failed to preload models: %s", e)
        raise
|
|
|
|
|
def initialize_for_streamlit_cloud() -> None:
    """Initialize the application for Streamlit Cloud deployment.

    Call this at the very beginning of the main application file, before
    any user interaction, so the HuggingFace cache is configured first.

    Model preloading is deliberately skipped here; models are loaded
    on-demand for better compatibility (see ``preload_models`` for the
    eager alternative).
    """
    logger.info("Initializing application...")

    # Set up HuggingFace cache environment variables (cloud or local).
    configure_streamlit_cloud_cache()

    # NOTE(review): original message contained a mojibake emoji; replaced
    # with plain ASCII text.
    logger.info("Skipping model preloading - models loaded on-demand for better compatibility")

    logger.info("Application initialization complete")
|
|
|