import os
import gradio as gr
from dotenv import load_dotenv
import openai
from mistralai.client import MistralClient
import google.generativeai as genai
from anthropic import Anthropic
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import base64
from io import BytesIO
import json
import re
import traceback
import uuid
import onnx
import onnxruntime as ort
from onnx import helper, numpy_helper
import networkx as nx
from collections import defaultdict, Counter

# --- 1. INITIALIZATION & API KEY SETUP ---
load_dotenv()

# Securely get API keys from environment variables
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")

# Initialize the LLM clients
anthropic_client = None
openai_client = None
mistral_client = None

try:
    if ANTHROPIC_API_KEY:
        anthropic_client = Anthropic(api_key=ANTHROPIC_API_KEY)
    if OPENAI_API_KEY:
        openai_client = openai.OpenAI(api_key=OPENAI_API_KEY)
    if GEMINI_API_KEY:
        genai.configure(api_key=GEMINI_API_KEY)
    if MISTRAL_API_KEY:
        mistral_client = MistralClient(api_key=MISTRAL_API_KEY)
except Exception as e:
    print(f"Error initializing clients: {e}. Please check your API keys.")

# Create a directory to store the plot images
os.makedirs("temp_plots", exist_ok=True)


# --- 2. ONNX MODEL ANALYSIS FUNCTIONS ---

def analyze_onnx_model(file_path: str) -> dict:
    """
    Analyzes an ONNX model file and extracts comprehensive information.
    Returns a dictionary with model structure, operators, and metadata.
    """
    try:
        # Load the ONNX model
        model = onnx.load(file_path)

        # Basic model info
        model_info = {
            'ir_version': model.ir_version,
            'producer_name': model.producer_name,
            'producer_version': model.producer_version,
            'domain': model.domain,
            'model_version': model.model_version,
            'doc_string': model.doc_string
        }

        # Graph analysis
        graph = model.graph

        # Node analysis
        nodes = []
        op_types = Counter()
        for i, node in enumerate(graph.node):
            node_info = {
                'index': i,
                'op_type': node.op_type,
                'name': node.name or f"{node.op_type}_{i}",
                'inputs': list(node.input),
                'outputs': list(node.output),
                'attributes': {}
            }
            # Parse attributes
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.INT:
                    node_info['attributes'][attr.name] = attr.i
                elif attr.type == onnx.AttributeProto.FLOAT:
                    node_info['attributes'][attr.name] = attr.f
                elif attr.type == onnx.AttributeProto.STRING:
                    node_info['attributes'][attr.name] = attr.s.decode('utf-8')
                elif attr.type == onnx.AttributeProto.INTS:
                    node_info['attributes'][attr.name] = list(attr.ints)
                elif attr.type == onnx.AttributeProto.FLOATS:
                    node_info['attributes'][attr.name] = list(attr.floats)
            nodes.append(node_info)
            op_types[node.op_type] += 1

        # Input/Output analysis
        inputs = []
        for inp in graph.input:
            input_info = {
                'name': inp.name,
                'type': inp.type.tensor_type.elem_type,
                'shape': [dim.dim_value if dim.dim_value > 0 else dim.dim_param
                          for dim in inp.type.tensor_type.shape.dim]
            }
            inputs.append(input_info)

        outputs = []
        for out in graph.output:
            output_info = {
                'name': out.name,
                'type': out.type.tensor_type.elem_type,
                'shape': [dim.dim_value if dim.dim_value > 0 else dim.dim_param
                          for dim in out.type.tensor_type.shape.dim]
            }
            outputs.append(output_info)

        # Value info (intermediate tensors)
        value_info = []
        for val in graph.value_info:
            val_info = {
                'name': val.name,
                'type': val.type.tensor_type.elem_type,
                'shape': [dim.dim_value if dim.dim_value > 0 else dim.dim_param
                          for dim in val.type.tensor_type.shape.dim]
            }
            value_info.append(val_info)

        # Initializers (weights/constants)
        initializers = []
        for init in graph.initializer:
            init_info = {
                'name': init.name,
                'data_type': init.data_type,
                'dims': list(init.dims),
                'size': np.prod(init.dims) if init.dims else 0
            }
            initializers.append(init_info)

        return {
            'model_info': model_info,
            'nodes': nodes,
            'op_types': dict(op_types),
            'inputs': inputs,
            'outputs': outputs,
            'value_info': value_info,
            'initializers': initializers,
            'total_nodes': len(nodes),
            'total_parameters': sum(init['size'] for init in initializers)
        }
    except Exception as e:
        return {'error': f"Failed to analyze ONNX model: {str(e)}"}


def create_onnx_description(onnx_analysis: dict) -> str:
    """
    Creates a comprehensive description from ONNX analysis for LLM processing.
    """
    if 'error' in onnx_analysis:
        return f"ONNX Analysis Error: {onnx_analysis['error']}"

    description = f"""
# ONNX Model Analysis Report

## Model Information
- Producer: {onnx_analysis['model_info']['producer_name']} v{onnx_analysis['model_info']['producer_version']}
- IR Version: {onnx_analysis['model_info']['ir_version']}
- Domain: {onnx_analysis['model_info']['domain']}
- Total Nodes: {onnx_analysis['total_nodes']}
- Total Parameters: {onnx_analysis['total_parameters']:,}

## Architecture Overview
The model contains {len(onnx_analysis['op_types'])} different operation types:
{chr(10).join([f"- {op}: {count} nodes" for op, count in onnx_analysis['op_types'].items()])}

## Input/Output Specification
### Inputs:
{chr(10).join([f"- {inp['name']}: shape {inp['shape']}, type {inp['type']}" for inp in onnx_analysis['inputs']])}

### Outputs:
{chr(10).join([f"- {out['name']}: shape {out['shape']}, type {out['type']}" for out in onnx_analysis['outputs']])}

## Detailed Node Structure
{chr(10).join([f"Node {i}: {node['op_type']} ({node['name']})" for i, node in enumerate(onnx_analysis['nodes'][:10])])}
{'...' if len(onnx_analysis['nodes']) > 10 else ''}

## Key Architectural Patterns
Based on the operation types and structure, this appears to be a {_infer_model_type(onnx_analysis['op_types'])} model.
"""
    return description


def _infer_model_type(op_types: dict) -> str:
    """Infers the model type based on operation types."""
    if any(op in op_types for op in ['Conv', 'ConvTranspose', 'MaxPool', 'AveragePool']):
        if any(op in op_types for op in ['LSTM', 'GRU', 'RNN']):
            return "Hybrid CNN-RNN"
        elif 'Attention' in op_types or 'MatMul' in op_types:
            return "CNN with Attention/Transformer components"
        else:
            return "Convolutional Neural Network (CNN)"
    elif any(op in op_types for op in ['LSTM', 'GRU', 'RNN']):
        return "Recurrent Neural Network (RNN/LSTM/GRU)"
    elif any(op in op_types for op in ['Attention', 'MatMul']) and 'Reshape' in op_types:
        return "Transformer/Attention-based model"
    elif 'MatMul' in op_types and 'Add' in op_types:
        return "Feed-forward Neural Network"
    else:
        return "Custom/Mixed architecture"


# --- 3. ENHANCED VISUALIZATION PROMPTS ---

def get_visualization_prompt(analysis_type: str) -> str:
    """Returns the appropriate system prompt for generating visualization code."""
    prompts = {
        "shap": """You are an expert data scientist. Based on the provided model description, generate a Python function `generate_shap_plot()` that creates a SHAP-style feature importance bar chart using only the Plotly library.

**Constraints:**
- The function must take no arguments.
- It must return a Plotly Figure object (`go.Figure`).
- **Crucially, do NOT import any external libraries like `shap`.** Use `plotly.graph_objects` (`go`) and `numpy` (`np`).
- Create realistic placeholder data (e.g., feature names and importance values) within the function. Do not assume access to a live model object.
- Make the chart visually appealing with proper titles and labels.
- For ONNX models, base feature names on the actual input/output tensor names if provided.

The output should ONLY be the complete Python code for the function.
""",
        "lime": """You are an expert data scientist. Based on the model description, generate a Python function `generate_lime_plot()` that creates a LIME-style local interpretation horizontal bar chart using only Plotly.

**Constraints:**
- The function must take no arguments and return a `go.Figure` object.
- **Do not import external libraries like `lime`.**
- Generate realistic placeholder data for component contributions (both positive and negative) inside the function.
- For ONNX models, consider the actual model operations and create relevant feature contributions.
- Create visually appealing charts with proper titles and colors.

The output should ONLY be the Python code for the function.
""",
        "attention": """You are an expert in transformer architectures. Based on the model description, generate a Python function `generate_attention_plot()` that creates an attention heatmap using Plotly Express.

**Constraints:**
- The function must take no arguments and return a `px.imshow` Figure object.
- Generate a realistic placeholder 2D numpy array for the attention weights inside the function.
- For ONNX models with attention mechanisms, create appropriate attention patterns.
- For non-attention models, generate a plausible feature correlation matrix.
- Use proper labels and titles for the heatmap.

The output should ONLY be the complete Python code for the function.
""",
        "architecture": """You are a neural network architect. Based on the model description, generate a Python function `generate_architecture_plot()` that creates a complex, detailed architecture flow diagram using Plotly.

**Enhanced Requirements for Complex Architecture:**
- The function must take no arguments and return a `go.Figure` object.
- Create a multi-layered, hierarchical visualization showing different layer types
- Use different colors and shapes for different layer types
- Show connections between layers with arrows
- Include layer dimensions/parameters as annotations
- For ONNX models, use the actual operation types and create a detailed flow
- Make it visually impressive with proper spacing and professional styling

The output should ONLY be the complete Python code for the function.
""",
        "parameter": """You are a deep learning engineer. Based on the model description, generate a Python function `generate_parameter_plot()` that creates a comprehensive parameter distribution visualization using Plotly.

**Enhanced Requirements:**
- The function must take no arguments and return a `go.Figure` object.
- Create a subplot with multiple visualizations:
  * Donut chart for parameter distribution across layers
  * Bar chart for layer-wise parameter counts
  * Histogram of parameter magnitudes (simulated)
- For ONNX models, use actual initializer information if available
- Make it visually impressive with proper styling

The output should ONLY be the complete Python code for the function.
"""
    }
    return prompts.get(analysis_type, "")
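
# --- Optional sanity check (illustrative only, not wired into the app) ---
# A minimal sketch of what analyze_onnx_model() returns: it builds a tiny
# single-node ONNX graph in memory with onnx.helper, saves it to a temporary
# file, and prints the operator counts plus the heuristic model type. The
# function name and file name below are illustrative placeholders.
def _demo_analyze_tiny_model():
    import tempfile

    x = helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])
    y = helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])
    relu = helper.make_node("Relu", inputs=["X"], outputs=["Y"], name="relu_0")
    graph = helper.make_graph([relu], "tiny_graph", [x], [y])
    model = helper.make_model(graph, producer_name="demo")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "tiny_relu.onnx")
        onnx.save(model, model_path)
        analysis = analyze_onnx_model(model_path)

    print(analysis["op_types"])                     # e.g. {'Relu': 1}
    print(_infer_model_type(analysis["op_types"]))  # "Custom/Mixed architecture"
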
# --- 4. LLM COMMUNICATION FUNCTIONS ---

def get_llm_response(prompt: str, model_description: str, client_name: str) -> str:
    """Generic function to get a response from the selected LLM."""
    try:
        if client_name == "Claude (Anthropic)" and anthropic_client:
            message = anthropic_client.messages.create(
                model="claude-3-5-sonnet-20240620",
                max_tokens=3000,
                temperature=0.1,
                system=prompt + "\n\nIMPORTANT: Respond with ONLY the Python code, no explanations or markdown formatting.",
                messages=[{"role": "user", "content": model_description}]
            )
            return message.content[0].text

        elif client_name == "GPT (OpenAI)" and openai_client:
            response = openai_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": prompt + "\n\nIMPORTANT: Respond with ONLY the Python code, no explanations or markdown formatting."},
                    {"role": "user", "content": model_description}
                ],
                max_tokens=3000,
                temperature=0.1
            )
            return response.choices[0].message.content

        elif client_name == "Gemini (Google)" and GEMINI_API_KEY:
            model = genai.GenerativeModel('gemini-1.5-flash')
            full_prompt = f"{prompt}\n\nIMPORTANT: Respond with ONLY the Python code, no explanations or markdown formatting.\n\nModel Description: {model_description}"
            response = model.generate_content(full_prompt)
            return response.text

        elif client_name == "Mistral (Mistral)" and mistral_client:
            messages = [
                {"role": "system", "content": prompt + "\n\nIMPORTANT: Respond with ONLY the Python code, no explanations or markdown formatting."},
                {"role": "user", "content": model_description}
            ]
            response = mistral_client.chat(
                model="mistral-large-latest",
                messages=messages,
                temperature=0.1
            )
            return response.choices[0].message.content

        else:
            return f"Error: {client_name} API key not configured or client unavailable."
    except Exception as e:
        return f"Error communicating with {client_name}: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"


def extract_python_code(response: str) -> str:
    """Extracts Python code from a markdown-formatted string."""
    # Remove markdown code blocks
    pattern = r'```python\s*\n(.*?)\n```'
    match = re.search(pattern, response, re.DOTALL)
    if match:
        return match.group(1).strip()

    pattern = r'```python(.*?)```'
    match = re.search(pattern, response, re.DOTALL)
    if match:
        return match.group(1).strip()

    pattern = r'```\s*(.*?)\s*```'
    match = re.search(pattern, response, re.DOTALL)
    if match:
        code = match.group(1).strip()
        if 'def ' in code or 'import' in code:
            return code

    # If no code blocks were found, try to extract a function definition
    lines = response.split('\n')
    code_lines = []
    in_code = False
    for line in lines:
        if (line.strip().startswith('import ')
                or line.strip().startswith('from ')
                or line.strip().startswith('def ')):
            in_code = True
        if in_code:
            code_lines.append(line)

    if code_lines:
        extracted_code = '\n'.join(code_lines).strip()
        if 'def ' in extracted_code:
            return extracted_code

    # Last resort: return the whole response if it contains a function definition
    if 'def ' in response and ('import' in response or 'plotly' in response):
        return response.strip()

    return ""


def safely_execute_visualization_code(code: str, plot_type: str, save_dir: str) -> str | None:
    """
    Executes LLM-generated code to create a Plotly viz and saves it as a PNG.
    Returns the file path of the generated image, or None on failure.
    """
    if not code:
        print(f"Warning: No {plot_type} visualization code was generated.")
        return None

    try:
        # Define the execution environment
        exec_globals = {
            "go": go,
            "px": px,
            "np": np,
            "pd": pd,
            "make_subplots": make_subplots,
            "__builtins__": __builtins__
        }

        # Execute the function definition
        exec(code, exec_globals)

        # Find the function name
        func_name_match = re.search(r'def (\w+)\(', code)
        if not func_name_match:
            raise NameError("Could not find function definition in the generated code.")
        func_name = func_name_match.group(1)

        if func_name not in exec_globals:
            raise NameError(f"Function {func_name} was not properly defined.")

        # Call the generated function to get the figure object
        fig = exec_globals[func_name]()
        if not hasattr(fig, 'write_image'):
            raise ValueError("Generated function did not return a valid Plotly figure object.")

        # Define the save path and save the image
        file_path = os.path.join(save_dir, f"{plot_type}_plot.png")
        fig.write_image(file_path, width=1200, height=800, scale=2)

        print(f"Successfully generated and saved image: {file_path}")
        return file_path
    except Exception as e:
        print(f"ERROR executing or saving {plot_type} plot: {str(e)}")
        print(f"Full traceback:\n{traceback.format_exc()}")
        return None


# --- 5. MAIN PROCESSING FUNCTIONS ---

def process_comprehensive_analysis(model_input: str, onnx_file, llms_to_query: list, analysis_types: list) -> tuple:
    """
    Processes both text/code descriptions and ONNX files for comprehensive analysis.
    """
    if not llms_to_query:
        return "Please select at least one LLM to query.", None, None, None, None, None

    # Determine input type and create description
    model_description = ""
    if onnx_file is not None:
        # Process ONNX file
        print(f"Processing ONNX file: {onnx_file.name}")
        try:
            onnx_analysis = analyze_onnx_model(onnx_file.name)
            model_description = create_onnx_description(onnx_analysis)
            print("ONNX analysis completed successfully")
        except Exception as e:
            return f"Error processing ONNX file: {str(e)}", None, None, None, None, None
    elif model_input.strip():
        # Use text input
        model_description = model_input.strip()
    else:
        return "Please provide either a model description/code or upload an ONNX file.", None, None, None, None, None

    # Create a unique directory for this run's plots
    run_dir = os.path.join("temp_plots", str(uuid.uuid4()))
    os.makedirs(run_dir, exist_ok=True)
    print(f"Created temporary directory for plots: {run_dir}")

    # Text Analysis Generation
    text_analysis_prompt = """You are a model analysis expert. Provide a comprehensive technical analysis of the provided model. Cover:

1. **Model Architecture**: Type, key components, and design patterns
2. **Technical Specifications**: Layer types, parameters, complexity
3. **Operational Analysis**: Key operations, computational requirements
4. **Performance Characteristics**: Strengths, limitations, optimization opportunities
5. **Use Case Assessment**: Suitable applications, domain-specific considerations
6. **ONNX-Specific Analysis** (if applicable): Export quality, operator support, optimization potential

Provide detailed, technical insights with proper markdown formatting."""

    full_response = ""
    if "comprehensive" in analysis_types:
        full_response += "# 🧠 Comprehensive AI Model Analysis\n\n"
        if onnx_file is not None:
            full_response += "## 📊 ONNX Model Summary\n"
            full_response += f"**File:** {onnx_file.name}\n"
            full_response += "**Analysis Status:** Successfully processed\n\n"
        for llm in llms_to_query:
            print(f"Getting text analysis from {llm}...")
            full_response += f"## 🤖 {llm} - Technical Analysis\n\n"
            interpretation = get_llm_response(text_analysis_prompt, model_description, llm)
            full_response += f"{interpretation}\n\n---\n\n"

    # Visualization Generation
    viz_llm = llms_to_query[0]
    print(f"Using {viz_llm} for visualization generation...")

    viz_outputs = {}
    viz_types = ["shap", "lime", "attention", "architecture", "parameter"]
    for viz_type in viz_types:
        try:
            prompt = get_visualization_prompt(viz_type)
            if not prompt:
                viz_outputs[viz_type] = None
                continue

            print(f"Generating {viz_type} visualization...")
            generated_code_response = get_llm_response(prompt, model_description, viz_llm)
            if "Error" in generated_code_response:
                print(f"LLM Error for {viz_type}: {generated_code_response}")
                viz_outputs[viz_type] = None
                continue

            cleaned_code = extract_python_code(generated_code_response)
            if not cleaned_code:
                print(f"Could not extract valid code for {viz_type}.")
                viz_outputs[viz_type] = None
                continue

            # Execute code, save plot, and get the file path
            image_path = safely_execute_visualization_code(cleaned_code, viz_type, run_dir)
            viz_outputs[viz_type] = image_path
        except Exception as e:
            print(f"Unexpected error in main loop for {viz_type}: {e}")
            viz_outputs[viz_type] = None

    return (
        full_response,
        viz_outputs.get("shap"),
        viz_outputs.get("lime"),
        viz_outputs.get("attention"),
        viz_outputs.get("architecture"),
        viz_outputs.get("parameter")
    )


def process_text_comprehensive_analysis(model_input: str, llms_to_query, analysis_types: list) -> tuple:
    """
    Processes only text/code descriptions for comprehensive analysis (no ONNX support).
    Accepts llms_to_query as a string or a list for compatibility.
    """
    # Ensure llms_to_query is a list
    if isinstance(llms_to_query, str):
        llms_to_query = [llms_to_query]
    if not llms_to_query:
        return "Please select at least one LLM to query.", None, None, None, None, None
    if not model_input.strip():
        return "Please provide a model description or code.", None, None, None, None, None

    model_description = model_input.strip()

    # Create a unique directory for this run's plots
    run_dir = os.path.join("temp_plots", str(uuid.uuid4()))
    os.makedirs(run_dir, exist_ok=True)
    print(f"Created temporary directory for plots: {run_dir}")

    # Text Analysis Generation
    text_analysis_prompt = """You are a model analysis expert. Provide a comprehensive technical analysis of the provided model. Cover:

1. **Model Architecture**: Type, key components, and design patterns
2. **Technical Specifications**: Layer types, parameters, complexity
3. **Operational Analysis**: Key operations, computational requirements
4. **Performance Characteristics**: Strengths, limitations, optimization opportunities
5. **Use Case Assessment**: Suitable applications, domain-specific considerations

Provide detailed, technical insights with proper markdown formatting."""

    full_response = ""
    if "comprehensive" in analysis_types:
        full_response += "# 🧠 Comprehensive AI Model Analysis\n\n"
        for llm in llms_to_query:
            print(f"Getting text analysis from {llm}...")
            full_response += f"## 🤖 {llm} - Technical Analysis\n\n"
            interpretation = get_llm_response(text_analysis_prompt, model_description, llm)
            full_response += f"{interpretation}\n\n---\n\n"

    # Visualization Generation
    viz_llm = llms_to_query[0]
    print(f"Using {viz_llm} for visualization generation...")

    viz_outputs = {}
    viz_types = ["shap", "lime", "attention", "architecture", "parameter"]
    for viz_type in viz_types:
        try:
            prompt = get_visualization_prompt(viz_type)
            if not prompt:
                viz_outputs[viz_type] = None
                continue

            print(f"Generating {viz_type} visualization...")
            generated_code_response = get_llm_response(prompt, model_description, viz_llm)
            if "Error" in generated_code_response:
                print(f"LLM Error for {viz_type}: {generated_code_response}")
                viz_outputs[viz_type] = None
                continue

            cleaned_code = extract_python_code(generated_code_response)
            if not cleaned_code:
                print(f"Could not extract valid code for {viz_type}.")
                viz_outputs[viz_type] = None
                continue

            # Execute code, save plot, and get the file path
            image_path = safely_execute_visualization_code(cleaned_code, viz_type, run_dir)
            viz_outputs[viz_type] = image_path
        except Exception as e:
            print(f"Unexpected error in main loop for {viz_type}: {e}")
            viz_outputs[viz_type] = None

    return (
        full_response,
        viz_outputs.get("shap"),
        viz_outputs.get("lime"),
        viz_outputs.get("attention"),
        viz_outputs.get("architecture"),
        viz_outputs.get("parameter")
    )
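
# --- Optional sanity check (illustrative only, not wired into the app) ---
# A minimal sketch of the code-extraction/execution pipeline above, assuming the
# `kaleido` package is installed (fig.write_image() needs it). A hand-written,
# fenced "LLM response" is passed through extract_python_code() and
# safely_execute_visualization_code(); the printed result is the PNG path on
# success or None on failure. The function and directory names are placeholders.
def _demo_viz_pipeline():
    fake_llm_response = """```python
def generate_demo_plot():
    fig = go.Figure(go.Bar(x=["feature_a", "feature_b"], y=[0.7, 0.3]))
    fig.update_layout(title="Demo feature importance")
    return fig
```"""
    code = extract_python_code(fake_llm_response)
    demo_dir = os.path.join("temp_plots", "demo")
    os.makedirs(demo_dir, exist_ok=True)
    print(safely_execute_visualization_code(code, "demo", demo_dir))
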

# --- 6. ENHANCED GRADIO USER INTERFACE ---
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {max-width: 1400px !important;}
    .tab-nav button {font-size: 16px; font-weight: bold;}
    .output-image {border: 2px solid #e0e0e0; border-radius: 10px; padding: 10px;}
    .analysis-output {background: #f8f9fa; border-radius: 10px; padding: 20px;}
    .header-text {text-align: center; color: #2c3e50; margin-bottom: 20px;}
    """,
    title="🧠 AI Model Analysis Suite",
) as app:
    # NOTE: minimal placeholder markup (<div>/<h1>/<p> here, <h3> for the section
    # headings below) wraps the original header text.
    gr.HTML("""
        <div class="header-text">
            <h1>🧠 AI Model Analysis Suite</h1>
            <p>Comprehensive AI model analysis using multiple LLMs with advanced visualizations</p>
            <p>Upload ONNX models or describe your model architecture for deep technical analysis</p>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📝 Model Input</h3>")
            with gr.Tabs():
                with gr.Tab("Text Description"):
                    model_input = gr.Textbox(
                        label="Model Description or Code",
                        placeholder="Describe your model architecture, paste PyTorch/TensorFlow code, or provide technical specifications...",
                        lines=10,
                        max_lines=20
                    )
                with gr.Tab("ONNX Upload"):
                    onnx_file = gr.File(
                        label="Upload ONNX Model File",
                        file_types=[".onnx"],
                        file_count="single"
                    )

            gr.HTML("<h3>🤖 Analysis Configuration</h3>")
            llm_selection = gr.CheckboxGroup(
                choices=["Claude (Anthropic)", "GPT (OpenAI)", "Gemini (Google)", "Mistral (Mistral)"],
                label="Select LLMs for Analysis",
                value=["Claude (Anthropic)"],
                info="Choose which LLMs to use for model interpretation"
            )
            analysis_types = gr.CheckboxGroup(
                choices=["comprehensive", "visualizations"],
                label="Analysis Types",
                value=["comprehensive", "visualizations"],
                info="Select the types of analysis to perform"
            )
            analyze_btn = gr.Button("🚀 Analyze Model", variant="primary", size="lg")

        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Analysis Results</h3>")
            with gr.Tabs():
                with gr.Tab("📋 Technical Analysis"):
                    analysis_output = gr.Markdown(
                        label="Comprehensive Analysis",
                        elem_classes=["analysis-output"]
                    )
                with gr.Tab("📈 SHAP Analysis"):
                    shap_plot = gr.Image(
                        label="SHAP Feature Importance",
                        elem_classes=["output-image"]
                    )
                with gr.Tab("🔍 LIME Interpretation"):
                    lime_plot = gr.Image(
                        label="LIME Local Interpretation",
                        elem_classes=["output-image"]
                    )
                with gr.Tab("🎯 Attention Visualization"):
                    attention_plot = gr.Image(
                        label="Attention Heatmap",
                        elem_classes=["output-image"]
                    )
                with gr.Tab("🏗️ Architecture Diagram"):
                    architecture_plot = gr.Image(
                        label="Model Architecture",
                        elem_classes=["output-image"]
                    )
                with gr.Tab("📊 Parameter Analysis"):
                    parameter_plot = gr.Image(
                        label="Parameter Distribution",
                        elem_classes=["output-image"]
                    )

    # Event handlers
    analyze_btn.click(
        fn=process_text_comprehensive_analysis,
        inputs=[model_input, llm_selection, analysis_types],
        outputs=[analysis_output, shap_plot, lime_plot, attention_plot, architecture_plot, parameter_plot],
        show_progress=True
    )

    # Add examples
    gr.HTML("<h3>💡 Example Inputs</h3>")
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[
                    ["ResNet-50 convolutional neural network with 50 layers, batch normalization, and residual connections for image classification"],
                    ["BERT transformer model with 12 layers, 768 hidden dimensions, and multi-head attention for NLP tasks"],
                    ["LSTM recurrent neural network with 256 hidden units for time series prediction"],
                    ["U-Net architecture with encoder-decoder structure and skip connections for image segmentation"]
                ],
                inputs=model_input,
                label="Model Description Examples"
            )

if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
        show_error=True,
        mcp_server=True
    )