File size: 8,808 Bytes
c413127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import asyncio
import logging
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union

import fire
from langchain_core.embeddings import Embeddings

from src.chains import PresentationAnalysis
from src.config import EmbeddingConfig, Navigator, Provider
from src.config.logging import setup_logging
from src.rag.storage import (ChromaSlideStore, create_slides_database,
                             create_slides_database_async)

# Module-level logger shared by every function and class below;
# ChromaCLI.__init__ passes it to setup_logging.
logger: logging.Logger = logging.getLogger(__name__)


class Mode(str, Enum):
    """Conversion modes understood by the ``convert`` command.

    Because the enum mixes in ``str``, each member compares equal to its
    plain string value (``Mode.FRESH == "fresh"``).
    """

    # Build a brand-new collection from the loaded presentations.
    FRESH = "fresh"
    # Add slides to a collection that already exists.
    APPEND = "append"


def load_openai_embeddings(
    provider: Provider, model_name: Optional[str] = "text-embedding-3-small"
) -> Embeddings:
    """Get embeddings model based on provider and name

    Despite its name, this loader also serves the VseGPT provider; the
    provider only selects which ``EmbeddingConfig`` loader is called.

    Args:
        provider: Provider type (``Provider.VSEGPT`` or ``Provider.OPENAI``)
        model_name: Model identifier forwarded unchanged as ``model=`` to the
            provider-specific loader

    Returns:
        The embeddings model returned by ``EmbeddingConfig.load_vsegpt`` or
        ``EmbeddingConfig.load_openai``

    Raises:
        ValueError: If ``provider`` is neither VSEGPT nor OPENAI.
    """
    config = EmbeddingConfig()

    logger.info(f"Using {provider} embeddings model: {model_name}")

    if provider == Provider.VSEGPT:
        return config.load_vsegpt(model=model_name)
    if provider == Provider.OPENAI:
        return config.load_openai(model=model_name)
    raise ValueError(f"Unknown provider: {provider}")


class FindPresentationJsons:
    """Helper class for finding presentation JSON files

    Pattern searches are delegated to ``Navigator.find_file_by_substr``;
    without patterns every ``*.json`` file under the base directory is
    returned.
    """

    # Class-level attribute: one Navigator is created when this class body
    # executes and is shared by every FindPresentationJsons instance.
    navigator: Navigator = Navigator()

    def find_jsons(
        self,
        patterns: Optional[List[str]] = None,
        base_dir: Optional[Union[str, Path]] = None,
    ) -> List[Path]:
        """Find JSON files using patterns

        Args:
            patterns: Substrings to search for, or ``None``/empty to get all
                JSONs. A single string is treated as a one-element list.
            base_dir: Directory to search in (defaults to
                ``navigator.interim``). A plain string is converted to ``Path``.

        Returns:
            Unique JSON file paths. When collecting all JSONs the paths are
            sorted; for pattern searches they keep first-found order.
        """
        base_dir = self.navigator.interim if base_dir is None else Path(base_dir)

        if not patterns:
            # rglob yields entries in filesystem order, which varies between
            # platforms and runs; sort so the same files are processed in the
            # same order every time.
            return sorted(base_dir.rglob("*.json"))

        if isinstance(patterns, str):
            # Iterating a bare string would search for each character.
            patterns = [patterns]

        found_files: List[Path] = []
        for pattern in patterns:
            found = self.navigator.find_file_by_substr(
                substr=pattern, extension=".json", base_dir=base_dir, return_first=False
            )
            if found:
                found_files.extend(found)
            else:
                logger.warning(f"No JSONs found matching '{pattern}'")

        # A file may match several patterns: drop duplicates while keeping
        # the order in which files were first found.
        return list(dict.fromkeys(found_files))


def process_presentations(
    json_paths: List[Path],
    collection_name: str = "pres1",
    mode: Mode = Mode.FRESH,
    embeddings: Optional[Embeddings] = None,
) -> None:
    """Process presentation JSONs into ChromaDB collection

    Files that fail to load are logged and skipped. Errors raised while
    building or updating the collection are logged with a traceback and are
    not re-raised, so callers cannot distinguish success from failure by
    the return value.

    Args:
        json_paths: List of JSON file paths
        collection_name: Name for ChromaDB collection
        mode: Processing mode. ``Mode.FRESH`` builds the collection with
            ``create_slides_database``; any other value adds every slide to
            an existing collection through ``ChromaSlideStore.add_slide``.
        embeddings: Optional embedding model, passed through unchanged as
            ``embedding_model`` (``None`` is forwarded as-is)
    """
    logger.info(f"Processing presentations in {mode} mode")
    logger.debug(f"JSON paths: {json_paths}")

    # Load presentations from JSONs; best effort, one bad file must not
    # abort the whole batch.
    presentations = []
    for path in json_paths:
        try:
            pres = PresentationAnalysis.load(path)
        except Exception as e:
            logger.error(f"Failed to load {path}: {str(e)}")
            continue
        presentations.append(pres)
        logger.info(f"Loaded presentation: {path.stem}")

    if not presentations:
        logger.error("No presentations loaded")
        return

    try:
        if mode == Mode.FRESH:
            logger.info(f"Creating new collection: {collection_name}")
            create_slides_database(
                presentations=presentations,
                collection_name=collection_name,
                embedding_model=embeddings,
            )
        else:
            logger.info(f"Adding to existing collection: {collection_name}")
            store = ChromaSlideStore(
                collection_name=collection_name, embedding_model=embeddings
            )
            for pres in presentations:
                for slide in pres.slides:
                    store.add_slide(slide)
    except Exception:
        logger.exception("Processing failed")
    else:
        logger.info("Processing completed successfully")


async def process_presentations_async(
    json_paths: List[Path],
    collection_name: str = "pres0",
    mode: Mode = Mode.FRESH,
    embeddings: Optional[Embeddings] = None,
    max_concurrent_slides: int = 5,
) -> None:
    """Process presentation JSONs into ChromaDB collection asynchronously

    Asynchronous counterpart of ``process_presentations``. The JSON files are
    still loaded with the synchronous ``PresentationAnalysis.load``; only the
    ChromaDB work is awaited. Load failures are logged and skipped; errors
    while building or updating the collection are logged with a traceback
    and not re-raised.

    NOTE: the default ``collection_name`` ("pres0") differs from the one in
    ``process_presentations`` ("pres1").

    Args:
        json_paths: List of JSON file paths
        collection_name: Name for ChromaDB collection
        mode: Processing mode. ``Mode.FRESH`` builds the collection with
            ``create_slides_database_async``; any other value appends each
            presentation, one after another, via
            ``ChromaSlideStore.process_presentation_async``.
        embeddings: Optional embedding model, passed through unchanged as
            ``embedding_model``
        max_concurrent_slides: Concurrency limit forwarded as
            ``max_concurrent_slides`` (fresh mode) or ``max_concurrent``
            (append mode)
    """
    logger.info(f"Processing presentations in {mode} mode")
    logger.debug(f"JSON paths: {json_paths}")

    # Load presentations from JSONs; best effort, one bad file must not
    # abort the whole batch.
    presentations = []
    for path in json_paths:
        try:
            pres = PresentationAnalysis.load(path)
        except Exception as e:
            logger.error(f"Failed to load {path}: {str(e)}")
            continue
        presentations.append(pres)
        logger.info(f"Loaded presentation: {path.stem}")

    if not presentations:
        logger.error("No presentations loaded")
        return

    try:
        if mode == Mode.FRESH:
            logger.info(f"Creating new collection: {collection_name}")
            await create_slides_database_async(
                presentations=presentations,
                collection_name=collection_name,
                embedding_model=embeddings,
                max_concurrent_slides=max_concurrent_slides,
            )
        else:
            logger.info(f"Adding to existing collection: {collection_name}")
            store = ChromaSlideStore(
                collection_name=collection_name,
                embedding_model=embeddings,
            )
            for pres in presentations:
                await store.process_presentation_async(
                    pres, max_concurrent=max_concurrent_slides
                )
    except Exception:
        logger.exception("Processing failed")
    else:
        logger.info("Processing completed successfully")


class ChromaCLI:
    """CLI for converting presentation JSONs to ChromaDB"""

    def __init__(self):
        """Initialize CLI with logging setup

        Calls ``setup_logging`` on the module logger with a ``logs``
        directory and creates the helpers used by :meth:`convert`.
        """
        setup_logging(logger, Path("logs"))
        self.navigator = Navigator()
        self.finder = FindPresentationJsons()

    def convert(
        self,
        *patterns: str,
        collection: str = "pres1",
        mode: str = "fresh",
        provider: str = "openai",
        model_name: Optional[str] = "text-embedding-3-small",
        base_dir: Optional[str] = None,
        max_concurrent: int = 5,
    ) -> None:
        """Convert presentation JSONs to ChromaDB collection

        Every failure (bad parameter, embeddings init, no files found,
        processing error) is logged and the command returns without raising.

        Args:
            *patterns: Optional patterns to search for specific JSONs
            collection: Name for ChromaDB collection
            mode: Processing mode ('fresh' or 'append'), case-insensitive
            provider: Embedding provider ('vsegpt' or 'openai'), case-insensitive
            model_name: Optional specific model name
            base_dir: Optional base directory to search in
            max_concurrent: Maximum number of slides processed concurrently;
                must be a positive integer
        """
        # Fire parses argument values as Python literals, so they can arrive
        # as non-strings (e.g. ints); str() keeps .lower() from crashing and
        # lets bad values surface as the ValueError handled below.
        try:
            run_mode = Mode(str(mode).lower())
            provider_kind = Provider(str(provider).lower())
        except ValueError as e:
            logger.error(f"Invalid parameter: {str(e)}")
            return

        # A concurrency limit below one leaves no slot for any slide; reject
        # it together with the other parameter checks, before any setup work.
        if not isinstance(max_concurrent, int) or max_concurrent < 1:
            logger.error(
                "Invalid parameter: max_concurrent must be a positive integer, "
                f"got {max_concurrent!r}"
            )
            return

        # Get embeddings model
        try:
            embeddings = load_openai_embeddings(provider_kind, model_name)
        except Exception as e:
            logger.error(f"Failed to initialize embeddings: {str(e)}")
            return

        base_path = Path(base_dir) if base_dir else None

        json_paths = self.finder.find_jsons(
            patterns=list(patterns) if patterns else None, base_dir=base_path
        )

        if not json_paths:
            logger.error("No JSON files found")
            return

        logger.info(f"Found {len(json_paths)} JSON files")
        logger.debug(f"Files: {[p.name for p in json_paths]}")

        try:
            asyncio.run(
                process_presentations_async(
                    json_paths=json_paths,
                    collection_name=collection,
                    mode=run_mode,
                    embeddings=embeddings,
                    max_concurrent_slides=max_concurrent,
                )
            )
        except KeyboardInterrupt:
            logger.warning("Processing interrupted by user")
        except Exception:
            logger.exception("Processing failed with error")


def main():
    """Run the Fire-generated command line interface built from ChromaCLI."""
    cli_component = ChromaCLI
    fire.Fire(cli_component)


# Start the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()