"""
Visualization utilities for semantic space exploration.

Creates interactive plots showing the semantic relationships between
legal passages.
"""

import numpy as np
import plotly.graph_objects as go
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import json


def load_embeddings_and_metadata(persist_directory="./chromadb",
                                 collection_name="legal_cases"):
    """Load all embeddings, metadata and documents from ChromaDB.

    Args:
        persist_directory: Directory of the persistent ChromaDB store.
        collection_name: Name of the collection to read.

    Returns:
        Tuple ``(embeddings, metadatas, documents)``. ``embeddings`` is a
        NumPy array built from the stored embeddings; ``metadatas`` and
        ``documents`` are returned exactly as the collection provides them.
    """
    client = chromadb.Client(Settings(
        anonymized_telemetry=False,
        persist_directory=persist_directory,
        is_persistent=True,
    ))
    collection = client.get_collection(collection_name)

    # Get all documents
    results = collection.get(include=['embeddings', 'metadatas', 'documents'])

    embeddings = np.array(results['embeddings'])
    return embeddings, results['metadatas'], results['documents']


def create_2d_projection(embeddings, method='tsne', perplexity=5):
    """
    Project high-dimensional embeddings to 2D space.

    Args:
        embeddings: Array of embeddings, one row per sample.
        method: 'tsne' for t-SNE; any other value selects PCA.
        perplexity: t-SNE perplexity parameter (use lower for small
            datasets). It is capped according to the dataset size.

    Returns:
        2D coordinates, one row per input embedding.

    Raises:
        ValueError: If fewer than two embeddings are supplied; neither
            projection can produce two components from a single sample.
    """
    n_samples = len(embeddings)
    if n_samples < 2:
        raise ValueError(
            f"Need at least 2 embeddings for a 2D projection, got {n_samples}"
        )

    if method == 'tsne':
        # Cap perplexity by dataset size. The floor of 1 matters for 2-3
        # samples, where (n_samples - 1) // 3 is 0 and t-SNE would reject
        # the value; for 4+ samples the cap is unchanged.
        size_cap = max(1, (n_samples - 1) // 3)
        perplexity = min(perplexity, size_cap)
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    else:  # pca
        reducer = PCA(n_components=2, random_state=42)

    return reducer.fit_transform(embeddings)


def _passage_hover_text(meta, doc):
    """Build the hover label for one passage.

    Uses ``<br>`` because Plotly hover labels break lines on that tag,
    not on newline characters.
    """
    return (
        f"{meta['case_name']}<br>"
        f"Position: {meta['position_pct']:.1f}% through opinion<br>"
        f"<br>Passage preview:<br>{doc[:200]}..."
    )


def create_semantic_space_plot(method='tsne'):
    """
    Create an interactive plot of the semantic space.

    Loads every passage from the default ChromaDB collection, projects the
    embeddings to 2D and draws one marker per passage, coloured by case.

    Args:
        method: Projection method passed to ``create_2d_projection``
            ('tsne' or 'pca').

    Returns:
        Plotly figure object
    """
    print("Loading embeddings...")
    embeddings, metadatas, documents = load_embeddings_and_metadata()

    print(f"Projecting to 2D using {method.upper()}...")
    coords_2d = create_2d_projection(embeddings, method=method)

    # Map each case to an integer colour index. First-seen order keeps the
    # colour assignment and colorbar ticks stable between runs, which a
    # plain set() does not guarantee.
    case_names = [meta['case_name'] for meta in metadatas]
    unique_cases = list(dict.fromkeys(case_names))
    color_map = {case: idx for idx, case in enumerate(unique_cases)}
    colors = [color_map[name] for name in case_names]

    hover_texts = [
        _passage_hover_text(meta, doc)
        for meta, doc in zip(metadatas, documents)
    ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=coords_2d[:, 0],
        y=coords_2d[:, 1],
        mode='markers',
        marker=dict(
            size=10,
            color=colors,
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(
                title="Case",
                tickvals=list(range(len(unique_cases))),
                ticktext=unique_cases,
                len=0.7,
            ),
            line=dict(width=0.5, color='white'),
        ),
        text=hover_texts,
        hovertemplate='%{text}',
        name='Legal Passages',
    ))

    fig.update_layout(
        title=f'Semantic Space of Legal Cases ({method.upper()} Projection)',
        xaxis_title=f'{method.upper()} Dimension 1',
        yaxis_title=f'{method.upper()} Dimension 2',
        hovermode='closest',
        width=900,
        height=700,
        plot_bgcolor='rgba(240, 240, 240, 0.9)',
        showlegend=False,
    )

    return fig


def _case_display_name(case_id):
    """Turn an identifier such as ``roe_v_wade`` into a title-cased label."""
    return case_id.replace('_', ' ').title()


def create_citation_network_plot(graph_path='citation_graph.json'):
    """
    Create an interactive network plot showing citation relationships.

    The graph file is a JSON object keyed by case id. Each value may carry a
    'cites' list (entries are objects with a 'case_id') and a 'cited_by'
    list. Citations pointing at cases that are not keys of the graph are
    not drawn.

    Args:
        graph_path: Path to the citation graph JSON file.

    Returns:
        Plotly figure object
    """
    print("Loading citation graph...")
    with open(graph_path, 'r', encoding='utf-8') as f:
        citation_graph = json.load(f)

    case_ids = list(citation_graph.keys())
    n_cases = len(case_ids)
    # Position lookup built once, instead of scanning the list per citation.
    index_of = {case_id: idx for idx, case_id in enumerate(case_ids)}

    # Simple circular layout
    angles = np.linspace(0, 2 * np.pi, n_cases, endpoint=False)
    node_x = np.cos(angles)
    node_y = np.sin(angles)

    # Each edge is a (start, end, None) triple; the None breaks the line so
    # all edges can live in a single trace.
    edge_x = []
    edge_y = []
    for i, case_id in enumerate(case_ids):
        for cited in citation_graph[case_id].get('cites', []):
            j = index_of.get(cited['case_id'])
            if j is None:
                continue
            edge_x.extend([node_x[i], node_x[j], None])
            edge_y.extend([node_y[i], node_y[j], None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color='rgba(125, 125, 125, 0.3)'),
        hoverinfo='none',
        mode='lines',
        showlegend=False,
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        hoverinfo='text',
        marker=dict(
            size=20,
            color='lightblue',
            line=dict(width=2, color='darkblue'),
        ),
        text=[_case_display_name(cid) for cid in case_ids],
        textposition="top center",
        showlegend=False,
    )

    # Hover text with citation counts; <br> is Plotly's line break.
    hover_texts = []
    for case_id in case_ids:
        entry = citation_graph[case_id]
        cites = entry.get('cites', [])
        cited_by = entry.get('cited_by', [])
        hover_texts.append(
            f"{_case_display_name(case_id)}<br>"
            f"Cites: {len(cites)} cases<br>"
            f"Cited by: {len(cited_by)} cases"
        )
    node_trace.hovertext = hover_texts

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        title='Citation Network of Legal Cases',
        showlegend=False,
        hovermode='closest',
        width=900,
        height=700,
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        plot_bgcolor='white',
    )

    return fig


if __name__ == "__main__":
    # Test the visualization
    fig = create_semantic_space_plot(method='tsne')
    fig.show()

    fig2 = create_citation_network_plot()
    fig2.show()