| |
| |
| |
| |
|
|
| import marimo |
|
|
| __generated_with = "0.13.14" |
| app = marimo.App(width="full") |
|
|
| with app.setup: |
| |
| import marimo as mo |
| from typing import Dict, Optional, List, Union, Any |
| from ibm_watsonx_ai import APIClient, Credentials |
| from pathlib import Path |
| import pandas as pd |
| import mimetypes |
| import requests |
| import zipfile |
| import tempfile |
| import certifi |
| import base64 |
| import polars |
| import nltk |
| import time |
| import json |
| import ast |
| import os |
| import io |
| import re |
|
|
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
def get_iam_token(api_key):
    """Exchange an IBM Cloud API key for a short-lived IAM bearer token.

    Args:
        api_key: IBM Cloud API key string.

    Returns:
        The ``access_token`` string issued by the IAM identity service.

    Raises:
        requests.HTTPError: if the IAM endpoint returns a 4xx/5xx status.
        requests.Timeout: if the endpoint does not answer within the timeout.
    """
    response = requests.post(
        'https://iam.cloud.ibm.com/identity/token',
        headers={'Content-Type': 'application/x-www-form-urlencoded'},
        data={'grant_type': 'urn:ibm:params:oauth:grant-type:apikey', 'apikey': api_key},
        verify=certifi.where(),
        timeout=60,  # never hang indefinitely on an unreachable endpoint
    )
    # Surface HTTP failures directly instead of a confusing KeyError on
    # the missing 'access_token' field of an error payload.
    response.raise_for_status()
    return response.json()['access_token']
|
|
def setup_task_credentials(client):
    """Replace any stored task credentials on *client* with a fresh one.

    Deletes every existing task credential, then stores (and returns) a
    new one so deployments run with a known-current credential.
    """
    details = client.task_credentials.get_details()
    # Wipe whatever credentials already exist before creating the new one.
    for resource in details.get("resources") or []:
        client.task_credentials.delete(client.task_credentials.get_id(resource))
    return client.task_credentials.store()
|
|
|
|
|
|
@app.cell
def client_variables(client_instantiation_form):
    # Submitted form values (None until the user presses submit).
    client_setup = client_instantiation_form.value or None

    if not client_setup:
        os.environ["WATSONX_APIKEY"] = ""
        wx_url = wx_api_key = project_id = space_id = None
    else:
        # Region falls back to "EU" when left unselected.
        wx_url = client_setup["wx_region"] or "EU"
        _raw_key = client_setup["wx_api_key"]
        wx_api_key = _raw_key.strip() if _raw_key else None
        # Expose the key for SDK code paths that read the environment.
        os.environ["WATSONX_APIKEY"] = wx_api_key or ""

        _raw_project = client_setup["project_id"]
        project_id = _raw_project.strip() if _raw_project else None
        _raw_space = client_setup["space_id"]
        space_id = _raw_space.strip() if _raw_space else None
    return client_setup, project_id, space_id, wx_api_key, wx_url
|
|
|
|
@app.cell
def _():
    # Project-local imports: baked-in demo credentials, region constants and
    # shared UI/data helper functions used across these notebooks.
    from baked_in_credentials.creds import credentials
    from base_variables import wx_regions, wx_platform_url
    from helper_functions.helper_functions import wrap_with_spaces, get_key_by_value, markdown_spacing, get_cell_values, create_parameter_table, get_cred_value
    return (
        create_parameter_table,
        credentials,
        get_cell_values,
        get_cred_value,
        get_key_by_value,
        wrap_with_spaces,
        wx_regions,
    )
|
|
|
|
@app.cell
def client_instantiation(
    client_setup,
    project_id,
    space_id,
    wx_api_key,
    wx_url,
):
    # Create one APIClient per provided scope (project and/or deployment
    # space) and record the outcome for the status callout.
    wx_credentials = project_client = deployment_client = None
    instantiation_success = instantiation_error = None

    if client_setup:
        try:
            wx_credentials = Credentials(url=wx_url, api_key=wx_api_key)
            if project_id:
                project_client = APIClient(credentials=wx_credentials, project_id=project_id)
            if space_id:
                deployment_client = APIClient(credentials=wx_credentials, space_id=space_id)
            instantiation_success = True
        except Exception as e:
            # Keep the message for display; drop any partially built clients.
            instantiation_success = False
            instantiation_error = str(e)
            wx_credentials = project_client = deployment_client = None
    return (
        deployment_client,
        instantiation_error,
        instantiation_success,
        project_client,
    )
|
|
|
|
@app.cell
def _():
    # Title cell: renders the notebook header, author note and licence text.
    mo.md(
        r"""
        #watsonx.ai Embedding Visualizer - Marimo Notebook

        #### This marimo notebook can be used to develop a more intuitive understanding of how vector embeddings work by creating a 3D visualization of vector embeddings based on chunked PDF document pages.

        #### It can also serve as a useful tool for identifying gaps in model choice, chunking strategy or contents used in building collections by showing how far you are from what you want.
        <br>

        /// admonition
        Created by ***Milan Mrdenovic*** [milan.mrdenovic@ibm.com] for IBM Ecosystem Client Engineering, NCEE - ***version 5.3** - 20.04.2025*
        ///


        >Licensed under apache 2.0, users hold full accountability for any use or modification of the code.
        ><br>This asset is part of a set meant to support IBMers, IBM Partners, Clients in developing understanding of how to better utilize various watsonx features and generative AI as a subject matter.

        <br>
        """
    )
    return
|
|
|
|
@app.cell
def _():
    # Section header for Part 1 of the notebook.
    mo.md("""###Part 1 - Client Setup, File Preparation and Chunking""")
    return
|
|
|
|
@app.cell
def accordion_client_setup(
    client_selector,
    client_stack,
    current_mode,
    switch_file_loader_type,
):
    # Collapsible section bundling the credential form, the client switcher
    # and the file-loader mode indicator.
    _controls = mo.hstack(
        [client_selector, switch_file_loader_type],
        justify="space-around",
        gap=2,
    )
    _body = mo.vstack([client_stack, _controls, current_mode], align="center")
    ui_accordion_part_1_1 = mo.accordion({"Instantiate Client": _body})

    ui_accordion_part_1_1
    return
|
|
|
|
@app.cell
def _(switch_file_loader_type):
    # Human-readable indicator of which ingestion path the switch selects.
    current_mode = mo.md(
        "**Current Mode:** Using pre-made embedding/text files."
        if switch_file_loader_type.value
        else "**Current Mode:** Using loaded pdf files and chunking."
    )
    return (current_mode,)
|
|
|
|
@app.cell
def accordion_file_upload(select_stack):
    # Wrap the model/file selection stack in its own accordion section.
    ui_accordion_part_1_2 = mo.accordion(
        {"Select Model & Upload Files": select_stack}
    )

    ui_accordion_part_1_2
    return
|
|
|
|
@app.cell
def loaded_texts(
    create_temp_files_from_uploads,
    file_loader,
    pdf_reader,
    run_upload_button,
    set_text_state,
    switch_file_loader_type,
):
    # Persist the uploads as temp files and parse them once "Load Files"
    # is pressed; the parsed texts are cached in notebook state.
    if file_loader.value is None or not run_upload_button.value:
        filepaths = None
        loaded_texts = None
    else:
        filepaths = create_temp_files_from_uploads(file_loader.value)
        if switch_file_loader_type.value:
            # Pre-made embedding/text pairs (JSON or CSV).
            loaded_texts = load_json_csv_data_with_progress(
                filepaths, file_loader.value, show_progress=True
            )
        else:
            # Raw PDFs parsed page-by-page.
            loaded_texts = load_pdf_data_with_progress(
                pdf_reader, filepaths, file_loader.value, show_progress=True
            )
        set_text_state(loaded_texts)
    return (loaded_texts,)
|
|
|
|
@app.cell
def _(chunker_setup, file_column_setup, switch_file_loader_type):
    # Pre-made files need a column selector; raw PDFs need the chunker form.
    _section = (
        {"Column Selector": file_column_setup}
        if switch_file_loader_type.value
        else {"Chunker Setup": chunker_setup}
    )
    ui_accordion_part_1_3 = mo.accordion(_section)

    ui_accordion_part_1_3
    return
|
|
|
|
@app.cell
def accordion_chunker_setup():
    # Intentionally empty placeholder cell; the chunker accordion is built
    # in the previous cell from `chunker_setup`. Kept so the named cell
    # slot survives in the notebook layout.
    return
|
|
|
|
@app.cell
def chunk_documents_to_nodes(
    get_text_state,
    sentence_splitter,
    sentence_splitter_config,
    set_chunk_state,
):
    # Split the loaded documents into nodes once the chunker form has been
    # submitted, remembering the result in notebook state.
    chunked_texts = None
    if sentence_splitter_config.value and sentence_splitter and get_text_state() is not None:
        chunked_texts = chunk_documents(get_text_state(), sentence_splitter, show_progress=True)
        set_chunk_state(chunked_texts)
    return (chunked_texts,)
|
|
|
|
@app.cell
def _():
    # Section header for Part 2 of the notebook.
    mo.md(r"""###Part 2 - Query Setup and Visualization""")
    return
|
|
|
|
@app.cell
def accordion_chunk_range():
    # Intentionally empty placeholder cell; the chunk-range accordion is
    # assembled in the following cell from `chart_range_selection`.
    return
|
|
|
|
@app.cell
def _(chart_range_selection, switch_file_loader_type):
    # Chunk-range selection only applies to the PDF/chunking workflow, so
    # the accordion is rendered only while the pre-made-files switch is off.
    _sections = {"Chunk Range Selection": chart_range_selection}
    ui_accordion_part_2_1 = mo.accordion(_sections)
    ui_accordion_part_2_1 if switch_file_loader_type.value == False else None
    return
|
|
|
|
@app.cell
def chunk_embedding(
    chunks_to_process,
    embedding,
    sentence_splitter_config,
    set_embedding_state,
):
    # Embed the selected chunks with the watsonx.ai model behind a spinner;
    # the resulting vectors are stored in notebook state.
    output_embeddings = None
    if sentence_splitter_config.value is not None and chunks_to_process is not None:
        with mo.status.spinner(title="Embedding Documents...", remove_on_exit=True) as _spinner:
            output_embeddings = embedding.embed_documents(chunks_to_process)
            _spinner.update("Almost Done")
            # Brief pause so the status message is visible before it closes.
            time.sleep(1.5)
            set_embedding_state(output_embeddings)
            _spinner.update("Documents Embedded")
    return
|
|
|
|
@app.cell
def preview_chunks(chunks_dict):
    # Render per-chunk stat cards inside a collapsible viewer once chunks exist.
    ui_chunk_viewer = None
    if chunks_dict is not None:
        _stats = create_stats(
            chunks_dict,
            bordered=True,
            object_names=['text', 'text'],
            group_by_row=True,
            items_per_row=5,
            gap=1,
            label="Chunk",
        )
        ui_chunk_viewer = mo.accordion({"View Chunks": _stats})

    ui_chunk_viewer
    return
|
|
|
|
@app.cell
def accordion_query_view(chart_visualization, query_stack):
    # Combine the query controls and the chart into one accordion entry.
    _content = mo.vstack(
        [query_stack, mo.hstack([chart_visualization])],
        align="center",
        gap=3,
    )
    ui_accordion_part_2_2 = mo.accordion({"Query": _content})
    ui_accordion_part_2_2
    return
|
|
|
|
@app.cell
def chunker_setup(sentence_splitter_config):
    # Centre the chunking form and cap its width to ~55% of the row.
    chunker_setup = mo.hstack(
        [sentence_splitter_config],
        justify="space-around",
        align="center",
        widths=[0.55],
    )
    return (chunker_setup,)
|
|
|
|
@app.cell
def file_and_model_select(
    file_loader,
    get_embedding_model_list,
    run_upload_button,
):
    # Model table on the left; drop-zone, hint text and trigger on the right.
    _upload_column = mo.vstack(
        [
            mo.md("Drag & Drop or Double Click to select PDFs, then press **Load Files**"),
            file_loader,
            run_upload_button,
        ],
        align="center",
    )
    select_stack = mo.hstack(
        [get_embedding_model_list(), _upload_column],
        justify="space-around",
        align="center",
        widths=[0.3, 0.3],
    )
    return (select_stack,)
|
|
|
|
@app.cell
def client_instantiation_form(credentials, get_cred_value, wx_regions):
    # Alias used by get_cred_value, which resolves defaults by variable name.
    baked_in_creds = credentials

    # Credential entry form. Fields are pre-filled from baked-in credentials
    # when available; `.form()` defers the values until the user submits.
    client_instantiation_form = (
        mo.md('''
    ###**watsonx.ai credentials:**

    {wx_region}

    {wx_api_key}

    {project_id}

    {space_id}

    > You can add either a project_id, space_id or both, **only one is required**.
    > If you provide both you can switch the active one in the dropdown.
    ''')
        .batch(
            wx_region = mo.ui.dropdown(
                wx_regions,
                label="Select your watsonx.ai region:",
                value=get_cred_value('region', creds_var_name='baked_in_creds') or "EU",
                searchable=True
            ),
            wx_api_key = mo.ui.text(
                placeholder="Add your IBM Cloud api-key...",
                label="IBM Cloud Api-key:",
                kind="password",
                value=get_cred_value('api_key', creds_var_name='baked_in_creds')
            ),
            project_id = mo.ui.text(
                placeholder="Add your watsonx.ai project_id...",
                label="Project_ID:",
                kind="text",
                value=get_cred_value('project_id', creds_var_name='baked_in_creds')
            ),
            space_id = mo.ui.text(
                placeholder="Add your watsonx.ai space_id...",
                label="Space_ID:",
                kind="text",
                value=get_cred_value('space_id', creds_var_name='baked_in_creds')
            )
        ,)
        .form(show_clear_button=True, bordered=False)
    )
    return (client_instantiation_form,)
|
|
|
|
@app.cell
def instantiation_status(
    client_callout_kind,
    client_instantiation_form,
    client_status,
):
    # Pair the credential form with a coloured callout reflecting the
    # client instantiation outcome.
    client_callout = mo.callout(client_status, kind=client_callout_kind)
    client_stack = mo.hstack(
        [client_instantiation_form, client_callout],
        align="center",
        justify="space-around",
        gap=10,
    )
    return (client_stack,)
|
|
|
|
@app.cell
def _(
    client_key,
    client_options,
    client_selector,
    client_setup,
    get_key_by_value,
    instantiation_error,
    instantiation_success,
    wrap_with_spaces,
):
    # Human-readable name of the currently selected client.
    active_client_name = get_key_by_value(client_options, client_key)

    # Selector + active-client line shared by the success and idle states.
    _active_block = (
        f"{client_selector}\n\n"
        f"**Active Client:**{wrap_with_spaces(active_client_name, prefix_spaces=5)}"
    )

    if client_setup and instantiation_success:
        client_status = mo.md(
            f"### ✅ Client Instantiation Successful ✅\n\n" + _active_block
        )
        client_callout_kind = "success"
    elif client_setup:
        client_status = mo.md(
            f"### ❌ Client Instantiation Failed\n**Error:** {instantiation_error}\n\nCheck your region selection and credentials"
        )
        client_callout_kind = "danger"
    else:
        client_status = mo.md(
            f"### Client Instantiation Status will turn Green When Ready\n\n" + _active_block
        )
        client_callout_kind = "neutral"
    return client_callout_kind, client_status
|
|
|
|
@app.cell
def client_selector(deployment_client, project_client):
    # Offer whichever clients were successfully created; fall back to a
    # placeholder entry when neither exists yet.
    client_options = {}
    if project_client is not None:
        client_options["Project Client"] = project_client
    if deployment_client is not None:
        client_options["Deployment Client"] = deployment_client
    if not client_options:
        client_options = {"No Client": "Instantiate a Client"}

    # Default to the first available option (insertion order).
    default_client = next(iter(client_options))
    client_selector = mo.ui.dropdown(
        client_options, value=default_client, label="**Switch your active client:**"
    )
    return client_options, client_selector
|
|
|
|
@app.cell
def active_client(client_selector):
    # The dropdown maps names to client objects; the placeholder maps to a
    # sentinel string, which is normalised to None here.
    client_key = client_selector.value
    client = None if client_key == "Instantiate a Client" else client_key
    return client, client_key
|
|
|
|
@app.cell
def emb_model_selection(client, set_embedding_model_list):
    # Build the embedding-model picker. With a live client we list the
    # models the service actually offers, annotated with known token limits
    # and vector sizes; without one we fall back to a single default row.
    if client is not None:
        model_specs = client.foundation_models.get_embeddings_model_specs()
        resources = model_specs["resources"]

        # Known context windows / output dimensions for common models;
        # models missing from this table get zeros.
        embedding_models = {
            "ibm/granite-embedding-107m-multilingual": {"max_tokens": 512, "embedding_dimensions": 384},
            "ibm/granite-embedding-278m-multilingual": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-125m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-125m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 768},
            "ibm/slate-30m-english-rtrvr-v2": {"max_tokens": 512, "embedding_dimensions": 384},
            "ibm/slate-30m-english-rtrvr": {"max_tokens": 512, "embedding_dimensions": 384},
            "sentence-transformers/all-minilm-l6-v2": {"max_tokens": 128, "embedding_dimensions": 384},
            "sentence-transformers/all-minilm-l12-v2": {"max_tokens": 128, "embedding_dimensions": 384},
            "intfloat/multilingual-e5-large": {"max_tokens": 512, "embedding_dimensions": 1024}
        }

        _unknown = {"max_tokens": 0, "embedding_dimensions": 0}
        embedding_model_data = []
        for resource in resources:
            _model_id = resource["model_id"]
            _known = embedding_models.get(_model_id, _unknown)
            embedding_model_data.append(
                {
                    "model_id": _model_id,
                    "max_tokens": _known["max_tokens"],
                    "embedding_dimensions": _known["embedding_dimensions"],
                }
            )

        embedding_model_selection = mo.ui.table(
            embedding_model_data,
            selection="single",
            label="Select an embedding model to use.",
            page_size=30,
            initial_selection=[1]
        )
        set_embedding_model_list(embedding_model_selection)
    else:
        # No client yet: seed the state with a single sensible default row.
        default_model_data = [{
            "model_id": "ibm/granite-embedding-107m-multilingual",
            "max_tokens": 512,
            "embedding_dimensions": 384
        }]
        set_embedding_model_list(
            create_emb_model_selection_table(
                default_model_data,
                initial_selection=0,
                selection_type="single",
                label="Select a model to use.",
            )
        )
    return
|
|
|
|
@app.function
def create_emb_model_selection_table(model_data, initial_selection=0, selection_type="single", label="Select a model to use."):
    """Build a single-page marimo table for picking an embedding model.

    Args:
        model_data: List of row dicts (model_id / max_tokens / embedding_dimensions).
        initial_selection: Row index selected when the table first renders.
        selection_type: marimo table selection mode (e.g. "single").
        label: Caption displayed above the table.

    Returns:
        A configured mo.ui.table instance.
    """
    return mo.ui.table(
        model_data,
        selection=selection_type,
        label=label,
        page_size=30,
        initial_selection=[initial_selection],
    )
|
|
|
|
@app.cell
def embedding_model():
    # Reactive state holding the embedding-model selection table UI element.
    get_embedding_model_list, set_embedding_model_list = mo.state(None)
    return get_embedding_model_list, set_embedding_model_list
|
|
|
|
@app.cell
def emb_model_parameters(emb_model_max_tk, embedding_model):
    from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams

    # Truncate at the selected model's context window; fall back to a
    # conservative 128 tokens before any model has been chosen.
    _truncate_tokens = emb_model_max_tk if embedding_model is not None else 128
    embed_params = {
        EmbedParams.TRUNCATE_INPUT_TOKENS: _truncate_tokens,
        EmbedParams.RETURN_OPTIONS: {
            'input_text': True
        }
    }
    return (embed_params,)
|
|
|
|
@app.cell
def emb_model_state(get_embedding_model_list):
    # Dereference the state getter so downstream cells react to changes.
    embedding_model = get_embedding_model_list()
    return (embedding_model,)
|
|
|
|
@app.cell
def emb_model_setup(embedding_model):
    # Unpack the selected table row into individual model attributes.
    if embedding_model is None:
        emb_model = emb_model_max_tk = emb_model_emb_dim = None
    else:
        _row = embedding_model.value[0]
        emb_model = _row['model_id']
        emb_model_max_tk = _row['max_tokens']
        emb_model_emb_dim = _row['embedding_dimensions']
    return emb_model, emb_model_emb_dim, emb_model_max_tk
|
|
|
|
@app.cell
def emb_model_instantiation(client, emb_model, embed_params):
    from ibm_watsonx_ai.foundation_models import Embeddings

    # Embedding interface bound to the active client: up to 1000 texts per
    # request, at most 10 concurrent requests.
    embedding = None
    if client is not None:
        embedding = Embeddings(
            model_id=emb_model,
            api_client=client,
            params=embed_params,
            batch_size=1000,
            concurrency_limit=10,
        )
    return (embedding,)
|
|
|
|
@app.cell
def _():
    # Reactive state holding the computed embedding vectors.
    get_embedding_state, set_embedding_state = mo.state(None)
    return get_embedding_state, set_embedding_state
|
|
|
|
@app.cell
def _():
    # Reactive state holding the embedded query vector/result.
    get_query_state, set_query_state = mo.state(None)
    return get_query_state, set_query_state
|
|
|
|
@app.cell
def file_loader_input(switch_file_loader_type):
    # The switch flips the uploader between pre-made embedding/text pairs
    # (one JSON/CSV file) and raw PDFs (multiple files allowed).
    if switch_file_loader_type.value:
        _options = dict(
            filetypes=[".json", ".csv"],
            label=" Load pre-made embedding/text pair files (.json,.csv) ",
            multiple=False,
        )
    else:
        _options = dict(
            filetypes=[".pdf"],
            label=" Load .pdf files ",
            multiple=True,
        )
    file_loader = mo.ui.file(kind="area", **_options)
    return (file_loader,)
|
|
|
|
@app.cell
def file_loader_run(file_loader):
    # Keep the trigger disabled until at least one file has been attached.
    run_upload_button = mo.ui.run_button(
        disabled=not file_loader.value, label="Load Files"
    )
    return (run_upload_button,)
|
|
|
|
@app.cell
def helper_function_tempfiles():
    def create_temp_files_from_uploads(upload_results) -> List[str]:
        """
        Creates temporary files from a tuple of FileUploadResults objects and returns their paths.
        Args:
            upload_results: Object containing a value attribute that is a tuple of FileUploadResults
        Returns:
            List of temporary file paths
        """
        temp_file_paths = []
        # Hoisted out of the loop: the temp directory never changes.
        temp_dir = tempfile.gettempdir()

        # Iterate the results directly instead of indexing via range(len(...)).
        for result in upload_results:
            # Reuse the uploaded filename so downstream UIs can match the
            # temp path back to the original file.
            temp_path = os.path.join(temp_dir, result.name)
            with open(temp_path, 'wb') as temp_file:
                temp_file.write(result.contents)
            temp_file_paths.append(temp_path)

        return temp_file_paths

    def cleanup_temp_files(temp_file_paths: List[str]) -> None:
        """Delete temporary files after use."""
        for path in temp_file_paths:
            if os.path.exists(path):
                os.unlink(path)
    return (create_temp_files_from_uploads,)
|
|
|
|
@app.function
def load_pdf_data_with_progress(pdf_reader, filepaths, file_loader_value, show_progress=True):
    """
    Loads PDF data for each file path and organizes results by original filename.
    Args:
        pdf_reader: The PyMuPDFReader instance
        filepaths: List of temporary file paths
        file_loader_value: The original upload results value containing file information
        show_progress: Whether to show a progress bar during loading (default: False)
    Returns:
        Dictionary mapping original filenames to their loaded text content
    """
    results = {}

    if not show_progress:
        # Silent path: parse each file without any UI reporting.
        for idx, path in enumerate(filepaths):
            results[file_loader_value[idx].name] = pdf_reader.load_data(
                file_path=path, metadata=True
            )
        return results

    import marimo as mo

    with mo.status.progress_bar(
        total=len(filepaths),
        title="Loading PDFs",
        subtitle="Processing documents...",
        completion_title="PDF Loading Complete",
        completion_subtitle=f"{len(filepaths)} documents processed",
        remove_on_exit=True
    ) as bar:
        for idx, path in enumerate(filepaths):
            # Key results by the uploaded file's original name, not the
            # temp-file path on disk.
            original_file_name = file_loader_value[idx].name
            bar.update(subtitle=f"Processing {original_file_name}...")
            results[original_file_name] = pdf_reader.load_data(
                file_path=path, metadata=True
            )
            bar.update(increment=1)

    return results
|
|
|
|
@app.cell
def file_readers():
    # llama-index file readers plus the sentence-splitting node parser class.
    from llama_index.readers.file import PyMuPDFReader
    from llama_index.readers.file import FlatReader
    from llama_index.core.node_parser import SentenceSplitter

    # Single shared PDF reader instance used by the loading helpers.
    pdf_reader = PyMuPDFReader()

    return SentenceSplitter, pdf_reader
|
|
|
|
@app.cell
def sentence_splitter_setup():
    # Chunking configuration form. Its values feed the SentenceSplitter in
    # the instantiation cell; submission is deferred until the user presses
    # "Chunk Documents".
    sentence_splitter_config = (
        mo.md('''
    ###**Chunking Setup:**

    > Unless you want to do some advanced sentence splitting, it's best to stick to adjusting only the chunk size and overlap. Changing the other settings might result in unexpected results.

    Separator value is set to **" "** by default, while the paragraph separator is **"\\n\\n\\n"**.

    {chunk_size}

    {chunk_overlap}

    {separator} {paragraph_separator}

    {secondary_chunking_regex} {include_metadata}

    ''')
        .batch(
            chunk_size = mo.ui.slider(start=100, stop=5000, step=1, label="**Chunk Size:**", value=275, show_value=True, full_width=True),
            chunk_overlap = mo.ui.slider(start=0, stop=1000, step=1, label="**Chunk Overlap** *(Must always be smaller than Chunk Size)* **:**", value=0, show_value=True, full_width=True),
            separator = mo.ui.text(placeholder="Define a separator", label="**Separator:**", kind="text", value=" "),
            paragraph_separator = mo.ui.text(placeholder="Define a paragraph separator",
                                             label="**Paragraph Separator:**", kind="text",
                                             value="\n\n\n"),
            secondary_chunking_regex = mo.ui.text(placeholder="Define a secondary chunking regex",
                                                  label="**Chunking Regex:**", kind="text",
                                                  value="[^,.;?!]+[,.;?!]?"),
            include_metadata= mo.ui.checkbox(value=True, label="**Include Metadata**")
        )
        .form(show_clear_button=True, bordered=False, submit_button_label="Chunk Documents")
    )
    return (sentence_splitter_config,)
|
|
|
|
@app.cell
def sentence_splitter_instantiation(
    SentenceSplitter,
    sentence_splitter_config,
):
    def simple_whitespace_tokenizer(text):
        # Token count approximated by whitespace-separated words.
        return text.split()

    if sentence_splitter_config.value is None:
        # Form not submitted yet: use conservative defaults.
        sentence_splitter = SentenceSplitter(
            chunk_size=2048,
            chunk_overlap=204,
            separator=" ",
            paragraph_separator="\n\n\n",
            secondary_chunking_regex="[^,.;?!]+[,.;?!]?",
            include_metadata=True,
            tokenizer=simple_whitespace_tokenizer,
        )
    else:
        _cfg = sentence_splitter_config.value
        # Clamp overlap to 30% of the chunk size so it can never reach it.
        validated_chunk_overlap = min(
            _cfg.get("chunk_overlap"), int(_cfg.get("chunk_size") * 0.3)
        )
        sentence_splitter = SentenceSplitter(
            chunk_size=_cfg.get("chunk_size"),
            chunk_overlap=validated_chunk_overlap,
            separator=_cfg.get("separator"),
            paragraph_separator=_cfg.get("paragraph_separator"),
            secondary_chunking_regex=_cfg.get("secondary_chunking_regex"),
            include_metadata=_cfg.get("include_metadata"),
            tokenizer=simple_whitespace_tokenizer,
        )
    return (sentence_splitter,)
|
|
|
|
@app.cell
def text_state():
    # Reactive state holding the raw loaded document texts.
    get_text_state, set_text_state = mo.state(None)
    return get_text_state, set_text_state
|
|
|
|
@app.cell
def chunk_state():
    # Reactive state holding the chunked node dictionary.
    get_chunk_state, set_chunk_state = mo.state(None)
    return get_chunk_state, set_chunk_state
|
|
|
|
@app.function
def chunk_documents(loaded_texts, sentence_splitter, show_progress=True):
    """
    Process each document in the loaded_texts dictionary using the sentence_splitter,
    with an optional marimo progress bar tracking progress at document level.

    Args:
        loaded_texts (dict): Dictionary containing lists of Document objects
        sentence_splitter: The sentence splitter object with get_nodes_from_documents method
        show_progress (bool): Whether to show a progress bar during processing

    Returns:
        dict: Dictionary with the same structure but containing chunked texts
    """
    chunked_texts_dict = {}

    # Total document count drives the progress bar's denominator.
    # (The unused `processed_docs` running total was removed.)
    total_docs = sum(len(docs) for docs in loaded_texts.values())

    if show_progress:
        import marimo as mo

        with mo.status.progress_bar(
            total=total_docs,
            title="Processing Documents",
            subtitle="Chunking documents...",
            completion_title="Processing Complete",
            completion_subtitle=f"{total_docs} documents processed",
            remove_on_exit=True
        ) as bar:
            for key, documents in loaded_texts.items():
                doc_count = len(documents)
                bar.update(subtitle=f"Chunking {key}... ({doc_count} documents)")

                # The splitter's own progress display is disabled because
                # the marimo bar already reports progress.
                chunked_texts_dict[key] = sentence_splitter.get_nodes_from_documents(
                    documents,
                    show_progress=False
                )
                # Brief pause so the per-file subtitle is visible in the UI.
                time.sleep(0.15)

                bar.update(increment=doc_count)
    else:
        for key, documents in loaded_texts.items():
            # NOTE(review): with the marimo bar off this falls back to the
            # splitter's own progress output (show_progress=True preserved
            # from the original) — confirm that is intended.
            chunked_texts_dict[key] = sentence_splitter.get_nodes_from_documents(
                documents,
                show_progress=True
            )

    return chunked_texts_dict
|
|
|
|
@app.cell
def chunked_nodes(chunked_texts, get_chunk_state, sentence_splitter):
    # Read the chunked nodes back out of state once chunking has happened.
    chunked_documents = (
        get_chunk_state()
        if chunked_texts is not None and sentence_splitter
        else None
    )
    return (chunked_documents,)
|
|
|
|
@app.cell
def prep_cumulative_df(chunked_documents, llamaindex_convert_docs_multi):
    # Convert the chunked node objects into plain dictionaries for the
    # DataFrame construction downstream. The original also converted the
    # dicts straight back into nodes (`nodes_from_dict`) without ever using
    # the result — that dead round-trip has been removed.
    if chunked_documents is not None:
        dict_from_nodes = llamaindex_convert_docs_multi(chunked_documents)
    else:
        dict_from_nodes = None
    return (dict_from_nodes,)
|
|
|
|
@app.cell
def chunks_to_process(
    dict_from_nodes,
    document_range_stack,
    get_data_in_range_triplequote,
):
    # Flatten the per-file node dicts into one DataFrame, slice the rows the
    # user selected, and extract the raw chunk texts to embed.
    # (The unused `chunk_objects` local in the fallback branch was removed.)
    if dict_from_nodes is not None and document_range_stack is not None:
        chunk_dict_df = create_cumulative_dataframe(dict_from_nodes)

        # Default to the full table when no explicit range has been picked.
        if document_range_stack.value is not None:
            chunk_start_idx = document_range_stack.value[0]
            chunk_end_idx = document_range_stack.value[1]
        else:
            chunk_start_idx = 0
            chunk_end_idx = len(chunk_dict_df)

        chunk_range_index = [chunk_start_idx, chunk_end_idx]
        chunks_dict = get_data_in_range_triplequote(
            chunk_dict_df,
            index_range=chunk_range_index,
            columns_to_include=["text"],
        )

        chunks_to_process = chunks_dict['text'] if 'text' in chunks_dict else []
    else:
        chunks_dict = None
        chunks_to_process = None
    return chunks_dict, chunks_to_process
|
|
|
|
@app.cell
def helper_function_doc_formatting():
    def llamaindex_convert_docs_multi(items):
        """
        Automatically convert between document objects and dictionaries.

        This function handles:
        - Converting dictionaries to document objects
        - Converting document objects to dictionaries
        - Processing lists or individual items
        - Supporting dictionary structures where values are lists of documents

        Args:
            items: A document object, dictionary, or list of either.
                Can also be a dictionary mapping filenames to lists of documents.

        Returns:
            Converted item(s) maintaining the original structure
        """
        # Empty input (None, [], {}) normalises to an empty list.
        if not items:
            return []

        # Case 1: {filename: [docs]} mapping — convert each file's list,
        # preserving the mapping structure.
        if isinstance(items, dict) and all(isinstance(v, list) for v in items.values()):
            result = {}
            for filename, doc_list in items.items():
                result[filename] = llamaindex_convert_docs(doc_list)
            return result

        # Case 2: a single (non-list) item.
        if not isinstance(items, list):
            if isinstance(items, dict):
                # Serialized document: resolve its concrete class from the
                # embedded dotted 'doc_type' path when present.
                doc_class = None
                if 'doc_type' in items:
                    import importlib
                    module_path, class_name = items['doc_type'].rsplit('.', 1)
                    module = importlib.import_module(module_path)
                    doc_class = getattr(module, class_name)
                if not doc_class:
                    # Fall back to the generic llama-index Document class.
                    from llama_index.core.schema import Document
                    doc_class = Document
                return doc_class.from_dict(items)

            elif hasattr(items, 'to_dict'):
                # A document object: serialize it to a plain dict.
                return items.to_dict()

            # Anything else passes through unchanged.
            return items

        # Case 3: a list of items.
        result = []

        if len(items) == 0:
            return result

        # Use the first non-None element to decide conversion direction.
        first_item = next((item for item in items if item is not None), None)

        if first_item is None:
            return result

        if isinstance(first_item, dict):
            # Direction: dicts -> document objects.
            doc_class = None

            if 'doc_type' in first_item:
                import importlib
                module_path, class_name = first_item['doc_type'].rsplit('.', 1)
                module = importlib.import_module(module_path)
                doc_class = getattr(module, class_name)
            if not doc_class:
                # Fall back to the generic llama-index Document class.
                from llama_index.core.schema import Document
                doc_class = Document

            # None and non-dict entries pass through; nested lists recurse.
            for item in items:
                if isinstance(item, dict):
                    result.append(doc_class.from_dict(item))
                elif item is None:
                    result.append(None)
                elif isinstance(item, list):
                    result.append(llamaindex_convert_docs(item))
                else:
                    result.append(item)

        else:
            # Direction: document objects -> dicts.
            for item in items:
                if hasattr(item, 'to_dict'):
                    result.append(item.to_dict())
                elif item is None:
                    result.append(None)
                elif isinstance(item, list):
                    result.append(llamaindex_convert_docs(item))
                else:
                    result.append(item)

        return result

    def llamaindex_convert_docs(items):
        """
        Automatically convert between document objects and dictionaries.

        Args:
            items: A list of document objects or dictionaries

        Returns:
            List of converted items (dictionaries or document objects)
        """
        result = []

        if not items:
            return result

        # Direction is decided by the type of the first element only.
        if isinstance(items[0], dict):
            doc_class = None

            # Resolve the document class from the dotted 'doc_type' path.
            if 'doc_type' in items[0]:
                import importlib
                module_path, class_name = items[0]['doc_type'].rsplit('.', 1)
                module = importlib.import_module(module_path)
                doc_class = getattr(module, class_name)

            if not doc_class:
                # Fall back to the generic llama-index Document class.
                from llama_index.core.schema import Document
                doc_class = Document

            # NOTE: non-dict entries are silently skipped in this direction.
            for item in items:
                if isinstance(item, dict):
                    result.append(doc_class.from_dict(item))
        else:
            # NOTE: entries without to_dict are silently skipped here.
            for item in items:
                if hasattr(item, 'to_dict'):
                    result.append(item.to_dict())

        return result
    return (llamaindex_convert_docs_multi,)
|
|
|
|
@app.cell
def helper_function_create_df():
    def create_document_dataframes(dict_from_docs):
        """
        Creates a pandas DataFrame for each file in the dictionary.

        Args:
            dict_from_docs: Dictionary mapping filenames to lists of documents

        Returns:
            List of pandas DataFrames, each representing all documents from a single file
        """
        dataframes = []

        for filename, docs in dict_from_docs.items():
            file_records = []

            for i, doc in enumerate(docs):
                # Normalise each document into a plain record dict.
                if hasattr(doc, 'to_dict'):
                    doc_data = doc.to_dict()
                elif isinstance(doc, dict):
                    doc_data = doc
                else:
                    doc_data = {'content': str(doc)}

                # Remember the document's position within its file.
                doc_data['doc_index'] = i
                file_records.append(doc_data)

            # Files with no records produce no DataFrame at all.
            if file_records:
                df = pd.DataFrame(file_records)
                df['filename'] = filename
                dataframes.append(df)

        return dataframes

    def create_dataframe_previews(dataframe_list, page_size=5):
        """
        Creates a list of mo.ui.dataframe components, one for each DataFrame in the input list.

        Args:
            dataframe_list: List of pandas DataFrames (output from create_document_dataframes)
            page_size: Number of rows to show per page for each component

        Returns:
            List of mo.ui.dataframe components
        """
        return [mo.ui.dataframe(df, page_size=page_size) for df in dataframe_list]
    return
|
|
|
|
@app.cell
def _():
    # UI toggle: off = load raw documents for chunking/embedding,
    # on = load pre-made embedding/text pairs from CSV/JSON.
    switch_file_loader_type = mo.ui.switch(label="**Switch** to pre-made Embedding/Text pairs")
    return (switch_file_loader_type,)
|
|
|
|
| @app.cell |
| def _(): |
| import csv |
|
|
| def csv_to_json(csv_file_path, json_file_path): |
| """ |
| Convert CSV file to JSON format. |
| |
| Args: |
| csv_file_path (str): Path to input CSV file |
| json_file_path (str): Path to output JSON file |
| """ |
| with open(csv_file_path, 'r', encoding='utf-8') as csv_file: |
| csv_reader = csv.DictReader(csv_file) |
| data = list(csv_reader) |
| |
| with open(json_file_path, 'w', encoding='utf-8') as json_file: |
| json.dump(data, json_file, indent=2) |
|
|
| return |
|
|
|
|
| @app.function |
def load_json_csv_data_with_progress(filepaths, file_loader_value, show_progress=True):
    """
    Load a single CSV or JSON file and return its parsed content.

    Args:
        filepaths: List of file paths; only the first entry is read.
        file_loader_value: marimo file-upload value; the first entry's ``name``
            is used for progress display.
        show_progress: When True, wrap the load in a marimo progress bar.

    Returns:
        A list of row dicts for CSV, whatever ``json.load`` yields for JSON,
        or None for an unsupported extension (the original raised NameError
        on the unbound ``result`` in that case).
    """
    import csv
    import json

    def _load(path):
        # Single dispatch point so the progress and non-progress paths
        # cannot drift apart (the original duplicated this logic).
        if path.lower().endswith('.csv'):
            with open(path, 'r', encoding='utf-8') as csv_file:
                return list(csv.DictReader(csv_file))
        if path.lower().endswith('.json'):
            with open(path, 'r', encoding='utf-8') as json_file:
                return json.load(json_file)
        return None

    filepath = filepaths[0]
    original_file_name = file_loader_value[0].name

    if show_progress:
        import marimo as mo
        with mo.status.progress_bar(
            total=1,
            title="Loading File",
            subtitle=f"Processing {original_file_name}...",
            completion_title="File Loading Complete",
            completion_subtitle="1 file processed",
            remove_on_exit=True
        ) as bar:
            result = _load(filepath)
            # Advance the bar for both formats (the original only updated
            # it on the JSON branch).
            bar.update(increment=1)
    else:
        result = _load(filepath)

    return result
|
|
|
|
| @app.function |
def organize_data_by_columns(loaded_texts, columns_to_use):
    """
    Group field values from a list of records into named column groups.

    Args:
        loaded_texts: List of dictionaries containing the data
        columns_to_use: Mapping of group name -> {field name: include flag}

    Returns:
        Mapping of group name -> flat list of the selected fields' values,
        in record order (fields absent from a record are skipped)
    """
    grouped = {}

    for group_name, field_config in columns_to_use.items():
        # Only fields whose flag is truthy participate in this group.
        wanted = [name for name, keep in field_config.items() if keep]
        grouped[group_name] = [
            record[name]
            for record in loaded_texts
            for name in wanted
            if name in record
        ]

    return grouped
|
|
|
|
@app.cell
def _(
    create_parameter_table,
    get_text_state,
    loaded_texts,
    switch_file_loader_type,
):
    # Column-mapping UI, shown only in "pre-made embedding/text pairs" mode:
    # the user picks which column holds the embedded text and which holds the
    # embedding vectors.
    if switch_file_loader_type.value and loaded_texts:
        # First element of the text state is taken as the set of column names
        # -- presumably a header record; TODO confirm against its producer.
        column_list = list(get_text_state()[0])
        text_column = create_parameter_table(
            label="Select the Embedded Text Column",
            input_list=column_list,
            column_name="Text Column",
            selection_type="single-cell",
            text_justify="center"
        )
        embedding_column = create_parameter_table(
            label="Select the Corresponding Embeddings Column",
            input_list=column_list,
            column_name="Embedding Column",
            selection_type="single-cell",
            text_justify="center"
        )
        column_selection_stack = mo.hstack([text_column, embedding_column], justify="space-around", widths=[0.4,0.4])
        file_column_setup = mo.hstack([column_selection_stack], justify="space-around", align="center", widths=[0.75])
    else:
        # Not in pre-made mode (or nothing loaded yet): expose None so
        # downstream cells can gate on these values.
        text_column = embedding_column = column_selection_stack = file_column_setup = None
    return embedding_column, file_column_setup, text_column
|
|
|
|
@app.cell
def _(embedding_column, get_cell_values, switch_file_loader_type, text_column):
    # Resolve the user's table selections into concrete column values and
    # bundle them into the shape organize_data_by_columns expects.
    if switch_file_loader_type.value:
        text_col_value = get_cell_values(text_column)
        emb_col_value = get_cell_values(embedding_column)
        columns_to_use = {
            "texts": text_col_value,
            "embeddings": emb_col_value
        }
    else:
        text_col_value = emb_col_value = columns_to_use = None
    return columns_to_use, text_col_value
|
|
|
|
@app.cell
def _(columns_to_use, loaded_texts, text_col_value):
    # Build the pre-made documents dict and confirm both the text and the
    # embedding columns produced non-empty data.
    if text_col_value and columns_to_use and loaded_texts:
        premade_documents = organize_data_by_columns(columns_to_use=columns_to_use, loaded_texts=loaded_texts)
        text_col_state = validate_value(premade_documents["texts"])
        emb_col_state = validate_value(premade_documents["embeddings"])
        columns_selected = all_true(text_col_state, emb_col_state)
    else:
        # (The original chained assignment repeated emb_col_state twice.)
        premade_documents = text_col_state = emb_col_state = columns_selected = None
    return columns_selected, premade_documents
|
|
|
|
| @app.function |
def validate_value(value):
    """
    Return True when *value* is neither None nor an empty/falsy data object.
    """
    if value is None:
        return False
    return bool(value)
|
|
|
|
| @app.cell |
| def helper_function_chart_preparation(): |
| import altair as alt |
| import numpy as np |
| import plotly.express as px |
| from sklearn.manifold import TSNE |
|
|
| def prepare_embedding_data(embeddings, texts, model_id=None, embedding_dimensions=None): |
| """ |
| Prepare embedding data for visualization |
| |
| Args: |
| embeddings: List of embeddings arrays |
| texts: List of text strings |
| model_id: Embedding model ID (optional) |
| embedding_dimensions: Embedding dimensions (optional) |
| |
| Returns: |
| DataFrame with processed data and metadata |
| """ |
| |
| flattened_embeddings = [] |
| for emb in embeddings: |
| if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list): |
| flattened_embeddings.append(emb[0]) |
| else: |
| flattened_embeddings.append(emb) |
|
|
| |
| embedding_array = np.array(flattened_embeddings) |
|
|
| |
| tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embedding_array)-1)) |
| reduced_embeddings = tsne.fit_transform(embedding_array) |
|
|
| |
| truncated_texts = [text[:50] + "..." if len(text) > 50 else text for text in texts] |
|
|
| |
| df = pd.DataFrame({ |
| "x": reduced_embeddings[:, 0], |
| "y": reduced_embeddings[:, 1], |
| "text": truncated_texts, |
| "full_text": texts, |
| "index": range(len(texts)) |
| }) |
|
|
| |
| metadata = { |
| "model_id": model_id, |
| "embedding_dimensions": embedding_dimensions |
| } |
|
|
| return df, metadata |
|
|
    def create_embedding_chart(df, metadata=None):
        """
        Create an Altair chart for embedding visualization.

        Args:
            df: DataFrame with x, y coordinates, text and index columns
            metadata: Dictionary with model_id and embedding_dimensions

        Returns:
            Interactive Altair chart (colored points plus index labels)
        """
        model_id = metadata.get("model_id") if metadata else None
        embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None

        # Click-selection over point indices; selected points render opaque.
        # NOTE(review): alt.selection_multi / .add_selection are deprecated in
        # Altair 5 (selection_point / add_params) -- confirm the pinned version.
        selection = alt.selection_multi(fields=['index'])

        base = alt.Chart(df).encode(
            x=alt.X("x:Q", title="Dimension 1"),
            y=alt.Y("y:Q", title="Dimension 2"),
            tooltip=["text", "index"]
        )

        points = base.mark_circle(size=100).encode(
            color=alt.Color("index:N", legend=None),
            opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
        ).add_selection(selection)

        # Numeric index label drawn just right of each point.
        text = base.mark_text(align="left", dx=7).encode(
            text="index:N"
        )

        return (points + text).properties(
            width=700,
            height=500,
            title=f"Embedding Visualization{f' - Model: {model_id}' if model_id else ''}{f' ({embedding_dimensions} dimensions)' if embedding_dimensions else ''}"
        ).interactive()
|
|
| def show_selected_text(indices, texts): |
| """ |
| Create markdown display for selected texts |
| |
| Args: |
| indices: List of selected indices |
| texts: List of all texts |
| |
| Returns: |
| Markdown string |
| """ |
| if not indices: |
| return "No text selected" |
|
|
| selected_texts = [texts[i] for i in indices if i < len(texts)] |
| return "\n\n".join([f"**Document {i}**:\n{text}" for i, text in zip(indices, selected_texts)]) |
|
|
    def prepare_embedding_data_3d(embeddings, texts, model_id=None, embedding_dimensions=None):
        """
        Prepare embedding data for 3D visualization.

        Args:
            embeddings: List of embeddings arrays
            texts: List of text strings
            model_id: Embedding model ID (optional)
            embedding_dimensions: Embedding dimensions (optional)

        Returns:
            Tuple of (DataFrame with x/y/z coordinates, hover text, original
            vectors, and indices; metadata dict)
        """
        # Unwrap provider responses that nest each vector in an extra list.
        flattened_embeddings = []
        for emb in embeddings:
            if isinstance(emb, list) and len(emb) > 0 and isinstance(emb[0], list):
                flattened_embeddings.append(emb[0])
            else:
                flattened_embeddings.append(emb)

        embedding_array = np.array(flattened_embeddings)

        if len(embedding_array) == 1:
            # t-SNE cannot run on a single sample; pin it to the origin.
            reduced_embeddings = np.array([[0.0, 0.0, 0.0]])
        else:
            # Perplexity must be >= 1 and strictly below the sample count.
            perplexity_value = max(1.0, min(30, len(embedding_array)-1))
            tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity_value)
            reduced_embeddings = tsne.fit_transform(embedding_array)

        # Hover labels: truncate to 500 chars, hard-wrap every 50 chars with
        # <br>, and bold the block (Plotly hover text accepts HTML).
        formatted_texts = []
        for text in texts:
            if len(text) > 500:
                text = text[:500] + "..."

            wrapped_text = ""
            for i in range(0, len(text), 50):
                wrapped_text += text[i:i+50] + "<br>"

            formatted_texts.append("<b>"+wrapped_text+"</b>")

        df = pd.DataFrame({
            "x": reduced_embeddings[:, 0],
            "y": reduced_embeddings[:, 1],
            "z": reduced_embeddings[:, 2],
            "text": formatted_texts,
            "full_text": texts,
            "index": range(len(texts)),
            # Original vectors kept alongside coordinates so similarity
            # lookups (get_query_coordinates) can reuse them.
            "embedding": flattened_embeddings
        })

        metadata = {
            "model_id": model_id,
            "embedding_dimensions": embedding_dimensions
        }

        return df, metadata
|
|
    def create_3d_embedding_chart(df, metadata=None, chart_width=1200, chart_height=800, marker_size_var: int=3):
        """
        Create a 3D Plotly chart for embedding visualization with proximity-based coloring.

        Args:
            df: DataFrame with x/y/z coordinates, hover text and index columns
            metadata: Optional dict with model_id and embedding_dimensions
            chart_width: Chart width in pixels
            chart_height: Chart height in pixels
            marker_size_var: Marker size for the document points

        Returns:
            Plotly Figure. Side effect: adds a 'proximity' column to *df*.
        """
        model_id = metadata.get("model_id") if metadata else None
        embedding_dimensions = metadata.get("embedding_dimensions") if metadata else None

        from scipy.spatial.distance import pdist, squareform

        coords = df[['x', 'y', 'z']].values

        # Full pairwise distance matrix in the reduced 3D space.
        dist_matrix = squareform(pdist(coords))

        # Color key: each point's average distance to all points
        # (lower = denser neighborhood).
        avg_distances = np.mean(dist_matrix, axis=1)

        # NOTE(review): mutates the caller's DataFrame in place.
        df['proximity'] = avg_distances

        fig = px.scatter_3d(
            df,
            x='x',
            y='y',
            z='z',

            color='proximity',
            color_continuous_scale='Viridis_r',
            hover_data=['text', 'index', 'proximity'],
            labels={'x': 'Dimension 1', 'y': 'Dimension 2', 'z': 'Dimension 3', 'proximity': 'Avg Distance'},

            title=f"<b>3D Embedding Visualization</b>{f' - Model: <b>{model_id}</b>' if model_id else ''}{f' <i>({embedding_dimensions} dimensions)</i>' if embedding_dimensions else ''}",
            text='index',

        )

        # Marker/hover styling for the document points.
        fig.update_traces(
            marker=dict(
                size=marker_size_var,
                opacity=0.7,
                symbol="diamond",
                line=dict(
                    width=0.5,
                    color="white"
                )
            ),
            textfont=dict(
                color="rgba(255, 255, 255, 0.3)",
                size=8
            ),
            # customdata columns come from hover_data above.
            hovertemplate="text:<br><b>%{customdata[0]}</b><br>index: <b>%{text}</b><br><br>Avg Distance: <b>%{customdata[2]:.4f}</b><extra></extra>",
            hoverinfo="text+name",
            hoverlabel=dict(
                bgcolor="white",
                font_size=12
            ),
            selector=dict(type='scatter3d')
        )

        # Dark scene with white grid lines on all three axes.
        fig.update_layout(
            scene=dict(
                xaxis=dict(
                    title='Dimension 1',
                    nticks=40,
                    backgroundcolor="rgb(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),
                yaxis=dict(
                    title='Dimension 2',
                    nticks=40,
                    backgroundcolor="rgb(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),
                zaxis=dict(
                    title='Dimension 3',
                    nticks=40,
                    backgroundcolor="rgb(10, 10, 20, 0.1)",
                    gridcolor="white",
                    showbackground=True,
                    gridwidth=0.35,
                    zerolinecolor="white",
                ),

                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.25, y=1.25, z=1.25),
                ),
                aspectratio=dict(x=1, y=1, z=1),
                aspectmode='data'
            ),
            width=int(chart_width),
            height=int(chart_height),
            margin=dict(r=20, l=10, b=10, t=50),
            paper_bgcolor="rgb(0, 0, 0)",
            plot_bgcolor="rgb(0, 0, 0)",
            coloraxis_colorbar=dict(
                title="Average Distance",
                thicknessmode="pixels", thickness=20,
                lenmode="pixels", len=400,
                yanchor="top", y=1,
                ticks="outside",
                dtick=0.1
            )
        )

        return fig
| return create_3d_embedding_chart, prepare_embedding_data_3d |
|
|
|
|
| @app.cell |
| def helper_function_text_preparation(): |
| def convert_table_to_json_docs(df, selected_columns=None): |
| """ |
| Convert a pandas DataFrame or dictionary to a list of JSON documents. |
| Dynamically includes columns based on user selection. |
| Column names are standardized to lowercase with underscores instead of spaces |
| and special characters removed. |
| |
| Args: |
| df: The DataFrame or dictionary to process |
| selected_columns: List of column names to include in the output documents |
| |
| Returns: |
| list: A list of dictionaries, each representing a row as a JSON document |
| """ |
| import pandas as pd |
| import re |
|
|
| def standardize_key(key): |
| """Convert a column name to lowercase with underscores instead of spaces and no special characters""" |
| if not isinstance(key, str): |
| return str(key).lower() |
| |
| key = key.lower().replace(' ', '_') |
| |
| return re.sub(r'[^\w]', '', key) |
|
|
| |
| if isinstance(df, dict): |
| |
| if selected_columns: |
| return [{standardize_key(k): df.get(k, None) for k in selected_columns}] |
| else: |
| |
| return [{standardize_key(k): v for k, v in df.items()}] |
|
|
| |
| if df is None: |
| return [] |
|
|
| |
| if not isinstance(df, pd.DataFrame): |
| try: |
| df = pd.DataFrame(df) |
| except: |
| return [] |
|
|
| |
| if df.empty: |
| return [] |
|
|
| |
| if not selected_columns or not isinstance(selected_columns, list) or len(selected_columns) == 0: |
| selected_columns = list(df.columns) |
|
|
| |
| available_columns = [] |
| columns_lower = {col.lower(): col for col in df.columns if isinstance(col, str)} |
|
|
| for col in selected_columns: |
| if col in df.columns: |
| available_columns.append(col) |
| elif isinstance(col, str) and col.lower() in columns_lower: |
| available_columns.append(columns_lower[col.lower()]) |
|
|
| |
| if not available_columns: |
| return [] |
|
|
| |
| json_docs = [] |
| for _, row in df.iterrows(): |
| doc = {} |
| for col in available_columns: |
| value = row[col] |
| |
| std_col = standardize_key(col) |
| doc[std_col] = None if pd.isna(value) else value |
| json_docs.append(doc) |
|
|
| return json_docs |
|
|
| def get_column_values(df, columns_to_include): |
| """ |
| Extract values from specified columns of a dataframe as lists. |
| |
| Args: |
| df: A pandas DataFrame |
| columns_to_include: A list of column names to extract |
| |
| Returns: |
| Dictionary with column names as keys and their values as lists |
| """ |
| result = {} |
|
|
| |
| valid_columns = [col for col in columns_to_include if col in df.columns] |
| invalid_columns = set(columns_to_include) - set(valid_columns) |
|
|
| if invalid_columns: |
| print(f"Warning: These columns don't exist in the dataframe: {list(invalid_columns)}") |
|
|
| |
| for col in valid_columns: |
| result[col] = df[col].tolist() |
|
|
| return result |
|
|
| def get_data_in_range(doc_dict_df, index_range, columns_to_include): |
| """ |
| Extract values from specified columns of a dataframe within a given index range. |
| |
| Args: |
| doc_dict_df: The pandas DataFrame to extract data from |
| index_range: An integer specifying the number of rows to include (from 0 to index_range-1) |
| columns_to_include: A list of column names to extract |
| |
| Returns: |
| Dictionary with column names as keys and their values (within the index range) as lists |
| """ |
| |
| max_index = len(doc_dict_df) |
| if index_range <= 0: |
| print(f"Warning: Invalid index range {index_range}. Must be positive.") |
| return {} |
|
|
| |
| if index_range > max_index: |
| print(f"Warning: Index range {index_range} exceeds dataframe length {max_index}. Using maximum length.") |
| index_range = max_index |
|
|
| |
| df_subset = doc_dict_df.iloc[:index_range] |
|
|
| |
| return get_column_values(df_subset, columns_to_include) |
|
|
    def get_data_in_range_triplequote(doc_dict_df, index_range, columns_to_include):
        """
        Extract values from specified columns of a dataframe within a given index range.

        Despite the name, values are NOT wrapped in triple quotes; the only
        transformation applied is escaping "http://" and "https://" in string
        values so the URLs survive later templating.

        Args:
            doc_dict_df: The pandas DataFrame to extract data from
            index_range: A list of two integers specifying the start and end indices of rows to include
                         (e.g., [0, 10] includes rows from index 0 to 9 inclusive)
            columns_to_include: A list of column names to extract

        Returns:
            Dictionary with column names as keys and their (possibly escaped)
            values as lists
        """
        start_idx, end_idx = index_range
        max_index = len(doc_dict_df)

        # Clamp the window to valid bounds, warning on each correction.
        if start_idx < 0:
            print(f"Warning: Invalid start index {start_idx}. Using 0 instead.")
            start_idx = 0

        if end_idx <= start_idx:
            print(f"Warning: End index {end_idx} must be greater than start index {start_idx}. Using {start_idx + 1} instead.")
            end_idx = start_idx + 1

        if end_idx > max_index:
            print(f"Warning: End index {end_idx} exceeds dataframe length {max_index}. Using maximum length.")
            end_idx = max_index

        df_subset = doc_dict_df.iloc[start_idx:end_idx]

        result = get_column_values(df_subset, columns_to_include)

        # Escape URL schemes in string values so "://" is not interpreted
        # downstream.
        for col in result:
            if isinstance(result[col], list):
                processed_items = []
                for item in result[col]:
                    if isinstance(item, str):
                        item = item.replace("http://", "http\\://").replace("https://", "https\\://")

                        processed_items.append(item)
                    else:
                        processed_items.append(item)
                result[col] = processed_items
        return result
| return (get_data_in_range_triplequote,) |
|
|
|
|
| @app.cell |
| def prepare_doc_select(sentence_splitter_config): |
| def prepare_document_selection(node_dict): |
| """ |
| Creates document selection UI component. |
| Args: |
| node_dict: Dictionary mapping filenames to lists of documents |
| Returns: |
| mo.ui component for document selection |
| """ |
| |
| total_docs = sum(len(docs) for docs in node_dict.values()) |
|
|
| |
| all_docs_records = [] |
| doc_index_global = 0 |
| for filename, docs in node_dict.items(): |
| for i, doc in enumerate(docs): |
| |
| if hasattr(doc, 'to_dict'): |
| doc_data = doc.to_dict() |
| elif isinstance(doc, dict): |
| doc_data = doc |
| else: |
| doc_data = {'content': str(doc)} |
|
|
| |
| doc_data['filename'] = filename |
| doc_data['doc_index'] = i |
| doc_data['global_index'] = doc_index_global |
| all_docs_records.append(doc_data) |
| doc_index_global += 1 |
|
|
| |
| stop_value = max(total_docs, 1) |
| llama_docs = mo.ui.range_slider( |
| start=1, |
| stop=stop_value, |
| step=1, |
| full_width=True, |
| show_value=True, |
| label="**Select a Range of Chunks to Visualize:**" |
| ).form(submit_button_disabled=check_state(sentence_splitter_config.value), submit_button_label="Change Document View Range") |
|
|
| return llama_docs |
| return (prepare_document_selection,) |
|
|
|
|
@app.cell
def document_range_selection(
    dict_from_nodes,
    prepare_document_selection,
    set_range_slider_state,
):
    # Build the range slider for the current nodes; only publish it to shared
    # state when real documents exist (the empty-dict slider is a placeholder).
    if dict_from_nodes is not None:
        llama_docs = prepare_document_selection(dict_from_nodes)
        set_range_slider_state(llama_docs)
    else:
        bare_dict = {}
        llama_docs = prepare_document_selection(bare_dict)
    return
|
|
|
|
| @app.function |
def create_cumulative_dataframe(dict_from_docs):
    """
    Flatten a {filename: [documents]} mapping into one DataFrame.

    Args:
        dict_from_docs: Dictionary mapping filenames to lists of documents

    Returns:
        DataFrame with one row per document carrying filename, per-file
        doc_index, a 1-based global_index, and a 'text' column mirroring
        'content' when only the latter is present.
    """
    records = []

    for file_name, documents in dict_from_docs.items():
        for position, document in enumerate(documents):
            # Normalize each document into a plain dict without mutating
            # the caller's data.
            if hasattr(document, 'to_dict'):
                row = document.to_dict()
            elif isinstance(document, dict):
                row = document.copy()
            else:
                row = {'content': str(document)}

            row['filename'] = file_name
            row['doc_index'] = position
            # Global index is 1-based across every file.
            row['global_index'] = len(records) + 1

            # Expose 'content' under the 'text' key as well.
            if 'content' in row and 'text' not in row:
                row['text'] = row['content']

            records.append(row)

    return pd.DataFrame(records)
|
|
|
|
| @app.function |
def create_stats(texts_dict, bordered=False, object_names=None, group_by_row=False, items_per_row=6, gap=2, label="Chunk"):
    """
    Build mo.stat widgets for each entry of a texts dictionary.

    Parameters:
    - texts_dict (dict): Dictionary containing the text data
    - bordered (bool): Whether the stats should be bordered
    - object_names (list or tuple): Two keys into texts_dict,
      [label_object, value_object]
    - group_by_row (bool): Whether to group stats in rows (horizontal stacks)
    - items_per_row (int): Number of stat objects per row when group_by_row is True
    - gap (int): Gap between stats within a row
    - label (str): Prefix used in each stat's value line

    Returns:
    - object: A vertical stack of stat objects or rows of stat objects

    Raises:
    - ValueError: when object_names is missing/too short, or either name is
      absent from texts_dict
    """
    if not object_names or len(object_names) < 2:
        raise ValueError("You must provide two object names as a list or tuple")

    label_object, value_object = object_names[0], object_names[1]

    if label_object not in texts_dict:
        raise ValueError(f"Label object '{label_object}' not found in texts_dict")
    if value_object not in texts_dict:
        raise ValueError(f"Value object '{value_object}' not found in texts_dict")

    # One stat per entry: label text plus the length of the paired value.
    individual_stats = [
        mo.stat(
            label=texts_dict[label_object][i],
            value=f"{label} Number: {len(texts_dict[value_object][i])}",
            bordered=bordered
        )
        for i in range(len(texts_dict[label_object]))
    ]

    if not group_by_row:
        return mo.vstack(individual_stats, wrap=False)

    # Batch the stats into fixed-width horizontal rows.
    rows = []
    for start in range(0, len(individual_stats), items_per_row):
        row_stats = individual_stats[start:start + items_per_row]
        rows.append(
            mo.hstack(row_stats, gap=gap, align="start", justify="center", widths=[0.35] * len(row_stats))
        )

    return mo.vstack(rows)
|
|
|
|
@app.cell
def prepare_chart_embeddings(
    chunks_to_process,
    emb_model,
    emb_model_emb_dim,
    get_embedding_state,
    prepare_embedding_data_3d,
):
    # Reduce the freshly computed embeddings to 3D chart data once both the
    # chunks and their embeddings are available.
    if chunks_to_process is not None and get_embedding_state() is not None:
        chart_dataframe, chart_metadata = prepare_embedding_data_3d(
            get_embedding_state(),
            chunks_to_process,
            model_id=emb_model,
            embedding_dimensions=emb_model_emb_dim
        )
    else:
        chart_dataframe, chart_metadata = None, None
    return chart_dataframe, chart_metadata
|
|
|
|
| @app.function |
def all_true(*args):
    """
    Return True when every provided argument is truthy
    (vacuously True when called with no arguments).
    """
    return all(args)
|
|
|
|
@app.cell
def _(chart_dataframe_prem, columns_selected):
    # NOTE(review): debug leftover -- prints intermediate state to the console;
    # consider removing before release.
    print(columns_selected,chart_dataframe_prem)
    return
|
|
|
|
@app.cell
def _(
    columns_selected,
    emb_model,
    emb_model_emb_dim,
    premade_documents,
    prepare_embedding_data_3d,
):
    # Same 3D reduction as prepare_chart_embeddings, but for pre-made
    # embedding/text pairs loaded from CSV/JSON.
    if premade_documents and columns_selected:
        chart_dataframe_prem, chart_metadata_prem = prepare_embedding_data_3d(
            premade_documents["embeddings"],
            premade_documents["texts"],
            model_id=emb_model,
            embedding_dimensions=emb_model_emb_dim
        )
    else:
        chart_dataframe_prem = chart_metadata_prem = None
    return chart_dataframe_prem, chart_metadata_prem
|
|
|
|
@app.cell
def chart_dims():
    # Sliders controlling the plot window size.
    # NOTE(review): the width key is spelled "chat_width" (sic); the consumer
    # cell reads the same key, so any rename must change both places.
    chart_dimensions = (
        mo.md('''
        > **Adjust Chart Window**

        {chart_height}

        {chat_width}

        ''').batch(
            chart_height = mo.ui.slider(start=500, step=30, stop=1000, label="**Height:**", value=800, show_value=True),
            chat_width = mo.ui.slider(start=900, step=50, stop=1400, label="**Width:**", value=1200, show_value=True)
        )
    )
    return (chart_dimensions,)
|
|
|
|
@app.cell
def chart_dim_values(chart_dimensions):
    # Unpack slider values (width is stored under the "chat_width" key).
    chart_height = chart_dimensions.value['chart_height']
    chart_width = chart_dimensions.value['chat_width']
    return chart_height, chart_width
|
|
|
|
@app.cell
def create_baseline_chart(
    chart_dataframe,
    chart_dataframe_prem,
    chart_height,
    chart_metadata,
    chart_metadata_prem,
    chart_width,
    create_3d_embedding_chart,
):
    # Render whichever data source is ready: freshly embedded chunks take
    # precedence over pre-made embedding/text pairs.
    if chart_dataframe is not None and chart_metadata is not None:
        emb_plot = create_3d_embedding_chart(chart_dataframe, chart_metadata, chart_width, chart_height, marker_size_var=9)
        chart = mo.ui.plotly(emb_plot)
        chart_reference = chart_dataframe

    elif chart_dataframe_prem is not None and chart_metadata_prem is not None:
        emb_plot = create_3d_embedding_chart(chart_dataframe_prem, chart_metadata_prem, chart_width, chart_height, marker_size_var=9)
        chart = mo.ui.plotly(emb_plot)
        chart_reference = chart_dataframe_prem

    else:
        emb_plot = chart = chart_reference = None
    return chart, chart_reference, emb_plot
|
|
|
|
@app.cell
def test_query(get_chunk_state, premade_documents, switch_file_loader_type):
    # Free-text query form; submission stays disabled until the active data
    # source is ready.
    placeholder = """How can i use watsonx.data to perform vector search?"""
    # The gate differs by loader mode (pre-made documents vs. chunk state);
    # the original duplicated the whole form construction for this one value.
    _query_gate = premade_documents if switch_file_loader_type.value else get_chunk_state()
    query = mo.ui.text_area(
        label="**Write text to check:**", full_width=True, rows=8, value=placeholder
    ).form(
        show_clear_button=True,
        submit_button_disabled=check_state(_query_gate),
        submit_button_label="Query and View Visualization",
    )
    return (query,)
|
|
|
|
@app.cell
def query_stack(chart_dimensions, query):
    # Query form laid out side-by-side with the chart-size controls.
    query_stack = mo.hstack([query, chart_dimensions], justify="space-around", align="center", gap=15)
    return (query_stack,)
|
|
|
|
| @app.function |
def check_state(variable):
    """Return True when *variable* is None (used to disable submit buttons)."""
    return variable is None
|
|
|
|
| @app.cell |
| def helper_function_add_query_to_chart(): |
| def add_query_to_embedding_chart(existing_chart, query_coords, query_text, marker_size=12): |
| """ |
| Add a query point to an existing 3D embedding chart as a large red dot. |
| |
| Args: |
| existing_chart: The existing plotly figure or chart data |
| query_coords: Dictionary with 'x', 'y', 'z' coordinates for the query point |
| query_text: Text of the query to display on hover |
| marker_size: Size of the query marker (default: 18, typically 2x other markers) |
| |
| Returns: |
| A modified plotly figure with the query point added as a red dot |
| """ |
| import plotly.graph_objects as go |
|
|
| |
| import copy |
| chart_copy = copy.deepcopy(existing_chart) |
|
|
| |
| if isinstance(chart_copy, (dict, list)): |
| |
| import plotly.graph_objects as go |
|
|
| if isinstance(chart_copy, list): |
| |
| fig = go.Figure(data=chart_copy) |
| else: |
| |
| fig = go.Figure(data=chart_copy.get('data', []), layout=chart_copy.get('layout', {})) |
|
|
| chart_copy = fig |
|
|
| |
| query_trace = go.Scatter3d( |
| x=[query_coords['x']], |
| y=[query_coords['y']], |
| z=[query_coords['z']], |
| mode='markers', |
| name='Query', |
| marker=dict( |
| size=marker_size, |
| color='red', |
| symbol='circle', |
| opacity=0.70, |
| line=dict( |
| width=1, |
| color='white' |
| ) |
| ), |
| |
| text=['<b>Query:</b><br>' + '<br>'.join([query_text[i:i+50] for i in range(0, len(query_text), 50)])], |
| hoverinfo="text+name" |
| ) |
|
|
| |
| chart_copy.add_trace(query_trace) |
|
|
| return chart_copy |
|
|
|
|
| def get_query_coordinates(reference_embeddings=None, query_embedding=None): |
| """ |
| Calculate appropriate coordinates for a query point based on reference embeddings. |
| |
| This function handles several scenarios: |
| 1. If both reference embeddings and query embedding are provided, it places the |
| query near similar documents. |
| 2. If only reference embeddings are provided, it places the query at a visible |
| location near the center of the chart. |
| 3. If neither are provided, it returns default origin coordinates. |
| |
| Args: |
| reference_embeddings: DataFrame with x, y, z coordinates from the main chart |
| query_embedding: The embedding vector of the query |
| |
| Returns: |
| Dictionary with x, y, z coordinates for the query point |
| """ |
| import numpy as np |
|
|
| |
| default_coords = {'x': 0.0, 'y': 0.0, 'z': 0.0} |
|
|
| |
| if reference_embeddings is None or len(reference_embeddings) == 0: |
| return default_coords |
|
|
| |
| |
| if query_embedding is None: |
| center_coords = { |
| 'x': reference_embeddings['x'].mean(), |
| 'y': reference_embeddings['y'].mean(), |
| 'z': reference_embeddings['z'].mean() |
| } |
| return center_coords |
|
|
| |
| |
| try: |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
| |
| if 'embedding' in reference_embeddings.columns: |
| |
| if isinstance(reference_embeddings['embedding'].iloc[0], list): |
| doc_embeddings = np.array(reference_embeddings['embedding'].tolist()) |
| else: |
| doc_embeddings = np.array([emb for emb in reference_embeddings['embedding'].values]) |
|
|
| |
| query_emb_array = np.array(query_embedding) |
| if query_emb_array.ndim == 1: |
| query_emb_array = query_emb_array.reshape(1, -1) |
|
|
| |
| similarities = cosine_similarity(query_emb_array, doc_embeddings)[0] |
|
|
| |
| closest_idx = np.argmax(similarities) |
|
|
| |
| query_coords = { |
| 'x': reference_embeddings['x'].iloc[closest_idx] + 0.2, |
| 'y': reference_embeddings['y'].iloc[closest_idx] + 0.2, |
| 'z': reference_embeddings['z'].iloc[closest_idx] + 0.2 |
| } |
| return query_coords |
| except Exception as e: |
| print(f"Error positioning query near similar documents: {e}") |
|
|
| |
| center_coords = { |
| 'x': reference_embeddings['x'].mean(), |
| 'y': reference_embeddings['y'].mean(), |
| 'z': reference_embeddings['z'].mean() |
| } |
| return center_coords |
| return add_query_to_embedding_chart, get_query_coordinates |
|
|
|
|
@app.cell
def combined_chart_visualization(
    add_query_to_embedding_chart,
    chart,
    chart_reference,
    emb_plot,
    embedding,
    get_query_coordinates,
    get_query_state,
    query,
    set_chart_state,
    set_query_state,
):
    # Embed the submitted query, position it against the current chart's
    # documents, overlay it as a red marker, and publish the combined plot
    # through shared state.
    if chart is not None and query.value:
        with mo.status.spinner(title="Embedding Query...", remove_on_exit=True) as _spinner:
            query_emb = embedding.embed_documents([query.value])
            set_query_state(query_emb)

            _spinner.update("Preparing Query Coordinates")
            # Brief pauses keep the spinner stages visible to the user.
            time.sleep(1.0)

            query_coords = get_query_coordinates(
                reference_embeddings=chart_reference,
                query_embedding=get_query_state()
            )

            _spinner.update("Adding Query to Chart")
            time.sleep(1.0)

            result = add_query_to_embedding_chart(
                existing_chart=emb_plot,
                query_coords=query_coords,
                query_text=query.value,
            )

            chart_with_query = result

            _spinner.update("Preparing Visualization")
            time.sleep(1.0)

            combined_viz = mo.ui.plotly(chart_with_query)
            set_chart_state(combined_viz)

            _spinner.update("Done")
    else:
        combined_viz = None
    return
|
|
|
|
@app.cell
def _():
    # Reactive state pair for the document range-slider UI element.
    # Holds None until another cell builds the slider and stores it.
    get_range_slider_state, set_range_slider_state = mo.state(None)
    return get_range_slider_state, set_range_slider_state
|
|
|
|
@app.cell
def _(get_range_slider_state):
    # Expose the stored range-slider element (or None when not yet created).
    #
    # The original guarded on `is not None` and called the getter twice, but
    # both branches produced exactly the getter's value — when the state is
    # unset the getter already returns None — so a single call is equivalent.
    document_range_stack = get_range_slider_state()
    return (document_range_stack,)
|
|
|
|
@app.cell
def _():
    # Reactive state pair for the combined chart visualization widget.
    # Holds None until combined_chart_visualization stores a plotly element.
    get_chart_state, set_chart_state = mo.state(None)
    return get_chart_state, set_chart_state
|
|
|
|
@app.cell
def _(get_chart_state, query):
    # Surface the chart stored in state only while a query value is present
    # (None query hides the visualization entirely).
    chart_visualization = get_chart_state() if query.value is not None else None
    return (chart_visualization,)
|
|
|
|
@app.cell
def c(document_range_stack):
    # Center the document range slider horizontally, giving it 65% of the row.
    _layout_opts = dict(justify="space-around", align="center", widths=[0.65])
    chart_range_selection = mo.hstack([document_range_stack], **_layout_opts)
    return (chart_range_selection,)
|
|
|
|
# Script entry point: launch the marimo app when this file is run directly
# (marimo's standard generated-file footer).
if __name__ == "__main__":
    app.run()
|
|