import marimo
import os

__generated_with = "0.13.0"
app = marimo.App(width="full")


with app.setup:
    import marimo as mo
    from typing import Dict, Optional, List, Union, Any
    from ibm_watsonx_ai import APIClient, Credentials
    from pathlib import Path
    import pandas as pd
    import mimetypes
    import requests
    import zipfile
    import tempfile
    import certifi
    import base64
    import polars
    import nltk
    import time
    import json
    import ast
    import io
    import re


def get_iam_token(api_key):
    return requests.post(
        'https://iam.cloud.ibm.com/identity/token',
        headers={'Content-Type': 'application/x-www-form-urlencoded'},
        data={'grant_type': 'urn:ibm:params:oauth:grant-type:apikey', 'apikey': api_key},
        verify=certifi.where()
    ).json()['access_token']
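

# Hedged usage sketch for get_iam_token: exchange an IBM Cloud API key for a
# bearer token and attach it to a request. This helper and its url parameter
# are illustrative placeholders, not part of the original notebook flow.
def example_authorized_request(api_key: str, url: str):
    token = get_iam_token(api_key)
    return requests.get(
        url,
        headers={"Authorization": f"Bearer {token}"},
        verify=certifi.where(),
    )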


def setup_task_credentials(client):
    existing_credentials = client.task_credentials.get_details()

    # Remove any task credentials that already exist...
    if "resources" in existing_credentials and existing_credentials["resources"]:
        for cred in existing_credentials["resources"]:
            cred_id = client.task_credentials.get_id(cred)
            client.task_credentials.delete(cred_id)

    # ...then store and return a fresh set.
    return client.task_credentials.store()


def get_cred_value(key, creds_var_name="baked_in_creds", default=""):
    """
    Helper function to safely get a value from a credentials dictionary.

    Args:
        key: The key to look up in the credentials dictionary.
        creds_var_name: The variable name of the credentials dictionary.
        default: The default value to return if the key is not found.

    Returns:
        The value from the credentials dictionary if it exists and contains the key,
        otherwise returns the default value.
    """
    if creds_var_name in globals():
        creds_dict = globals()[creds_var_name]
        if isinstance(creds_dict, dict) and key in creds_dict:
            value = creds_dict[key]
            return "" if value is None else value
    return default
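

# A hypothetical shape for the baked-in credentials dictionary that
# get_cred_value reads via globals(); the keys mirror the form lookups below
# and every value here is a placeholder, not a real credential.
example_baked_in_creds = {
    "region": "https://eu-de.ml.cloud.ibm.com",  # assumed region URL format
    "api_key": "<your-ibm-cloud-api-key>",
    "project_id": "<your-project-id>",
    "space_id": "",
}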


@app.cell
def client_variables(client_instantiation_form):
    client_setup = client_instantiation_form.value or None

    if client_setup:
        wx_url = client_setup["wx_region"] if client_setup["wx_region"] else "EU"
        wx_api_key = client_setup["wx_api_key"].strip() if client_setup["wx_api_key"] else None
        os.environ["WATSONX_APIKEY"] = wx_api_key or ""

        project_id = client_setup["project_id"].strip() if client_setup["project_id"] else None
        space_id = client_setup["space_id"].strip() if client_setup["space_id"] else None
    else:
        os.environ["WATSONX_APIKEY"] = ""
        project_id = space_id = wx_api_key = wx_url = None
    return client_setup, project_id, space_id, wx_api_key, wx_url


@app.cell
def _():
    from baked_in_credentials.creds import credentials
    from base_variables import wx_regions, wx_platform_url
    from helper_functions.helper_functions import wrap_with_spaces, get_key_by_value, markdown_spacing
    return (
        credentials,
        wx_regions,
        wx_platform_url,
        wrap_with_spaces,
        get_key_by_value,
        markdown_spacing,
    )


@app.cell
def client_instantiation(
    APIClient,
    Credentials,
    client_setup,
    project_id,
    space_id,
    wx_api_key,
    wx_url,
):
    if client_setup:
        try:
            wx_credentials = Credentials(url=wx_url, api_key=wx_api_key)
            project_client = (
                APIClient(credentials=wx_credentials, project_id=project_id)
                if project_id
                else None
            )
            deployment_client = (
                APIClient(credentials=wx_credentials, space_id=space_id)
                if space_id
                else None
            )
            instantiation_success = True
            instantiation_error = None
        except Exception as e:
            instantiation_success = False
            instantiation_error = str(e)
            wx_credentials = project_client = deployment_client = None
    else:
        wx_credentials = project_client = deployment_client = None
        instantiation_success = None
        instantiation_error = None

    return (
        deployment_client,
        instantiation_error,
        instantiation_success,
        project_client,
    )


@app.cell
def _():
    mo.md(
        r"""
        # watsonx.ai Embedding Visualizer - Marimo Notebook

        #### This marimo notebook can be used to develop a more intuitive understanding of how vector embeddings work by creating a 3D visualization of vector embeddings based on chunked PDF document pages.

        #### It can also serve as a useful tool for identifying gaps in model choice, chunking strategy, or the contents used to build collections, by showing how far you are from what you want.
        <br>

        /// admonition
        Created by ***Milan Mrdenovic*** [milan.mrdenovic@ibm.com] for IBM Ecosystem Client Engineering, NCEE - ***version 5.3** - 20.04.2025*
        ///

        > Licensed under Apache 2.0; users hold full accountability for any use or modification of the code.
        > <br>This asset is part of a set meant to support IBMers, IBM Partners, and clients in developing an understanding of how to better utilize various watsonx features and generative AI as a subject matter.

        <br>
        """
    )
    return


@app.cell
def _():
    mo.md("""### Part 1 - Client Setup, File Preparation and Chunking""")
    return


@app.cell
def accordion_client_setup(client_selector, client_stack):
    ui_accordion_part_1_1 = mo.accordion(
        {
            "Instantiate Client": mo.vstack([client_stack, client_selector], align="center"),
        }
    )

    ui_accordion_part_1_1
    return


@app.cell
def accordion_file_upload(select_stack):
    ui_accordion_part_1_2 = mo.accordion(
        {
            "Select Model & Upload Files": select_stack
        }
    )

    ui_accordion_part_1_2
    return


@app.cell
def loaded_texts(
    create_temp_files_from_uploads,
    file_loader,
    pdf_reader,
    run_upload_button,
    set_text_state,
):
    if file_loader.value is not None and run_upload_button.value:
        filepaths = create_temp_files_from_uploads(file_loader.value)
        loaded_texts = load_pdf_data_with_progress(pdf_reader, filepaths, file_loader.value, show_progress=True)

        set_text_state(loaded_texts)
    else:
        filepaths = None
        loaded_texts = None
    return


@app.cell
def accordion_chunker_setup(chunker_setup):
    ui_accordion_part_1_3 = mo.accordion(
        {
            "Chunker Setup": chunker_setup
        }
    )

    ui_accordion_part_1_3
    return


@app.cell
def chunk_documents_to_nodes(
    get_text_state,
    sentence_splitter,
    sentence_splitter_config,
    set_chunk_state,
):
    if sentence_splitter_config.value and sentence_splitter and get_text_state() is not None:
        chunked_texts = chunk_documents(get_text_state(), sentence_splitter, show_progress=True)
        set_chunk_state(chunked_texts)
    else:
        chunked_texts = None
    return (chunked_texts,)


@app.cell
def _():
    mo.md(r"""### Part 2 - Query Setup and Visualization""")
    return


@app.cell
def accordion_chunk_range(chart_range_selection):
    ui_accordion_part_2_1 = mo.accordion(
        {
            "Chunk Range Selection": chart_range_selection
        }
    )
    ui_accordion_part_2_1
    return


@app.cell
def chunk_embedding(
    chunks_to_process,
    embedding,
    sentence_splitter_config,
    set_embedding_state,
):
    if sentence_splitter_config.value is not None and chunks_to_process is not None:
        with mo.status.spinner(title="Embedding Documents...", remove_on_exit=True) as _spinner:
            output_embeddings = embedding.embed_documents(chunks_to_process)
            _spinner.update("Almost Done")
            time.sleep(1.5)
            set_embedding_state(output_embeddings)
            _spinner.update("Documents Embedded")
    else:
        output_embeddings = None
    return


@app.cell
def preview_chunks(chunks_dict):
    if chunks_dict is not None:
        stats = create_stats(
            chunks_dict,
            bordered=True,
            object_names=['text', 'text'],
            group_by_row=True,
            items_per_row=5,
            gap=1,
            label="Chunk"
        )
        ui_chunk_viewer = mo.accordion(
            {
                "View Chunks": stats,
            }
        )
    else:
        ui_chunk_viewer = None

    ui_chunk_viewer
    return


@app.cell
def accordion_query_view(chart_visualization, query_stack):
    ui_accordion_part_2_2 = mo.accordion(
        {
            "Query": mo.vstack([query_stack, mo.hstack([chart_visualization])], align="center", gap=3)
        }
    )
    ui_accordion_part_2_2
    return


@app.cell
def chunker_setup(sentence_splitter_config):
    chunker_setup = mo.hstack([sentence_splitter_config], justify="space-around", align="center", widths=[0.55])
    return (chunker_setup,)


@app.cell
def file_and_model_select(
    file_loader,
    get_embedding_model_list,
    run_upload_button,
):
    select_stack = mo.hstack(
        [
            get_embedding_model_list(),
            mo.vstack(
                [
                    mo.md("Drag & Drop or Double Click to select PDFs, then press **Load Files**"),
                    file_loader,
                    run_upload_button,
                ],
                align="center",
            ),
        ],
        justify="space-around",
        align="center",
        widths=[0.3, 0.3],
    )
    return (select_stack,)


@app.cell
def client_instantiation_form():
    baked_in_creds = credentials

    client_instantiation_form = (
        mo.md('''
        ### **watsonx.ai credentials:**

        {wx_region}

        {wx_api_key}

        {project_id}

        {space_id}

        > You can add a project_id, a space_id, or both; **only one is required**.
        > If you provide both, you can switch the active one in the dropdown.
        ''')
        .batch(
            wx_region = mo.ui.dropdown(
                wx_regions,
                label="Select your watsonx.ai region:",
                value=get_cred_value('region', creds_var_name='baked_in_creds'),
                searchable=True
            ),
            wx_api_key = mo.ui.text(
                placeholder="Add your IBM Cloud api-key...",
                label="IBM Cloud Api-key:",
                kind="password",
                value=get_cred_value('api_key', creds_var_name='baked_in_creds')
            ),
            project_id = mo.ui.text(
                placeholder="Add your watsonx.ai project_id...",
                label="Project_ID:",
                kind="text",
                value=get_cred_value('project_id', creds_var_name='baked_in_creds')
            ),
            space_id = mo.ui.text(
                placeholder="Add your watsonx.ai space_id...",
                label="Space_ID:",
                kind="text",
                value=get_cred_value('space_id', creds_var_name='baked_in_creds')
            ),
        )
        .form(show_clear_button=True, bordered=False)
    )
    return (client_instantiation_form,)


@app.cell
def instantiation_status(
    client_callout_kind,
    client_instantiation_form,
    client_status,
):
    client_callout = mo.callout(client_status, kind=client_callout_kind)
    client_stack = mo.hstack([client_instantiation_form, client_callout], align="center", justify="space-around", gap=10)
    return (client_stack,)


@app.cell
def _(
    client,
    client_key,
    client_options,
    client_selector,
    client_setup,
    get_key_by_value,
    instantiation_error,
    instantiation_success,
    mo,
    wrap_with_spaces,
):
    active_client_name = get_key_by_value(client_options, client_key)

    if client_setup:
        if instantiation_success:
            client_status = mo.md(
                f"### ✅ Client Instantiation Successful ✅\n\n"
                f"{client_selector}\n\n"
                f"**Active Client:**{wrap_with_spaces(active_client_name, prefix_spaces=5)}"
            )
            client_callout_kind = "success"
        else:
            client_status = mo.md(
                f"### ❌ Client Instantiation Failed\n**Error:** {instantiation_error}\n\nCheck your region selection and credentials"
            )
            client_callout_kind = "danger"
    else:
        client_status = mo.md(
            f"### Client Instantiation Status will turn Green When Ready\n\n"
            f"{client_selector}\n\n"
            f"**Active Client:**{wrap_with_spaces(active_client_name, prefix_spaces=5)}"
        )
        client_callout_kind = "neutral"

    return active_client_name, client_callout_kind, client_status


@app.cell
def client_selector(deployment_client, project_client):
    # Check the combined case first; after the single-client branches it
    # could never be reached.
    if project_client is not None and deployment_client is not None:
        client_options = {"Project Client": project_client, "Deployment Client": deployment_client}
    elif deployment_client is not None:
        client_options = {"Deployment Client": deployment_client}
    elif project_client is not None:
        client_options = {"Project Client": project_client}
    else:
        client_options = {"No Client": "Instantiate a Client"}

    default_client = next(iter(client_options))
    client_selector = mo.ui.dropdown(client_options, value=default_client, label="**Select your active client:**")
    return client_options, client_selector


@app.cell
def active_client(client_selector):
    client_key = client_selector.value
    if client_key == "Instantiate a Client":
        client = None
    else:
        client = client_key
    return client, client_key


@app.cell
def emb_model_selection(client, set_embedding_model_list):
    if client is not None:
        model_specs = client.foundation_models.get_embeddings_model_specs()

        resources = model_specs["resources"]

        # Known token limits and output dimensions for common embedding models
        embedding_models = {
            "ibm/granite-embedding-107m-multilingual": {"max_tokens": 512, "embedding_dimensions": 384},
            "ibm/granite-embedding-278m-multilingual": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-125m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-125m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-30m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 384},
            "ibm/slate-30m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 384},
            "sentence-transformers/all-minilm-l6-v2": {"max_tokens": 128, "embedding_dimensions": 384},
            "sentence-transformers/all-minilm-l12-v2": {"max_tokens": 128, "embedding_dimensions": 384},
            "intfloat/multilingual-e5-large": {"max_tokens": 512, "embedding_dimensions": 1024}
        }

        # Collect the model ids advertised by the service
        model_id_list = []
        for resource in resources:
            model_id_list.append(resource["model_id"])

        # Join the service's model list with the known metadata above
        embedding_model_data = []
        for model_id in model_id_list:
            model_entry = {"model_id": model_id}

            if model_id in embedding_models:
                model_entry["max_tokens"] = embedding_models[model_id]["max_tokens"]
                model_entry["embedding_dimensions"] = embedding_models[model_id]["embedding_dimensions"]
            else:
                model_entry["max_tokens"] = 0
                model_entry["embedding_dimensions"] = 0

            embedding_model_data.append(model_entry)

        embedding_model_selection = mo.ui.table(
            embedding_model_data,
            selection="single",
            label="Select an embedding model to use.",
            page_size=30,
            initial_selection=[1]
        )
        set_embedding_model_list(embedding_model_selection)
    else:
        default_model_data = [{
            "model_id": "ibm/granite-embedding-107m-multilingual",
            "max_tokens": 512,
            "embedding_dimensions": 384
        }]

        set_embedding_model_list(create_emb_model_selection_table(default_model_data, initial_selection=0, selection_type="single", label="Select a model to use."))
    return


@app.function
def create_emb_model_selection_table(model_data, initial_selection=0, selection_type="single", label="Select a model to use."):
    embedding_model_selection = mo.ui.table(
        model_data,
        selection=selection_type,
        label=label,
        page_size=30,
        initial_selection=[initial_selection]
    )
    return embedding_model_selection


@app.cell
def embedding_model():
    get_embedding_model_list, set_embedding_model_list = mo.state(None)
    return get_embedding_model_list, set_embedding_model_list


@app.cell
def emb_model_parameters(emb_model_max_tk):
    from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
    # Use the selected model's max token limit when available; otherwise fall
    # back to a conservative 128-token truncation. (The original checked
    # `embedding_model`, which is not in scope in this cell.)
    if emb_model_max_tk is not None:
        embed_params = {
            EmbedParams.TRUNCATE_INPUT_TOKENS: emb_model_max_tk,
            EmbedParams.RETURN_OPTIONS: {
                'input_text': True
            }
        }
    else:
        embed_params = {
            EmbedParams.TRUNCATE_INPUT_TOKENS: 128,
            EmbedParams.RETURN_OPTIONS: {
                'input_text': True
            }
        }
    return (embed_params,)


@app.cell
def emb_model_state(get_embedding_model_list):
    embedding_model = get_embedding_model_list()
    return (embedding_model,)


@app.cell
def emb_model_setup(embedding_model):
    if embedding_model is not None:
        emb_model = embedding_model.value[0]['model_id']
        emb_model_max_tk = embedding_model.value[0]['max_tokens']
        emb_model_emb_dim = embedding_model.value[0]['embedding_dimensions']
    else:
        emb_model = None
        emb_model_max_tk = None
        emb_model_emb_dim = None
    return emb_model, emb_model_emb_dim, emb_model_max_tk


@app.cell
def emb_model_instantiation(client, emb_model, embed_params):
    from ibm_watsonx_ai.foundation_models import Embeddings
    if client is not None:
        embedding = Embeddings(
            model_id=emb_model,
            api_client=client,
            params=embed_params,
            batch_size=1000,
            concurrency_limit=10
        )
    else:
        embedding = None
    return (embedding,)
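

# Minimal sketch of how the Embeddings client above gets used later on:
# embed_documents takes a list of strings and returns one vector per string
# (ibm-watsonx-ai SDK). The texts are placeholders; pass any instantiated
# Embeddings object.
def example_embed_texts(embedding_client):
    texts = ["vector search in watsonx.data", "chunking strategies for RAG"]
    vectors = embedding_client.embed_documents(texts)
    return len(vectors), len(vectors[0])  # (2, embedding_dimensions)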


@app.cell
def _():
    get_embedding_state, set_embedding_state = mo.state(None)
    return get_embedding_state, set_embedding_state


@app.cell
def _():
    get_query_state, set_query_state = mo.state(None)
    return get_query_state, set_query_state


@app.cell
def file_loader_input():
    file_loader = mo.ui.file(
        kind="area",
        filetypes=[".pdf"],
        label=" Load .pdf files ",
        multiple=True
    )
    return (file_loader,)


@app.cell
def file_loader_run(file_loader):
    if file_loader.value:
        run_upload_button = mo.ui.run_button(label="Load Files")
    else:
        run_upload_button = mo.ui.run_button(disabled=True, label="Load Files")
    return (run_upload_button,)


@app.cell
def helper_function_tempfiles():
    def create_temp_files_from_uploads(upload_results) -> List[str]:
        """
        Creates temporary files from a tuple of FileUploadResults objects and returns their paths.
        Args:
            upload_results: Object containing a value attribute that is a tuple of FileUploadResults
        Returns:
            List of temporary file paths
        """
        temp_file_paths = []

        num_items = len(upload_results)

        for i in range(num_items):
            result = upload_results[i]

            # Write each upload's bytes to a file in the system temp directory
            temp_dir = tempfile.gettempdir()
            file_name = result.name
            temp_path = os.path.join(temp_dir, file_name)

            with open(temp_path, 'wb') as temp_file:
                temp_file.write(result.contents)

            temp_file_paths.append(temp_path)

        return temp_file_paths

    def cleanup_temp_files(temp_file_paths: List[str]) -> None:
        """Delete temporary files after use."""
        for path in temp_file_paths:
            if os.path.exists(path):
                os.unlink(path)
    return (create_temp_files_from_uploads,)


@app.function
def load_pdf_data_with_progress(pdf_reader, filepaths, file_loader_value, show_progress=True):
    """
    Loads PDF data for each file path and organizes results by original filename.
    Args:
        pdf_reader: The PyMuPDFReader instance
        filepaths: List of temporary file paths
        file_loader_value: The original upload results value containing file information
        show_progress: Whether to show a progress bar during loading (default: True)
    Returns:
        Dictionary mapping original filenames to their loaded text content
    """
    results = {}

    if show_progress:
        import marimo as mo

        with mo.status.progress_bar(
            total=len(filepaths),
            title="Loading PDFs",
            subtitle="Processing documents...",
            completion_title="PDF Loading Complete",
            completion_subtitle=f"{len(filepaths)} documents processed",
            remove_on_exit=True
        ) as bar:
            for i, file_path in enumerate(filepaths):
                original_file_name = file_loader_value[i].name
                bar.update(subtitle=f"Processing {original_file_name}...")
                loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)

                results[original_file_name] = loaded_text

                bar.update(increment=1)
    else:
        for i, file_path in enumerate(filepaths):
            original_file_name = file_loader_value[i].name
            loaded_text = pdf_reader.load_data(file_path=file_path, metadata=True)
            results[original_file_name] = loaded_text

    return results
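

# Standalone sketch of calling load_pdf_data_with_progress outside the UI,
# with SimpleNamespace objects standing in for marimo's FileUploadResults.
# The paths argument is a placeholder list of local PDF files.
def example_load_pdfs(paths):
    from types import SimpleNamespace
    from llama_index.readers.file import PyMuPDFReader

    uploads = [SimpleNamespace(name=os.path.basename(p)) for p in paths]
    return load_pdf_data_with_progress(PyMuPDFReader(), paths, uploads,
                                       show_progress=False)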


@app.cell
def file_readers():
    from llama_index.readers.file import PyMuPDFReader
    from llama_index.readers.file import FlatReader
    from llama_index.core.node_parser import SentenceSplitter

    pdf_reader = PyMuPDFReader()

    return SentenceSplitter, pdf_reader


@app.cell
def sentence_splitter_setup():
    sentence_splitter_config = (
        mo.md('''
        ### **Chunking Setup:**

        > Unless you want to do some advanced sentence splitting, it's best to stick to adjusting only the chunk size and overlap. Changing the other settings might lead to unexpected results.

        The separator value is set to **" "** by default, while the paragraph separator is **"\\n\\n\\n"**.

        {chunk_size}

        {chunk_overlap}

        {separator} {paragraph_separator}

        {secondary_chunking_regex} {include_metadata}

        ''')
        .batch(
            chunk_size = mo.ui.slider(start=100, stop=5000, step=1, label="**Chunk Size:**", value=275, show_value=True, full_width=True),
            chunk_overlap = mo.ui.slider(start=0, stop=1000, step=1, label="**Chunk Overlap** *(Must always be smaller than Chunk Size)* **:**", value=0, show_value=True, full_width=True),
            separator = mo.ui.text(placeholder="Define a separator", label="**Separator:**", kind="text", value=" "),
            paragraph_separator = mo.ui.text(
                placeholder="Define a paragraph separator",
                label="**Paragraph Separator:**", kind="text",
                value="\n\n\n"
            ),
            secondary_chunking_regex = mo.ui.text(
                placeholder="Define a secondary chunking regex",
                label="**Chunking Regex:**", kind="text",
                value="[^,.;?!]+[,.;?!]?"
            ),
            include_metadata = mo.ui.checkbox(value=True, label="**Include Metadata**")
        )
        .form(show_clear_button=True, bordered=False, submit_button_label="Chunk Documents")
    )
    return (sentence_splitter_config,)


@app.cell
def sentence_splitter_instantiation(
    SentenceSplitter,
    sentence_splitter_config,
):
    def simple_whitespace_tokenizer(text):
        return text.split()

    if sentence_splitter_config.value is not None:
        sentence_splitter_config_values = sentence_splitter_config.value
        # Cap the overlap at 30% of the chunk size so it always stays smaller
        # than the chunks themselves.
        validated_chunk_overlap = min(
            sentence_splitter_config_values.get("chunk_overlap"),
            int(sentence_splitter_config_values.get("chunk_size") * 0.3)
        )

        sentence_splitter = SentenceSplitter(
            chunk_size=sentence_splitter_config_values.get("chunk_size"),
            chunk_overlap=validated_chunk_overlap,
            separator=sentence_splitter_config_values.get("separator"),
            paragraph_separator=sentence_splitter_config_values.get("paragraph_separator"),
            secondary_chunking_regex=sentence_splitter_config_values.get("secondary_chunking_regex"),
            include_metadata=sentence_splitter_config_values.get("include_metadata"),
            tokenizer=simple_whitespace_tokenizer
        )

    else:
        sentence_splitter = SentenceSplitter(
            chunk_size=2048,
            chunk_overlap=204,
            separator=" ",
            paragraph_separator="\n\n\n",
            secondary_chunking_regex="[^,.;?!]+[,.;?!]?",
            include_metadata=True,
            tokenizer=simple_whitespace_tokenizer
        )
    return (sentence_splitter,)
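

# A small sketch of how the chunking settings interact, assuming the same
# llama-index SentenceSplitter used above. With a whitespace tokenizer,
# chunk_size and chunk_overlap count words rather than model tokens, so a
# 200-word document at size 50 / overlap 10 yields roughly
# 200 / (50 - 10) = 5 chunks.
def example_chunking_demo():
    from llama_index.core import Document
    from llama_index.core.node_parser import SentenceSplitter

    splitter = SentenceSplitter(
        chunk_size=50,
        chunk_overlap=10,
        tokenizer=lambda text: text.split(),
    )
    doc = Document(text="word " * 200)
    nodes = splitter.get_nodes_from_documents([doc])
    return len(nodes)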


@app.cell
def text_state():
    get_text_state, set_text_state = mo.state(None)
    return get_text_state, set_text_state


@app.cell
def chunk_state():
    get_chunk_state, set_chunk_state = mo.state(None)
    return get_chunk_state, set_chunk_state


@app.function
def chunk_documents(loaded_texts, sentence_splitter, show_progress=True):
    """
    Process each document in the loaded_texts dictionary using the sentence_splitter,
    with an optional marimo progress bar tracking progress at document level.

    Args:
        loaded_texts (dict): Dictionary containing lists of Document objects
        sentence_splitter: The sentence splitter object with get_nodes_from_documents method
        show_progress (bool): Whether to show a progress bar during processing

    Returns:
        dict: Dictionary with the same structure but containing chunked texts
    """
    chunked_texts_dict = {}

    total_docs = sum(len(docs) for docs in loaded_texts.values())
    processed_docs = 0

    if show_progress:
        import marimo as mo

        with mo.status.progress_bar(
            total=total_docs,
            title="Processing Documents",
            subtitle="Chunking documents...",
            completion_title="Processing Complete",
            completion_subtitle=f"{total_docs} documents processed",
            remove_on_exit=True
        ) as bar:
            for key, documents in loaded_texts.items():
                doc_count = len(documents)
                bar.update(subtitle=f"Chunking {key}... ({doc_count} documents)")

                chunked_texts = sentence_splitter.get_nodes_from_documents(
                    documents,
                    show_progress=False
                )

                chunked_texts_dict[key] = chunked_texts
                time.sleep(0.15)

                bar.update(increment=doc_count)
                processed_docs += doc_count
    else:
        for key, documents in loaded_texts.items():
            chunked_texts = sentence_splitter.get_nodes_from_documents(
                documents,
                show_progress=True
            )
            chunked_texts_dict[key] = chunked_texts

    return chunked_texts_dict
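

# End-to-end sketch chaining the helpers above: load PDFs from disk, then
# chunk them. example_load_pdfs is the illustrative helper defined earlier in
# this file, and paths is a placeholder list of PDF file paths.
def example_chunk_pipeline(paths):
    from llama_index.core.node_parser import SentenceSplitter

    loaded = example_load_pdfs(paths)
    splitter = SentenceSplitter(chunk_size=275, chunk_overlap=0)
    return chunk_documents(loaded, splitter, show_progress=False)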


@app.cell
def chunked_nodes(chunked_texts, get_chunk_state, sentence_splitter):
    if chunked_texts is not None and sentence_splitter:
        chunked_documents = get_chunk_state()
    else:
        chunked_documents = None
    return (chunked_documents,)


@app.cell
def prep_cumulative_df(chunked_documents, llamaindex_convert_docs_multi):
    if chunked_documents is not None:
        dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents)
        nodes_from_dict = llamaindex_convert_docs_multi(dict_from_nodes)
    else:
        dict_from_nodes = None
        nodes_from_dict = None
    return (dict_from_nodes,)


@app.cell
def chunks_to_process(
    dict_from_nodes,
    document_range_stack,
    get_data_in_range_triplequote,
):
    if dict_from_nodes is not None and document_range_stack is not None:
        chunk_dict_df = create_cumulative_dataframe(dict_from_nodes)

        if document_range_stack.value is not None:
            chunk_start_idx = document_range_stack.value[0]
            chunk_end_idx = document_range_stack.value[1]
        else:
            chunk_start_idx = 0
            chunk_end_idx = len(chunk_dict_df)

        chunk_range_index = [chunk_start_idx, chunk_end_idx]
        chunks_dict = get_data_in_range_triplequote(
            chunk_dict_df,
            index_range=chunk_range_index,
            columns_to_include=["text"]
        )

        chunks_to_process = chunks_dict['text'] if 'text' in chunks_dict else []
    else:
        chunk_objects = None
        chunks_dict = None
        chunks_to_process = None
    return chunks_dict, chunks_to_process


@app.cell
def helper_function_doc_formatting():
    def llamaindex_convert_docs_multi(items):
        """
        Automatically convert between document objects and dictionaries.

        This function handles:
        - Converting dictionaries to document objects
        - Converting document objects to dictionaries
        - Processing lists or individual items
        - Supporting dictionary structures where values are lists of documents

        Args:
            items: A document object, dictionary, or list of either.
                Can also be a dictionary mapping filenames to lists of documents.

        Returns:
            Converted item(s) maintaining the original structure
        """
        if not items:
            return []

        # Dictionary mapping filenames to lists of documents
        if isinstance(items, dict) and all(isinstance(v, list) for v in items.values()):
            result = {}
            for filename, doc_list in items.items():
                result[filename] = llamaindex_convert_docs(doc_list)
            return result

        # Single (non-list) item
        if not isinstance(items, list):
            if isinstance(items, dict):
                doc_class = None
                if 'doc_type' in items:
                    import importlib
                    module_path, class_name = items['doc_type'].rsplit('.', 1)
                    module = importlib.import_module(module_path)
                    doc_class = getattr(module, class_name)
                if not doc_class:
                    from llama_index.core.schema import Document
                    doc_class = Document
                return doc_class.from_dict(items)

            elif hasattr(items, 'to_dict'):
                return items.to_dict()

            return items

        result = []

        if len(items) == 0:
            return result

        # Find the first non-None item to decide the conversion direction
        first_item = next((item for item in items if item is not None), None)

        if first_item is None:
            return result

        # Dictionaries -> document objects
        if isinstance(first_item, dict):
            doc_class = None

            if 'doc_type' in first_item:
                import importlib
                module_path, class_name = first_item['doc_type'].rsplit('.', 1)
                module = importlib.import_module(module_path)
                doc_class = getattr(module, class_name)
            if not doc_class:
                from llama_index.core.schema import Document
                doc_class = Document

            for item in items:
                if isinstance(item, dict):
                    result.append(doc_class.from_dict(item))
                elif item is None:
                    result.append(None)
                elif isinstance(item, list):
                    result.append(llamaindex_convert_docs(item))
                else:
                    result.append(item)

        # Document objects -> dictionaries
        else:
            for item in items:
                if hasattr(item, 'to_dict'):
                    result.append(item.to_dict())
                elif item is None:
                    result.append(None)
                elif isinstance(item, list):
                    result.append(llamaindex_convert_docs(item))
                else:
                    result.append(item)

        return result

    def llamaindex_convert_docs(items):
        """
        Automatically convert between document objects and dictionaries.

        Args:
            items: A list of document objects or dictionaries

        Returns:
            List of converted items (dictionaries or document objects)
        """
        result = []

        if not items:
            return result

        # Dictionaries -> document objects
        if isinstance(items[0], dict):
            doc_class = None

            if 'doc_type' in items[0]:
                import importlib
                module_path, class_name = items[0]['doc_type'].rsplit('.', 1)
                module = importlib.import_module(module_path)
                doc_class = getattr(module, class_name)

            if not doc_class:
                from llama_index.core.schema import Document
                doc_class = Document

            for item in items:
                if isinstance(item, dict):
                    result.append(doc_class.from_dict(item))
        # Document objects -> dictionaries
        else:
            for item in items:
                if hasattr(item, 'to_dict'):
                    result.append(item.to_dict())

        return result
    return (llamaindex_convert_docs_multi,)


@app.cell
def helper_function_create_df():
    def create_document_dataframes(dict_from_docs):
        """
        Creates a pandas DataFrame for each file in the dictionary.

        Args:
            dict_from_docs: Dictionary mapping filenames to lists of documents

        Returns:
            List of pandas DataFrames, each representing all documents from a single file
        """
        dataframes = []

        for filename, docs in dict_from_docs.items():
            file_records = []

            for i, doc in enumerate(docs):
                if hasattr(doc, 'to_dict'):
                    doc_data = doc.to_dict()
                elif isinstance(doc, dict):
                    doc_data = doc
                else:
                    doc_data = {'content': str(doc)}

                doc_data['doc_index'] = i

                file_records.append(doc_data)

            if file_records:
                df = pd.DataFrame(file_records)
                df['filename'] = filename
                dataframes.append(df)

        return dataframes

    def create_dataframe_previews(dataframe_list, page_size=5):
        """
        Creates a list of mo.ui.dataframe components, one for each DataFrame in the input list.

        Args:
            dataframe_list: List of pandas DataFrames (output from create_document_dataframes)
            page_size: Number of rows to show per page for each component

        Returns:
            List of mo.ui.dataframe components
        """
        preview_components = []

        for df in dataframe_list:
            preview = mo.ui.dataframe(df, page_size=page_size)
            preview_components.append(preview)

        return preview_components
    return


@app.cell
def helper_function_chart_preparation():
    import altair as alt
    import numpy as np
    import plotly.express as px
    from sklearn.manifold import TSNE

    def prepare_embedding_data(embeddings, texts, model_id=None, embedding_dimensions=None):
        """
        Prepare embedding data for visualization

        Args:
            embeddings: List of embeddings arrays
            texts: List of text strings
            model_id: Embedding model ID (optional)
            embedding_dimensions: Embedding dimensions (optional)

        Returns:
            DataFrame with processed data and metadata
        """
        # Unwrap embeddings that arrive as single-element nested lists
        flattened_embeddings = []
        for emb in embeddings:
            if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
                flattened_embeddings.append(emb[0])
            else:
                flattened_embeddings.append(emb)

        embedding_array = np.array(flattened_embeddings)

        # Reduce to 2D with t-SNE; perplexity must stay below the sample count
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embedding_array) - 1))
        reduced_embeddings = tsne.fit_transform(embedding_array)

        truncated_texts = [text[:50] + "..." if len(text) > 50 else text for text in texts]

        df = pd.DataFrame({
            "x": reduced_embeddings[:, 0],
            "y": reduced_embeddings[:, 1],
            "text": truncated_texts,
            "full_text": texts,
            "index": range(len(texts))
        })

        metadata = {
            "model_id": model_id,
            "embedding_dimensions": embedding_dimensions
        }

        return df, metadata

    def create_embedding_chart(df, metadata=None):
        """
        Create an Altair chart for embedding visualization

        Args:
            df: DataFrame with x, y coordinates and text
            metadata: Dictionary with model_id and embedding_dimensions

        Returns:
            Altair chart
        """
        model_id = metadata.get("model_id") if metadata else None
        embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None

        selection = alt.selection_multi(fields=['index'])

        base = alt.Chart(df).encode(
            x=alt.X("x:Q", title="Dimension 1"),
            y=alt.Y("y:Q", title="Dimension 2"),
            tooltip=["text", "index"]
        )

        points = base.mark_circle(size=100).encode(
            color=alt.Color("index:N", legend=None),
            opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
        ).add_selection(selection)

        text = base.mark_text(align="left", dx=7).encode(
            text="index:N"
        )

        return (points + text).properties(
            width=700,
            height=500,
            title=f"Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if embedding_dimensions else ''}"
        ).interactive()

    def show_selected_text(indices, texts):
        """
        Create markdown display for selected texts

        Args:
            indices: List of selected indices
            texts: List of all texts

        Returns:
            Markdown string
        """
        if not indices:
            return "No text selected"

        selected_texts = [texts[i] for i in indices if i < len(texts)]
        return "\n\n".join([f"**Document {i}**:\n{text}" for i, text in zip(indices, selected_texts)])

    def prepare_embedding_data_3d(embeddings, texts, model_id=None, embedding_dimensions=None):
        """
        Prepare embedding data for 3D visualization

        Args:
            embeddings: List of embeddings arrays
            texts: List of text strings
            model_id: Embedding model ID (optional)
            embedding_dimensions: Embedding dimensions (optional)

        Returns:
            DataFrame with processed data and metadata
        """
        flattened_embeddings = []
        for emb in embeddings:
            if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
                flattened_embeddings.append(emb[0])
            else:
                flattened_embeddings.append(emb)

        embedding_array = np.array(flattened_embeddings)

        if len(embedding_array) == 1:
            # t-SNE cannot reduce a single point; place it at the origin
            reduced_embeddings = np.array([[0.0, 0.0, 0.0]])
        else:
            # Keep perplexity valid for small sample counts
            perplexity_value = max(1.0, min(30, len(embedding_array) - 1))
            tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity_value)
            reduced_embeddings = tsne.fit_transform(embedding_array)

        # Truncate long texts and wrap them for hover display
        formatted_texts = []
        for text in texts:
            if len(text) > 500:
                text = text[:500] + "..."

            wrapped_text = ""
            for i in range(0, len(text), 50):
                wrapped_text += text[i:i+50] + "<br>"

            formatted_texts.append("<b>" + wrapped_text + "</b>")

        df = pd.DataFrame({
            "x": reduced_embeddings[:, 0],
            "y": reduced_embeddings[:, 1],
            "z": reduced_embeddings[:, 2],
            "text": formatted_texts,
            "full_text": texts,
            "index": range(len(texts)),
            "embedding": flattened_embeddings
        })

        metadata = {
            "model_id": model_id,
            "embedding_dimensions": embedding_dimensions
        }

        return df, metadata

    def create_3d_embedding_chart(df, metadata=None, chart_width=1200, chart_height=800, marker_size_var: int = 3):
        """
        Create a 3D Plotly chart for embedding visualization with proximity-based coloring
        """
        model_id = metadata.get("model_id") if metadata else None
        embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None

        from scipy.spatial.distance import pdist, squareform

        coords = df[['x', 'y', 'z']].values

        # Pairwise distances between all points
        dist_matrix = squareform(pdist(coords))

        # Color each point by its average distance to all other points
        avg_distances = np.mean(dist_matrix, axis=1)

        df['proximity'] = avg_distances

        fig = px.scatter_3d(
            df,
            x='x',
            y='y',
            z='z',
            color='proximity',
            color_continuous_scale='Viridis_r',
            hover_data=['text', 'index', 'proximity'],
            labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'},
            title=f"<b>3D Embedding Visualization</b>{f' - Model: <b>{model_id}</b>' if model_id else ''}{f' <i>({embedding_dimensions} dimensions)</i>' if embedding_dimensions else ''}",
            text='index',
        )

        fig.update_traces(
            marker=dict(
                size=marker_size_var,
                opacity=0.7,
                symbol="diamond",
                line=dict(
                    width=0.5,
                    color="white"
                )
            ),
            textfont=dict(
                color="rgba(255, 255, 255, 0.3)",
                size=8
            ),
            hovertemplate="text:<br><b>%{customdata[0]}</b><br>index: <b>%{text}</b><br><br>Avg Distance: <b>%{customdata[2]:.4f}</b><extra></extra>",
            hoverinfo="text+name",
            hoverlabel=dict(
                bgcolor="white",
                font_size=12
            ),
            selector=dict(type='scatter3d')
        )

        fig.update_layout(
            scene=dict(
                xaxis=dict(
                    title='Dimension 1',
                    nticks=40,
                    backgroundcolor="rgba(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),
                yaxis=dict(
                    title='Dimension 2',
                    nticks=40,
                    backgroundcolor="rgba(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),
                zaxis=dict(
                    title='Dimension 3',
                    nticks=40,
                    backgroundcolor="rgba(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.25, y=1.25, z=1.25),
                ),
                aspectratio=dict(x=1, y=1, z=1),
                aspectmode='data'
            ),
            width=int(chart_width),
            height=int(chart_height),
            margin=dict(r=20, l=10, b=10, t=50),
            paper_bgcolor="rgb(0, 0, 0)",
            plot_bgcolor="rgb(0, 0, 0)",
            coloraxis_colorbar=dict(
                title="Average Distance",
                thicknessmode="pixels", thickness=20,
                lenmode="pixels", len=400,
                yanchor="top", y=1,
                ticks="outside",
                dtick=0.1
            )
        )

        return fig
    return create_3d_embedding_chart, prepare_embedding_data_3d
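

# Self-contained sketch of the t-SNE reduction performed in
# prepare_embedding_data_3d, on synthetic embeddings. Perplexity must stay
# below the sample count, which is why the helper above clamps it.
def example_tsne_reduction(n_points: int = 10, dim: int = 384):
    import numpy as np
    from sklearn.manifold import TSNE

    rng = np.random.default_rng(42)
    embeddings = rng.normal(size=(n_points, dim))
    tsne = TSNE(
        n_components=3,
        random_state=42,
        perplexity=max(1.0, min(30, n_points - 1)),
    )
    return tsne.fit_transform(embeddings)  # shape: (n_points, 3)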


@app.cell
def helper_function_text_preparation():
    def convert_table_to_json_docs(df, selected_columns=None):
        """
        Convert a pandas DataFrame or dictionary to a list of JSON documents.
        Dynamically includes columns based on user selection.
        Column names are standardized to lowercase with underscores instead of spaces
        and special characters removed.

        Args:
            df: The DataFrame or dictionary to process
            selected_columns: List of column names to include in the output documents

        Returns:
            list: A list of dictionaries, each representing a row as a JSON document
        """
        import pandas as pd
        import re

        def standardize_key(key):
            """Convert a column name to lowercase with underscores instead of spaces and no special characters"""
            if not isinstance(key, str):
                return str(key).lower()

            key = key.lower().replace(' ', '_')

            return re.sub(r'[^\w]', '', key)

        if isinstance(df, dict):
            if selected_columns:
                return [{standardize_key(k): df.get(k, None) for k in selected_columns}]
            else:
                return [{standardize_key(k): v for k, v in df.items()}]

        if df is None:
            return []

        if not isinstance(df, pd.DataFrame):
            try:
                df = pd.DataFrame(df)
            except Exception:
                return []

        if df.empty:
            return []

        if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0:
            selected_columns = list(df.columns)

        # Match requested columns case-insensitively against the dataframe
        available_columns = []
        columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)}

        for col in selected_columns:
            if col in df.columns:
                available_columns.append(col)
            elif isinstance(col, str) and col.lower() in columns_lower:
                available_columns.append(columns_lower[col.lower()])

        if not available_columns:
            return []

        json_docs = []
        for _, row in df.iterrows():
            doc = {}
            for col in available_columns:
                value = row[col]
                std_col = standardize_key(col)
                doc[std_col] = None if pd.isna(value) else value
            json_docs.append(doc)

        return json_docs

    def get_column_values(df, columns_to_include):
        """
        Extract values from specified columns of a dataframe as lists.

        Args:
            df: A pandas DataFrame
            columns_to_include: A list of column names to extract

        Returns:
            Dictionary with column names as keys and their values as lists
        """
        result = {}

        valid_columns = [col for col in columns_to_include if col in df.columns]
        invalid_columns = set(columns_to_include) - set(valid_columns)

        if invalid_columns:
            print(f"Warning: These columns don't exist in the dataframe: {list(invalid_columns)}")

        for col in valid_columns:
            result[col] = df[col].tolist()

        return result

    def get_data_in_range(doc_dict_df, index_range, columns_to_include):
        """
        Extract values from specified columns of a dataframe within a given index range.

        Args:
            doc_dict_df: The pandas DataFrame to extract data from
            index_range: An integer specifying the number of rows to include (from 0 to index_range-1)
            columns_to_include: A list of column names to extract

        Returns:
            Dictionary with column names as keys and their values (within the index range) as lists
        """
        max_index = len(doc_dict_df)
        if index_range <= 0:
            print(f"Warning: Invalid index range {index_range}. Must be positive.")
            return {}

        if index_range > max_index:
            print(f"Warning: Index range {index_range} exceeds dataframe length {max_index}. Using maximum length.")
            index_range = max_index

        df_subset = doc_dict_df.iloc[:index_range]

        return get_column_values(df_subset, columns_to_include)

    def get_data_in_range_triplequote(doc_dict_df, index_range, columns_to_include):
        """
        Extract values from specified columns of a dataframe within a given index range.
        Wraps string values with triple quotes and escapes URLs.

        Args:
            doc_dict_df: The pandas DataFrame to extract data from
            index_range: A list of two integers specifying the start and end indices of rows to include
                (e.g., [0, 10] includes rows from index 0 to 9 inclusive)
            columns_to_include: A list of column names to extract
        """
        start_idx, end_idx = index_range
        max_index = len(doc_dict_df)

        if start_idx < 0:
            print(f"Warning: Invalid start index {start_idx}. Using 0 instead.")
            start_idx = 0

        if end_idx <= start_idx:
            print(f"Warning: End index {end_idx} must be greater than start index {start_idx}. Using {start_idx + 1} instead.")
            end_idx = start_idx + 1

        if end_idx > max_index:
            print(f"Warning: End index {end_idx} exceeds dataframe length {max_index}. Using maximum length.")
            end_idx = max_index

        df_subset = doc_dict_df.iloc[start_idx:end_idx]

        result = get_column_values(df_subset, columns_to_include)

        # Escape URL schemes so downstream markdown rendering doesn't auto-link them
        for col in result:
            if isinstance(result[col], list):
                processed_items = []
                for item in result[col]:
                    if isinstance(item, str):
                        item = item.replace("http://", "http\\://").replace("https://", "https\\://")
                        processed_items.append(item)
                    else:
                        processed_items.append(item)
                result[col] = processed_items
        return result
    return (get_data_in_range_triplequote,)


@app.cell
def prepare_doc_select(sentence_splitter_config):
    def prepare_document_selection(node_dict):
        """
        Creates document selection UI component.
        Args:
            node_dict: Dictionary mapping filenames to lists of documents
        Returns:
            mo.ui component for document selection
        """
        total_docs = sum(len(docs) for docs in node_dict.values())

        all_docs_records = []
        doc_index_global = 0
        for filename, docs in node_dict.items():
            for i, doc in enumerate(docs):
                if hasattr(doc, 'to_dict'):
                    doc_data = doc.to_dict()
                elif isinstance(doc, dict):
                    doc_data = doc
                else:
                    doc_data = {'content': str(doc)}

                doc_data['filename'] = filename
                doc_data['doc_index'] = i
                doc_data['global_index'] = doc_index_global
                all_docs_records.append(doc_data)
                doc_index_global += 1

        stop_value = max(total_docs, 1)
        llama_docs = mo.ui.range_slider(
            start=1,
            stop=stop_value,
            step=1,
            full_width=True,
            show_value=True,
            label="**Select a Range of Chunks to Visualize:**"
        ).form(submit_button_disabled=check_state(sentence_splitter_config.value), submit_button_label="Change Document View Range")

        return llama_docs
    return (prepare_document_selection,)


@app.cell
def document_range_selection(
    dict_from_nodes,
    prepare_document_selection,
    set_range_slider_state,
):
    if dict_from_nodes is not None:
        llama_docs = prepare_document_selection(dict_from_nodes)
        set_range_slider_state(llama_docs)
    else:
        bare_dict = {}
        llama_docs = prepare_document_selection(bare_dict)
    return


@app.function
def create_cumulative_dataframe(dict_from_docs):
    """
    Creates a cumulative DataFrame from a nested dictionary of documents.

    Args:
        dict_from_docs: Dictionary mapping filenames to lists of documents

    Returns:
        DataFrame with all documents flattened with global indices
    """
    all_records = []
    global_idx = 1

    for filename, docs in dict_from_docs.items():
        for i, doc in enumerate(docs):
            if hasattr(doc, 'to_dict'):
                doc_data = doc.to_dict()
            elif isinstance(doc, dict):
                doc_data = doc.copy()
            else:
                doc_data = {'content': str(doc)}

            doc_data['filename'] = filename
            doc_data['doc_index'] = i
            doc_data['global_index'] = global_idx

            if 'content' in doc_data and 'text' not in doc_data:
                doc_data['text'] = doc_data['content']

            all_records.append(doc_data)
            global_idx += 1

    return pd.DataFrame(all_records)
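

# Tiny usage sketch for create_cumulative_dataframe, with plain dicts
# standing in for llama-index nodes; filenames and texts are placeholders.
def example_cumulative_dataframe():
    docs = {
        "a.pdf": [{"text": "first chunk"}, {"text": "second chunk"}],
        "b.pdf": [{"text": "third chunk"}],
    }
    df = create_cumulative_dataframe(docs)
    return df[["filename", "doc_index", "global_index", "text"]]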


@app.function
def create_stats(texts_dict, bordered=False, object_names=None, group_by_row=False, items_per_row=6, gap=2, label="Chunk"):
    """
    Create a list of stat objects for each item in the specified dictionary.

    Parameters:
    - texts_dict (dict): Dictionary containing the text data
    - bordered (bool): Whether the stats should be bordered
    - object_names (list or tuple): Two object names to use for label and value
        [label_object, value_object]
    - group_by_row (bool): Whether to group stats in rows (horizontal stacks)
    - items_per_row (int): Number of stat objects per row when group_by_row is True

    Returns:
    - object: A vertical stack of stat objects or rows of stat objects
    """
    if not object_names or len(object_names) < 2:
        raise ValueError("You must provide two object names as a list or tuple")

    label_object = object_names[0]
    value_object = object_names[1]

    if label_object not in texts_dict:
        raise ValueError(f"Label object '{label_object}' not found in texts_dict")
    if value_object not in texts_dict:
        raise ValueError(f"Value object '{value_object}' not found in texts_dict")

    num_items = len(texts_dict[label_object])

    individual_stats = []
    for i in range(num_items):
        stat = mo.stat(
            label=texts_dict[label_object][i],
            value=f"{label} Number: {len(texts_dict[value_object][i])}",
            bordered=bordered
        )
        individual_stats.append(stat)

    if not group_by_row:
        return mo.vstack(individual_stats, wrap=False)

    rows = []
    for i in range(0, num_items, items_per_row):
        row_stats = individual_stats[i:i+items_per_row]

        widths = [0.35] * len(row_stats)
        row = mo.hstack(row_stats, gap=gap, align="start", justify="center", widths=widths)
        rows.append(row)

    return mo.vstack(rows)
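

# Usage sketch for create_stats with a toy chunks dictionary, mirroring how
# preview_chunks calls it with object_names=['text', 'text'] so each stat's
# label is the chunk text and its value reflects that chunk's character count.
def example_create_stats():
    chunks = {"text": ["first chunk text", "second chunk text"]}
    return create_stats(
        chunks,
        bordered=True,
        object_names=["text", "text"],
        group_by_row=True,
        items_per_row=5,
        label="Chunk",
    )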


@app.cell
def prepare_chart_embeddings(
    chunks_to_process,
    emb_model,
    emb_model_emb_dim,
    get_embedding_state,
    prepare_embedding_data_3d,
):
    if chunks_to_process is not None and get_embedding_state() is not None:
        chart_dataframe, chart_metadata = prepare_embedding_data_3d(
            get_embedding_state(),
            chunks_to_process,
            model_id=emb_model,
            embedding_dimensions=emb_model_emb_dim
        )
    else:
        chart_dataframe, chart_metadata = None, None
    return chart_dataframe, chart_metadata


@app.cell
def chart_dims():
    chart_dimensions = (
        mo.md('''
        > **Adjust Chart Window**

        {chart_height}

        {chat_width}

        ''').batch(
            chart_height = mo.ui.slider(start=500, step=30, stop=1000, label="**Height:**", value=800, show_value=True),
            chat_width = mo.ui.slider(start=900, step=50, stop=1400, label="**Width:**", value=1200, show_value=True)
        )
    )
    return (chart_dimensions,)


@app.cell
def chart_dim_values(chart_dimensions):
    chart_height = chart_dimensions.value['chart_height']
    chart_width = chart_dimensions.value['chat_width']
    return chart_height, chart_width


@app.cell
def create_baseline_chart(
    chart_dataframe,
    chart_height,
    chart_metadata,
    chart_width,
    create_3d_embedding_chart,
):
    if chart_dataframe is not None and chart_metadata is not None:
        emb_plot = create_3d_embedding_chart(chart_dataframe, chart_metadata, chart_width, chart_height, marker_size_var=9)
        chart = mo.ui.plotly(emb_plot)
    else:
        emb_plot = None
        chart = None
    return (emb_plot,)


@app.cell
def test_query(get_chunk_state):
    placeholder = """How can I use watsonx.data to perform vector search?"""

    query = mo.ui.text_area(
        label="**Write text to check:**",
        full_width=True,
        rows=8,
        value=placeholder
    ).form(show_clear_button=True, submit_button_disabled=check_state(get_chunk_state()), submit_button_label="Query and View Visualization")
    return (query,)


@app.cell
def query_stack(chart_dimensions, query):
    query_stack = mo.hstack([query, chart_dimensions], justify="space-around", align="center", gap=15)
    return (query_stack,)


@app.function
def check_state(variable):
    return variable is None


@app.cell
def helper_function_add_query_to_chart():
    def add_query_to_embedding_chart(existing_chart, query_coords, query_text, marker_size=12):
        """
        Add a query point to an existing 3D embedding chart as a large red dot.

        Args:
            existing_chart: The existing plotly figure or chart data
            query_coords: Dictionary with 'x', 'y', 'z' coordinates for the query point
            query_text: Text of the query to display on hover
            marker_size: Size of the query marker (default: 12, larger than the document markers)

        Returns:
            A modified plotly figure with the query point added as a red dot
        """
        import copy
        import plotly.graph_objects as go

        # Work on a copy so the original chart stays untouched
        chart_copy = copy.deepcopy(existing_chart)

        # Rebuild a Figure if we were handed raw chart data
        if isinstance(chart_copy, (dict, list)):
            if isinstance(chart_copy, list):
                fig = go.Figure(data=chart_copy)
            else:
                fig = go.Figure(data=chart_copy.get('data', []), layout=chart_copy.get('layout', {}))

            chart_copy = fig

        query_trace = go.Scatter3d(
            x=[query_coords['x']],
            y=[query_coords['y']],
            z=[query_coords['z']],
            mode='markers',
            name='Query',
            marker=dict(
                size=marker_size,
                color='red',
                symbol='circle',
                opacity=0.70,
                line=dict(
                    width=1,
                    color='white'
                )
            ),
            text=['<b>Query:</b><br>' + '<br>'.join([query_text[i:i+50] for i in range(0, len(query_text), 50)])],
            hoverinfo="text+name"
        )

        chart_copy.add_trace(query_trace)

        return chart_copy

    def get_query_coordinates(reference_embeddings=None, query_embedding=None):
        """
        Calculate appropriate coordinates for a query point based on reference embeddings.

        This function handles several scenarios:
        1. If both reference embeddings and query embedding are provided, it places the
           query near similar documents.
        2. If only reference embeddings are provided, it places the query at a visible
           location near the center of the chart.
        3. If neither are provided, it returns default origin coordinates.

        Args:
            reference_embeddings: DataFrame with x, y, z coordinates from the main chart
            query_embedding: The embedding vector of the query

        Returns:
            Dictionary with x, y, z coordinates for the query point
        """
        import numpy as np

        default_coords = {'x': 0.0, 'y': 0.0, 'z': 0.0}

        if reference_embeddings is None or len(reference_embeddings) == 0:
            return default_coords

        # Without a query embedding, fall back to the centroid of the chart
        if query_embedding is None:
            center_coords = {
                'x': reference_embeddings['x'].mean(),
                'y': reference_embeddings['y'].mean(),
                'z': reference_embeddings['z'].mean()
            }
            return center_coords

        # Otherwise, place the query next to its most similar document
        try:
            from sklearn.metrics.pairwise import cosine_similarity

            if 'embedding' in reference_embeddings.columns:
                if isinstance(reference_embeddings['embedding'].iloc[0], list):
                    doc_embeddings = np.array(reference_embeddings['embedding'].tolist())
                else:
                    doc_embeddings = np.array([emb for emb in reference_embeddings['embedding'].values])

                query_emb_array = np.array(query_embedding)
                if query_emb_array.ndim == 1:
                    query_emb_array = query_emb_array.reshape(1, -1)

                similarities = cosine_similarity(query_emb_array, doc_embeddings)[0]

                closest_idx = np.argmax(similarities)

                # Offset slightly so the query marker doesn't cover the document
                query_coords = {
                    'x': reference_embeddings['x'].iloc[closest_idx] + 0.2,
                    'y': reference_embeddings['y'].iloc[closest_idx] + 0.2,
                    'z': reference_embeddings['z'].iloc[closest_idx] + 0.2
                }
                return query_coords
        except Exception as e:
            print(f"Error positioning query near similar documents: {e}")

        center_coords = {
            'x': reference_embeddings['x'].mean(),
            'y': reference_embeddings['y'].mean(),
            'z': reference_embeddings['z'].mean()
        }
        return center_coords
    return add_query_to_embedding_chart, get_query_coordinates
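

# Hedged sketch of the cosine-similarity step get_query_coordinates relies
# on, with toy 2-D vectors standing in for real embeddings; the closest
# document here is index 0.
def example_cosine_similarity():
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    doc_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
    query = np.array([[0.9, 0.1]])
    sims = cosine_similarity(query, doc_embeddings)[0]
    return int(np.argmax(sims))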


@app.cell
def combined_chart_visualization(
    add_query_to_embedding_chart,
    chart_dataframe,
    emb_plot,
    embedding,
    get_query_coordinates,
    get_query_state,
    query,
    set_chart_state,
    set_query_state,
):
    if chart_dataframe is not None and query.value:
        with mo.status.spinner(title="Embedding Query...", remove_on_exit=True) as _spinner:
            query_emb = embedding.embed_documents([query.value])
            set_query_state(query_emb)

            _spinner.update("Preparing Query Coordinates")
            time.sleep(1.0)

            query_coords = get_query_coordinates(
                reference_embeddings=chart_dataframe,
                query_embedding=get_query_state()
            )

            _spinner.update("Adding Query to Chart")
            time.sleep(1.0)

            result = add_query_to_embedding_chart(
                existing_chart=emb_plot,
                query_coords=query_coords,
                query_text=query.value,
            )

            chart_with_query = result

            _spinner.update("Preparing Visualization")
            time.sleep(1.0)

            combined_viz = mo.ui.plotly(chart_with_query)
            set_chart_state(combined_viz)

            _spinner.update("Done")
    else:
        combined_viz = None
    return


@app.cell
def _():
    get_range_slider_state, set_range_slider_state = mo.state(None)
    return get_range_slider_state, set_range_slider_state


@app.cell
def _(get_range_slider_state):
    if get_range_slider_state() is not None:
        document_range_stack = get_range_slider_state()
    else:
        document_range_stack = None
    return (document_range_stack,)


@app.cell
def _():
    get_chart_state, set_chart_state = mo.state(None)
    return get_chart_state, set_chart_state


@app.cell
def _(get_chart_state, query):
    if query.value is not None:
        chart_visualization = get_chart_state()
    else:
        chart_visualization = None
    return (chart_visualization,)


@app.cell
def c(document_range_stack):
    chart_range_selection = mo.hstack([document_range_stack], justify="space-around", align="center", widths=[0.65])
    return (chart_range_selection,)


if __name__ == "__main__":
    app.run()
|