Spaces:
Runtime error
Runtime error
| """ | |
| Model tracing evaluation for computing p-values from neuron matching statistics. | |
| This module runs the model-tracing comparison using the main.py script from model-tracing | |
| to determine structural similarity via p-value analysis. | |
| """ | |
import numbers
import os
import pickle
import statistics
import subprocess
import sys
import tempfile
# Locate the model-tracing checkout two directories above this file and check
# that its main.py entry script is present. The path is normalized to an
# absolute path so it stays valid no matter what the process's working
# directory is later, and so the logged path contains no ".." segments.
model_tracing_path = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'model-tracing')
)
_model_tracing_main = os.path.join(model_tracing_path, 'main.py')
MODEL_TRACING_AVAILABLE = os.path.isdir(model_tracing_path) and os.path.isfile(_model_tracing_main)
sys.stderr.write("π§ CHECKING MODEL TRACING AVAILABILITY...\n")
sys.stderr.write(f" - Model tracing path: {model_tracing_path}\n")
sys.stderr.write(f" - Path exists: {os.path.exists(model_tracing_path)}\n")
sys.stderr.write(f" - main.py exists: {os.path.exists(_model_tracing_main)}\n")
sys.stderr.write(f"π― Final MODEL_TRACING_AVAILABLE = {MODEL_TRACING_AVAILABLE}\n")
sys.stderr.flush()
def _is_valid_p_value(value):
    """Return True if *value* is a real number in [0, 1]; bools are rejected."""
    # numbers.Real also admits NumPy floating scalars (e.g. float32), which a
    # plain (int, float) check rejects. NaN fails the range comparison.
    return (
        isinstance(value, numbers.Real)
        and not isinstance(value, bool)
        and 0 <= value <= 1
    )


def _aggregate_p_values(stat, label):
    """Reduce one "... test stat" entry from main.py's results to one p-value.

    Args:
        stat: The stored statistic. A list/tuple is treated as per-layer
            p-values and reduced to the median of its valid entries; a single
            real number in [0, 1] is used as-is.
        label: Name of the results key, used only in log messages.

    Returns:
        float: The aggregated p-value, or 1.0 when no valid p-value exists.
    """
    sys.stderr.write(f"π {label}: {stat}\n")
    sys.stderr.write(f"π Type: {type(stat)}\n")
    if isinstance(stat, (list, tuple)):
        sys.stderr.write(f"π List of {len(stat)} p-values: {stat}\n")
        valid_p_values = [p for p in stat if _is_valid_p_value(p)]
        sys.stderr.write(f"π Valid p-values: {len(valid_p_values)}/{len(stat)}\n")
        if not valid_p_values:
            sys.stderr.write("β οΈ No valid p-values found, using default\n")
            return 1.0
        # Use the median across layers as the representative p-value.
        aggregate = float(statistics.median(valid_p_values))
        sys.stderr.write(f"π Using median p-value: {aggregate}\n")
        return aggregate
    if _is_valid_p_value(stat):
        aggregate = float(stat)
        sys.stderr.write(f"π Using single p-value: {aggregate}\n")
        return aggregate
    sys.stderr.write(f"Unusable {label} (type {type(stat)}), using default 1.0\n")
    return 1.0


def _load_p_value(results_path):
    """Read main.py's pickled results and reduce them to a single p-value.

    Prefers the "aligned test stat" entry and falls back to
    "non-aligned test stat".

    Returns:
        tuple: (True, p_value) on success, (False, error_message) otherwise.
    """
    try:
        sys.stderr.write(f"π Loading results from: {results_path}\n")
        sys.stderr.flush()
        # NOTE(security): pickle.load can execute arbitrary code. This is
        # acceptable only because the file was just written by our own
        # main.py subprocess, not supplied by a user.
        with open(results_path, 'rb') as f:
            results = pickle.load(f)
        sys.stderr.write("β Results loaded successfully\n")
        sys.stderr.write(f"π Available result keys: {list(results.keys())}\n")
        sys.stderr.flush()

        if "aligned test stat" in results:
            p_value = _aggregate_p_values(results["aligned test stat"], "aligned test stat")
            sys.stderr.flush()
            return True, p_value

        sys.stderr.write("β οΈ No 'aligned test stat' found in results, checking non-aligned\n")
        if "non-aligned test stat" in results:
            p_value = _aggregate_p_values(results["non-aligned test stat"], "non-aligned test stat")
            sys.stderr.flush()
            return True, p_value

        sys.stderr.write("β No test stat found in results\n")
        sys.stderr.flush()
        return False, "No test statistic found in results"
    except Exception as e:
        sys.stderr.write(f"β Failed to load results: {e}\n")
        sys.stderr.flush()
        return False, f"Failed to load results: {e}"


def run_model_trace_analysis(ft_model_name, revision="main", precision="float16", *, align=True):
    """
    Run model tracing analysis using the main.py script from model-tracing directory.

    The script is run inside the model-tracing directory under the current
    Python interpreter. With the default align=True the command is:

        <python> main.py --base_model_id meta-llama/Llama-2-7b-hf
            --ft_model_id <ft_model_name> --stat match
            --save <tmp.pkl> --align

    Args:
        ft_model_name: HuggingFace model identifier for the fine-tuned model
        revision: Model revision/commit hash. Only logged; it is not passed
            to main.py.
        precision: Model precision (float16, bfloat16). Only logged; it is
            not passed to main.py.
        align: Keyword-only. When True (default), pass --align so main.py
            produces the "aligned test stat". Pass False to run without it.

    Returns:
        tuple: (success: bool, result: float or error_message)
            If success, result is the aggregate p-value, taken from the
            aligned test stat when present, else from the non-aligned one.
            If failure, result is an error message.
    """
    if not MODEL_TRACING_AVAILABLE:
        return False, "Model tracing main.py script not available"

    base_model_id = "meta-llama/Llama-2-7b-hf"
    tmp_results_path = None
    try:
        sys.stderr.write("\n=== RUNNING MODEL TRACE ANALYSIS VIA SUBPROCESS ===\n")
        sys.stderr.write(f"Base model: {base_model_id}\n")
        sys.stderr.write(f"Fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"Revision: {revision}\n")
        sys.stderr.write(f"Precision: {precision}\n")
        sys.stderr.flush()

        # main.py writes its results here via --save; we only need the name.
        with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
            tmp_results_path = tmp_file.name
        sys.stderr.write(f"π Temporary results file: {tmp_results_path}\n")
        sys.stderr.flush()

        # sys.executable runs main.py with this same interpreter/environment
        # instead of whatever "python" resolves to on PATH. It can be empty
        # in embedded interpreters, hence the fallback.
        cmd = [
            sys.executable or "python", "main.py",
            "--base_model_id", base_model_id,
            "--ft_model_id", ft_model_name,
            "--stat", "match",
            "--save", tmp_results_path,
        ]
        if align:
            cmd.append("--align")

        if revision and revision != "main":
            # main.py does not appear to accept a revision flag, so it is only logged.
            sys.stderr.write(f"β οΈ Note: Revision '{revision}' specified but main.py doesn't support --revision flag\n")
            sys.stderr.flush()

        sys.stderr.write(f"π Running command: {' '.join(cmd)}\n")
        sys.stderr.write(f"Working directory: {model_tracing_path}\n")
        sys.stderr.flush()

        # cwd= scopes the directory change to the child process. os.chdir
        # would change it for the whole process and race with other threads.
        result = subprocess.run(
            cmd,
            cwd=model_tracing_path,
            capture_output=True,
            text=True,
            errors="replace",  # never fail on undecodable child output
            timeout=3600,  # 1 hour timeout
        )
        sys.stderr.write(f"π Subprocess completed with return code: {result.returncode}\n")
        if result.stdout:
            sys.stderr.write(f"π STDOUT from model tracing:\n{result.stdout}\n")
        if result.stderr:
            sys.stderr.write(f"β οΈ STDERR from model tracing:\n{result.stderr}\n")
        sys.stderr.flush()

        if result.returncode != 0:
            error_msg = f"Model tracing script failed with return code {result.returncode}"
            if result.stderr:
                error_msg += f"\nSTDERR: {result.stderr}"
            return False, error_msg

        success, outcome = _load_p_value(tmp_results_path)
        if not success:
            return False, outcome

        sys.stderr.write(f"β Final aggregate p-value: {outcome}\n")
        sys.stderr.write("=== MODEL TRACE ANALYSIS COMPLETED ===\n")
        sys.stderr.flush()
        return True, outcome
    except subprocess.TimeoutExpired:
        sys.stderr.write("β Model tracing analysis timed out after 1 hour\n")
        sys.stderr.flush()
        return False, "Analysis timed out"
    except Exception as e:
        error_msg = str(e)
        sys.stderr.write(f"π₯ Error in model trace analysis: {error_msg}\n")
        import traceback
        sys.stderr.write(f"Traceback: {traceback.format_exc()}\n")
        sys.stderr.flush()
        return False, error_msg
    finally:
        # Remove the temp file on every path (success, script failure,
        # timeout, exception) so results files never accumulate.
        if tmp_results_path is not None:
            try:
                os.unlink(tmp_results_path)
                sys.stderr.write(f"ποΈ Cleaned up temporary file: {tmp_results_path}\n")
            except FileNotFoundError:
                pass
            except OSError as e:
                sys.stderr.write(f"Could not remove temporary file {tmp_results_path}: {e}\n")
            sys.stderr.flush()
def compute_model_trace_p_value(model_name, revision="main", precision="float16"):
    """
    Compute the model-trace p-value for a single model without ever raising.

    Logs progress to stderr and delegates to run_model_trace_analysis. Every
    failure mode collapses to a None return: tracing unavailable, the analysis
    reporting failure, or an unexpected exception.

    Args:
        model_name: HuggingFace model identifier
        revision: Model revision, forwarded unchanged
        precision: Model precision, forwarded unchanged

    Returns:
        The result from run_model_trace_analysis on success, otherwise None.
    """
    rule = '=' * 60
    sys.stderr.write(
        f"\n{rule}\n"
        "COMPUTE_MODEL_TRACE_P_VALUE CALLED\n"
        f"Model: {model_name}\n"
        f"Revision: {revision}\n"
        f"Precision: {precision}\n"
        f"Model tracing available: {MODEL_TRACING_AVAILABLE}\n"
        f"{rule}\n"
    )
    sys.stderr.flush()

    # Guard clause: without the model-tracing checkout there is nothing to run.
    if not MODEL_TRACING_AVAILABLE:
        sys.stderr.write("β MODEL TRACING NOT AVAILABLE - returning None\n")
        sys.stderr.flush()
        return None

    try:
        sys.stderr.write("π Starting model trace analysis...\n")
        sys.stderr.flush()
        ok, outcome = run_model_trace_analysis(model_name, revision, precision)
        sys.stderr.write(f"π Analysis completed - Success: {ok}, Result: {outcome}\n")
        sys.stderr.flush()

        if not ok:
            sys.stderr.write(f"β FAILED: {outcome}\nπ Returning None as fallback\n")
            sys.stderr.flush()
            return None

        sys.stderr.write(f"β SUCCESS: Returning p-value {outcome}\n")
        sys.stderr.flush()
        return outcome
    except Exception as exc:
        import traceback
        sys.stderr.write(
            f"π₯ CRITICAL ERROR in compute_model_trace_p_value for {model_name}:\n"
            f"Exception: {exc}\n"
            f"Full traceback:\n{traceback.format_exc()}\n"
            "π Returning None as fallback\n"
        )
        sys.stderr.flush()
        return None