File size: 4,838 Bytes
12f0afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a71b8f
 
 
12f0afd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
"""
Streamlit Cloud Configuration Module

Configures HuggingFace model caching for Streamlit Cloud deployment.
Ensures models are cached persistently and preloaded to avoid downloads on each session.
"""

import os
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

def _is_streamlit_cloud():
    """
    Heuristically determine whether this process is running on Streamlit Cloud.

    Three independent signals are checked; any single match is treated as
    "cloud". None of them is definitive on its own, which is why they are
    OR-ed together.

    Returns:
        bool: True when a Streamlit Cloud indicator is present, False for
        local development.
    """
    # Streamlit Cloud containers expose a writable /app working directory.
    app_dir_writable = Path("/app").exists() and os.access("/app", os.W_OK)
    # The platform launches Streamlit in headless server mode.
    headless_server = os.environ.get("STREAMLIT_SERVER_HEADLESS", "").lower() == "true"
    # HOME is remapped under /app on the platform.
    home_under_app = os.environ.get("HOME", "").startswith("/app")

    is_cloud = app_dir_writable or headless_server or home_under_app

    if is_cloud:
        logger.info("🌐 Detected Streamlit Cloud environment")
    else:
        logger.info("🏠 Detected local development environment")

    return is_cloud

def configure_streamlit_cloud_cache():
    """
    Point HuggingFace caches at a persistent location when on Streamlit Cloud.

    Redirects HF_HOME / HF_HUB_CACHE / HF_DATASETS_CACHE / TRANSFORMERS_CACHE
    to /app/.cache/huggingface so models survive session restarts. Local runs
    are left on HuggingFace's default cache. Existing HF_* environment
    variables are never overwritten (``setdefault`` only). Falls back to the
    default cache if the directory cannot be created.
    """
    if not _is_streamlit_cloud():
        # Local development - use default HuggingFace cache, don't override
        logger.info("🏠 Local development detected - using default HuggingFace cache")
    else:
        cache_base = Path("/app/.cache/huggingface")

        # Create the cache root up front; a read-only filesystem or missing
        # permissions means we must fall back to HF's defaults.
        try:
            cache_base.mkdir(parents=True, exist_ok=True)
            logger.info(f"βœ… Streamlit Cloud cache directory created: {cache_base}")
        except OSError as e:  # PermissionError is a subclass of OSError
            logger.warning(f"Could not create Streamlit Cloud cache directory: {e}")
            logger.warning("Falling back to default HuggingFace cache")
            return

        # setdefault: respect any operator-provided overrides.
        for var, subdir in (
            ("HF_HOME", cache_base),
            ("HF_HUB_CACHE", cache_base / "hub"),
            ("HF_DATASETS_CACHE", cache_base / "datasets"),
            ("TRANSFORMERS_CACHE", cache_base / "transformers"),
        ):
            os.environ.setdefault(var, str(subdir))

        logger.info(f"βœ… Configured Streamlit Cloud HuggingFace cache: {cache_base}")

    # Enable tokenizers parallelism for better performance
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")

def preload_models():
    """
    Preload embedding and reranking models so they are cached before the app serves traffic.

    Loads the main sentence-transformers embedding model and the ms-marco
    cross-encoder through the shared model cache, then runs a tiny smoke test
    on each so the first user request does not pay the download/warm-up cost.

    Raises:
        Exception: re-raises any failure from model loading or validation
        after logging it.
    """
    try:
        # Lazy import: app.core.model_cache pulls in heavy ML dependencies
        # that must not load at module-import time.
        from app.core.model_cache import get_cached_embeddings, get_cached_cross_encoder

        logger.info("Preloading models for Streamlit Cloud...")

        # Preload main embedding model using cache
        embeddings = get_cached_embeddings("sentence-transformers/all-mpnet-base-v2")
        logger.info("βœ… Main embedding model preloaded")

        # Preload cross-encoder for reranking using cache
        cross_encoder = get_cached_cross_encoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        logger.info("βœ… Cross-encoder model preloaded")

        # Smoke-test both models with dummy data so failures surface here,
        # during startup, rather than on a user interaction.
        test_text = "This is a test document for model validation."
        embeddings.embed_query(test_text)

        test_pairs = [[test_text, "This is a relevant query."]]
        cross_encoder.predict(test_pairs)

        logger.info("βœ… All models validated and cached")

    except Exception as e:
        # Log with context, then propagate so callers can decide how to react.
        logger.error(f"Failed to preload models: {e}")
        raise

def initialize_for_streamlit_cloud():
    """
    One-shot startup hook for Streamlit Cloud deployments.

    Call this at the very top of the main application file, before any model
    code runs. It wires up the HuggingFace cache configuration; model
    preloading is intentionally skipped so models load on demand via
    model_cache.py, avoiding PyTorch device-placement issues in containers.
    """
    logger.info("Initializing application...")

    # Cache configuration must happen before any HF library reads its env vars.
    configure_streamlit_cloud_cache()

    # Preloading is deliberately disabled; see docstring for rationale.
    logger.info("⏭️ Skipping model preloading - models loaded on-demand for better compatibility")

    logger.info("Application initialization complete")