Sarthak
committed on
Commit
·
ea0b2a0
1
Parent(s):
0b74f1f
feat: created a cli to manage the complete generation process
Browse files- patches/model2vec.patch +39 -0
- src/distiller/__init__.py +7 -0
- src/distiller/__main__.py +183 -0
- src/distiller/analyze.py +1495 -0
- src/distiller/beam_utils.py +753 -0
- src/distiller/benchmark.py +1181 -0
- src/distiller/distill.py +1306 -0
- src/distiller/distill_simplified.py +413 -0
- src/distiller/evaluate.py +839 -0
- src/distiller/patch_utils.py +145 -0
- src/distiller/sync.py +262 -0
patches/model2vec.patch
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--- a/model2vec/train/base.py
|
| 2 |
+
+++ b/model2vec/train/base.py
|
| 3 |
+
@@ -35,7 +35,7 @@ class FinetunableStaticModel(nn.Module):
|
| 4 |
+
)
|
| 5 |
+
self.vectors = vectors.float()
|
| 6 |
+
|
| 7 |
+
- self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
|
| 8 |
+
+ self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
|
| 9 |
+
self.head = self.construct_head()
|
| 10 |
+
self.w = self.construct_weights()
|
| 11 |
+
self.tokenizer = tokenizer
|
| 12 |
+
--- a/model2vec/distill/distillation.py
|
| 13 |
+
+++ b/model2vec/distill/distillation.py
|
| 14 |
+
@@ -137,7 +137,10 @@ def distill_from_model(
|
| 15 |
+
# Get the language from the model card.
|
| 16 |
+
try:
|
| 17 |
+
info = model_info(model_name)
|
| 18 |
+
- language = info.cardData.get("language", None)
|
| 19 |
+
+ if info is not None and hasattr(info, 'cardData') and info.cardData is not None:
|
| 20 |
+
+ language = info.cardData.get("language", None)
|
| 21 |
+
+ else:
|
| 22 |
+
+ language = None
|
| 23 |
+
except RepositoryNotFoundError:
|
| 24 |
+
logger.info("No model info found for the model. Setting language to None.")
|
| 25 |
+
language = None
|
| 26 |
+
--- a/model2vec/distill/inference.py
|
| 27 |
+
+++ b/model2vec/distill/inference.py
|
| 28 |
+
@@ -109,5 +109,12 @@ def create_embeddings(
|
| 29 |
+
out_tokens.extend([Token(x, False) for x in tokens])
|
| 30 |
+
out_weights = np.stack(intermediate_weights)
|
| 31 |
+
|
| 32 |
+
+ # Validate token-vector consistency to prevent failures
|
| 33 |
+
+ if len(out_tokens) != out_weights.shape[0]:
|
| 34 |
+
+ logger.warning(f"Token-vector mismatch: {len(out_tokens)} tokens vs {out_weights.shape[0]} vectors. Truncating to prevent failure.")
|
| 35 |
+
+ min_count = min(len(out_tokens), out_weights.shape[0])
|
| 36 |
+
+ out_tokens = out_tokens[:min_count]
|
| 37 |
+
+ out_weights = out_weights[:min_count]
|
| 38 |
+
+
|
| 39 |
+
return out_tokens, out_weights
|
src/distiller/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model2Vec Distillation Pipeline for gte-Qwen2-7B-instruct."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
| 4 |
+
|
| 5 |
+
from .distill import beam_code_distillation, code_specialized_distillation
|
| 6 |
+
|
| 7 |
+
__all__ = ["beam_code_distillation", "code_specialized_distillation"]
|
src/distiller/__main__.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Main entry point for the distiller package."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def main() -> None:
    """Main entry point for the distiller package.

    Builds the `distiller` command-line interface with one sub-command per
    pipeline stage (distill, evaluate, analyze, sync, benchmark, plus the
    "-simple" variants restricted to simplified models) and dispatches to the
    corresponding sub-module. Sub-modules are imported lazily inside each
    branch so that a command only pays for the dependencies it actually uses.
    """
    parser = argparse.ArgumentParser(description="Model2Vec Code-Specialized Distillation Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Distillation command
    distill_parser = subparsers.add_parser("distill", help="Run code-specialized model distillation")
    distill_parser.add_argument("--model", default="Alibaba-NLP/gte-Qwen2-7B-instruct", help="Model to distill")
    distill_parser.add_argument("--output-dir", default="gte_qwen2_m2v_code", help="Output directory")
    distill_parser.add_argument("--pca-dims", type=int, default=512, help="PCA dimensions")
    distill_parser.add_argument("--max-samples", type=int, default=50000, help="Max CodeSearchNet samples")
    distill_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud GPU distillation")

    # Simplified distillation command
    simple_parser = subparsers.add_parser("distill-simple", help="Run simplified Model2Vec distillation (local)")
    simple_parser.add_argument(
        "--teacher", default="sentence-transformers/all-MiniLM-L6-v2", help="Teacher model to distill from"
    )
    simple_parser.add_argument("--output-dir", default="gte_qwen2_m2v_code_simplified", help="Output directory")
    simple_parser.add_argument("--pca-dims", type=int, default=256, help="PCA dimensions")

    # CodeSearchNet evaluation command
    evaluate_parser = subparsers.add_parser("evaluate", help="Run CodeSearchNet evaluation on all default models")
    evaluate_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud evaluation")

    # CodeSearchNet evaluation command (simplified models only)
    evaluate_simple_parser = subparsers.add_parser(
        "evaluate-simple", help="Run CodeSearchNet evaluation on simplified models only"
    )
    evaluate_simple_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud evaluation")

    # Analysis command
    analysis_parser = subparsers.add_parser("analyze", help="Generate CodeSearchNet analysis report")
    analysis_parser.add_argument("--results-dir", default="code_evaluation_results", help="Results directory")
    analysis_parser.add_argument("--results-file", help="Single results file to analyze")
    analysis_parser.add_argument("--model-name", default="gte_qwen2_m2v_code", help="Model name for report")
    analysis_parser.add_argument("--output", default="README.md", help="Output report file")
    analysis_parser.add_argument("--export-csv", help="Export comparison results to CSV")
    analysis_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud analysis")

    # Sync command
    sync_parser = subparsers.add_parser("sync", help="Download files from Beam volume to local directory")
    sync_parser.add_argument("--model-files", action="store_true", help="Download final model files")
    sync_parser.add_argument(
        "--analysis-files",
        action="store_true",
        help="Download analysis reports and charts",
    )
    sync_parser.add_argument("--all", action="store_true", help="Download all generated files")
    sync_parser.add_argument("--output-dir", default=".", help="Local output directory")

    # Benchmark command
    benchmark_parser = subparsers.add_parser("benchmark", help="Run performance benchmarking on all default models")
    benchmark_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud benchmarking")

    # Benchmark command (simplified models only)
    benchmark_simple_parser = subparsers.add_parser(
        "benchmark-simple", help="Run performance benchmarking on simplified models only"
    )
    benchmark_simple_parser.add_argument("--use-beam", action="store_true", help="Use Beam for cloud benchmarking")

    args = parser.parse_args()

    if args.command == "distill":
        # NOTE(review): this branch imports from distill_simplified, not distill,
        # and never forwards --model/--output-dir/--pca-dims/--max-samples —
        # confirm whether those options are intentionally unused here.
        from .distill_simplified import run_local_distillation, beam_distill_all_teachers

        if args.use_beam:
            # Run on Beam
            print("Running comprehensive teacher model distillation on Beam...")
            results = beam_distill_all_teachers()
        else:
            # Run locally
            print("Running comprehensive teacher model distillation locally...")
            results = run_local_distillation()

        # NOTE(review): the leading symbols below were mangled in transit
        # (likely originally emoji) — restore the intended characters.
        print(f"β Distillation complete! Created {results['total_successful']} models")
        print("π Models location: ./code_model2vec/final/")
        print("\nβ Created models:")
        for model_name in results["successful_models"]:
            model_info = results["all_results"][model_name]
            print(f" β’ {model_name} (from {model_info['teacher_model']})")

    elif args.command == "distill-simple":
        from .distill_simplified import run_local_distillation

        # Run simplified distillation for all teacher models locally
        print("Running comprehensive teacher model distillation locally...")
        results = run_local_distillation()
        print(f"β Distillation complete! Created {results['total_successful']} models")
        print("π Models location: ./code_model2vec/final/")
        print("\nβ Created models:")
        for model_name in results["successful_models"]:
            model_info = results["all_results"][model_name]
            print(f" β’ {model_name} (from {model_info['teacher_model']})")

    elif args.command == "evaluate":
        from .evaluate import main as evaluate_main, run_local_evaluation

        if args.use_beam:
            # Run on Beam with all default models
            print("Running comprehensive evaluation on Beam...")
            evaluate_main()
        else:
            # Run locally with all default models
            print("Running comprehensive evaluation locally...")
            run_local_evaluation()

    elif args.command == "evaluate-simple":
        from .evaluate import evaluate_simplified_only, run_local_evaluation_simplified

        if args.use_beam:
            # Run on Beam with simplified models only
            print("Running simplified model evaluation on Beam...")
            evaluate_simplified_only()
        else:
            # Run locally with simplified models only
            print("Running simplified model evaluation locally...")
            run_local_evaluation_simplified()

    elif args.command == "analyze":
        from .analyze import main as analyze_main

        # Run locally - Override sys.argv to pass arguments to the analyze script
        # (analyze_main re-parses sys.argv with its own argparse configuration;
        # only non-default values are forwarded so analyze keeps its own defaults).
        sys.argv = ["analyze.py"]
        if args.results_dir != "code_evaluation_results":
            sys.argv.extend(["--results-dir", args.results_dir])
        if args.results_file:
            sys.argv.extend(["--results-file", args.results_file])
        if args.model_name != "gte_qwen2_m2v_code":
            sys.argv.extend(["--model-name", args.model_name])
        if args.output != "README.md":
            sys.argv.extend(["--output", args.output])
        if args.export_csv:
            sys.argv.extend(["--export-csv", args.export_csv])
        analyze_main()

    elif args.command == "sync":
        from .sync import sync_files

        # Run locally
        sync_files(
            model_files=args.model_files,
            analysis_files=args.analysis_files,
            all_files=args.all,
            output_dir=args.output_dir,
        )

    elif args.command == "benchmark":
        from .benchmark import main as benchmark_main, run_local_benchmark

        if args.use_beam:
            # Run on Beam with all default models
            print("Running comprehensive benchmarking on Beam...")
            benchmark_main()
        else:
            # Run locally with all default models
            print("Running comprehensive benchmarking locally...")
            run_local_benchmark()

    elif args.command == "benchmark-simple":
        from .benchmark import benchmark_simplified_only, run_local_benchmark_simplified

        if args.use_beam:
            # Run on Beam with simplified models only
            print("Running simplified model benchmarking on Beam...")
            benchmark_simplified_only()
        else:
            # Run locally with simplified models only
            print("Running simplified model benchmarking locally...")
            run_local_benchmark_simplified()

    else:
        # No (or unknown) sub-command: show usage instead of failing silently.
        parser.print_help()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Support `python -m distiller` / direct script execution.
if __name__ == "__main__":
    main()
|
src/distiller/analyze.py
ADDED
|
@@ -0,0 +1,1495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Comprehensive CodeSearchNet Analysis and Reporting Script.
|
| 3 |
+
|
| 4 |
+
This script provides a complete CodeSearchNet evaluation pipeline that includes:
|
| 5 |
+
1. Model evaluation results analysis
|
| 6 |
+
2. Peer model comparison analysis
|
| 7 |
+
3. Advanced visualizations and charts
|
| 8 |
+
4. Leaderboard comparison and ranking analysis
|
| 9 |
+
5. Comprehensive README report generation
|
| 10 |
+
6. Performance efficiency analysis
|
| 11 |
+
7. Language-specific performance analysis
|
| 12 |
+
|
| 13 |
+
Features:
|
| 14 |
+
- CodeSearchNet-style scoring (NDCG@10, MRR, Recall metrics)
|
| 15 |
+
- Comparison with peer code-specialized models
|
| 16 |
+
- Model efficiency metrics (performance per parameter)
|
| 17 |
+
- Interactive visualizations with Plotly and Matplotlib
|
| 18 |
+
- Professional charts for README integration
|
| 19 |
+
- Statistical analysis of results across programming languages
|
| 20 |
+
|
| 21 |
+
Usage:
|
| 22 |
+
python analyze.py --results-dir results/ --model-name my_model
|
| 23 |
+
distiller analyze --results-dir evaluation_results
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
import logging
|
| 29 |
+
import time
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Any
|
| 32 |
+
|
| 33 |
+
import matplotlib.pyplot as plt
|
| 34 |
+
import numpy as np
|
| 35 |
+
import pandas as pd
|
| 36 |
+
import seaborn as sns
|
| 37 |
+
|
| 38 |
+
# Optional Plotly import with fallback
|
| 39 |
+
PLOTLY_AVAILABLE = True
|
| 40 |
+
try:
|
| 41 |
+
import plotly.graph_objects as go
|
| 42 |
+
except ImportError:
|
| 43 |
+
PLOTLY_AVAILABLE = False
|
| 44 |
+
|
| 45 |
+
# Set plotting style
|
| 46 |
+
try:
|
| 47 |
+
plt.style.use("seaborn-v0_8")
|
| 48 |
+
except OSError:
|
| 49 |
+
plt.style.use("seaborn") # Fallback for older matplotlib versions
|
| 50 |
+
sns.set_palette("husl")
|
| 51 |
+
|
| 52 |
+
# =============================================================================
|
| 53 |
+
# CONFIGURATION
|
| 54 |
+
# =============================================================================
|
| 55 |
+
|
| 56 |
+
# Constants
|
| 57 |
+
MIN_SCORES_FOR_STATS = 2
|
| 58 |
+
HIGH_PERFORMANCE_THRESHOLD = 0.3
|
| 59 |
+
MEDIUM_PERFORMANCE_THRESHOLD = 0.2
|
| 60 |
+
|
| 61 |
+
# Model Configuration
|
| 62 |
+
MODEL_NAME = "code_model2vec_analysis" # Generic name for multi-model analysis
|
| 63 |
+
ORIGINAL_MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"
|
| 64 |
+
OUTPUT_DIR = Path("analysis_results")
|
| 65 |
+
IMAGES_DIR = Path("analysis_charts")
|
| 66 |
+
REPORT_FILE = Path("REPORT.md") # Changed from README.md
|
| 67 |
+
|
| 68 |
+
# Local directories for results - updated for new structure
|
| 69 |
+
DEFAULT_EVALUATION_DIR = "code_model2vec/evaluation_results"
|
| 70 |
+
DEFAULT_BENCHMARK_DIR = "code_model2vec/benchmark_results"
|
| 71 |
+
|
| 72 |
+
# CodeSearchNet Languages
|
| 73 |
+
CODE_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]
|
| 74 |
+
|
| 75 |
+
# Model name mapping from the default models in evaluate.py and benchmark.py
|
| 76 |
+
MODEL_NAME_MAPPING = {
|
| 77 |
+
# File names to display names
|
| 78 |
+
"gte_qwen2_m2v_code": "gte_qwen2_m2v_code (Ours)",
|
| 79 |
+
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
| 80 |
+
"codebert-base": "microsoft/codebert-base",
|
| 81 |
+
"graphcodebert-base": "microsoft/graphcodebert-base",
|
| 82 |
+
"CodeBERTa-small-v1": "huggingface/CodeBERTa-small-v1",
|
| 83 |
+
"all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
|
| 84 |
+
"all-MiniLM-L12-v2": "sentence-transformers/all-MiniLM-L12-v2",
|
| 85 |
+
"potion-base-8M": "minishlab/potion-base-8M",
|
| 86 |
+
"potion-retrieval-32M": "minishlab/potion-retrieval-32M",
|
| 87 |
+
"codet5-base": "Salesforce/codet5-base",
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
# Reverse mapping for lookups
|
| 91 |
+
DISPLAY_NAME_TO_FILE = {v: k for k, v in MODEL_NAME_MAPPING.items()}
|
| 92 |
+
|
| 93 |
+
# Peer models for comparison (code-specialized models)
|
| 94 |
+
PEER_MODELS = {
|
| 95 |
+
"sentence-transformers/all-MiniLM-L6-v2": {"overall_ndcg": 0.25, "type": "General"},
|
| 96 |
+
"microsoft/codebert-base": {"overall_ndcg": 0.32, "type": "Code-Specific"},
|
| 97 |
+
"microsoft/graphcodebert-base": {"overall_ndcg": 0.35, "type": "Code-Specific"},
|
| 98 |
+
"huggingface/CodeBERTa-small-v1": {"overall_ndcg": 0.28, "type": "Code-Specific"},
|
| 99 |
+
"sentence-transformers/all-mpnet-base-v2": {"overall_ndcg": 0.27, "type": "General"},
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
# Model specifications for efficiency analysis
|
| 103 |
+
MODEL_SPECS = {
|
| 104 |
+
"sentence-transformers/all-MiniLM-L6-v2": {"parameters": 22.7, "size_mb": 90},
|
| 105 |
+
"microsoft/codebert-base": {"parameters": 125.0, "size_mb": 500},
|
| 106 |
+
"microsoft/graphcodebert-base": {"parameters": 125.0, "size_mb": 500},
|
| 107 |
+
"huggingface/CodeBERTa-small-v1": {"parameters": 84.0, "size_mb": 340},
|
| 108 |
+
"sentence-transformers/all-mpnet-base-v2": {"parameters": 109.0, "size_mb": 440},
|
| 109 |
+
"Alibaba-NLP/gte-Qwen2-7B-instruct": {"parameters": 7000.0, "size_mb": 13000},
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
# Distilled model specifications
|
| 113 |
+
DISTILLED_MODEL_SPECS = {
|
| 114 |
+
"parameters": 39.0, # Model2Vec parameters
|
| 115 |
+
"size_mb": 149.0, # Actual model size
|
| 116 |
+
"dimensions": 256, # Model2Vec dimensions
|
| 117 |
+
"original_dimensions": 3584,
|
| 118 |
+
"distillation_method": "Model2Vec",
|
| 119 |
+
"training_dataset": "CodeSearchNet",
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# =============================================================================
|
| 123 |
+
# UTILITY FUNCTIONS
|
| 124 |
+
# =============================================================================
|
| 125 |
+
|
| 126 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 127 |
+
logger = logging.getLogger(__name__)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def setup_directories(base_path: Path | None = None) -> tuple[Path, Path, Path]:
|
| 131 |
+
"""Create necessary directories and return their paths."""
|
| 132 |
+
if base_path:
|
| 133 |
+
output_dir = base_path / "analysis_results"
|
| 134 |
+
images_dir = base_path / "analysis_results" / "charts"
|
| 135 |
+
reports_dir = base_path / "analysis_results" / "reports"
|
| 136 |
+
else:
|
| 137 |
+
output_dir = OUTPUT_DIR
|
| 138 |
+
images_dir = IMAGES_DIR
|
| 139 |
+
reports_dir = OUTPUT_DIR / "reports"
|
| 140 |
+
|
| 141 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 142 |
+
images_dir.mkdir(parents=True, exist_ok=True)
|
| 143 |
+
reports_dir.mkdir(parents=True, exist_ok=True)
|
| 144 |
+
|
| 145 |
+
return output_dir, images_dir, reports_dir
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def extract_model_name_from_filename(filename: str) -> str:
    """Map a results filename to its display model name.

    Strips the known ``codesearchnet_eval_`` / ``benchmark_`` prefixes and the
    ``.json`` suffix, then resolves the remainder through
    ``MODEL_NAME_MAPPING`` — first by exact key, then by substring match in
    either direction. Falls back to the cleaned name when nothing matches.
    """
    cleaned = filename
    for fragment in ("codesearchnet_eval_", "benchmark_", ".json"):
        cleaned = cleaned.replace(fragment, "")

    # Exact mapping wins over fuzzy matching.
    if cleaned in MODEL_NAME_MAPPING:
        return MODEL_NAME_MAPPING[cleaned]

    # Fuzzy fallback: first mapping key that contains (or is contained in)
    # the cleaned name; otherwise return the cleaned name unchanged.
    return next(
        (display for key, display in MODEL_NAME_MAPPING.items() if key in cleaned or cleaned in key),
        cleaned,
    )
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class CodeSearchNetAnalyzer:
|
| 167 |
+
"""Analyzer for CodeSearchNet evaluation results and performance benchmarks."""
|
| 168 |
+
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
results_dir: str | None = None,
|
| 172 |
+
benchmark_dir: str | None = None,
|
| 173 |
+
images_dir: Path | None = None,
|
| 174 |
+
) -> None:
|
| 175 |
+
"""Initialize analyzer with results directories."""
|
| 176 |
+
self.results_dir = Path(results_dir) if results_dir else Path(DEFAULT_EVALUATION_DIR)
|
| 177 |
+
self.benchmark_dir = Path(benchmark_dir) if benchmark_dir else Path(DEFAULT_BENCHMARK_DIR)
|
| 178 |
+
self.images_dir = images_dir or IMAGES_DIR
|
| 179 |
+
self.results: list[dict[str, Any]] = []
|
| 180 |
+
self.benchmark_results: list[dict[str, Any]] = []
|
| 181 |
+
self.comparison_df: pd.DataFrame | None = None
|
| 182 |
+
self.benchmark_df: pd.DataFrame | None = None
|
| 183 |
+
|
| 184 |
+
def load_benchmark_results(self) -> None:
|
| 185 |
+
"""Load benchmark results from local directory."""
|
| 186 |
+
logger.info("π Loading benchmark results...")
|
| 187 |
+
|
| 188 |
+
if not self.benchmark_dir.exists():
|
| 189 |
+
logger.warning(f"Benchmark directory not found: {self.benchmark_dir}")
|
| 190 |
+
return
|
| 191 |
+
|
| 192 |
+
logger.info(f"π Searching for benchmark files in: {self.benchmark_dir}")
|
| 193 |
+
benchmark_files = list(self.benchmark_dir.glob("benchmark_*.json"))
|
| 194 |
+
logger.info(f"π Found {len(benchmark_files)} benchmark files")
|
| 195 |
+
|
| 196 |
+
for benchmark_file_path in benchmark_files:
|
| 197 |
+
try:
|
| 198 |
+
logger.info(f"π Loading: {benchmark_file_path.name}")
|
| 199 |
+
with benchmark_file_path.open() as f:
|
| 200 |
+
data = json.load(f)
|
| 201 |
+
if data is not None:
|
| 202 |
+
# Update model name with proper mapping
|
| 203 |
+
original_name = data.get("model_name", "Unknown")
|
| 204 |
+
mapped_name = extract_model_name_from_filename(benchmark_file_path.stem)
|
| 205 |
+
data["model_name"] = mapped_name
|
| 206 |
+
data["original_model_name"] = original_name
|
| 207 |
+
|
| 208 |
+
self.benchmark_results.append(data)
|
| 209 |
+
logger.info(f"β
Successfully loaded: {mapped_name}")
|
| 210 |
+
except (json.JSONDecodeError, KeyError) as e:
|
| 211 |
+
logger.warning(f"β Failed to load {benchmark_file_path}: {e}")
|
| 212 |
+
|
| 213 |
+
logger.info(f"π Total benchmark results loaded: {len(self.benchmark_results)}")
|
| 214 |
+
if self.benchmark_results:
|
| 215 |
+
model_names = [r.get("model_name", "Unknown") for r in self.benchmark_results]
|
| 216 |
+
logger.info(f"π― Benchmark models found: {', '.join(model_names)}")
|
| 217 |
+
|
| 218 |
+
self._create_benchmark_dataframe()
|
| 219 |
+
|
| 220 |
+
def _create_benchmark_dataframe(self) -> None:
|
| 221 |
+
"""Create benchmark comparison DataFrame from results."""
|
| 222 |
+
if not self.benchmark_results:
|
| 223 |
+
return
|
| 224 |
+
|
| 225 |
+
benchmark_data = []
|
| 226 |
+
for result in self.benchmark_results:
|
| 227 |
+
model_name = result.get("model_name", "Unknown")
|
| 228 |
+
size_metrics = result.get("size_metrics", {})
|
| 229 |
+
speed_benchmarks = result.get("speed_benchmarks", {})
|
| 230 |
+
memory_benchmarks = result.get("memory_benchmarks", {})
|
| 231 |
+
cpu_vs_gpu = result.get("cpu_vs_gpu", {})
|
| 232 |
+
|
| 233 |
+
# Extract key metrics
|
| 234 |
+
row = {
|
| 235 |
+
"Model": model_name,
|
| 236 |
+
"Disk_Size_MB": size_metrics.get("disk_size_mb", 0),
|
| 237 |
+
"Parameters_M": size_metrics.get("parameters_millions", 0),
|
| 238 |
+
"Embedding_Dim": size_metrics.get("embedding_dim", 0),
|
| 239 |
+
"RAM_Usage_MB": size_metrics.get("ram_usage_mb", 0),
|
| 240 |
+
"GPU_Memory_MB": size_metrics.get("gpu_memory_mb", 0),
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
# Speed metrics (medium texts, batch 32)
|
| 244 |
+
if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
|
| 245 |
+
batch_32 = speed_benchmarks["medium"]["batch_32"]
|
| 246 |
+
row.update(
|
| 247 |
+
{
|
| 248 |
+
"Throughput_TextsPerSec": batch_32.get("texts_per_second", 0),
|
| 249 |
+
"Latency_MsPerText": batch_32.get("time_per_text_ms", 0),
|
| 250 |
+
"TokenSpeed_TokensPerSec": batch_32.get("tokens_per_second", 0),
|
| 251 |
+
}
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# Memory scaling (batch 32)
|
| 255 |
+
if "batch_32" in memory_benchmarks:
|
| 256 |
+
batch_32_mem = memory_benchmarks["batch_32"]
|
| 257 |
+
if not batch_32_mem.get("oom", False) and "error" not in batch_32_mem:
|
| 258 |
+
row.update(
|
| 259 |
+
{
|
| 260 |
+
"Memory_Used_MB": batch_32_mem.get("memory_used_mb", 0),
|
| 261 |
+
"Memory_Per_Text_MB": batch_32_mem.get("memory_per_text_mb", 0),
|
| 262 |
+
}
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# CPU vs GPU comparison
|
| 266 |
+
for device in ["cpu", "cuda"]:
|
| 267 |
+
if device in cpu_vs_gpu and "error" not in cpu_vs_gpu[device]:
|
| 268 |
+
device_key = f"{device.upper()}_TextsPerSec"
|
| 269 |
+
row[device_key] = cpu_vs_gpu[device].get("texts_per_second", 0)
|
| 270 |
+
|
| 271 |
+
benchmark_data.append(row)
|
| 272 |
+
|
| 273 |
+
self.benchmark_df = pd.DataFrame(benchmark_data)
|
| 274 |
+
|
| 275 |
+
def load_results(self) -> None:
|
| 276 |
+
"""Load evaluation results from local directory."""
|
| 277 |
+
logger.info("π Loading evaluation results...")
|
| 278 |
+
|
| 279 |
+
if not self.results_dir.exists():
|
| 280 |
+
logger.warning(f"Evaluation directory not found: {self.results_dir}")
|
| 281 |
+
return
|
| 282 |
+
|
| 283 |
+
logger.info(f"π Searching for evaluation files in: {self.results_dir}")
|
| 284 |
+
json_files = list(self.results_dir.glob("codesearchnet_eval_*.json"))
|
| 285 |
+
logger.info(f"π Found {len(json_files)} evaluation files")
|
| 286 |
+
|
| 287 |
+
for json_file in json_files:
|
| 288 |
+
try:
|
| 289 |
+
logger.info(f"π Loading: {json_file.name}")
|
| 290 |
+
with json_file.open() as f:
|
| 291 |
+
data = json.load(f)
|
| 292 |
+
if data is not None:
|
| 293 |
+
# Update model name with proper mapping
|
| 294 |
+
original_name = data.get("model_name", "Unknown")
|
| 295 |
+
mapped_name = extract_model_name_from_filename(json_file.stem)
|
| 296 |
+
data["model_name"] = mapped_name
|
| 297 |
+
data["original_model_name"] = original_name
|
| 298 |
+
|
| 299 |
+
self.results.append(data)
|
| 300 |
+
logger.info(f"β
Successfully loaded: {mapped_name}")
|
| 301 |
+
except (json.JSONDecodeError, KeyError) as e:
|
| 302 |
+
logger.warning(f"β Failed to load {json_file}: {e}")
|
| 303 |
+
|
| 304 |
+
logger.info(f"π Total loaded: {len(self.results)} model results")
|
| 305 |
+
if self.results:
|
| 306 |
+
model_names = [r.get("model_name", "Unknown") for r in self.results]
|
| 307 |
+
logger.info(f"π― Models found: {', '.join(model_names)}")
|
| 308 |
+
|
| 309 |
+
self._create_comparison_dataframe()
|
| 310 |
+
|
| 311 |
+
# Also load benchmark results
|
| 312 |
+
self.load_benchmark_results()
|
| 313 |
+
|
| 314 |
+
def _create_comparison_dataframe(self) -> None:
|
| 315 |
+
"""Create comparison DataFrame from results."""
|
| 316 |
+
if not self.results:
|
| 317 |
+
return
|
| 318 |
+
|
| 319 |
+
comparison_data = []
|
| 320 |
+
for result in self.results:
|
| 321 |
+
overall = result.get("overall", {})
|
| 322 |
+
row = {
|
| 323 |
+
"Model": result["model_name"],
|
| 324 |
+
"MRR": overall.get("mrr", 0),
|
| 325 |
+
"NDCG@1": overall.get("ndcg@1", 0),
|
| 326 |
+
"NDCG@5": overall.get("ndcg@5", 0),
|
| 327 |
+
"NDCG@10": overall.get("ndcg@10", 0),
|
| 328 |
+
"Recall@1": overall.get("recall@1", 0),
|
| 329 |
+
"Recall@5": overall.get("recall@5", 0),
|
| 330 |
+
"Recall@10": overall.get("recall@10", 0),
|
| 331 |
+
"Mean_Rank": overall.get("mean_rank", 0),
|
| 332 |
+
"Median_Rank": overall.get("median_rank", 0),
|
| 333 |
+
}
|
| 334 |
+
comparison_data.append(row)
|
| 335 |
+
|
| 336 |
+
self.comparison_df = pd.DataFrame(comparison_data)
|
| 337 |
+
if not self.comparison_df.empty:
|
| 338 |
+
self.comparison_df = self.comparison_df.sort_values("NDCG@10", ascending=False)
|
| 339 |
+
|
| 340 |
+
def print_summary(self) -> None:
|
| 341 |
+
"""Print summary of results."""
|
| 342 |
+
if not self.results:
|
| 343 |
+
logger.warning("No results to summarize")
|
| 344 |
+
return
|
| 345 |
+
|
| 346 |
+
print(f"\n{'=' * 60}")
|
| 347 |
+
print("CodeSearchNet Evaluation Summary")
|
| 348 |
+
print(f"{'=' * 60}")
|
| 349 |
+
print(f"Total models evaluated: {len(self.results)}")
|
| 350 |
+
|
| 351 |
+
if self.comparison_df is not None and not self.comparison_df.empty:
|
| 352 |
+
print(f"\nTop performing model: {self.comparison_df.iloc[0]['Model']}")
|
| 353 |
+
print(f"Best NDCG@10: {self.comparison_df.iloc[0]['NDCG@10']:.4f}")
|
| 354 |
+
print(f"Best MRR: {self.comparison_df['MRR'].max():.4f}")
|
| 355 |
+
|
| 356 |
+
print(f"\nEvaluated languages: {', '.join(CODE_LANGUAGES)}")
|
| 357 |
+
|
| 358 |
+
# Also print benchmark summary if available
|
| 359 |
+
if self.benchmark_results:
|
| 360 |
+
print(f"\n{'=' * 60}")
|
| 361 |
+
print("Performance Benchmark Summary")
|
| 362 |
+
print(f"{'=' * 60}")
|
| 363 |
+
print(f"Total models benchmarked: {len(self.benchmark_results)}")
|
| 364 |
+
|
| 365 |
+
if self.benchmark_df is not None and not self.benchmark_df.empty:
|
| 366 |
+
# Safely get fastest and smallest models
|
| 367 |
+
fastest_model = "N/A"
|
| 368 |
+
smallest_model = "N/A"
|
| 369 |
+
|
| 370 |
+
if "Throughput_TextsPerSec" in self.benchmark_df.columns:
|
| 371 |
+
fastest_idx = self.benchmark_df["Throughput_TextsPerSec"].idxmax()
|
| 372 |
+
fastest_model = str(self.benchmark_df.loc[fastest_idx, "Model"])
|
| 373 |
+
|
| 374 |
+
if "Disk_Size_MB" in self.benchmark_df.columns:
|
| 375 |
+
smallest_idx = self.benchmark_df["Disk_Size_MB"].idxmin()
|
| 376 |
+
smallest_model = str(self.benchmark_df.loc[smallest_idx, "Model"])
|
| 377 |
+
|
| 378 |
+
print(f"\nFastest model: {fastest_model}")
|
| 379 |
+
print(f"Smallest model: {smallest_model}")
|
| 380 |
+
|
| 381 |
+
def analyze_language_performance(self) -> None:
|
| 382 |
+
"""Analyze performance across programming languages."""
|
| 383 |
+
if not self.results:
|
| 384 |
+
return
|
| 385 |
+
|
| 386 |
+
print(f"\n{'=' * 60}")
|
| 387 |
+
print("Language-Specific Performance Analysis")
|
| 388 |
+
print(f"{'=' * 60}")
|
| 389 |
+
|
| 390 |
+
for result in self.results:
|
| 391 |
+
model_name = result["model_name"]
|
| 392 |
+
print(f"\nModel: {model_name}")
|
| 393 |
+
print("-" * 40)
|
| 394 |
+
|
| 395 |
+
languages = result.get("languages", {})
|
| 396 |
+
lang_data = []
|
| 397 |
+
|
| 398 |
+
for lang, lang_results in languages.items():
|
| 399 |
+
metrics = lang_results.get("metrics", {})
|
| 400 |
+
lang_data.append(
|
| 401 |
+
{
|
| 402 |
+
"Language": lang,
|
| 403 |
+
"NDCG@10": metrics.get("ndcg@10", 0),
|
| 404 |
+
"MRR": metrics.get("mrr", 0),
|
| 405 |
+
"Recall@5": metrics.get("recall@5", 0),
|
| 406 |
+
"Queries": lang_results.get("num_queries", 0),
|
| 407 |
+
}
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
if lang_data:
|
| 411 |
+
lang_df = pd.DataFrame(lang_data)
|
| 412 |
+
print(lang_df.to_string(index=False, float_format="%.4f"))
|
| 413 |
+
print(f"\nBest language: {lang_df.loc[lang_df['NDCG@10'].idxmax(), 'Language']}")
|
| 414 |
+
print(f"Average NDCG@10: {lang_df['NDCG@10'].mean():.4f}")
|
| 415 |
+
print(f"Average queries per language: {lang_df['Queries'].mean():.0f}")
|
| 416 |
+
|
| 417 |
+
def analyze_benchmark_performance(self) -> None:
|
| 418 |
+
"""Analyze and print benchmark performance summary."""
|
| 419 |
+
if not self.benchmark_results:
|
| 420 |
+
logger.warning("No benchmark results to analyze")
|
| 421 |
+
return
|
| 422 |
+
|
| 423 |
+
print(f"\n{'=' * 60}")
|
| 424 |
+
print("Performance Benchmark Analysis")
|
| 425 |
+
print(f"{'=' * 60}")
|
| 426 |
+
|
| 427 |
+
for result in self.benchmark_results:
|
| 428 |
+
model_name = result.get("model_name", "Unknown")
|
| 429 |
+
print(f"\nModel: {model_name}")
|
| 430 |
+
print("-" * 40)
|
| 431 |
+
|
| 432 |
+
# Size metrics
|
| 433 |
+
size_metrics = result.get("size_metrics", {})
|
| 434 |
+
if size_metrics:
|
| 435 |
+
print("π Model Size:")
|
| 436 |
+
print(f" Disk Size: {size_metrics.get('disk_size_mb', 0):.1f} MB")
|
| 437 |
+
if "parameters_millions" in size_metrics:
|
| 438 |
+
print(f" Parameters: {size_metrics['parameters_millions']:.1f}M")
|
| 439 |
+
if "embedding_dim" in size_metrics:
|
| 440 |
+
print(f" Embedding Dimension: {size_metrics['embedding_dim']}")
|
| 441 |
+
|
| 442 |
+
# Speed metrics
|
| 443 |
+
speed_benchmarks = result.get("speed_benchmarks", {})
|
| 444 |
+
if "medium" in speed_benchmarks and "batch_32" in speed_benchmarks["medium"]:
|
| 445 |
+
batch_32 = speed_benchmarks["medium"]["batch_32"]
|
| 446 |
+
print("β‘ Performance (Batch 32, Medium Texts):")
|
| 447 |
+
print(f" Throughput: {batch_32.get('texts_per_second', 0):.1f} texts/sec")
|
| 448 |
+
print(f" Latency: {batch_32.get('time_per_text_ms', 0):.1f} ms/text")
|
| 449 |
+
print(f" Token Speed: {batch_32.get('tokens_per_second', 0):.0f} tokens/sec")
|
| 450 |
+
|
| 451 |
+
# CPU vs GPU
|
| 452 |
+
cpu_vs_gpu = result.get("cpu_vs_gpu", {})
|
| 453 |
+
if cpu_vs_gpu:
|
| 454 |
+
print("π₯οΈ CPU vs GPU:")
|
| 455 |
+
for device, metrics in cpu_vs_gpu.items():
|
| 456 |
+
if "error" not in metrics:
|
| 457 |
+
print(f" {device.upper()}: {metrics.get('texts_per_second', 0):.1f} texts/sec")
|
| 458 |
+
|
| 459 |
+
# Memory efficiency
|
| 460 |
+
memory_benchmarks = result.get("memory_benchmarks", {})
|
| 461 |
+
if "batch_32" in memory_benchmarks:
|
| 462 |
+
batch_32_mem = memory_benchmarks["batch_32"]
|
| 463 |
+
if not batch_32_mem.get("oom", False) and "error" not in batch_32_mem:
|
| 464 |
+
print("πΎ Memory Usage (Batch 32):")
|
| 465 |
+
print(f" Total: {batch_32_mem.get('memory_used_mb', 0):.1f} MB")
|
| 466 |
+
print(f" Per Text: {batch_32_mem.get('memory_per_text_mb', 0):.2f} MB")
|
| 467 |
+
|
| 468 |
+
def create_performance_radar_chart(self, model_name: str, language_scores: dict[str, float]) -> str:
|
| 469 |
+
"""Create radar chart showing performance across languages."""
|
| 470 |
+
if not PLOTLY_AVAILABLE:
|
| 471 |
+
logger.warning("Plotly not available, skipping radar chart")
|
| 472 |
+
return ""
|
| 473 |
+
|
| 474 |
+
languages = list(language_scores.keys())
|
| 475 |
+
scores = list(language_scores.values())
|
| 476 |
+
|
| 477 |
+
if not languages:
|
| 478 |
+
return ""
|
| 479 |
+
|
| 480 |
+
# Close the radar chart
|
| 481 |
+
languages_closed = [*languages, languages[0]]
|
| 482 |
+
scores_closed = [*scores, scores[0]]
|
| 483 |
+
|
| 484 |
+
fig = go.Figure()
|
| 485 |
+
|
| 486 |
+
fig.add_trace(
|
| 487 |
+
go.Scatterpolar(
|
| 488 |
+
r=scores_closed,
|
| 489 |
+
theta=languages_closed,
|
| 490 |
+
fill="toself",
|
| 491 |
+
name=model_name,
|
| 492 |
+
line_color="rgb(67, 147, 195)",
|
| 493 |
+
fillcolor="rgba(67, 147, 195, 0.3)",
|
| 494 |
+
)
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
fig.update_layout(
|
| 498 |
+
polar={"radialaxis": {"visible": True, "range": [0, max(scores) * 1.1]}},
|
| 499 |
+
showlegend=True,
|
| 500 |
+
title=f"CodeSearchNet Performance by Language: {model_name}",
|
| 501 |
+
width=800,
|
| 502 |
+
height=600,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
static_path = self.images_dir / "code_performance_radar.png"
|
| 506 |
+
try:
|
| 507 |
+
fig.write_image(str(static_path), width=800, height=600, scale=2)
|
| 508 |
+
return str(static_path)
|
| 509 |
+
except Exception as e:
|
| 510 |
+
logger.warning(f"Could not create static image: {e}")
|
| 511 |
+
return ""
|
| 512 |
+
|
| 513 |
+
def create_comparative_radar_chart(self, simplified_models: list, peer_models: list) -> str:
|
| 514 |
+
"""Create comparative radar chart between best distilled model and top peer models."""
|
| 515 |
+
if not PLOTLY_AVAILABLE:
|
| 516 |
+
logger.warning("Plotly not available, skipping comparative radar chart")
|
| 517 |
+
return ""
|
| 518 |
+
|
| 519 |
+
if not simplified_models:
|
| 520 |
+
return ""
|
| 521 |
+
|
| 522 |
+
# Get the best simplified model
|
| 523 |
+
best_simplified = max(simplified_models, key=lambda x: x.get("overall", {}).get("ndcg@10", 0))
|
| 524 |
+
|
| 525 |
+
# Get top 3 peer models by performance
|
| 526 |
+
peer_models_sorted = sorted(peer_models, key=lambda x: x.get("overall", {}).get("ndcg@10", 0), reverse=True)
|
| 527 |
+
top_peers = peer_models_sorted[:3]
|
| 528 |
+
|
| 529 |
+
models_to_compare = [best_simplified, *top_peers]
|
| 530 |
+
|
| 531 |
+
fig = go.Figure()
|
| 532 |
+
|
| 533 |
+
# Define colors for each model
|
| 534 |
+
colors = ["rgb(255, 99, 132)", "rgb(54, 162, 235)", "rgb(255, 205, 86)", "rgb(75, 192, 192)"]
|
| 535 |
+
|
| 536 |
+
for i, model_result in enumerate(models_to_compare):
|
| 537 |
+
model_name = model_result["model_name"]
|
| 538 |
+
languages = model_result.get("languages", {})
|
| 539 |
+
|
| 540 |
+
# Calculate language scores
|
| 541 |
+
language_scores = {}
|
| 542 |
+
for lang, lang_data in languages.items():
|
| 543 |
+
metrics = lang_data.get("metrics", {})
|
| 544 |
+
language_scores[lang.title()] = metrics.get("ndcg@10", 0)
|
| 545 |
+
|
| 546 |
+
if language_scores:
|
| 547 |
+
languages_list = list(language_scores.keys())
|
| 548 |
+
scores_list = list(language_scores.values())
|
| 549 |
+
|
| 550 |
+
# Close the radar chart
|
| 551 |
+
languages_closed = [*languages_list, languages_list[0]]
|
| 552 |
+
scores_closed = [*scores_list, scores_list[0]]
|
| 553 |
+
|
| 554 |
+
# Determine line style - solid for best distilled, dash for peers
|
| 555 |
+
line_dash = "solid" if i == 0 else "dash"
|
| 556 |
+
line_width = 3 if i == 0 else 2
|
| 557 |
+
|
| 558 |
+
fig.add_trace(
|
| 559 |
+
go.Scatterpolar(
|
| 560 |
+
r=scores_closed,
|
| 561 |
+
theta=languages_closed,
|
| 562 |
+
fill="toself" if i == 0 else "none",
|
| 563 |
+
name=model_name,
|
| 564 |
+
line={"color": colors[i % len(colors)], "dash": line_dash, "width": line_width},
|
| 565 |
+
fillcolor=f"rgba{colors[i % len(colors)][3:-1]}, 0.2)" if i == 0 else None,
|
| 566 |
+
)
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
fig.update_layout(
|
| 570 |
+
polar={"radialaxis": {"visible": True, "range": [0, 0.5]}}, # Adjust max range as needed
|
| 571 |
+
showlegend=True,
|
| 572 |
+
title="Model Comparison: Best Distilled vs Top Peer Models",
|
| 573 |
+
width=900,
|
| 574 |
+
height=700,
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
static_path = self.images_dir / "comparative_radar.png"
|
| 578 |
+
try:
|
| 579 |
+
fig.write_image(str(static_path), width=900, height=700, scale=2)
|
| 580 |
+
return str(static_path)
|
| 581 |
+
except Exception as e:
|
| 582 |
+
logger.warning(f"Could not create comparative radar chart: {e}")
|
| 583 |
+
return ""
|
| 584 |
+
|
| 585 |
+
def create_individual_radar_charts(self, simplified_models: list) -> dict[str, str]:
|
| 586 |
+
"""Create individual radar charts for all simplified models."""
|
| 587 |
+
radar_charts = {}
|
| 588 |
+
|
| 589 |
+
for result in simplified_models:
|
| 590 |
+
model_name = result["model_name"]
|
| 591 |
+
model_languages = result.get("languages", {})
|
| 592 |
+
model_language_scores = {}
|
| 593 |
+
for lang, lang_data in model_languages.items():
|
| 594 |
+
metrics = lang_data.get("metrics", {})
|
| 595 |
+
model_language_scores[lang.title()] = metrics.get("ndcg@10", 0)
|
| 596 |
+
|
| 597 |
+
if model_language_scores:
|
| 598 |
+
# Create unique filename for each model
|
| 599 |
+
safe_model_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_")).rstrip()
|
| 600 |
+
radar_chart_path = self.create_performance_radar_chart_individual(
|
| 601 |
+
model_name, model_language_scores, safe_model_name
|
| 602 |
+
)
|
| 603 |
+
if radar_chart_path:
|
| 604 |
+
radar_charts[model_name] = radar_chart_path
|
| 605 |
+
|
| 606 |
+
return radar_charts
|
| 607 |
+
|
| 608 |
+
def create_performance_radar_chart_individual(
|
| 609 |
+
self, model_name: str, language_scores: dict[str, float], filename_suffix: str
|
| 610 |
+
) -> str:
|
| 611 |
+
"""Create radar chart for individual model with unique filename."""
|
| 612 |
+
if not PLOTLY_AVAILABLE:
|
| 613 |
+
logger.warning("Plotly not available, skipping radar chart")
|
| 614 |
+
return ""
|
| 615 |
+
|
| 616 |
+
languages = list(language_scores.keys())
|
| 617 |
+
scores = list(language_scores.values())
|
| 618 |
+
|
| 619 |
+
if not languages:
|
| 620 |
+
return ""
|
| 621 |
+
|
| 622 |
+
# Close the radar chart
|
| 623 |
+
languages_closed = [*languages, languages[0]]
|
| 624 |
+
scores_closed = [*scores, scores[0]]
|
| 625 |
+
|
| 626 |
+
fig = go.Figure()
|
| 627 |
+
|
| 628 |
+
fig.add_trace(
|
| 629 |
+
go.Scatterpolar(
|
| 630 |
+
r=scores_closed,
|
| 631 |
+
theta=languages_closed,
|
| 632 |
+
fill="toself",
|
| 633 |
+
name=model_name,
|
| 634 |
+
line_color="rgb(67, 147, 195)",
|
| 635 |
+
fillcolor="rgba(67, 147, 195, 0.3)",
|
| 636 |
+
)
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
fig.update_layout(
|
| 640 |
+
polar={"radialaxis": {"visible": True, "range": [0, max(scores) * 1.1]}},
|
| 641 |
+
showlegend=True,
|
| 642 |
+
title=f"CodeSearchNet Performance by Language: {model_name}",
|
| 643 |
+
width=800,
|
| 644 |
+
height=600,
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
static_path = self.images_dir / f"radar_{filename_suffix}.png"
|
| 648 |
+
try:
|
| 649 |
+
fig.write_image(str(static_path), width=800, height=600, scale=2)
|
| 650 |
+
return str(static_path)
|
| 651 |
+
except Exception as e:
|
| 652 |
+
logger.warning(f"Could not create static image for {model_name}: {e}")
|
| 653 |
+
return ""
|
| 654 |
+
|
| 655 |
+
def plot_model_comparison(self, save_path: str | None = None) -> str:
|
| 656 |
+
"""Create comparison plots for models."""
|
| 657 |
+
if self.comparison_df is None or self.comparison_df.empty:
|
| 658 |
+
logger.warning("No comparison data available for plotting")
|
| 659 |
+
return ""
|
| 660 |
+
|
| 661 |
+
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
|
| 662 |
+
fig.suptitle("CodeSearchNet Model Comparison", fontsize=16, fontweight="bold")
|
| 663 |
+
|
| 664 |
+
# NDCG@10 comparison
|
| 665 |
+
axes[0, 0].barh(self.comparison_df["Model"], self.comparison_df["NDCG@10"])
|
| 666 |
+
axes[0, 0].set_title("NDCG@10 Comparison")
|
| 667 |
+
axes[0, 0].set_xlabel("NDCG@10")
|
| 668 |
+
|
| 669 |
+
# MRR comparison
|
| 670 |
+
axes[0, 1].barh(self.comparison_df["Model"], self.comparison_df["MRR"])
|
| 671 |
+
axes[0, 1].set_title("Mean Reciprocal Rank (MRR)")
|
| 672 |
+
axes[0, 1].set_xlabel("MRR")
|
| 673 |
+
|
| 674 |
+
# Recall@5 comparison
|
| 675 |
+
axes[1, 0].barh(self.comparison_df["Model"], self.comparison_df["Recall@5"])
|
| 676 |
+
axes[1, 0].set_title("Recall@5")
|
| 677 |
+
axes[1, 0].set_xlabel("Recall@5")
|
| 678 |
+
|
| 679 |
+
# Mean Rank comparison (lower is better)
|
| 680 |
+
axes[1, 1].barh(self.comparison_df["Model"], self.comparison_df["Mean_Rank"])
|
| 681 |
+
axes[1, 1].set_title("Mean Rank (lower is better)")
|
| 682 |
+
axes[1, 1].set_xlabel("Mean Rank")
|
| 683 |
+
|
| 684 |
+
plt.tight_layout()
|
| 685 |
+
|
| 686 |
+
output_path = save_path or str(self.images_dir / "model_comparison.png")
|
| 687 |
+
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
| 688 |
+
plt.close()
|
| 689 |
+
|
| 690 |
+
return output_path
|
| 691 |
+
|
| 692 |
+
def plot_language_heatmap(self, save_path: str | None = None) -> str:
|
| 693 |
+
"""Create a heatmap of performance across languages."""
|
| 694 |
+
if not self.results:
|
| 695 |
+
return ""
|
| 696 |
+
|
| 697 |
+
# Prepare data for heatmap
|
| 698 |
+
heatmap_data = []
|
| 699 |
+
for result in self.results:
|
| 700 |
+
model_name = result["model_name"]
|
| 701 |
+
languages = result.get("languages", {})
|
| 702 |
+
|
| 703 |
+
row = {"Model": model_name}
|
| 704 |
+
for lang in CODE_LANGUAGES:
|
| 705 |
+
if lang in languages:
|
| 706 |
+
metrics = languages[lang].get("metrics", {})
|
| 707 |
+
row[lang.title()] = metrics.get("ndcg@10", 0)
|
| 708 |
+
else:
|
| 709 |
+
row[lang.title()] = 0
|
| 710 |
+
heatmap_data.append(row)
|
| 711 |
+
|
| 712 |
+
if not heatmap_data:
|
| 713 |
+
return ""
|
| 714 |
+
|
| 715 |
+
df = pd.DataFrame(heatmap_data).set_index("Model")
|
| 716 |
+
|
| 717 |
+
plt.figure(figsize=(12, 8))
|
| 718 |
+
sns.heatmap(
|
| 719 |
+
df,
|
| 720 |
+
annot=True,
|
| 721 |
+
fmt=".3f",
|
| 722 |
+
cmap="RdYlBu_r",
|
| 723 |
+
center=0.2,
|
| 724 |
+
vmin=0,
|
| 725 |
+
vmax=df.to_numpy().max(),
|
| 726 |
+
cbar_kws={"label": "NDCG@10 Score"},
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
plt.title(
|
| 730 |
+
"CodeSearchNet Performance Heatmap by Language",
|
| 731 |
+
fontsize=16,
|
| 732 |
+
fontweight="bold",
|
| 733 |
+
)
|
| 734 |
+
plt.xlabel("Programming Language", fontsize=12)
|
| 735 |
+
plt.ylabel("Model", fontsize=12)
|
| 736 |
+
plt.tight_layout()
|
| 737 |
+
|
| 738 |
+
output_path = save_path or str(self.images_dir / "language_heatmap.png")
|
| 739 |
+
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
| 740 |
+
plt.close()
|
| 741 |
+
|
| 742 |
+
return output_path
|
| 743 |
+
|
| 744 |
+
def plot_benchmark_performance(self, save_path: str | None = None) -> str:
|
| 745 |
+
"""Create comprehensive benchmark performance plots."""
|
| 746 |
+
if not self.benchmark_results:
|
| 747 |
+
logger.warning("No benchmark data available for plotting")
|
| 748 |
+
return ""
|
| 749 |
+
|
| 750 |
+
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
|
| 751 |
+
fig.suptitle("Performance Benchmark Analysis", fontsize=16, fontweight="bold")
|
| 752 |
+
|
| 753 |
+
# 1. Model Size Comparison
|
| 754 |
+
if self.benchmark_df is not None and "Disk_Size_MB" in self.benchmark_df.columns:
|
| 755 |
+
axes[0, 0].barh(self.benchmark_df["Model"], self.benchmark_df["Disk_Size_MB"])
|
| 756 |
+
axes[0, 0].set_title("Model Size (MB)")
|
| 757 |
+
axes[0, 0].set_xlabel("Size (MB)")
|
| 758 |
+
|
| 759 |
+
# 2. Inference Throughput
|
| 760 |
+
if self.benchmark_df is not None and "Throughput_TextsPerSec" in self.benchmark_df.columns:
|
| 761 |
+
axes[0, 1].barh(self.benchmark_df["Model"], self.benchmark_df["Throughput_TextsPerSec"])
|
| 762 |
+
axes[0, 1].set_title("Inference Throughput")
|
| 763 |
+
axes[0, 1].set_xlabel("Texts/Second")
|
| 764 |
+
|
| 765 |
+
# 3. Memory Usage
|
| 766 |
+
if self.benchmark_df is not None and "Memory_Used_MB" in self.benchmark_df.columns:
|
| 767 |
+
axes[0, 2].barh(self.benchmark_df["Model"], self.benchmark_df["Memory_Used_MB"])
|
| 768 |
+
axes[0, 2].set_title("Memory Usage (Batch 32)")
|
| 769 |
+
axes[0, 2].set_xlabel("Memory (MB)")
|
| 770 |
+
|
| 771 |
+
# 4. Latency Comparison
|
| 772 |
+
if self.benchmark_df is not None and "Latency_MsPerText" in self.benchmark_df.columns:
|
| 773 |
+
axes[1, 0].barh(self.benchmark_df["Model"], self.benchmark_df["Latency_MsPerText"])
|
| 774 |
+
axes[1, 0].set_title("Inference Latency")
|
| 775 |
+
axes[1, 0].set_xlabel("Milliseconds/Text")
|
| 776 |
+
|
| 777 |
+
# 5. CPU vs GPU Performance
|
| 778 |
+
if self.benchmark_df is not None:
|
| 779 |
+
cpu_col = "CPU_TextsPerSec"
|
| 780 |
+
gpu_col = "CUDA_TextsPerSec"
|
| 781 |
+
if cpu_col in self.benchmark_df.columns and gpu_col in self.benchmark_df.columns:
|
| 782 |
+
x = np.arange(len(self.benchmark_df))
|
| 783 |
+
width = 0.35
|
| 784 |
+
axes[1, 1].bar(x - width / 2, self.benchmark_df[cpu_col], width, label="CPU", alpha=0.7)
|
| 785 |
+
axes[1, 1].bar(x + width / 2, self.benchmark_df[gpu_col], width, label="GPU", alpha=0.7)
|
| 786 |
+
axes[1, 1].set_title("CPU vs GPU Performance")
|
| 787 |
+
axes[1, 1].set_ylabel("Texts/Second")
|
| 788 |
+
axes[1, 1].set_xticks(x)
|
| 789 |
+
axes[1, 1].set_xticklabels(self.benchmark_df["Model"], rotation=45, ha="right")
|
| 790 |
+
axes[1, 1].legend()
|
| 791 |
+
|
| 792 |
+
# 6. Parameter Efficiency
|
| 793 |
+
if (
|
| 794 |
+
self.benchmark_df is not None
|
| 795 |
+
and "Parameters_M" in self.benchmark_df.columns
|
| 796 |
+
and "Throughput_TextsPerSec" in self.benchmark_df.columns
|
| 797 |
+
):
|
| 798 |
+
# Efficiency = Throughput / Parameters (higher is better)
|
| 799 |
+
efficiency = self.benchmark_df["Throughput_TextsPerSec"] / (self.benchmark_df["Parameters_M"] + 1e-6)
|
| 800 |
+
axes[1, 2].barh(self.benchmark_df["Model"], efficiency)
|
| 801 |
+
axes[1, 2].set_title("Parameter Efficiency")
|
| 802 |
+
axes[1, 2].set_xlabel("Texts/Sec per Million Parameters")
|
| 803 |
+
|
| 804 |
+
plt.tight_layout()
|
| 805 |
+
|
| 806 |
+
output_path = save_path or str(self.images_dir / "benchmark_performance.png")
|
| 807 |
+
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
| 808 |
+
plt.close()
|
| 809 |
+
|
| 810 |
+
return output_path
|
| 811 |
+
|
| 812 |
+
def plot_batch_size_scaling(self, save_path: str | None = None) -> str:
|
| 813 |
+
"""Create batch size scaling analysis plot."""
|
| 814 |
+
if not self.benchmark_results:
|
| 815 |
+
return ""
|
| 816 |
+
|
| 817 |
+
plt.figure(figsize=(12, 8))
|
| 818 |
+
|
| 819 |
+
for result in self.benchmark_results:
|
| 820 |
+
model_name = result.get("model_name", "Unknown")
|
| 821 |
+
speed_benchmarks = result.get("speed_benchmarks", {})
|
| 822 |
+
|
| 823 |
+
# Extract batch size performance for medium texts
|
| 824 |
+
if "medium" in speed_benchmarks:
|
| 825 |
+
batch_sizes = []
|
| 826 |
+
throughputs = []
|
| 827 |
+
|
| 828 |
+
for batch_key, metrics in speed_benchmarks["medium"].items():
|
| 829 |
+
if batch_key.startswith("batch_"):
|
| 830 |
+
batch_size = int(batch_key.split("_")[1])
|
| 831 |
+
throughput = metrics.get("texts_per_second", 0)
|
| 832 |
+
batch_sizes.append(batch_size)
|
| 833 |
+
throughputs.append(throughput)
|
| 834 |
+
|
| 835 |
+
if batch_sizes:
|
| 836 |
+
plt.plot(batch_sizes, throughputs, marker="o", label=model_name, linewidth=2)
|
| 837 |
+
|
| 838 |
+
plt.xlabel("Batch Size", fontsize=12)
|
| 839 |
+
plt.ylabel("Throughput (Texts/Second)", fontsize=12)
|
| 840 |
+
plt.title("Batch Size Scaling Performance", fontsize=16, fontweight="bold")
|
| 841 |
+
plt.legend()
|
| 842 |
+
plt.grid(visible=True, alpha=0.3)
|
| 843 |
+
plt.xscale("log", base=2)
|
| 844 |
+
|
| 845 |
+
output_path = save_path or str(self.images_dir / "batch_size_scaling.png")
|
| 846 |
+
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
| 847 |
+
plt.close()
|
| 848 |
+
|
| 849 |
+
return output_path
|
| 850 |
+
|
| 851 |
+
def plot_memory_scaling(self, save_path: str | None = None) -> str:
|
| 852 |
+
"""Create memory scaling analysis plot."""
|
| 853 |
+
if not self.benchmark_results:
|
| 854 |
+
return ""
|
| 855 |
+
|
| 856 |
+
plt.figure(figsize=(12, 8))
|
| 857 |
+
|
| 858 |
+
for result in self.benchmark_results:
|
| 859 |
+
model_name = result.get("model_name", "Unknown")
|
| 860 |
+
memory_benchmarks = result.get("memory_benchmarks", {})
|
| 861 |
+
|
| 862 |
+
batch_sizes = []
|
| 863 |
+
memory_usage = []
|
| 864 |
+
|
| 865 |
+
for batch_key, metrics in memory_benchmarks.items():
|
| 866 |
+
if batch_key.startswith("batch_") and not metrics.get("oom", False) and "error" not in metrics:
|
| 867 |
+
batch_size = int(batch_key.split("_")[1])
|
| 868 |
+
memory_mb = metrics.get("memory_used_mb", 0)
|
| 869 |
+
batch_sizes.append(batch_size)
|
| 870 |
+
memory_usage.append(memory_mb)
|
| 871 |
+
|
| 872 |
+
if batch_sizes:
|
| 873 |
+
plt.plot(batch_sizes, memory_usage, marker="s", label=model_name, linewidth=2)
|
| 874 |
+
|
| 875 |
+
plt.xlabel("Batch Size", fontsize=12)
|
| 876 |
+
plt.ylabel("Memory Usage (MB)", fontsize=12)
|
| 877 |
+
plt.title("Memory Scaling by Batch Size", fontsize=16, fontweight="bold")
|
| 878 |
+
plt.legend()
|
| 879 |
+
plt.grid(visible=True, alpha=0.3)
|
| 880 |
+
plt.xscale("log", base=2)
|
| 881 |
+
|
| 882 |
+
output_path = save_path or str(self.images_dir / "memory_scaling.png")
|
| 883 |
+
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
| 884 |
+
plt.close()
|
| 885 |
+
|
| 886 |
+
return output_path
|
| 887 |
+
|
| 888 |
+
def create_peer_comparison_chart(self, model_name: str) -> str:
|
| 889 |
+
"""Create comparison chart using actual evaluation results."""
|
| 890 |
+
if self.comparison_df is None or self.comparison_df.empty:
|
| 891 |
+
logger.warning("No comparison data available for peer comparison chart")
|
| 892 |
+
return ""
|
| 893 |
+
|
| 894 |
+
# Use actual evaluation results instead of hardcoded scores
|
| 895 |
+
df_sorted = self.comparison_df.sort_values("NDCG@10", ascending=True)
|
| 896 |
+
|
| 897 |
+
plt.figure(figsize=(12, 8))
|
| 898 |
+
|
| 899 |
+
# Color models differently - highlight the user's model
|
| 900 |
+
colors = []
|
| 901 |
+
for model in df_sorted["Model"]:
|
| 902 |
+
if model_name.lower() in model.lower() or "gte_qwen2_m2v_code" in model.lower():
|
| 903 |
+
colors.append("red") # User's model
|
| 904 |
+
else:
|
| 905 |
+
colors.append("skyblue") # Peer models
|
| 906 |
+
|
| 907 |
+
bars = plt.barh(df_sorted["Model"], df_sorted["NDCG@10"], color=colors)
|
| 908 |
+
|
| 909 |
+
# Highlight current model with special formatting
|
| 910 |
+
for i, model in enumerate(df_sorted["Model"]):
|
| 911 |
+
if model_name.lower() in model.lower() or "gte_qwen2_m2v_code" in model.lower():
|
| 912 |
+
bars[i].set_alpha(0.8)
|
| 913 |
+
bars[i].set_edgecolor("black")
|
| 914 |
+
bars[i].set_linewidth(2)
|
| 915 |
+
|
| 916 |
+
plt.xlabel("NDCG@10 Score", fontsize=12)
|
| 917 |
+
plt.title(
|
| 918 |
+
"CodeSearchNet Model Comparison (Actual Results)",
|
| 919 |
+
fontsize=16,
|
| 920 |
+
fontweight="bold",
|
| 921 |
+
)
|
| 922 |
+
plt.grid(axis="x", alpha=0.3)
|
| 923 |
+
|
| 924 |
+
# Add score labels
|
| 925 |
+
for i, score in enumerate(df_sorted["NDCG@10"]):
|
| 926 |
+
plt.text(score + 0.005, i, f"{score:.3f}", va="center")
|
| 927 |
+
|
| 928 |
+
plt.tight_layout()
|
| 929 |
+
|
| 930 |
+
output_path = self.images_dir / "peer_comparison.png"
|
| 931 |
+
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
| 932 |
+
plt.close()
|
| 933 |
+
|
| 934 |
+
return str(output_path)
|
| 935 |
+
|
| 936 |
+
def create_efficiency_analysis(self, model_name: str) -> str:
|
| 937 |
+
"""Create efficiency analysis chart using actual evaluation results."""
|
| 938 |
+
if self.comparison_df is None or self.comparison_df.empty:
|
| 939 |
+
logger.warning("No comparison data available for efficiency analysis")
|
| 940 |
+
return ""
|
| 941 |
+
|
| 942 |
+
models = []
|
| 943 |
+
scores = []
|
| 944 |
+
params = []
|
| 945 |
+
is_user_model = []
|
| 946 |
+
|
| 947 |
+
# Process all evaluated models
|
| 948 |
+
for _, row in self.comparison_df.iterrows():
|
| 949 |
+
model_display_name = row["Model"]
|
| 950 |
+
current_model_score = row["NDCG@10"]
|
| 951 |
+
|
| 952 |
+
# Determine if this is the user's model
|
| 953 |
+
is_users = (
|
| 954 |
+
model_name.lower() in model_display_name.lower() or "gte_qwen2_m2v_code" in model_display_name.lower()
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
if is_users:
|
| 958 |
+
# User's distilled model
|
| 959 |
+
models.append(model_display_name)
|
| 960 |
+
# Safe conversion to float for pandas values
|
| 961 |
+
score_value = pd.to_numeric(current_model_score, errors="coerce")
|
| 962 |
+
scores.append(float(score_value) if not pd.isna(score_value) else 0.0)
|
| 963 |
+
# Safe conversion for DISTILLED_MODEL_SPECS parameters
|
| 964 |
+
param_value = DISTILLED_MODEL_SPECS.get("parameters", 39)
|
| 965 |
+
params.append(float(param_value) if isinstance(param_value, (int, float)) else 39.0)
|
| 966 |
+
is_user_model.append(True)
|
| 967 |
+
else:
|
| 968 |
+
# Find corresponding peer model specs
|
| 969 |
+
model_key = None
|
| 970 |
+
for peer_key in MODEL_SPECS:
|
| 971 |
+
peer_short_name = peer_key.split("/")[-1].lower()
|
| 972 |
+
if peer_short_name in model_display_name.lower():
|
| 973 |
+
model_key = peer_key
|
| 974 |
+
break
|
| 975 |
+
|
| 976 |
+
if model_key and model_key in MODEL_SPECS:
|
| 977 |
+
models.append(model_display_name.split("/")[-1]) # Short name
|
| 978 |
+
# Safe conversion to float for pandas values
|
| 979 |
+
score_value = pd.to_numeric(current_model_score, errors="coerce")
|
| 980 |
+
scores.append(float(score_value) if not pd.isna(score_value) else 0.0)
|
| 981 |
+
params.append(float(MODEL_SPECS[model_key].get("parameters", 100)))
|
| 982 |
+
is_user_model.append(False)
|
| 983 |
+
|
| 984 |
+
if not models:
|
| 985 |
+
logger.warning("No models with parameter specifications found")
|
| 986 |
+
return ""
|
| 987 |
+
|
| 988 |
+
plt.figure(figsize=(12, 8))
|
| 989 |
+
|
| 990 |
+
# Plot peer models
|
| 991 |
+
peer_models = [m for i, m in enumerate(models) if not is_user_model[i]]
|
| 992 |
+
peer_params = [p for i, p in enumerate(params) if not is_user_model[i]]
|
| 993 |
+
peer_scores = [s for i, s in enumerate(scores) if not is_user_model[i]]
|
| 994 |
+
|
| 995 |
+
if peer_models:
|
| 996 |
+
plt.scatter(
|
| 997 |
+
peer_params,
|
| 998 |
+
peer_scores,
|
| 999 |
+
s=100,
|
| 1000 |
+
alpha=0.6,
|
| 1001 |
+
label="Peer Models",
|
| 1002 |
+
color="skyblue",
|
| 1003 |
+
)
|
| 1004 |
+
|
| 1005 |
+
# Plot user's model
|
| 1006 |
+
user_models = [m for i, m in enumerate(models) if is_user_model[i]]
|
| 1007 |
+
user_params = [p for i, p in enumerate(params) if is_user_model[i]]
|
| 1008 |
+
user_scores = [s for i, s in enumerate(scores) if is_user_model[i]]
|
| 1009 |
+
|
| 1010 |
+
if user_models:
|
| 1011 |
+
plt.scatter(
|
| 1012 |
+
user_params,
|
| 1013 |
+
user_scores,
|
| 1014 |
+
s=200,
|
| 1015 |
+
color="red",
|
| 1016 |
+
alpha=0.8,
|
| 1017 |
+
label=f"{user_models[0]} (Distilled)",
|
| 1018 |
+
marker="*",
|
| 1019 |
+
)
|
| 1020 |
+
|
| 1021 |
+
# Add model labels
|
| 1022 |
+
for i, (model, param, score) in enumerate(zip(models, params, scores, strict=False)):
|
| 1023 |
+
if is_user_model[i]:
|
| 1024 |
+
plt.annotate(
|
| 1025 |
+
model,
|
| 1026 |
+
(param, score),
|
| 1027 |
+
xytext=(10, 10),
|
| 1028 |
+
textcoords="offset points",
|
| 1029 |
+
fontweight="bold",
|
| 1030 |
+
color="red",
|
| 1031 |
+
)
|
| 1032 |
+
else:
|
| 1033 |
+
plt.annotate(
|
| 1034 |
+
model,
|
| 1035 |
+
(param, score),
|
| 1036 |
+
xytext=(5, 5),
|
| 1037 |
+
textcoords="offset points",
|
| 1038 |
+
fontsize=9,
|
| 1039 |
+
)
|
| 1040 |
+
|
| 1041 |
+
plt.xlabel("Model Size (Million Parameters)", fontsize=12)
|
| 1042 |
+
plt.ylabel("NDCG@10 Score", fontsize=12)
|
| 1043 |
+
plt.title(
|
| 1044 |
+
"Model Efficiency: Performance vs Size (Actual Results)",
|
| 1045 |
+
fontsize=16,
|
| 1046 |
+
fontweight="bold",
|
| 1047 |
+
)
|
| 1048 |
+
plt.legend()
|
| 1049 |
+
plt.grid(visible=True, alpha=0.3)
|
| 1050 |
+
plt.xscale("log")
|
| 1051 |
+
|
| 1052 |
+
plt.tight_layout()
|
| 1053 |
+
|
| 1054 |
+
output_path = self.images_dir / "efficiency_analysis.png"
|
| 1055 |
+
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
| 1056 |
+
plt.close()
|
| 1057 |
+
|
| 1058 |
+
return str(output_path)
|
| 1059 |
+
|
| 1060 |
+
def generate_comprehensive_report(self, model_name: str = "Simplified Distillation Models") -> str:
|
| 1061 |
+
"""Generate comprehensive markdown report for all evaluated models."""
|
| 1062 |
+
if not self.results:
|
| 1063 |
+
logger.error("No results to analyze")
|
| 1064 |
+
return ""
|
| 1065 |
+
|
| 1066 |
+
# Find all simplified distillation models
|
| 1067 |
+
simplified_models = []
|
| 1068 |
+
peer_models = []
|
| 1069 |
+
|
| 1070 |
+
for result in self.results:
|
| 1071 |
+
result_model_name = result["model_name"]
|
| 1072 |
+
if (
|
| 1073 |
+
"code_model2vec" in result_model_name.lower()
|
| 1074 |
+
or "distilled" in result_model_name.lower()
|
| 1075 |
+
or "(ours)" in result_model_name.lower()
|
| 1076 |
+
):
|
| 1077 |
+
simplified_models.append(result)
|
| 1078 |
+
else:
|
| 1079 |
+
peer_models.append(result)
|
| 1080 |
+
|
| 1081 |
+
# Get the best performing simplified model for main analysis
|
| 1082 |
+
if simplified_models:
|
| 1083 |
+
main_result = max(simplified_models, key=lambda x: x.get("overall", {}).get("ndcg@10", 0))
|
| 1084 |
+
main_model_name = main_result["model_name"]
|
| 1085 |
+
else:
|
| 1086 |
+
# Fallback to first result if no simplified models found
|
| 1087 |
+
main_result = self.results[0]
|
| 1088 |
+
main_model_name = main_result["model_name"]
|
| 1089 |
+
|
| 1090 |
+
overall = main_result.get("overall", {})
|
| 1091 |
+
languages = main_result.get("languages", {})
|
| 1092 |
+
|
| 1093 |
+
# Calculate language scores for radar chart
|
| 1094 |
+
language_scores = {}
|
| 1095 |
+
for lang, lang_data in languages.items():
|
| 1096 |
+
metrics = lang_data.get("metrics", {})
|
| 1097 |
+
language_scores[lang.title()] = metrics.get("ndcg@10", 0)
|
| 1098 |
+
|
| 1099 |
+
# Create visualizations
|
| 1100 |
+
logger.info("Generating visualizations...")
|
| 1101 |
+
setup_directories()
|
| 1102 |
+
|
| 1103 |
+
self.create_performance_radar_chart(main_model_name, language_scores)
|
| 1104 |
+
comparison_chart = self.plot_model_comparison()
|
| 1105 |
+
heatmap_chart = self.plot_language_heatmap()
|
| 1106 |
+
peer_chart = self.create_peer_comparison_chart(main_model_name)
|
| 1107 |
+
efficiency_chart = self.create_efficiency_analysis(main_model_name)
|
| 1108 |
+
|
| 1109 |
+
# Generate individual radar charts for all simplified models
|
| 1110 |
+
individual_radar_charts = self.create_individual_radar_charts(simplified_models)
|
| 1111 |
+
|
| 1112 |
+
# Create comparative radar chart (best distilled vs top peer models)
|
| 1113 |
+
comparative_radar_chart = self.create_comparative_radar_chart(simplified_models, peer_models)
|
| 1114 |
+
|
| 1115 |
+
# Create benchmark visualizations
|
| 1116 |
+
benchmark_chart = ""
|
| 1117 |
+
batch_scaling_chart = ""
|
| 1118 |
+
memory_scaling_chart = ""
|
| 1119 |
+
if self.benchmark_results:
|
| 1120 |
+
benchmark_chart = self.plot_benchmark_performance()
|
| 1121 |
+
batch_scaling_chart = self.plot_batch_size_scaling()
|
| 1122 |
+
memory_scaling_chart = self.plot_memory_scaling()
|
| 1123 |
+
|
| 1124 |
+
# Generate report
|
| 1125 |
+
report = f"""# Code-Specialized Model2Vec Distillation Analysis
|
| 1126 |
+
|
| 1127 |
+
## π― Executive Summary
|
| 1128 |
+
|
| 1129 |
+
This report presents a comprehensive analysis of Model2Vec distillation experiments using different teacher models for code-specialized embedding generation.
|
| 1130 |
+
|
| 1131 |
+
### Evaluated Models Overview
|
| 1132 |
+
|
| 1133 |
+
**Simplified Distillation Models:** {len(simplified_models)}
|
| 1134 |
+
**Peer Comparison Models:** {len(peer_models)}
|
| 1135 |
+
**Total Models Analyzed:** {len(self.results)}
|
| 1136 |
+
|
| 1137 |
+
### Best Performing Simplified Model: {main_model_name}
|
| 1138 |
+
|
| 1139 |
+
**Overall CodeSearchNet Performance:**
|
| 1140 |
+
- **NDCG@10**: {overall.get("ndcg@10", 0):.4f}
|
| 1141 |
+
- **Mean Reciprocal Rank (MRR)**: {overall.get("mrr", 0):.4f}
|
| 1142 |
+
- **Recall@5**: {overall.get("recall@5", 0):.4f}
|
| 1143 |
+
- **Mean Rank**: {overall.get("mean_rank", 0):.1f}
|
| 1144 |
+
|
| 1145 |
+
## π Comprehensive Model Comparison
|
| 1146 |
+
|
| 1147 |
+
### All Simplified Distillation Models Performance
|
| 1148 |
+
|
| 1149 |
+
"""
|
| 1150 |
+
|
| 1151 |
+
# Add table of all simplified models
|
| 1152 |
+
if simplified_models:
|
| 1153 |
+
report += "| Model | Teacher | NDCG@10 | MRR | Recall@5 | Status |\n"
|
| 1154 |
+
report += "|-------|---------|---------|-----|----------|--------|\n"
|
| 1155 |
+
|
| 1156 |
+
# Sort by performance
|
| 1157 |
+
simplified_models_sorted = sorted(
|
| 1158 |
+
simplified_models, key=lambda x: x.get("overall", {}).get("ndcg@10", 0), reverse=True
|
| 1159 |
+
)
|
| 1160 |
+
|
| 1161 |
+
for rank, result in enumerate(simplified_models_sorted, 1):
|
| 1162 |
+
model_display = result["model_name"]
|
| 1163 |
+
overall_metrics = result.get("overall", {})
|
| 1164 |
+
|
| 1165 |
+
# Extract teacher model name from model name
|
| 1166 |
+
teacher = "Unknown"
|
| 1167 |
+
if "all_MiniLM_L6_v2" in model_display:
|
| 1168 |
+
teacher = "all-MiniLM-L6-v2"
|
| 1169 |
+
elif "codebert_base" in model_display:
|
| 1170 |
+
teacher = "codebert-base"
|
| 1171 |
+
elif "graphcodebert_base" in model_display:
|
| 1172 |
+
teacher = "graphcodebert-base"
|
| 1173 |
+
elif "gte_Qwen2_7B_instruct" in model_display:
|
| 1174 |
+
teacher = "gte-Qwen2-7B-instruct"
|
| 1175 |
+
elif "all_mpnet_base_v2" in model_display:
|
| 1176 |
+
teacher = "all-mpnet-base-v2"
|
| 1177 |
+
|
| 1178 |
+
status = "π₯ Best" if rank == 1 else "π₯ 2nd" if rank == 2 else "π₯ 3rd" if rank == 3 else f"#{rank}"
|
| 1179 |
+
|
| 1180 |
+
report += f"| {model_display} | {teacher} | {overall_metrics.get('ndcg@10', 0):.4f} | {overall_metrics.get('mrr', 0):.4f} | {overall_metrics.get('recall@5', 0):.4f} | {status} |\n"
|
| 1181 |
+
|
| 1182 |
+
report += """
|
| 1183 |
+
|
| 1184 |
+
### Key Findings
|
| 1185 |
+
|
| 1186 |
+
"""
|
| 1187 |
+
|
| 1188 |
+
if simplified_models and len(simplified_models) > 1:
|
| 1189 |
+
best_model = simplified_models_sorted[0]
|
| 1190 |
+
worst_model = simplified_models_sorted[-1]
|
| 1191 |
+
best_score = best_model.get("overall", {}).get("ndcg@10", 0)
|
| 1192 |
+
worst_score = worst_model.get("overall", {}).get("ndcg@10", 0)
|
| 1193 |
+
|
| 1194 |
+
report += f"""
|
| 1195 |
+
- **Best Teacher Model**: {best_model["model_name"]} (NDCG@10: {best_score:.4f})
|
| 1196 |
+
- **Least Effective Teacher**: {worst_model["model_name"]} (NDCG@10: {worst_score:.4f})
|
| 1197 |
+
- **Performance Range**: {((best_score - worst_score) / best_score * 100):.1f}% difference between best and worst
|
| 1198 |
+
- **Average Performance**: {sum(r.get("overall", {}).get("ndcg@10", 0) for r in simplified_models) / len(simplified_models):.4f} NDCG@10
|
| 1199 |
+
"""
|
| 1200 |
+
|
| 1201 |
+
# Add radar charts section
|
| 1202 |
+
report += """
|
| 1203 |
+
|
| 1204 |
+
## π― Language Performance Radar Charts
|
| 1205 |
+
|
| 1206 |
+
### Best Model vs Peer Models Comparison
|
| 1207 |
+
|
| 1208 |
+
"""
|
| 1209 |
+
if comparative_radar_chart:
|
| 1210 |
+
report += f"\n\n"
|
| 1211 |
+
report += "*Comparative view showing how the best simplified distillation model performs against top peer models across programming languages.*\n\n"
|
| 1212 |
+
|
| 1213 |
+
# Add individual radar charts for all simplified models
|
| 1214 |
+
if individual_radar_charts:
|
| 1215 |
+
report += "### Individual Model Performance by Language\n\n"
|
| 1216 |
+
for chart_model_name, chart_path in individual_radar_charts.items():
|
| 1217 |
+
# Extract teacher name for cleaner display
|
| 1218 |
+
teacher = "Unknown"
|
| 1219 |
+
if "all_MiniLM_L6_v2" in chart_model_name:
|
| 1220 |
+
teacher = "all-MiniLM-L6-v2"
|
| 1221 |
+
elif "codebert_base" in chart_model_name:
|
| 1222 |
+
teacher = "codebert-base"
|
| 1223 |
+
elif "graphcodebert_base" in chart_model_name:
|
| 1224 |
+
teacher = "graphcodebert-base"
|
| 1225 |
+
elif "gte_Qwen2_7B_instruct" in chart_model_name:
|
| 1226 |
+
teacher = "gte-Qwen2-7B-instruct"
|
| 1227 |
+
elif "all_mpnet_base_v2" in chart_model_name:
|
| 1228 |
+
teacher = "all-mpnet-base-v2"
|
| 1229 |
+
|
| 1230 |
+
report += f"#### {chart_model_name} (Teacher: {teacher})\n\n"
|
| 1231 |
+
report += f"\n\n"
|
| 1232 |
+
|
| 1233 |
+
report += f"""
|
| 1234 |
+
|
| 1235 |
+
## π Peer Model Comparison
|
| 1236 |
+
|
| 1237 |
+

|
| 1238 |
+
|
| 1239 |
+
*Comparison with established code-specialized embedding models using actual evaluation results.*
|
| 1240 |
+
|
| 1241 |
+
### Complete Model Ranking
|
| 1242 |
+
|
| 1243 |
+
"""
|
| 1244 |
+
|
| 1245 |
+
# Add comprehensive ranking table
|
| 1246 |
+
if self.comparison_df is not None and len(self.comparison_df) > 0:
|
| 1247 |
+
report += "| Rank | Model | Type | NDCG@10 | MRR | Recall@5 |\n"
|
| 1248 |
+
report += "|------|-------|------|---------|-----|----------|\n"
|
| 1249 |
+
|
| 1250 |
+
for rank in range(len(self.comparison_df)):
|
| 1251 |
+
row_data = self.comparison_df.iloc[rank]
|
| 1252 |
+
model_name_display = str(row_data["Model"])
|
| 1253 |
+
|
| 1254 |
+
# Determine model type
|
| 1255 |
+
if (
|
| 1256 |
+
"code_model2vec" in model_name_display.lower()
|
| 1257 |
+
or "distilled" in model_name_display.lower()
|
| 1258 |
+
or "(ours)" in model_name_display.lower()
|
| 1259 |
+
):
|
| 1260 |
+
model_type = "**π₯ Simplified Distillation**"
|
| 1261 |
+
elif any(code_term in model_name_display.lower() for code_term in ["codebert", "graphcode", "codet5"]):
|
| 1262 |
+
model_type = "Code-Specific"
|
| 1263 |
+
elif "potion" in model_name_display.lower():
|
| 1264 |
+
model_type = "Model2Vec"
|
| 1265 |
+
else:
|
| 1266 |
+
model_type = "General"
|
| 1267 |
+
|
| 1268 |
+
report += f"| {rank + 1} | {model_name_display} | {model_type} | {row_data['NDCG@10']:.4f} | {row_data['MRR']:.4f} | {row_data['Recall@5']:.4f} |\n"
|
| 1269 |
+
|
| 1270 |
+
report += f"""
|
| 1271 |
+
|
| 1272 |
+
## π Performance Analysis
|
| 1273 |
+
|
| 1274 |
+
### Multi-Model Comparison Charts
|
| 1275 |
+
|
| 1276 |
+

|
| 1277 |
+
|
| 1278 |
+
*Comprehensive comparison across all evaluation metrics.*
|
| 1279 |
+
|
| 1280 |
+
### Language Performance Analysis
|
| 1281 |
+
|
| 1282 |
+

|
| 1283 |
+
|
| 1284 |
+
*Performance heatmap showing how different models perform across programming languages.*
|
| 1285 |
+
|
| 1286 |
+
### Efficiency Analysis
|
| 1287 |
+
|
| 1288 |
+

|
| 1289 |
+
|
| 1290 |
+
*Performance vs model size analysis showing the efficiency benefits of distillation.*
|
| 1291 |
+
|
| 1292 |
+
"""
|
| 1293 |
+
|
| 1294 |
+
# Add benchmark analysis if available
|
| 1295 |
+
if self.benchmark_results:
|
| 1296 |
+
report += f"""
|
| 1297 |
+
|
| 1298 |
+
## β‘ Operational Performance Analysis
|
| 1299 |
+
|
| 1300 |
+

|
| 1301 |
+
|
| 1302 |
+
*Comprehensive performance benchmarking across multiple operational metrics.*
|
| 1303 |
+
|
| 1304 |
+
### Performance Scaling Analysis
|
| 1305 |
+
|
| 1306 |
+

|
| 1307 |
+
|
| 1308 |
+
*How performance scales with different batch sizes for optimal throughput.*
|
| 1309 |
+
|
| 1310 |
+

|
| 1311 |
+
|
| 1312 |
+
*Memory usage patterns across different batch sizes.*
|
| 1313 |
+
|
| 1314 |
+
"""
|
| 1315 |
+
|
| 1316 |
+
# Add detailed language analysis
|
| 1317 |
+
report += """
|
| 1318 |
+
|
| 1319 |
+
## π Language-Specific Analysis
|
| 1320 |
+
|
| 1321 |
+
### Performance by Programming Language
|
| 1322 |
+
|
| 1323 |
+
"""
|
| 1324 |
+
|
| 1325 |
+
if language_scores:
|
| 1326 |
+
report += "| Language | Best Model Performance | Average Performance | Language Difficulty |\n"
|
| 1327 |
+
report += "|----------|------------------------|--------------------|--------------------||\n"
|
| 1328 |
+
|
| 1329 |
+
for lang in sorted(language_scores.keys()):
|
| 1330 |
+
# Find best performance for this language across all models
|
| 1331 |
+
lang_performances = []
|
| 1332 |
+
for result in self.results:
|
| 1333 |
+
lang_data = result.get("languages", {}).get(lang.lower(), {})
|
| 1334 |
+
if lang_data:
|
| 1335 |
+
lang_performances.append(lang_data.get("metrics", {}).get("ndcg@10", 0))
|
| 1336 |
+
|
| 1337 |
+
if lang_performances:
|
| 1338 |
+
best_lang_perf = max(lang_performances)
|
| 1339 |
+
avg_lang_perf = sum(lang_performances) / len(lang_performances)
|
| 1340 |
+
difficulty = "Easy" if avg_lang_perf > 0.3 else "Medium" if avg_lang_perf > 0.2 else "Hard"
|
| 1341 |
+
|
| 1342 |
+
report += f"| {lang} | {best_lang_perf:.4f} | {avg_lang_perf:.4f} | {difficulty} |\n"
|
| 1343 |
+
|
| 1344 |
+
report += """
|
| 1345 |
+
|
| 1346 |
+
## π― Conclusions and Recommendations
|
| 1347 |
+
|
| 1348 |
+
### Teacher Model Analysis
|
| 1349 |
+
|
| 1350 |
+
Based on the evaluation results across all simplified distillation models:
|
| 1351 |
+
|
| 1352 |
+
"""
|
| 1353 |
+
|
| 1354 |
+
if simplified_models and len(simplified_models) > 1:
|
| 1355 |
+
# Analyze which teacher models work best
|
| 1356 |
+
teacher_performance = {}
|
| 1357 |
+
for result in simplified_models:
|
| 1358 |
+
model_name = result["model_name"]
|
| 1359 |
+
score = result.get("overall", {}).get("ndcg@10", 0)
|
| 1360 |
+
|
| 1361 |
+
if "all_MiniLM_L6_v2" in model_name:
|
| 1362 |
+
teacher_performance["all-MiniLM-L6-v2"] = score
|
| 1363 |
+
elif "codebert_base" in model_name:
|
| 1364 |
+
teacher_performance["codebert-base"] = score
|
| 1365 |
+
elif "graphcodebert_base" in model_name:
|
| 1366 |
+
teacher_performance["graphcodebert-base"] = score
|
| 1367 |
+
elif "gte_Qwen2_7B_instruct" in model_name:
|
| 1368 |
+
teacher_performance["gte-Qwen2-7B-instruct"] = score
|
| 1369 |
+
elif "all_mpnet_base_v2" in model_name:
|
| 1370 |
+
teacher_performance["all-mpnet-base-v2"] = score
|
| 1371 |
+
|
| 1372 |
+
if teacher_performance:
|
| 1373 |
+
best_teacher = max(teacher_performance.items(), key=lambda x: x[1])
|
| 1374 |
+
worst_teacher = min(teacher_performance.items(), key=lambda x: x[1])
|
| 1375 |
+
|
| 1376 |
+
report += f"""
|
| 1377 |
+
1. **Best Teacher Model**: {best_teacher[0]} (NDCG@10: {best_teacher[1]:.4f})
|
| 1378 |
+
2. **Least Effective Teacher**: {worst_teacher[0]} (NDCG@10: {worst_teacher[1]:.4f})
|
| 1379 |
+
3. **Teacher Model Impact**: Choice of teacher model affects performance by {((best_teacher[1] - worst_teacher[1]) / best_teacher[1] * 100):.1f}%
|
| 1380 |
+
|
| 1381 |
+
### Recommendations
|
| 1382 |
+
|
| 1383 |
+
- **For Production**: Use {best_teacher[0]} as teacher model for best performance
|
| 1384 |
+
- **For Efficiency**: Model2Vec distillation provides significant size reduction with competitive performance
|
| 1385 |
+
- **For Code Tasks**: Specialized models consistently outperform general-purpose models
|
| 1386 |
+
"""
|
| 1387 |
+
|
| 1388 |
+
report += f"""
|
| 1389 |
+
|
| 1390 |
+
## π Methodology
|
| 1391 |
+
|
| 1392 |
+
### Evaluation Protocol
|
| 1393 |
+
- **Dataset**: CodeSearchNet test sets for 6 programming languages
|
| 1394 |
+
- **Metrics**: NDCG@k, MRR, Recall@k following CodeSearchNet methodology
|
| 1395 |
+
- **Query Format**: Natural language documentation strings
|
| 1396 |
+
- **Corpus Format**: Function code strings
|
| 1397 |
+
- **Evaluation**: Retrieval of correct code for each documentation query
|
| 1398 |
+
|
| 1399 |
+
### Teacher Models Tested
|
| 1400 |
+
- sentence-transformers/all-MiniLM-L6-v2 (proven baseline)
|
| 1401 |
+
- microsoft/codebert-base (code-specialized)
|
| 1402 |
+
- microsoft/graphcodebert-base (graph-aware code model)
|
| 1403 |
+
- Alibaba-NLP/gte-Qwen2-7B-instruct (large instruction model)
|
| 1404 |
+
- sentence-transformers/all-mpnet-base-v2 (general purpose)
|
| 1405 |
+
|
| 1406 |
+
### Distillation Method
|
| 1407 |
+
- **Technique**: Model2Vec static embedding generation
|
| 1408 |
+
- **Parameters**: PCA dims=256, SIF coefficient=1e-3, Zipf weighting=True
|
| 1409 |
+
- **Training Data**: CodeSearchNet comment-code pairs
|
| 1410 |
+
- **Languages**: Python, JavaScript, Java, PHP, Ruby, Go
|
| 1411 |
+
|
| 1412 |
+
---
|
| 1413 |
+
|
| 1414 |
+
*Report generated on {time.strftime("%Y-%m-%d %H:%M:%S")} using automated analysis pipeline.*
|
| 1415 |
+
*For questions about methodology or results, please refer to the CodeSearchNet documentation.*
|
| 1416 |
+
"""
|
| 1417 |
+
|
| 1418 |
+
return report
|
| 1419 |
+
|
| 1420 |
+
def export_results(self, output_file: str) -> None:
|
| 1421 |
+
"""Export results to CSV format."""
|
| 1422 |
+
if self.comparison_df is not None:
|
| 1423 |
+
self.comparison_df.to_csv(output_file, index=False)
|
| 1424 |
+
logger.info(f"Results exported to {output_file}")
|
| 1425 |
+
|
| 1426 |
+
|
| 1427 |
+
def main() -> None:
    """Run the full analysis pipeline: load results, summarize, chart, report.

    Parses CLI arguments, loads evaluation + benchmark results, prints
    summaries, generates the markdown report and optional CSV exports.
    """
    # Local import: argparse was previously only imported under the
    # `if __name__ == "__main__":` guard, so calling main() from the package
    # CLI (distiller.__main__) raised NameError. Importing here makes main()
    # self-contained regardless of entry point.
    import argparse

    parser = argparse.ArgumentParser(description="Analyze CodeSearchNet evaluation results and performance benchmarks")
    parser.add_argument("--results-dir", default=DEFAULT_EVALUATION_DIR, help="Evaluation results directory")
    parser.add_argument("--benchmark-dir", default=DEFAULT_BENCHMARK_DIR, help="Benchmark results directory")
    parser.add_argument("--model-name", default="gte_qwen2_m2v_code (Ours)", help="Model name for report")
    parser.add_argument("--output", default="REPORT.md", help="Output report file")
    parser.add_argument("--export-csv", help="Export comparison results to CSV")

    args = parser.parse_args()

    logger.info("Starting CodeSearchNet Analysis with Benchmark Integration")
    logger.info("=" * 60)

    # Setup output directories (only images_dir is needed below).
    _, images_dir, _ = setup_directories()

    # Initialize analyzer with local directories.
    analyzer = CodeSearchNetAnalyzer(
        results_dir=args.results_dir,
        benchmark_dir=args.benchmark_dir,
        images_dir=images_dir,
    )

    # Load results (this will also load benchmark results).
    analyzer.load_results()

    if not analyzer.results:
        logger.error("No evaluation results found! Please run evaluation first.")
        return

    # Print summary (includes both evaluation and benchmark summaries).
    analyzer.print_summary()
    analyzer.analyze_language_performance()

    # Analyze benchmark performance if available.
    if analyzer.benchmark_results:
        analyzer.analyze_benchmark_performance()
    else:
        logger.warning("No benchmark results found. Run benchmark.py first for complete analysis.")

    # Generate comprehensive report with benchmark integration.
    logger.info("Generating comprehensive report with benchmark data...")
    report = analyzer.generate_comprehensive_report(args.model_name)

    # Save report.
    report_path = Path(args.output)
    with report_path.open("w") as f:
        f.write(report)

    # Export CSV if requested.
    if args.export_csv:
        analyzer.export_results(args.export_csv)

    # Export benchmark CSV if available.
    if analyzer.benchmark_df is not None and not analyzer.benchmark_df.empty:
        benchmark_csv = report_path.parent / f"{args.model_name}_benchmark_comparison.csv"
        analyzer.benchmark_df.to_csv(benchmark_csv, index=False)
        logger.info(f"π Benchmark comparison saved to: {benchmark_csv}")

    logger.info("β… CodeSearchNet analysis with benchmarks complete!")
    logger.info(f"π Report saved to: {report_path}")
    logger.info(f"πΌοΈ Charts saved to: {images_dir}")
|
| 1490 |
+
|
| 1491 |
+
|
| 1492 |
+
if __name__ == "__main__":
    # argparse is imported here (at module scope, inside the guard) so it is
    # available as a module global when this file is executed as a script.
    # NOTE(review): main() relies on this name being bound — calling main()
    # from another module without this import would fail; verify CLI callers.
    import argparse

    main()
|
src/distiller/beam_utils.py
ADDED
|
@@ -0,0 +1,753 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Beam Cloud Utilities for Model Distillation and Evaluation.
|
| 3 |
+
|
| 4 |
+
This module provides comprehensive utilities for managing Beam volumes, checkpoints,
|
| 5 |
+
and file operations across distillation, evaluation, and analysis workflows.
|
| 6 |
+
|
| 7 |
+
Features:
|
| 8 |
+
- Volume management (direct file operations when mounted)
|
| 9 |
+
- Checkpoint operations (save, load, cleanup, resume)
|
| 10 |
+
- File transfer utilities (copy, move, sync)
|
| 11 |
+
- Evaluation result management
|
| 12 |
+
- Model artifact handling
|
| 13 |
+
- Distributed storage optimization
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import logging
|
| 18 |
+
import shutil
|
| 19 |
+
import time
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
# Configure logging
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BeamVolumeManager:
    """Manager for Beam distributed storage volumes using direct file operations.

    All operations act on the locally mounted volume path via ``pathlib`` and
    ``shutil``; no Beam API calls are made here. Methods catch their own
    exceptions, log them, and return a failure value instead of raising.
    """

    def __init__(self, volume_name: str, mount_path: str = "./data") -> None:
        """
        Initialize Beam Volume Manager.

        Args:
            volume_name: Name of the Beam volume
            mount_path: Local mount path for the volume (should match Beam function mount path)
        """
        self.volume_name = volume_name
        self.mount_path = Path(mount_path)
        # NOTE(review): the mount path is created eagerly here, so exists()
        # returns True afterwards even if the Beam volume was never actually
        # mounted at this path — confirm this is the intended semantics.
        self.mount_path.mkdir(parents=True, exist_ok=True)

    def exists(self) -> bool:
        """Check if the volume mount path exists."""
        return self.mount_path.exists()

    def list_contents(self, subpath: str = "") -> list[dict[str, Any]]:
        """List contents of the volume directory.

        Args:
            subpath: Optional path relative to the mount root; defaults to the root.

        Returns:
            One dict per entry (keys: name, size, modified, is_dir, path),
            directories first, each group sorted by name. Empty list when the
            path is missing or an error occurred.
        """
        try:
            target_path = self.mount_path / subpath if subpath else self.mount_path
            if not target_path.exists():
                logger.warning(f"β οΈ Path does not exist: {target_path}")
                return []

            contents: list[dict[str, Any]] = []
            for item in target_path.iterdir():
                stat = item.stat()
                contents.append(
                    {
                        "name": item.name,
                        # Directories are reported with a fixed "0MB" size.
                        "size": f"{stat.st_size / (1024 * 1024):.2f}MB" if item.is_file() else "0MB",
                        "modified": time.ctime(stat.st_mtime),
                        "is_dir": item.is_dir(),
                        "path": str(item.relative_to(self.mount_path)),
                    }
                )

            # `not is_dir` sorts directories (False) before files (True).
            return sorted(contents, key=lambda x: (not x["is_dir"], x["name"]))

        except Exception:
            logger.exception("β Error listing contents")
            return []

    def copy_file(self, src: str | Path, dst: str | Path) -> bool:
        """Copy a file within the volume.

        Relative paths are resolved against the mount root; absolute paths are
        used as-is. Parent directories of the destination are created.

        Returns:
            True on success, False on error.
        """
        try:
            src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
            dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)

            dst_path.parent.mkdir(parents=True, exist_ok=True)
            # copy2 preserves file metadata (timestamps, permissions).
            shutil.copy2(src_path, dst_path)

            logger.info(f"π Copied {src_path.name} to {dst_path}")
            return True

        except Exception:
            logger.exception("β Error copying file")
            return False

    def copy_directory(self, src: str | Path, dst: str | Path) -> bool:
        """Copy a directory within the volume.

        An existing destination is deleted first, so this is a full replace
        rather than a merge.

        Returns:
            True on success, False on error.
        """
        try:
            src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
            dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)

            if dst_path.exists():
                shutil.rmtree(dst_path)

            shutil.copytree(src_path, dst_path)

            logger.info(f"π Copied directory {src_path.name} to {dst_path}")
            return True

        except Exception:
            logger.exception("β Error copying directory")
            return False

    def move_file(self, src: str | Path, dst: str | Path) -> bool:
        """Move a file within the volume.

        Parent directories of the destination are created as needed.

        Returns:
            True on success, False on error.
        """
        try:
            src_path = self.mount_path / src if not Path(src).is_absolute() else Path(src)
            dst_path = self.mount_path / dst if not Path(dst).is_absolute() else Path(dst)

            dst_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(src_path), str(dst_path))

            logger.info(f"β‘οΈ Moved {src_path.name} to {dst_path}")
            return True

        except Exception:
            logger.exception("β Error moving file")
            return False

    def remove_file(self, file_path: str | Path) -> bool:
        """Remove a file from the volume.

        Returns:
            True when the file was removed; False when the path is missing,
            is not a regular file, or an error occurred.
        """
        try:
            target_path = self.mount_path / file_path if not Path(file_path).is_absolute() else Path(file_path)

            if target_path.exists():
                if target_path.is_file():
                    target_path.unlink()
                    logger.info(f"ποΈ Removed file: {target_path.name}")
                else:
                    # Directories must go through remove_directory().
                    logger.warning(f"β οΈ Path is not a file: {target_path}")
                    return False
                return True
            logger.warning(f"β οΈ File does not exist: {target_path}")
            return False

        except Exception:
            logger.exception("β Error removing file")
            return False

    def remove_directory(self, dir_path: str | Path) -> bool:
        """Recursively remove a directory from the volume.

        Returns:
            True on success, False when missing or on error.
        """
        try:
            target_path = self.mount_path / dir_path if not Path(dir_path).is_absolute() else Path(dir_path)

            if target_path.exists() and target_path.is_dir():
                shutil.rmtree(target_path)
                logger.info(f"ποΈ Removed directory: {target_path.name}")
                return True
            logger.warning(f"β οΈ Directory does not exist: {target_path}")
            return False

        except Exception:
            logger.exception("β Error removing directory")
            return False

    def cleanup_old_files(self, pattern: str = "*", older_than_days: int = 7, subpath: str = "") -> list[str]:
        """Clean up old files in the volume based on age and pattern.

        Args:
            pattern: Glob pattern, matched recursively under the target path.
            older_than_days: Files with mtime older than this many days are removed.
            subpath: Optional path relative to the mount root.

        Returns:
            Mount-relative paths of the files that were removed.
        """
        try:
            target_path = self.mount_path / subpath if subpath else self.mount_path
            if not target_path.exists():
                return []

            cutoff_time = time.time() - (older_than_days * 24 * 3600)
            removed_files: list[str] = []

            for item in target_path.rglob(pattern):
                if item.is_file() and item.stat().st_mtime < cutoff_time:
                    try:
                        item.unlink()
                        removed_files.append(str(item.relative_to(self.mount_path)))
                        logger.info(f"π§Ή Removed old file: {item.name}")
                    except Exception as e:
                        # Best-effort: one failed removal does not stop the sweep.
                        logger.warning(f"β οΈ Could not remove {item.name}: {e}")

            if removed_files:
                logger.info(f"π§Ή Cleaned up {len(removed_files)} old files")

            return removed_files

        except Exception:
            logger.exception("β Error during cleanup")
            return []

    def get_size(self, subpath: str = "") -> int:
        """Get total size of volume or subpath in bytes.

        Returns:
            Sum of all file sizes under the target path; 0 when the path is
            missing or on error.
        """
        try:
            target_path = self.mount_path / subpath if subpath else self.mount_path
            if not target_path.exists():
                return 0

            total_size = 0
            for item in target_path.rglob("*"):
                if item.is_file():
                    total_size += item.stat().st_size

            return total_size

        except Exception:
            logger.exception("β Error calculating size")
            return 0
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class BeamCheckpointManager:
    """Manager for checkpoint operations on Beam volumes.

    Checkpoints are JSON files named
    ``{checkpoint_prefix}_{stage}_step_{step}.json`` inside a dedicated
    directory on the mounted volume.
    """

    def __init__(self, volume_manager: BeamVolumeManager, checkpoint_prefix: str = "checkpoints") -> None:
        """
        Initialize Checkpoint Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            checkpoint_prefix: Prefix for checkpoint files
        """
        self.volume = volume_manager
        self.checkpoint_prefix = checkpoint_prefix
        self.checkpoint_dir = self.volume.mount_path / checkpoint_prefix
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _checkpoint_path(self, stage: str, step: int) -> Path:
        """Build the canonical checkpoint file path for a stage/step pair."""
        return self.checkpoint_dir / f"{self.checkpoint_prefix}_{stage}_step_{step}.json"

    def _parse_checkpoint_name(self, checkpoint_file: Path) -> tuple[str, int] | None:
        """Parse (stage, step) from a checkpoint filename.

        Anchors on the final ``_step_`` marker so that stage names (and the
        prefix itself) may contain underscores. Returns None when the name
        does not follow the expected pattern.
        """
        stem = checkpoint_file.stem
        prefix = f"{self.checkpoint_prefix}_"
        if not stem.startswith(prefix):
            return None
        stage, sep, step_str = stem[len(prefix):].rpartition("_step_")
        if not sep:
            return None
        try:
            return stage, int(step_str)
        except ValueError:
            return None

    def _stage_checkpoints(self, stage: str) -> list[tuple[int, Path]]:
        """Collect (step, path) pairs for every checkpoint of a stage."""
        pattern = f"{self.checkpoint_prefix}_{stage}_step_*.json"
        found: list[tuple[int, Path]] = []
        for checkpoint_file in self.checkpoint_dir.glob(pattern):
            parsed = self._parse_checkpoint_name(checkpoint_file)
            if parsed is not None and parsed[0] == stage:
                found.append((parsed[1], checkpoint_file))
        return found

    def save_checkpoint(self, stage: str, data: dict[str, Any], step: int = 0) -> bool:
        """Save checkpoint to volume.

        Args:
            stage: Workflow stage name.
            data: JSON-serializable payload (non-JSON values are stringified).
            step: Step number embedded in the filename.

        Returns:
            True on success, False when the write failed.
        """
        try:
            checkpoint_path = self._checkpoint_path(stage, step)

            with checkpoint_path.open("w") as f:
                json.dump(data, f, indent=2, default=str)

            logger.info(f"πΎ Saved checkpoint: {stage} step {step}")
            return True

        except Exception:
            logger.exception("β Error saving checkpoint")
            return False

    def load_checkpoint(self, stage: str, step: int = 0) -> dict[str, Any] | None:
        """Load checkpoint from volume.

        Returns:
            The checkpoint payload, or None when it does not exist or on error.
        """
        try:
            checkpoint_path = self._checkpoint_path(stage, step)

            if checkpoint_path.exists():
                with checkpoint_path.open("r") as f:
                    data = json.load(f)
                logger.info(f"π Loaded checkpoint: {stage} step {step}")
                return data

            logger.info(f"Info: No checkpoint found: {stage} step {step}")
            return None

        except Exception:
            logger.exception("β Error loading checkpoint")
            return None

    def get_latest_checkpoint(self, stage: str) -> tuple[int, dict[str, Any]] | None:
        """Get the latest checkpoint for a stage.

        Returns:
            (step, data) for the highest step number, or None when no
            checkpoint exists or on error.
        """
        try:
            stage_checkpoints = self._stage_checkpoints(stage)

            if not stage_checkpoints:
                logger.info(f"Info: No checkpoints found for stage: {stage}")
                return None

            # Highest step number wins.
            latest_step, latest_file = max(stage_checkpoints, key=lambda x: x[0])

            with latest_file.open("r") as f:
                data = json.load(f)
            logger.info(f"π Found latest checkpoint: {stage} step {latest_step}")
            return latest_step, data

        except Exception:
            logger.exception("β Error finding latest checkpoint")
            return None

    def cleanup_old_checkpoints(self, stage: str, keep_latest: int = 3) -> list[str]:
        """Clean up old checkpoints, keeping only the latest N.

        Returns:
            Filenames of the checkpoints that were removed.
        """
        try:
            stage_checkpoints = self._stage_checkpoints(stage)

            # Newest first; everything past the first `keep_latest` is removed.
            stage_checkpoints.sort(key=lambda x: x[0], reverse=True)

            removed_files: list[str] = []
            for _step, checkpoint_file in stage_checkpoints[keep_latest:]:
                try:
                    checkpoint_file.unlink()
                    removed_files.append(checkpoint_file.name)
                    logger.info(f"π§Ή Removed old checkpoint: {checkpoint_file.name}")
                except Exception as e:
                    logger.warning(f"β οΈ Could not remove {checkpoint_file.name}: {e}")

            if removed_files:
                logger.info(f"π§Ή Cleaned up {len(removed_files)} old checkpoints for {stage}")

            return removed_files

        except Exception:
            logger.exception("β Error cleaning up checkpoints")
            return []

    def list_checkpoints(self, stage: str | None = None) -> list[dict[str, Any]]:
        """List all checkpoints, optionally filtered by stage.

        Bug fix: the previous implementation split the filename on every
        underscore and read fixed indices, which mis-parsed any stage (or
        prefix) containing underscores; parsing now anchors on the final
        ``_step_`` marker via _parse_checkpoint_name.

        Returns:
            One dict per checkpoint (stage, step, filename, size, modified),
            sorted by (stage, step).
        """
        try:
            checkpoints: list[dict[str, Any]] = []
            pattern = f"{self.checkpoint_prefix}_*.json"

            for checkpoint_file in self.checkpoint_dir.glob(pattern):
                parsed = self._parse_checkpoint_name(checkpoint_file)
                if parsed is None:
                    continue  # not a checkpoint file we wrote
                checkpoint_stage, step = parsed

                if stage is None or checkpoint_stage == stage:
                    stat = checkpoint_file.stat()
                    checkpoints.append(
                        {
                            "stage": checkpoint_stage,
                            "step": step,
                            "filename": checkpoint_file.name,
                            "size": f"{stat.st_size / 1024:.1f}KB",
                            "modified": time.ctime(stat.st_mtime),
                        }
                    )

            return sorted(checkpoints, key=lambda x: (x["stage"], x["step"]))

        except Exception:
            logger.exception("β Error listing checkpoints")
            return []
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class BeamModelManager:
    """Manager for model artifacts on Beam volumes.

    Each model lives in its own directory under ``{mount_path}/{model_prefix}``.
    """

    def __init__(self, volume_manager: BeamVolumeManager, model_prefix: str = "models") -> None:
        """
        Initialize Model Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            model_prefix: Prefix for model files
        """
        self.volume = volume_manager
        self.model_prefix = model_prefix
        self.model_dir = self.volume.mount_path / model_prefix
        self.model_dir.mkdir(parents=True, exist_ok=True)

    def save_model(self, model_name: str, local_model_path: str | Path) -> bool:
        """Save model to Beam volume.

        Directories are mirrored wholesale (replacing any previous copy);
        single files are dropped into a directory named after the model.

        Returns:
            True on success, False on error.
        """
        try:
            source = Path(local_model_path)
            if not source.exists():
                logger.error(f"β Model path does not exist: {source}")
                return False

            destination = self.model_dir / model_name

            if source.is_dir():
                # Replace any previous copy with the full directory tree.
                if destination.exists():
                    shutil.rmtree(destination)
                shutil.copytree(source, destination)
                logger.info(f"πΎ Saved model directory {model_name}")
            else:
                destination.mkdir(exist_ok=True)
                shutil.copy2(source, destination / source.name)
                logger.info(f"πΎ Saved model file {model_name}")

            return True

        except Exception:
            logger.exception("β Error saving model")
            return False

    def load_model(self, model_name: str, local_model_path: str | Path = "./models") -> bool:
        """Load model from Beam volume into a local directory.

        Any pre-existing local copy is removed before the transfer.

        Returns:
            True on success, False when the model is missing or on error.
        """
        try:
            local_root = Path(local_model_path)
            local_root.mkdir(parents=True, exist_ok=True)

            source = self.model_dir / model_name
            destination = local_root / model_name

            if not source.exists():
                logger.error(f"β Model does not exist: {model_name}")
                return False

            # Clear out whatever is already at the destination.
            if destination.exists():
                if destination.is_dir():
                    shutil.rmtree(destination)
                else:
                    destination.unlink()

            copier = shutil.copytree if source.is_dir() else shutil.copy2
            copier(source, destination)

            logger.info(f"π₯ Loaded model {model_name}")
            return True

        except Exception:
            logger.exception("β Error loading model")
            return False

    def list_models(self) -> list[dict[str, str]]:
        """List all models in the volume.

        Returns:
            One dict per model directory (name, size, modified), sorted by name.
        """
        try:
            models: list[dict[str, str]] = []

            if not self.model_dir.exists():
                return models

            for entry in self.model_dir.iterdir():
                if not entry.is_dir():
                    continue
                info = entry.stat()
                # Directory size = sum of all contained file sizes.
                dir_bytes = sum(f.stat().st_size for f in entry.rglob("*") if f.is_file())
                models.append(
                    {
                        "name": entry.name,
                        "size": f"{dir_bytes / (1024 * 1024):.1f}MB",
                        "modified": time.ctime(info.st_mtime),
                    }
                )

            models.sort(key=lambda m: m["name"])
            return models

        except Exception:
            logger.exception("β Error listing models")
            return []

    def remove_model(self, model_name: str) -> bool:
        """Remove a model from the volume.

        Returns:
            True when removed, False when missing or on error.
        """
        try:
            target = self.model_dir / model_name

            if not target.exists():
                logger.warning(f"β οΈ Model does not exist: {model_name}")
                return False

            if target.is_dir():
                shutil.rmtree(target)
            else:
                target.unlink()
            logger.info(f"ποΈ Removed model: {model_name}")
            return True

        except Exception:
            logger.exception("β Error removing model")
            return False
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class BeamEvaluationManager:
    """Manager for evaluation results on Beam volumes.

    Results are JSON files named ``{eval_type}_eval_{model_name}.json`` (with
    ``/`` in model names replaced by ``_``) inside a dedicated directory.
    """

    def __init__(
        self,
        volume_manager: BeamVolumeManager,
        results_prefix: str = "evaluation_results",
    ) -> None:
        """
        Initialize Evaluation Manager.

        Args:
            volume_manager: BeamVolumeManager instance
            results_prefix: Prefix for evaluation result files
        """
        self.volume = volume_manager
        self.results_prefix = results_prefix
        self.results_dir = self.volume.mount_path / results_prefix
        self.results_dir.mkdir(parents=True, exist_ok=True)

    def _results_path(self, model_name: str, eval_type: str) -> Path:
        """Build the canonical results file path for a model/eval-type pair."""
        return self.results_dir / f"{eval_type}_eval_{model_name.replace('/', '_')}.json"

    def save_evaluation_results(
        self, model_name: str, results: dict[str, Any], eval_type: str = "codesearchnet"
    ) -> bool:
        """Save evaluation results to Beam volume.

        Returns:
            True on success, False on error.
        """
        try:
            results_path = self._results_path(model_name, eval_type)

            with results_path.open("w") as f:
                json.dump(results, f, indent=2, default=str)

            logger.info(f"πΎ Saved evaluation results for {model_name}")
            return True

        except Exception:
            logger.exception("β Error saving evaluation results")
            return False

    def load_evaluation_results(self, model_name: str, eval_type: str = "codesearchnet") -> dict[str, Any] | None:
        """Load evaluation results from Beam volume.

        Returns:
            The stored results, or None when missing or on error.
        """
        try:
            results_path = self._results_path(model_name, eval_type)

            if results_path.exists():
                with results_path.open("r") as f:
                    results = json.load(f)
                logger.info(f"π Loaded evaluation results for {model_name}")
                return results

            logger.info(f"Info: No evaluation results found for {model_name}")
            return None

        except Exception:
            logger.exception("β Error loading evaluation results")
            return None

    def list_evaluation_results(self, eval_type: str | None = None) -> list[dict[str, str]]:
        """List all evaluation results, optionally filtered by eval type.

        Bug fix: model names were previously reconstructed with chained
        ``str.replace`` calls, which could corrupt names containing the
        replaced substrings and left the eval-type prefix attached when no
        filter was given; parsing now splits once on the ``_eval_`` marker.

        Returns:
            One dict per result file (model_name, filename, size, modified),
            sorted by model name.
        """
        try:
            results: list[dict[str, str]] = []

            if not self.results_dir.exists():
                return results

            for item in self.results_dir.glob("*.json"):
                stem = item.name.removesuffix(".json")
                file_eval_type, sep, model_name = stem.partition("_eval_")
                if not sep:
                    continue  # not a results file this manager wrote
                if eval_type is not None and file_eval_type != eval_type:
                    continue

                stat = item.stat()
                results.append(
                    {
                        "model_name": model_name,
                        "filename": item.name,
                        "size": f"{stat.st_size / 1024:.1f}KB",
                        "modified": time.ctime(stat.st_mtime),
                    }
                )

            return sorted(results, key=lambda x: x["model_name"])

        except Exception:
            logger.exception("β Error listing evaluation results")
            return []

    def remove_evaluation_results(self, model_name: str, eval_type: str = "codesearchnet") -> bool:
        """Remove evaluation results from volume.

        Returns:
            True when removed, False when missing or on error.
        """
        try:
            results_path = self._results_path(model_name, eval_type)

            if results_path.exists():
                results_path.unlink()
                logger.info(f"ποΈ Removed evaluation results for {model_name}")
                return True
            logger.warning(f"β οΈ Evaluation results do not exist for {model_name}")
            return False

        except Exception:
            logger.exception("β Error removing evaluation results")
            return False
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def create_beam_utilities(
    volume_name: str, mount_path: str = "./data"
) -> tuple[BeamVolumeManager, BeamCheckpointManager, BeamModelManager, BeamEvaluationManager]:
    """
    Create a complete set of Beam utilities.

    Builds one shared volume manager and layers the checkpoint, model, and
    evaluation managers on top of it.

    Args:
        volume_name: Name of the Beam volume
        mount_path: Local mount path for the volume

    Returns:
        Tuple of (volume_manager, checkpoint_manager, model_manager, evaluation_manager)
    """
    volume = BeamVolumeManager(volume_name, mount_path)
    return (
        volume,
        BeamCheckpointManager(volume),
        BeamModelManager(volume),
        BeamEvaluationManager(volume),
    )
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def cleanup_beam_workspace(volume_name: str, mount_path: str = "./data", confirm: bool = False) -> bool:
    """
    Clean up entire Beam workspace including all data in the mounted volume.

    Prompts interactively unless ``confirm`` is True, then removes every entry
    directly under the mount path, continuing past individual failures.

    Args:
        volume_name: Name of the volume to clean up
        mount_path: Mount path of the volume
        confirm: If True, skip confirmation prompt

    Returns:
        True if cleanup successful, False otherwise
    """
    if not confirm:
        answer = input(f"β οΈ This will delete all data in volume '{volume_name}' at '{mount_path}'. Continue? (y/N): ")
        if answer.lower() != "y":
            logger.info("Cleanup cancelled")
            return False

    try:
        workspace = BeamVolumeManager(volume_name, mount_path)

        if not workspace.exists():
            logger.info(f"Volume mount path does not exist: {mount_path}")
            return True

        # Report how much is about to go away.
        entries = workspace.list_contents()
        logger.info(f"ποΈ Will delete {len(entries)} items from volume '{volume_name}'")

        # Best-effort removal: a failure on one entry does not stop the rest.
        for entry in workspace.mount_path.iterdir():
            try:
                if entry.is_dir():
                    shutil.rmtree(entry)
                    logger.info(f"ποΈ Removed directory: {entry.name}")
                else:
                    entry.unlink()
                    logger.info(f"ποΈ Removed file: {entry.name}")
            except Exception as e:
                logger.warning(f"β οΈ Could not remove {entry.name}: {e}")

        logger.info(f"β Successfully cleaned up Beam workspace: {volume_name}")
        return True

    except Exception:
        logger.exception("β Error during cleanup")
        return False
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def get_workspace_info(volume_name: str, mount_path: str = "./data") -> dict[str, Any]:
    """
    Get information about the Beam workspace.

    Args:
        volume_name: Name of the volume
        mount_path: Mount path of the volume

    Returns:
        Dictionary with workspace information
    """
    try:
        workspace = BeamVolumeManager(volume_name, mount_path)

        # An unmounted workspace yields a minimal "empty" summary.
        if not workspace.exists():
            return {
                "volume_name": volume_name,
                "mount_path": mount_path,
                "exists": False,
                "size": 0,
                "contents": [],
            }

        entries = workspace.list_contents()
        byte_count = workspace.get_size()

        return {
            "volume_name": volume_name,
            "mount_path": str(workspace.mount_path),
            "exists": True,
            "size": byte_count,
            "size_mb": f"{byte_count / (1024 * 1024):.1f}MB",
            "num_items": len(entries),
            "contents": entries[:10],  # First 10 items
        }

    except Exception:
        logger.exception("β Error getting workspace info")
        return {
            "volume_name": volume_name,
            "mount_path": mount_path,
            "error": "Error occurred",
        }
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
# Example usage functions
|
| 709 |
+
def example_distillation_workflow() -> None:
    """Example of using Beam utilities for distillation workflow."""
    volume_name = "gte_qwen2_m2v_code"
    mount_path = "./gte_qwen2_m2v_code"  # Should match Beam function mount path

    # Build the full manager stack for this volume.
    volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)

    # Bail out early if the volume is not mounted.
    if not volume_mgr.exists():
        logger.warning(f"Volume {volume_name} not found at {mount_path}")
        return
    logger.info(f"Volume {volume_name} is mounted at {mount_path}")

    # Persist a dummy training checkpoint.
    checkpoint_mgr.save_checkpoint(
        "training",
        {
            "epoch": 1,
            "loss": 0.25,
            "model_state": "dummy_state",
            "timestamp": time.time(),
        },
        step=1000,
    )

    # Show how many checkpoints exist for the stage.
    checkpoints = checkpoint_mgr.list_checkpoints("training")
    logger.info(f"Found {len(checkpoints)} training checkpoints")

    # Persist sample evaluation results.
    eval_mgr.save_evaluation_results(
        "gte_qwen2_m2v_code",
        {
            "model_name": "gte_qwen2_m2v_code",
            "overall": {"ndcg@10": 0.35, "mrr": 0.42},
            "timestamp": time.time(),
        },
    )

    # Summarize what is stored in the workspace.
    info = get_workspace_info(volume_name, mount_path)
    logger.info(f"Workspace info: {info}")
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
if __name__ == "__main__":
    # Example usage: configure INFO-level logging, then run the demo workflow.
    logging.basicConfig(level=logging.INFO)
    example_distillation_workflow()
|
src/distiller/benchmark.py
ADDED
|
@@ -0,0 +1,1181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Operational Performance Benchmarking for Embedding Models.
|
| 3 |
+
|
| 4 |
+
This module benchmarks embedding models on operational metrics like:
|
| 5 |
+
- Inference speed (latency and throughput)
|
| 6 |
+
- Memory efficiency (RAM and GPU usage)
|
| 7 |
+
- Model size and storage requirements
|
| 8 |
+
- Scalability with batch size
|
| 9 |
+
- CPU vs GPU performance
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import gc
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import psutil
|
| 22 |
+
import torch
|
| 23 |
+
from beam import GpuType, Image, Volume, function
|
| 24 |
+
from sentence_transformers import SentenceTransformer
|
| 25 |
+
|
| 26 |
+
from .beam_utils import (
|
| 27 |
+
BeamCheckpointManager,
|
| 28 |
+
BeamEvaluationManager,
|
| 29 |
+
create_beam_utilities,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
# =============================================================================
|
| 35 |
+
# BEAM CONFIGURATION
|
| 36 |
+
# =============================================================================
|
| 37 |
+
|
| 38 |
+
GPU_NAME = GpuType.A100_40
|
| 39 |
+
VOLUME_NAME = "gte_qwen2_m2v_code" # Same volume as distill.py and evaluate.py
|
| 40 |
+
VOLUME_PATH = "./gte_qwen2_m2v_code" # Same mount path as distill.py and evaluate.py
|
| 41 |
+
BENCHMARK_RESULTS_DIR = "benchmark_results" # Subdirectory within volume
|
| 42 |
+
BENCHMARK_CACHE_DIR = "benchmark_cache" # Cache for models
|
| 43 |
+
|
| 44 |
+
IMAGE = Image(python_version="python3.12").add_python_packages(
|
| 45 |
+
[
|
| 46 |
+
"torch>=2.7.0",
|
| 47 |
+
"transformers>=4.40.0",
|
| 48 |
+
"datasets>=3.2.0",
|
| 49 |
+
"sentence-transformers>=4.1.0",
|
| 50 |
+
"model2vec[train]>=0.5.0",
|
| 51 |
+
"numpy>=1.26.4",
|
| 52 |
+
"scikit-learn>=1.6.1",
|
| 53 |
+
"pandas>=2.0.0",
|
| 54 |
+
"tqdm>=4.65.0",
|
| 55 |
+
"psutil>=5.9.0",
|
| 56 |
+
]
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# =============================================================================
|
| 60 |
+
# CONFIGURATION
|
| 61 |
+
# =============================================================================
|
| 62 |
+
|
| 63 |
+
DEFAULT_OUTPUT_DIR = "benchmark_results" # Local fallback directory
|
| 64 |
+
|
| 65 |
+
# Default models to benchmark (can be overridden via command line)
|
| 66 |
+
DEFAULT_BENCHMARK_MODELS = [
|
| 67 |
+
# Your distilled model (local files in Beam volume root)
|
| 68 |
+
"gte_qwen2_m2v_code", # This will be resolved to VOLUME_PATH in Beam
|
| 69 |
+
# Established Code Models
|
| 70 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
| 71 |
+
"microsoft/codebert-base",
|
| 72 |
+
"microsoft/graphcodebert-base",
|
| 73 |
+
"huggingface/CodeBERTa-small-v1",
|
| 74 |
+
"sentence-transformers/all-mpnet-base-v2",
|
| 75 |
+
"sentence-transformers/all-MiniLM-L12-v2",
|
| 76 |
+
# Model2Vec & Efficiency Models (Direct Competitors)
|
| 77 |
+
"minishlab/potion-base-8M",
|
| 78 |
+
"minishlab/potion-retrieval-32M",
|
| 79 |
+
# Small Transformer-Based Code Models
|
| 80 |
+
"Salesforce/codet5-base",
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
# =============================================================================
|
| 84 |
+
# CHECKPOINT CONFIGURATION
|
| 85 |
+
# =============================================================================
|
| 86 |
+
|
| 87 |
+
# Prevent conflicts with other modules by using unique prefixes
|
| 88 |
+
BENCHMARK_CHECKPOINT_PREFIX = "benchmark_checkpoints"
|
| 89 |
+
MODEL_CACHE_PREFIX = "model_cache"
|
| 90 |
+
|
| 91 |
+
# Sample texts for benchmarking (various lengths)
|
| 92 |
+
BENCHMARK_TEXTS = {
|
| 93 |
+
"short": [
|
| 94 |
+
"def add(a, b): return a + b",
|
| 95 |
+
"function multiply(x, y) { return x * y; }",
|
| 96 |
+
"class Calculator { public int subtract(int a, int b) { return a - b; } }",
|
| 97 |
+
]
|
| 98 |
+
* 100, # 300 short texts
|
| 99 |
+
"medium": [
|
| 100 |
+
"def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
|
| 101 |
+
"function quickSort(arr) {\n if (arr.length <= 1) return arr;\n const pivot = arr[arr.length - 1];\n const left = [], right = [];\n for (let i = 0; i < arr.length - 1; i++) {\n if (arr[i] < pivot) left.push(arr[i]);\n else right.push(arr[i]);\n }\n return [...quickSort(left), pivot, ...quickSort(right)];\n}",
|
| 102 |
+
]
|
| 103 |
+
* 50, # 100 medium texts
|
| 104 |
+
"long": [
|
| 105 |
+
"""
|
| 106 |
+
def complex_algorithm(data, config):
|
| 107 |
+
'''
|
| 108 |
+
Complex data processing algorithm with multiple steps.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
data: Input data structure
|
| 112 |
+
config: Configuration parameters
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Processed results
|
| 116 |
+
'''
|
| 117 |
+
results = []
|
| 118 |
+
|
| 119 |
+
# Step 1: Data validation
|
| 120 |
+
if not isinstance(data, (list, tuple)):
|
| 121 |
+
raise ValueError("Data must be list or tuple")
|
| 122 |
+
|
| 123 |
+
# Step 2: Preprocessing
|
| 124 |
+
processed_data = []
|
| 125 |
+
for item in data:
|
| 126 |
+
if config.get('normalize', False):
|
| 127 |
+
item = normalize_item(item)
|
| 128 |
+
if config.get('filter', False):
|
| 129 |
+
if not filter_item(item, config['filter_criteria']):
|
| 130 |
+
continue
|
| 131 |
+
processed_data.append(item)
|
| 132 |
+
|
| 133 |
+
# Step 3: Main processing
|
| 134 |
+
for item in processed_data:
|
| 135 |
+
result = process_item(item, config)
|
| 136 |
+
if result is not None:
|
| 137 |
+
results.append(result)
|
| 138 |
+
|
| 139 |
+
# Step 4: Post-processing
|
| 140 |
+
if config.get('sort', False):
|
| 141 |
+
results.sort(key=lambda x: x.get('score', 0), reverse=True)
|
| 142 |
+
|
| 143 |
+
return results
|
| 144 |
+
""".strip(),
|
| 145 |
+
]
|
| 146 |
+
* 20, # 20 long texts
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class PerformanceBenchmark:
    """Comprehensive performance benchmarking for embedding models.

    Measures operational metrics for a single model: load time, on-disk and
    in-memory size, inference latency/throughput across batch sizes and text
    lengths, memory scaling, and CPU-vs-GPU throughput. Results accumulate in
    ``self.results`` and can be saved via :meth:`save_results`.
    """

    def __init__(
        self,
        model_path: str,
        model_name: str | None = None,
        checkpoint_manager: BeamCheckpointManager | None = None,
        eval_manager: BeamEvaluationManager | None = None,
    ) -> None:
        """Initialize benchmarker with model and optional Beam utilities.

        Args:
            model_path: Local path or HuggingFace model id to benchmark.
            model_name: Display name; defaults to the last path component.
            checkpoint_manager: Optional Beam checkpoint helper (currently
                stored but not used by the benchmark methods in this class).
            eval_manager: Optional Beam evaluation helper (stored only).
        """
        self.model_path = model_path
        self.model_name = model_name or Path(model_path).name
        # Lazily loaded in load_model(); methods auto-load when still None.
        self.model: SentenceTransformer | None = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Accumulator for all benchmark sections (size, speed, memory, ...).
        self.results: dict[str, Any] = {}
        self.checkpoint_manager = checkpoint_manager
        self.eval_manager = eval_manager

    def load_model(self) -> None:
        """Load the embedding model and record load time in results.

        Raises:
            Exception: Re-raises whatever SentenceTransformer raises on failure.
        """
        logger.info(f"Loading model from {self.model_path}")
        start_time = time.time()

        try:
            # trust_remote_code is required for models shipping custom code
            # (e.g. gte-Qwen2 style architectures).
            self.model = SentenceTransformer(self.model_path, device=self.device, trust_remote_code=True)
            load_time = time.time() - start_time

            logger.info(f"β Model loaded in {load_time:.2f}s on {self.device}")
            self.results["model_load_time"] = load_time

        except Exception:
            logger.exception("β Failed to load model")
            raise

    def measure_model_size(self) -> dict[str, float]:
        """Measure model size metrics (disk, parameters, RAM/GPU footprint).

        Returns:
            Dict with keys such as ``disk_size_mb``, ``parameters_millions``,
            ``embedding_dim``, ``ram_usage_mb`` and (on CUDA) ``gpu_memory_mb``.
            Values default to 0 / are omitted when they cannot be determined.
        """
        logger.info("π Measuring model size...")

        size_metrics = {}

        # Disk size - handle both local paths and HuggingFace models
        try:
            if Path(self.model_path).is_dir():
                # Local directory - calculate size of model files only
                # (weights + tokenizer/config; skips caches and misc files).
                model_extensions = {".safetensors", ".bin", ".json", ".txt", ".tokenizer"}
                total_size = 0
                model_dir = Path(self.model_path)

                for file_path in model_dir.rglob("*"):
                    if file_path.is_file() and (
                        file_path.suffix.lower() in model_extensions
                        or file_path.name.lower() in {"config.json", "tokenizer.json", "modules.json", "README.md"}
                    ):
                        total_size += file_path.stat().st_size

                size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
            elif Path(self.model_path).is_file():
                # Single file
                total_size = Path(self.model_path).stat().st_size
                size_metrics["disk_size_mb"] = total_size / (1024 * 1024)
            else:
                # HuggingFace model - estimate from cache if available
                from transformers import AutoConfig

                try:
                    config = AutoConfig.from_pretrained(self.model_path)
                    # Estimate size based on parameters (rough approximation)
                    if hasattr(config, "hidden_size") and hasattr(config, "num_hidden_layers"):
                        # Rough estimation for transformer models
                        estimated_params = config.hidden_size * config.num_hidden_layers * 1000  # Very rough
                        size_metrics["disk_size_mb"] = estimated_params * 4 / (1024 * 1024)  # 4 bytes per float32
                    else:
                        size_metrics["disk_size_mb"] = 0  # Unknown
                except Exception:
                    logger.warning(f"Could not determine disk size for HuggingFace model: {self.model_path}")
                    size_metrics["disk_size_mb"] = 0  # Unknown
        except Exception as e:
            logger.warning(f"Could not determine disk size: {e}")
            size_metrics["disk_size_mb"] = 0

        # Model parameters (if accessible)
        try:
            if self.model is not None and hasattr(self.model, "modules"):
                total_params = sum(p.numel() for p in self.model.parameters())
                size_metrics["parameters_millions"] = total_params / 1_000_000

                # Try to get embedding dimension from model config
                try:
                    # Use the public modules() method instead of private _modules
                    modules = list(self.model.modules())
                    if len(modules) > 1:  # modules[0] is usually the entire model, modules[1] is first submodule
                        first_module = modules[1]
                        if hasattr(first_module, "auto_model") and hasattr(first_module.auto_model, "config"):
                            config = first_module.auto_model.config
                            if hasattr(config, "hidden_size"):
                                size_metrics["embedding_dim"] = config.hidden_size
                            elif hasattr(config, "model_dim"):
                                size_metrics["embedding_dim"] = config.model_dim
                except Exception as e:
                    logger.debug(
                        f"Could not extract embedding dimension from model config: {e}"
                    )  # Silently continue if this method fails

            # For Model2Vec static models
            elif self.model is not None and hasattr(self.model, "embedding"):
                # Handle both tensor and numpy array embeddings
                embedding = self.model.embedding
                if hasattr(embedding, "shape"):
                    vocab_size, embedding_dim = embedding.shape  # type: ignore[misc]
                    total_params = vocab_size * embedding_dim
                    size_metrics["parameters_millions"] = total_params / 1_000_000
                    size_metrics["vocab_size"] = vocab_size
                    size_metrics["embedding_dim"] = embedding_dim
                else:
                    logger.warning("Could not determine embedding shape for Model2Vec model")

            # Alternative method: get embedding dimension from a test encoding
            if "embedding_dim" not in size_metrics and self.model is not None:
                try:
                    test_embedding = self.model.encode(["test"], convert_to_tensor=False)
                    if hasattr(test_embedding, "shape") and len(test_embedding.shape) >= 2:
                        size_metrics["embedding_dim"] = test_embedding.shape[1]
                    elif (
                        isinstance(test_embedding, (list, tuple))
                        and len(test_embedding) > 0
                        and hasattr(test_embedding[0], "__len__")
                    ):
                        size_metrics["embedding_dim"] = len(test_embedding[0])
                except Exception as e:
                    logger.warning(f"Could not determine embedding dimension: {e}")

        except Exception as e:
            logger.warning(f"Could not determine parameter count: {e}")

        # Memory footprint
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()
            size_metrics["gpu_memory_mb"] = torch.cuda.memory_allocated() / (1024 * 1024)

        # RAM usage (approximate)
        process = psutil.Process(os.getpid())
        size_metrics["ram_usage_mb"] = process.memory_info().rss / (1024 * 1024)

        self.results["size_metrics"] = size_metrics
        return size_metrics

    def benchmark_inference_speed(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
        """Benchmark inference speed across different batch sizes.

        Encodes short/medium/long sample texts at each batch size and records
        wall-clock latency, throughput, and an approximate tokens/sec (tokens
        estimated as whitespace-separated words).

        Args:
            batch_sizes: Batch sizes to test; defaults to [1, 8, 16, 32, 64, 128].

        Returns:
            Nested dict: ``{text_length: {"batch_N": {metrics...}}}``.
        """
        if batch_sizes is None:
            batch_sizes = [1, 8, 16, 32, 64, 128]
        logger.info("β‘ Benchmarking inference speed...")

        if self.model is None:
            self.load_model()

        if self.model is None:
            msg = "Failed to load model"
            raise RuntimeError(msg)

        speed_results = {}
        text_lengths = ["short", "medium", "long"]

        for text_length in text_lengths:
            logger.info(f" π Testing {text_length} texts...")
            texts = BENCHMARK_TEXTS[text_length]

            length_results = {}

            for batch_size in batch_sizes:
                # Skip batch sizes larger than the available sample pool.
                if batch_size > len(texts):
                    continue

                logger.info(f" π Batch size: {batch_size}")

                # Prepare batch
                batch_texts = texts[:batch_size]

                # Warmup (avoids counting CUDA kernel compilation / caches).
                if self.device == "cuda":
                    torch.cuda.synchronize()
                _ = self.model.encode(batch_texts[: min(2, batch_size)], convert_to_tensor=False)

                # Clear cache
                if self.device == "cuda":
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()

                # Measure inference time (synchronize so GPU work is included).
                start_time = time.perf_counter()

                embeddings = self.model.encode(batch_texts, convert_to_tensor=False, show_progress_bar=False)

                if self.device == "cuda":
                    torch.cuda.synchronize()

                end_time = time.perf_counter()

                # Calculate metrics
                total_time = end_time - start_time
                time_per_text = total_time / batch_size
                texts_per_second = batch_size / total_time

                # Estimate tokens (rough approximation: whitespace word count)
                avg_tokens = sum(len(text.split()) for text in batch_texts) / batch_size
                total_tokens = avg_tokens * batch_size
                tokens_per_second = total_tokens / total_time

                length_results[f"batch_{batch_size}"] = {
                    "total_time_ms": total_time * 1000,
                    "time_per_text_ms": time_per_text * 1000,
                    "texts_per_second": texts_per_second,
                    "tokens_per_second": tokens_per_second,
                    "avg_tokens_per_text": avg_tokens,
                    "embedding_shape": embeddings.shape
                    if hasattr(embeddings, "shape")
                    else f"({len(embeddings)}, {len(embeddings[0]) if embeddings else 0})",
                }

            speed_results[text_length] = length_results

        self.results["speed_benchmarks"] = speed_results
        return speed_results

    def benchmark_memory_scaling(self, batch_sizes: list[int] | None = None) -> dict[str, Any]:
        """Benchmark memory usage across batch sizes.

        On CUDA, measures peak allocated memory per batch relative to a
        baseline taken before the sweep; stops early on OOM. On CPU, the
        CUDA counters stay 0 and the reported deltas are not meaningful.

        Args:
            batch_sizes: Batch sizes to test; defaults to [1, 8, ..., 256].

        Returns:
            Dict keyed by ``batch_N``; entries may instead contain ``oom``
            or ``error`` markers on failure.
        """
        if batch_sizes is None:
            batch_sizes = [1, 8, 16, 32, 64, 128, 256]
        logger.info("πΎ Benchmarking memory scaling...")

        if self.model is None:
            self.load_model()

        if self.model is None:
            msg = "Failed to load model"
            raise RuntimeError(msg)

        memory_results: dict[str, Any] = {}
        texts = BENCHMARK_TEXTS["medium"]

        baseline_memory = 0
        if self.device == "cuda":
            torch.cuda.empty_cache()
            baseline_memory = torch.cuda.memory_allocated()

        for batch_size in batch_sizes:
            if batch_size > len(texts):
                continue

            logger.info(f" π Testing batch size: {batch_size}")

            # Clear cache
            if self.device == "cuda":
                torch.cuda.empty_cache()
            gc.collect()

            batch_texts = texts[:batch_size]

            # Measure memory before
            if self.device == "cuda":
                # NOTE(review): return value is discarded — this call has no
                # effect; the "before" value actually used is baseline_memory.
                torch.cuda.memory_allocated()

            # Run inference
            try:
                embeddings = self.model.encode(
                    batch_texts,
                    convert_to_tensor=self.device == "cuda",
                    show_progress_bar=False,
                )

                # Measure memory after
                memory_after = 0
                if self.device == "cuda":
                    memory_after = torch.cuda.max_memory_allocated()
                    # Reset so the next batch's peak is measured independently.
                    torch.cuda.reset_peak_memory_stats()

                memory_used_mb = (memory_after - baseline_memory) / (1024 * 1024)
                memory_per_text_mb = memory_used_mb / batch_size if batch_size > 0 else 0

                memory_results[f"batch_{batch_size}"] = {
                    "memory_used_mb": memory_used_mb,
                    "memory_per_text_mb": memory_per_text_mb,
                    "baseline_memory_mb": baseline_memory / (1024 * 1024),
                    "peak_memory_mb": memory_after / (1024 * 1024),
                }

                # Clean up
                del embeddings

            except torch.cuda.OutOfMemoryError:
                logger.warning(f"β OOM at batch size {batch_size}")
                memory_results[f"batch_{batch_size}"] = {"oom": True}
                # Larger batches would also OOM, so abandon the sweep.
                break
            except Exception as e:
                logger.warning(f"β Error at batch size {batch_size}: {e}")
                memory_results[f"batch_{batch_size}"] = {"error": str(e)}

        self.results["memory_benchmarks"] = memory_results
        return memory_results

    def benchmark_cpu_vs_gpu(self) -> dict[str, Any]:
        """Compare CPU vs GPU performance.

        Loads a fresh model instance per device (32 medium texts, fixed
        batch) and records throughput/latency. CUDA is only tested when
        available; per-device failures are recorded as ``{"error": ...}``.
        """
        logger.info("π₯οΈ Benchmarking CPU vs GPU performance...")

        comparison_results = {}
        test_texts = BENCHMARK_TEXTS["medium"][:32]  # Fixed batch size

        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")

        for device in devices:
            logger.info(f" π Testing on {device}")

            # Load model on device
            try:
                model = SentenceTransformer(self.model_path, device=device)

                # Warmup
                _ = model.encode(test_texts[:2], convert_to_tensor=False)

                # Benchmark
                start_time = time.perf_counter()
                embeddings = model.encode(test_texts, convert_to_tensor=False, show_progress_bar=False)
                end_time = time.perf_counter()

                total_time = end_time - start_time

                comparison_results[device] = {
                    "total_time_ms": total_time * 1000,
                    "texts_per_second": len(test_texts) / total_time,
                    "time_per_text_ms": (total_time / len(test_texts)) * 1000,
                    "embedding_shape": embeddings.shape
                    if hasattr(embeddings, "shape")
                    else f"({len(embeddings)}, {len(embeddings[0]) if embeddings else 0})",
                }

                # Free the per-device model before testing the next device.
                del model
                if device == "cuda":
                    torch.cuda.empty_cache()

            except Exception as e:
                logger.warning(f"β Failed on {device}: {e}")
                comparison_results[device] = {"error": str(e)}

        self.results["cpu_vs_gpu"] = comparison_results
        return comparison_results

    def run_comprehensive_benchmark(self) -> dict[str, Any]:
        """Run all benchmarks and return comprehensive results.

        Executes size, speed, memory, and CPU-vs-GPU benchmarks, then
        attaches model/system metadata (torch version, GPU name, CPU count,
        RAM) to ``self.results``.
        """
        logger.info(f"π Starting comprehensive benchmark for {self.model_name}")

        # Load model
        self.load_model()

        # Run all benchmarks
        self.measure_model_size()
        self.benchmark_inference_speed()
        self.benchmark_memory_scaling()
        self.benchmark_cpu_vs_gpu()

        # Add metadata
        self.results["model_name"] = self.model_name
        self.results["model_path"] = self.model_path
        self.results["device"] = self.device
        self.results["torch_version"] = torch.__version__
        self.results["cuda_available"] = torch.cuda.is_available()

        if torch.cuda.is_available():
            self.results["gpu_name"] = torch.cuda.get_device_name(0)
            self.results["gpu_memory_gb"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)

        # System info
        self.results["cpu_count"] = psutil.cpu_count()
        self.results["ram_gb"] = psutil.virtual_memory().total / (1024**3)

        logger.info("β Comprehensive benchmark completed!")
        return self.results

    def save_results(self, output_file: str) -> None:
        """Save benchmark results to JSON file.

        Creates parent directories as needed; non-JSON-serializable values
        are stringified via ``default=str``.
        """
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with output_path.open("w") as f:
            json.dump(self.results, f, indent=2, default=str)

        logger.info(f"π Results saved to {output_path}")

    def print_summary(self) -> None:
        """Print a human-readable summary of benchmark results to stdout."""
        if not self.results:
            logger.warning("No results to summarize")
            return

        print(f"\n{'=' * 60}")
        print(f"Performance Benchmark Summary: {self.model_name}")
        print(f"{'=' * 60}")

        # Model size
        if "size_metrics" in self.results:
            size = self.results["size_metrics"]
            print("\nπ Model Size:")
            print(f" Disk Size: {size.get('disk_size_mb', 0):.1f} MB")
            if "parameters_millions" in size:
                print(f" Parameters: {size['parameters_millions']:.1f}M")
            if "embedding_dim" in size:
                print(f" Embedding Dim: {size['embedding_dim']}")

        # Speed summary (reports the batch-32 medium-text datapoint only)
        if "speed_benchmarks" in self.results:
            speed = self.results["speed_benchmarks"]
            print("\nβ‘ Speed (medium texts, batch 32):")
            if "medium" in speed and "batch_32" in speed["medium"]:
                batch_32 = speed["medium"]["batch_32"]
                print(f" Throughput: {batch_32['texts_per_second']:.1f} texts/sec")
                print(f" Latency: {batch_32['time_per_text_ms']:.1f} ms/text")
                print(f" Token Speed: {batch_32['tokens_per_second']:.0f} tokens/sec")

        # CPU vs GPU
        if "cpu_vs_gpu" in self.results:
            comparison = self.results["cpu_vs_gpu"]
            print("\nπ₯οΈ CPU vs GPU:")
            for device, metrics in comparison.items():
                if "error" not in metrics:
                    print(f" {device.upper()}: {metrics['texts_per_second']:.1f} texts/sec")

        print()
+
|
| 580 |
+
def run_benchmark(
    model_path: str | list[str],
    model_name: str | None = None,
    output: str = "benchmark_results.json",
    quick: bool = False,
    compare_models: list[str] | None = None,
) -> None:
    """Run benchmark for one or multiple models with comparison.

    Args:
        model_path: A single model path/id, or a list of them.
        model_name: Optional display name for the FIRST model only; remaining
            models use their path's last component.
        output: Output file or directory; if it has a suffix, its parent
            directory is used as the results directory.
        quick: If True, run only load + size + a reduced speed sweep.
        compare_models: Extra models benchmarked after ``model_path``.
    """
    # Handle both single model and multiple models.
    # Copy the list so that extending it below never mutates the caller's
    # argument (the original code aliased model_path and mutated it).
    models_to_benchmark = [model_path] if isinstance(model_path, str) else list(model_path)

    if compare_models:
        models_to_benchmark.extend(compare_models)

    all_results = []

    for i, model in enumerate(models_to_benchmark):
        # The explicit model_name applies only to the first model; fall back
        # to the path's basename (also when model_name is None).
        current_model_name = (model_name or Path(model).name) if i == 0 else Path(model).name

        print(f"\n{'=' * 60}")
        print(f"Benchmarking Model {i + 1}/{len(models_to_benchmark)}: {current_model_name}")
        print(f"{'=' * 60}")

        try:
            benchmarker = PerformanceBenchmark(model, current_model_name)

            if quick:
                # Quick benchmark: size + a reduced speed sweep only
                benchmarker.load_model()
                benchmarker.measure_model_size()
                benchmarker.benchmark_inference_speed([1, 16, 32])
            else:
                # Comprehensive benchmark
                benchmarker.run_comprehensive_benchmark()

            all_results.append(benchmarker.results)
            benchmarker.print_summary()

        except Exception:
            # One failing model must not abort the remaining benchmarks.
            logger.exception(f"β Failed to benchmark {current_model_name}")
            continue

    # Save individual results: treat `output` as a file path if it has a
    # suffix (use its parent), otherwise as a directory.
    output_dir = Path(output).parent if Path(output).suffix else Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    for results in all_results:
        # Sanitize the model name so it is safe as a filename component.
        model_name_safe = "".join(c for c in results["model_name"] if c.isalnum() or c in ("-", "_", "."))
        output_path = output_dir / f"benchmark_{model_name_safe}.json"

        with output_path.open("w") as f:
            json.dump(results, f, indent=2, default=str)

        logger.info(f"π Results saved to {output_path}")

    # Create comparison if multiple models
    if len(all_results) > 1:
        create_benchmark_comparison(all_results, str(output_dir / "benchmark_comparison.json"))

    print(f"\nβ Benchmark complete! Results saved to {output_dir}")
| 641 |
+
|
| 642 |
+
def create_benchmark_comparison(all_results: list[dict[str, Any]], output_path: str) -> None:
    """Create a comparison report for multiple benchmark results.

    Builds one summary row per model (size, parameter count, throughput,
    latency, and per-device speed where available), prints the table, and
    writes a JSON report with fastest / smallest / most-efficient picks.

    Args:
        all_results: Per-model benchmark result dictionaries.
        output_path: Destination file for the JSON comparison summary.
    """
    banner = "=" * 80
    print(f"\n{banner}")
    print("Performance Benchmark Comparison")
    print(banner)

    rows: list[dict[str, Any]] = []

    for result in all_results:
        size = result.get("size_metrics", {})
        entry: dict[str, Any] = {
            "Model": result.get("model_name", "Unknown"),
            "Disk Size (MB)": size.get("disk_size_mb", 0),
            "Parameters (M)": size.get("parameters_millions", 0),
            "Embedding Dim": size.get("embedding_dim", 0),
        }

        # Throughput/latency are taken from the medium-text, batch-32 run.
        medium = result.get("speed_benchmarks", {}).get("medium", {})
        if "batch_32" in medium:
            batch = medium["batch_32"]
            entry["Throughput (texts/sec)"] = batch.get("texts_per_second", 0)
            entry["Latency (ms/text)"] = batch.get("time_per_text_ms", 0)
            entry["Token Speed (tokens/sec)"] = batch.get("tokens_per_second", 0)

        # Per-device speed, skipping devices whose run recorded an error.
        device_results = result.get("cpu_vs_gpu", {})
        for device in ("cpu", "cuda"):
            if device in device_results and "error" not in device_results[device]:
                entry[f"{device.upper()} Speed (texts/sec)"] = device_results[device].get("texts_per_second", 0)

        rows.append(entry)

    table = pd.DataFrame(rows)

    # Rank by throughput when the column exists (it may not, e.g. quick runs).
    if "Throughput (texts/sec)" in table.columns:
        table = table.sort_values("Throughput (texts/sec)", ascending=False)

    print(table.to_string(index=False, float_format="%.2f"))

    has_rows = len(table) > 0
    comparison_summary = {
        "comparison_table": table.to_dict(orient="records"),
        "summary": {
            "fastest_model": table.iloc[0]["Model"] if has_rows else None,
            "smallest_model": table.loc[table["Disk Size (MB)"].idxmin()]["Model"] if has_rows else None,
            "most_efficient": (
                table.loc[table["Throughput (texts/sec)"].idxmax()]["Model"]
                if "Throughput (texts/sec)" in table.columns and has_rows
                else None
            ),
        },
        "timestamp": time.time(),
    }

    # default=str keeps numpy scalar types serializable.
    with Path(output_path).open("w") as f:
        json.dump(comparison_summary, f, indent=2, default=str)

    print(f"\nπ Comparison saved to {output_path}")
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def save_benchmark_results(
    results: dict[str, Any],
    output_dir: str,
    model_name: str,
    volume_results_dir: Path | None = None,
) -> None:
    """Persist benchmark results as JSON.

    When ``volume_results_dir`` is provided, results are written to the Beam
    volume on a best-effort basis (failures are logged, not raised). A local
    backup copy is always written under ``output_dir``.

    Args:
        results: Benchmark result payload to serialize.
        output_dir: Local directory for the backup copy (created if missing).
        model_name: Used to derive the output file name.
        volume_results_dir: Optional Beam volume results directory.
    """
    # Optional Beam-volume copy; failures here must not abort the run.
    if volume_results_dir:
        volume_output_path = volume_results_dir / f"benchmark_{model_name}.json"
        try:
            with volume_output_path.open("w") as f:
                json.dump(results, f, indent=2, default=str)
            logger.info(f"πΎ Results saved to Beam volume: {volume_output_path}")
        except Exception as e:
            logger.warning(f"β οΈ Failed to save to Beam volume: {e}")

    # Local backup is written unconditionally.
    backup_dir = Path(output_dir)
    backup_dir.mkdir(parents=True, exist_ok=True)

    # Strip anything that is not filename-safe from the model name.
    sanitized = "".join(ch for ch in model_name if ch.isalnum() or ch in ("-", "_", "."))
    filepath = backup_dir / f"benchmark_{sanitized}.json"

    with filepath.open("w") as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"π Local backup saved to {filepath}")
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
def beam_benchmark_models(
    models: list[str],
    quick: bool = False,
    output_dir: str = DEFAULT_OUTPUT_DIR,
    volume_name: str = VOLUME_NAME,
    mount_path: str = VOLUME_PATH,
) -> list[dict[str, Any]]:
    """Main benchmarking function for Beam execution with checkpoint support.

    Iterates over ``models`` (local paths, Beam-volume paths, or HuggingFace
    model names), benchmarks each with ``PerformanceBenchmark``, and persists
    results to the Beam volume. Previously benchmarked models (other than the
    trained model) are skipped by loading their saved JSON results.

    Args:
        models: Model paths / names to benchmark.
        quick: If True, run only size + inference-speed measurements.
        output_dir: Local directory for backup result files.
        volume_name: Beam volume name (used for utilities and logging).
        mount_path: Mount point of the Beam volume.

    Returns:
        One results dict per successfully benchmarked (or reloaded) model.
    """
    logger.info("π Starting Beam-powered performance benchmarking")
    logger.info(f"π Benchmarking {len(models)} models")

    # Initialize Beam utilities. volume_mgr and model_mgr are unused here but
    # are part of the utilities tuple returned by create_beam_utilities.
    volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)

    # Create benchmark results directory in volume
    results_dir = Path(mount_path) / BENCHMARK_RESULTS_DIR
    results_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"π Using Beam volume: {volume_name} at {mount_path}")
    logger.info(f"πΎ Benchmark results directory: {results_dir}")

    all_results = []
    skipped_models = []

    for model_path in models:
        # The volume root itself denotes the locally-trained model.
        model_name = Path(model_path).name if model_path != str(Path(mount_path)) else "gte_qwen2_m2v_code"

        # Check if this model has already been benchmarked (except for trained model)
        is_trained_model = model_path == str(Path(mount_path)) or model_name == "gte_qwen2_m2v_code"

        if not is_trained_model:
            # Check for existing benchmark results; reuse them when loadable.
            existing_result_file = results_dir / f"benchmark_{model_name}.json"
            if existing_result_file.exists():
                logger.info(f"β Model {model_name} already benchmarked - loading existing results")
                try:
                    with existing_result_file.open("r") as f:
                        existing_results = json.load(f)
                    all_results.append(existing_results)
                    skipped_models.append(model_name)
                    continue
                except Exception as e:
                    logger.warning(f"β οΈ Failed to load existing results for {model_name}: {e}")
                    # Continue with benchmarking if loading fails

        logger.info(f"\n{'=' * 60}")
        logger.info(f"π Benchmarking model: {model_name}")
        logger.info(f"π Path: {model_path}")
        if is_trained_model:
            logger.info("π― Trained model - always re-benchmark")
        logger.info(f"{'=' * 60}")

        try:
            # Distinguish between local paths and HuggingFace model names:
            # "org/name" with no matching local path is treated as an HF id.
            is_huggingface_model = (
                "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()
            )

            if is_huggingface_model:
                # This is a HuggingFace model name - pass directly to benchmarker
                logger.info(f"π₯ Loading HuggingFace model: {model_path}")
                benchmarker = PerformanceBenchmark(
                    model_path,
                    model_name,
                    checkpoint_manager=checkpoint_mgr,
                    eval_manager=eval_mgr,
                )
            else:
                # This is a local path - check if it exists in Beam volume
                actual_model_path = model_path  # Default to original path
                if not Path(model_path).exists() and not model_path.startswith("/"):
                    # Try to load from Beam volume
                    local_model_path = Path(mount_path) / model_name
                    logger.info(f"π Trying to load {model_name} from Beam volume: {local_model_path}")
                    if local_model_path.exists():
                        actual_model_path = str(local_model_path)
                        logger.info(f"β Found model in Beam volume: {actual_model_path}")
                    else:
                        # Try in root of volume (for your trained model)
                        root_model_path = Path(mount_path)
                        if (root_model_path / "config.json").exists():
                            actual_model_path = str(root_model_path)
                            logger.info(f"β Found model in Beam volume root: {actual_model_path}")
                        else:
                            logger.warning(f"β οΈ Model not found locally or in Beam volume: {model_name}")
                            continue

                benchmarker = PerformanceBenchmark(
                    actual_model_path,
                    model_name,
                    checkpoint_manager=checkpoint_mgr,
                    eval_manager=eval_mgr,
                )

            # Run benchmarking
            if quick:
                # Quick benchmark: size + a small set of batch sizes only.
                benchmarker.load_model()
                benchmarker.measure_model_size()
                benchmarker.benchmark_inference_speed([1, 16, 32])
            else:
                # Comprehensive benchmark
                benchmarker.run_comprehensive_benchmark()

            # Save results with Beam support
            save_benchmark_results(benchmarker.results, output_dir, model_name, results_dir)

            # Print summary
            benchmarker.print_summary()

            all_results.append(benchmarker.results)

        except Exception:
            # A single failing model must not abort the whole batch.
            logger.exception(f"β Failed to benchmark {model_name}")
            continue

    # Create comparison report in Beam volume (only meaningful for >1 model).
    if len(all_results) > 1:
        comparison_dir = results_dir / "comparisons"
        comparison_dir.mkdir(parents=True, exist_ok=True)
        create_benchmark_comparison(all_results, str(comparison_dir / "benchmark_comparison.json"))
        logger.info(f"π Comparison report saved to Beam volume: {comparison_dir}")

    # Log summary of what was done
    newly_benchmarked = len(all_results) - len(skipped_models)
    logger.info("\nβ Beam benchmarking complete!")
    logger.info(f"π Newly benchmarked: {newly_benchmarked} models")
    logger.info(f"βοΈ Skipped (already done): {len(skipped_models)} models")
    logger.info(f"π Total results: {len(all_results)} models")
    logger.info(f"πΎ Results available in Beam volume: {volume_name}")

    if skipped_models:
        logger.info(f"βοΈ Skipped models: {', '.join(skipped_models)}")

    return all_results
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
    },
    timeout=3600 * 4,  # 4 hours for benchmarking all models
)
def main() -> None:
    """Main benchmarking function - runs all default models on Beam.

    Builds the model list from DEFAULT_BENCHMARK_MODELS (rewriting the
    trained-model placeholder to the Beam volume root), appends any locally
    discovered simplified distillation models, and runs a quick benchmark
    pass over all of them via beam_benchmark_models.
    """
    logger.info("π Starting comprehensive performance benchmarking on Beam")

    # Use default models but replace the local model path with Beam volume path
    models = DEFAULT_BENCHMARK_MODELS.copy()

    # Replace "gte_qwen2_m2v_code" with actual Beam volume path
    for i, model in enumerate(models):
        if model == "gte_qwen2_m2v_code":
            models[i] = str(Path(VOLUME_PATH))  # Use the Beam volume root
            logger.info(f"π― Using trained model from Beam volume: {models[i]}")

    # Discover simplified distillation models
    logger.info("π Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    # Add discovered models
    if discovered_models:
        logger.info(f"β Found {len(discovered_models)} simplified models:")
        for model_path in discovered_models:
            models.append(model_path)
            logger.info(f" π {model_path}")
    else:
        logger.warning("β οΈ No simplified distillation models found")

    logger.info(f"π Benchmarking {len(models)} models:")
    for i, model in enumerate(models, 1):
        logger.info(f" {i}. {model}")

    logger.info("\nπ‘ Checkpoint Info:")
    logger.info(" - Already benchmarked models will be skipped")
    logger.info(" - Your trained model will always be re-benchmarked")
    logger.info(" - Results are saved persistently to Beam volume")

    # Run comprehensive benchmark using Beam utilities
    results = beam_benchmark_models(
        models=models,
        quick=True,  # Use quick benchmark for efficiency
        output_dir=str(Path(VOLUME_PATH) / BENCHMARK_RESULTS_DIR),
        volume_name=VOLUME_NAME,
        mount_path=VOLUME_PATH,
    )

    # Print final summary
    print("\nπ― Benchmarking Summary:")
    print(f"π Total models processed: {len(results)}")
    print(f"πΎ Results saved to Beam volume: {VOLUME_NAME}")
    print(f"π Directory: {BENCHMARK_RESULTS_DIR}")
    print("\nπ To view analysis:")
    print(" beam run src.distiller.analyze:beam_analysis")
    print("\nπ To run benchmarks again:")
    print(" distiller benchmark (will skip already completed models)")
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
def discover_simplified_models(base_path: str = ".") -> list[str]:
    """
    Discover all simplified distillation models in the correct directory.

    Scans ``<base_path>/code_model2vec/final`` for directories matching
    ``code_model2vec_*`` that contain a ``config.json`` and returns their
    paths sorted alphabetically.
    """
    # distill_simplified.py saves its models under this fixed location.
    search_root = Path(base_path) / "code_model2vec" / "final"

    if not search_root.exists():
        logger.warning(f"Models directory not found: {search_root}")
        return []

    found: list[str] = []
    for candidate in search_root.glob("code_model2vec_*"):
        # A valid model directory must carry its config file.
        if candidate.is_dir() and (candidate / "config.json").exists():
            found.append(str(candidate))
            logger.info(f"π Discovered simplified model: {candidate}")

    # Sorted for a consistent, reproducible ordering.
    return sorted(found)
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
    },
    timeout=3600 * 3,  # 3 hours for simplified models only
)
def benchmark_simplified_only() -> None:
    """Benchmark only simplified distillation models, skipping 3rd party models.

    Discovers locally produced simplified models and runs a quick benchmark
    pass over them via beam_benchmark_models. Returns early (with an error
    log) when no simplified models exist.
    """
    logger.info("π Starting simplified distillation models benchmarking on Beam")
    logger.info("βοΈ Skipping 3rd party models - benchmarking only simplified distillation models")

    # Discover simplified distillation models
    logger.info("π Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    if not discovered_models:
        logger.error("β No simplified distillation models found! Run distill-simple first.")
        return

    logger.info(f"β Found {len(discovered_models)} simplified models:")
    for model_path in discovered_models:
        logger.info(f" π {model_path}")

    logger.info("\nπ‘ Checkpoint Info:")
    logger.info(" - Already benchmarked models will be skipped")
    logger.info(" - Results are saved persistently to Beam volume")

    # Run comprehensive benchmark using Beam utilities
    results = beam_benchmark_models(
        models=discovered_models,
        quick=True,  # Use quick benchmark for efficiency
        output_dir=str(Path(VOLUME_PATH) / BENCHMARK_RESULTS_DIR),
        volume_name=VOLUME_NAME,
        mount_path=VOLUME_PATH,
    )

    # Print final summary
    print("\nπ― Simplified Benchmarking Summary:")
    print(f"π Total simplified models processed: {len(results)}")
    print(f"πΎ Results saved to Beam volume: {VOLUME_NAME}")
    print(f"π Directory: {BENCHMARK_RESULTS_DIR}")
    print("βοΈ 3rd party models were skipped")
    print("\nπ To view analysis:")
    print(" distiller analyze")
    print("\nπ To run full benchmarks (including 3rd party):")
    print(" distiller benchmark")
|
| 1025 |
+
|
| 1026 |
+
|
| 1027 |
+
def run_local_benchmark(
    models: list[str] | None = None,
    quick: bool = False,
    output_dir: str = DEFAULT_OUTPUT_DIR,
) -> list[dict[str, Any]]:
    """Main benchmarking function for local execution without Beam utilities.

    Resolves the trained-model placeholder to a local path, appends any
    discovered simplified distillation models, benchmarks each model with
    ``PerformanceBenchmark`` (reusing existing local result files), and
    writes a comparison report when more than one result is available.

    Args:
        models: Model paths to benchmark; defaults to DEFAULT_BENCHMARK_MODELS.
        quick: If True, run only size + inference-speed measurements.
        output_dir: Local directory for result JSON files.

    Returns:
        One results dict per successfully benchmarked (or reloaded) model.
    """
    logger.info("π₯οΈ Running performance benchmarking locally")

    if models is None:
        models = DEFAULT_BENCHMARK_MODELS.copy()

    # Replace "gte_qwen2_m2v_code" with a reasonable local path.
    # BUGFIX: the original mutated `models` (pop) while iterating it with
    # enumerate, which skipped the element after a removed placeholder and
    # also mutated the caller's list. Build a resolved copy instead.
    resolved_models: list[str] = []
    for model in models:
        if model == "gte_qwen2_m2v_code":
            # Look for local trained model in the usual output locations.
            local_model_paths = [
                "./gte_qwen2_m2v_code",
                "./models/gte_qwen2_m2v_code",
                "./output/gte_qwen2_m2v_code",
            ]
            found = False
            for local_path in local_model_paths:
                if Path(local_path).exists():
                    resolved_models.append(local_path)
                    logger.info(f"π― Found local trained model: {local_path}")
                    found = True
                    break
            if not found:
                logger.warning("β οΈ Local trained model not found, skipping")
        else:
            resolved_models.append(model)
    models = resolved_models

    # Discover simplified distillation models
    logger.info("π Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    # Add discovered models
    if discovered_models:
        logger.info(f"β Found {len(discovered_models)} simplified models:")
        for model_path in discovered_models:
            models.append(model_path)
            logger.info(f" π {model_path}")
    else:
        logger.warning("β οΈ No simplified distillation models found")

    logger.info(f"π Benchmarking {len(models)} models")
    logger.info(f"π Using local output directory: {output_dir}")

    # Create local output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    all_results = []
    skipped_models = []

    for model_path in models:
        model_name = Path(model_path).name

        # Check for existing benchmark results locally; reuse when loadable.
        safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
        result_file = output_path / f"benchmark_{safe_name}.json"

        if result_file.exists():
            logger.info(f"β Model {model_name} already benchmarked - loading existing results")
            try:
                with result_file.open("r") as f:
                    existing_results = json.load(f)
                all_results.append(existing_results)
                skipped_models.append(model_name)
                continue
            except Exception as e:
                logger.warning(f"β οΈ Failed to load existing results for {model_name}: {e}")

        logger.info(f"\n{'=' * 60}")
        logger.info(f"π Benchmarking model: {model_name}")
        logger.info(f"π Path: {model_path}")
        logger.info(f"{'=' * 60}")

        try:
            # Create benchmarker without Beam utilities
            benchmarker = PerformanceBenchmark(
                model_path,
                model_name,
                checkpoint_manager=None,  # No checkpointing for local benchmarking
                eval_manager=None,
            )

            # Run benchmarking
            if quick:
                # Quick benchmark: size + a small set of batch sizes only.
                benchmarker.load_model()
                benchmarker.measure_model_size()
                benchmarker.benchmark_inference_speed([1, 16, 32])
            else:
                # Comprehensive benchmark
                benchmarker.run_comprehensive_benchmark()

            # Save results locally only
            save_benchmark_results(benchmarker.results, output_dir, model_name, volume_results_dir=None)

            # Print summary
            benchmarker.print_summary()

            all_results.append(benchmarker.results)

        except Exception:
            # A single failing model must not abort the whole batch.
            logger.exception(f"β Failed to benchmark {model_name}")
            continue

    # Create comparison report locally
    if len(all_results) > 1:
        create_benchmark_comparison(all_results, str(output_path / "benchmark_comparison.json"))
        logger.info(f"π Comparison report saved locally: {output_dir}")

    # Log summary
    newly_benchmarked = len(all_results) - len(skipped_models)
    logger.info("\nβ Local benchmarking complete!")
    logger.info(f"π Newly benchmarked: {newly_benchmarked} models")
    logger.info(f"βοΈ Skipped (already done): {len(skipped_models)} models")
    logger.info(f"π Total results: {len(all_results)} models")
    logger.info(f"πΎ Results available locally: {output_dir}")

    if skipped_models:
        logger.info(f"βοΈ Skipped models: {', '.join(skipped_models)}")

    return all_results
|
| 1152 |
+
|
| 1153 |
+
|
| 1154 |
+
def run_local_benchmark_simplified(
    quick: bool = False,
    output_dir: str = DEFAULT_OUTPUT_DIR,
) -> list[dict[str, Any]]:
    """Run the local benchmark over simplified distillation models only.

    Discovers simplified models on disk and delegates the actual work to
    :func:`run_local_benchmark`. Returns an empty list (after logging an
    error) when no simplified models are found.
    """
    logger.info("π₯οΈ Running simplified model benchmarking locally")

    logger.info("π Discovering simplified distillation models...")
    simplified = discover_simplified_models(".")

    if not simplified:
        logger.error("β No simplified distillation models found! Run 'distiller distill-simple' first.")
        return []

    logger.info(f"β Found {len(simplified)} simplified models:")
    for path in simplified:
        logger.info(f" π {path}")

    # Delegate the benchmarking itself to the generic local runner.
    return run_local_benchmark(models=simplified, quick=quick, output_dir=output_dir)
|
| 1178 |
+
|
| 1179 |
+
|
| 1180 |
+
# Script entry point: dispatches to the Beam-decorated main() when this
# module is executed directly.
if __name__ == "__main__":
    main()
|
src/distiller/distill.py
ADDED
|
@@ -0,0 +1,1306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Code-Specialized Model2Vec Distillation Script with Checkpoint Support.

This script implements a focused approach for creating code-specialized embeddings
using Model2Vec distillation with one additional training round on code-specific tasks.

Features:
- Incremental checkpoint saving
- Resume from previous progress
- Persistent storage of embeddings and models
- Robust error handling and recovery
- Smart checkpoint validation for parameter compatibility

Approach:
1. Basic Model2Vec distillation with optimized parameters
2. Single code specialization round using sentence-transformers/codesearchnet dataset
"""

import json
import logging
import os
import time
from pathlib import Path
from typing import Any

import numpy as np
import torch
from beam import GpuType, Image, Volume, function
from datasets import load_dataset
from model2vec.distill import distill
from model2vec.train.base import FinetunableStaticModel, TextDataset
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from torch import nn, optim

from .beam_utils import (
    BeamCheckpointManager,
    BeamModelManager,
    create_beam_utilities,
)

# =============================================================================
# CODE-FOCUSED CONFIGURATION
# =============================================================================

# Model Configuration
# Teacher model (HF Hub id) whose embeddings are distilled into a static model.
MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"
OUTPUT_DIR = "gte_qwen2_m2v_code"
CHECKPOINT_DIR = "gte_qwen2_m2v_code/checkpoints"

# Code-optimized parameters
PCA_DIMS = 512  # Higher dims for code complexity
TRAINING_EPOCHS = 2
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
REGULARIZATION_WEIGHT = 0.01

# CodeSearchNet dataset configuration
CODESEARCHNET_DATASET = "sentence-transformers/codesearchnet"
MAX_TRAINING_SAMPLES = 50000  # Limit for manageable training time

# Checkpoint configuration
CHECKPOINT_INTERVAL = 1000  # Save every N samples
EMBEDDINGS_BATCH_SIZE = 100  # Save embeddings in smaller batches

# OPTIMIZED TEACHER MODEL CONFIGURATION FOR 40GB VRAM
# NOTE: changing any value that feeds get_current_config_hash() invalidates
# previously saved checkpoints/caches.
TEACHER_MODEL_CONFIG: dict[str, Any] = {
    "batch_size": 12,  # Slightly reduced due to float32 memory usage
    "precision": "float32",  # Use float32 for quality preservation
    "max_seq_length": 8192,  # Reduce from 32k default for better performance
    "device_map": "auto",  # Automatic device placement
    "torch_dtype": torch.float32,  # Use float32 for quality preservation
    "trust_remote_code": True,
    "use_flash_attention": True,  # Try to enable flash attention if available
    "attn_implementation": "flash_attention_2",  # Use flash attention 2 if available
}

# =============================================================================
# BEAM CONFIGURATION
# =============================================================================

GPU_NAME = GpuType.A100_40
VOLUME_NAME = "gte_qwen2_m2v_code"
VOLUME_PATH = "./gte_qwen2_m2v_code"
IMAGE = Image(python_version="python3.12").add_python_packages(
    [
        "torch>=2.7.0",  # Install torch first
        "transformers>=4.40.0",  # Latest transformers with flash attention support
        "accelerate>=1.7.0",
        "datasets>=3.2.0",
        "model2vec[train]>=0.5.0",
        "numpy>=1.26.4",
        "scikit-learn>=1.6.1",
        "sentence-transformers>=4.1.0",
    ]
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def get_current_config_hash() -> str:
    """Return a short hash fingerprinting the distillation-relevant configuration.

    Checkpoints store this value so that a resumed run can detect whether the
    configuration it was produced under still matches the current one.
    """
    import hashlib

    # Pre-sorted by key; the trailing sorted() keeps the exact same
    # fingerprint string as hashing sorted(dict.items()) would produce.
    relevant = [
        ("codesearchnet_dataset", CODESEARCHNET_DATASET),
        ("max_samples", MAX_TRAINING_SAMPLES),
        ("model_name", MODEL_NAME),
        ("pca_dims", PCA_DIMS),
        ("precision", TEACHER_MODEL_CONFIG["precision"]),
        ("torch_dtype", str(TEACHER_MODEL_CONFIG["torch_dtype"])),
    ]
    fingerprint = str(sorted(relevant))
    # MD5 is used only as a cheap fingerprint, not for security.
    return hashlib.md5(fingerprint.encode()).hexdigest()[:12]  # noqa: S324
| 117 |
+
|
| 118 |
+
|
def validate_checkpoint_compatibility(checkpoint_data: dict[str, Any]) -> bool:
    """
    Check whether a stored checkpoint matches the current configuration.

    Args:
        checkpoint_data: Checkpoint data dictionary

    Returns:
        True if compatible, False otherwise
    """
    current_hash = get_current_config_hash()
    stored_hash = checkpoint_data.get("config_hash", "")

    # Fast path: the hash covers every distillation-relevant parameter.
    if stored_hash != current_hash:
        logger.warning(f"Configuration mismatch: current={current_hash}, checkpoint={stored_hash}")
        return False

    stored_config = checkpoint_data.get("config", {})

    # Belt-and-braces checks on individual critical parameters; each entry is
    # (checkpoint key, expected current value, log-message prefix).
    critical_checks = [
        ("pca_dims", PCA_DIMS, "PCA dimensions mismatch"),
        ("precision", TEACHER_MODEL_CONFIG["precision"], "Precision mismatch"),
        ("max_samples", MAX_TRAINING_SAMPLES, "Max samples mismatch"),
    ]
    for key, expected, label in critical_checks:
        actual = stored_config.get(key)
        if actual != expected:
            logger.warning(f"{label}: current={expected}, checkpoint={actual}")
            return False

    logger.info("✅ Checkpoint configuration is compatible")
    return True
| 158 |
+
|
| 159 |
+
|
def create_checkpoint_data(stage: str, data: dict[str, Any], step: int = 0) -> dict[str, Any]:
    """
    Wrap a checkpoint payload with configuration metadata.

    Args:
        stage: Checkpoint stage name
        data: Core checkpoint data
        step: Step number

    Returns:
        Enhanced checkpoint data with configuration
    """
    # Snapshot of the parameters that determine checkpoint compatibility;
    # must mirror what get_current_config_hash() hashes.
    config_snapshot: dict[str, Any] = {
        "model_name": MODEL_NAME,
        "pca_dims": PCA_DIMS,
        "precision": TEACHER_MODEL_CONFIG["precision"],
        "torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
        "max_samples": MAX_TRAINING_SAMPLES,
        "codesearchnet_dataset": CODESEARCHNET_DATASET,
    }

    checkpoint: dict[str, Any] = {}
    checkpoint["config_hash"] = get_current_config_hash()
    checkpoint["config"] = config_snapshot
    checkpoint["stage"] = stage
    checkpoint["step"] = step
    checkpoint["timestamp"] = time.time()
    checkpoint["data"] = data
    return checkpoint
| 187 |
+
|
| 188 |
+
|
def load_codesearchnet_dataset_with_resume(
    max_samples: int = MAX_TRAINING_SAMPLES,
    checkpoint_manager: BeamCheckpointManager | None = None,
) -> list[str]:
    """Load and format the sentence-transformers/codesearchnet dataset with resume capability.

    Args:
        max_samples: Maximum number of comment/code pairs to collect.
        checkpoint_manager: Optional checkpoint manager used to persist partial
            progress and resume from it on a later run.

    Returns:
        List of "Comment: ...\\nCode:\\n..." training texts. May be shorter than
        ``max_samples`` if the stream is exhausted or loading fails part-way
        (the partial result is returned rather than raised).
    """
    logger.info(f"Loading CodeSearchNet dataset from {CODESEARCHNET_DATASET}")
    logger.info(f"Limiting to {max_samples} samples for training efficiency")

    texts: list[str] = []
    # Number of *stream examples* already consumed. Examples are filtered, so
    # this is NOT the same as len(texts).
    start_from = 0

    # Check for existing dataset checkpoint with validation
    if checkpoint_manager:
        checkpoint_data = checkpoint_manager.load_checkpoint("dataset", 0)
        if checkpoint_data:
            if validate_checkpoint_compatibility(checkpoint_data):
                payload = checkpoint_data.get("data", {})
                texts = payload.get("texts", [])
                if len(texts) >= max_samples:
                    logger.info(f"✅ Resumed dataset loading: {len(texts)} texts from checkpoint")
                    return texts[:max_samples]
                logger.info(f"🔄 Partial dataset found: {len(texts)} texts, continuing from there")
                # Bug fix: resume from the number of examples previously
                # *processed*, not the number of texts *kept*. Since examples
                # are filtered, skipping only len(texts) examples would
                # re-process part of the stream and append duplicates. Old
                # checkpoints without "processed" fall back to the previous
                # behavior.
                start_from = payload.get("processed", len(texts))
            else:
                logger.warning("🔄 Incompatible dataset checkpoint found, starting fresh")
                # Clean up incompatible checkpoint
                checkpoint_manager.cleanup_old_checkpoints("dataset", keep_latest=0)

    try:
        # Load the dataset as a stream so we never materialize the full corpus
        dataset = load_dataset(CODESEARCHNET_DATASET, split="train", streaming=True)

        # Skip to where we left off
        dataset_iter = iter(dataset)
        for _ in range(start_from):
            try:
                next(dataset_iter)
            except StopIteration:
                break

        processed = start_from
        for i, example in enumerate(dataset_iter, start=start_from):
            if len(texts) >= max_samples:
                break
            processed = i + 1

            comment = example.get("comment", "").strip()
            code = example.get("code", "").strip()

            # Keep only substantive pairs (non-trivial comment AND code body)
            if comment and code and len(comment) > 10 and len(code) > 50:
                # Format as comment-code pair for training
                text = f"Comment: {comment}\nCode:\n{code}"

                # Ensure reasonable length
                if len(text) <= 2048:  # Reasonable limit for embedding models
                    texts.append(text)

            # Save checkpoint periodically
            if checkpoint_manager and (i + 1) % CHECKPOINT_INTERVAL == 0:
                checkpoint_data = create_checkpoint_data(
                    "dataset", {"texts": texts, "processed": processed}, 0
                )
                checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)
                logger.info(f"💾 Saved dataset checkpoint: {len(texts)} texts collected")

            if (i + 1) % 10000 == 0:
                logger.info(f"Processed {i + 1} examples, collected {len(texts)} valid pairs")

        # Final checkpoint save
        if checkpoint_manager:
            checkpoint_data = create_checkpoint_data(
                "dataset", {"texts": texts, "processed": processed}, 0
            )
            checkpoint_manager.save_checkpoint("dataset", checkpoint_data, 0)

        logger.info(f"Successfully loaded {len(texts)} code-comment pairs from CodeSearchNet")
        return texts

    except Exception:
        # Deliberate best-effort: log and return the partial collection.
        logger.exception("Error loading CodeSearchNet dataset")
        return texts  # Return what we have so far
| 268 |
+
|
| 269 |
+
|
def generate_teacher_embeddings_with_checkpoints(
    teacher_model: SentenceTransformer,
    texts: list[str],
    checkpoint_manager: BeamCheckpointManager | None = None,
) -> torch.Tensor:
    """Generate teacher embeddings for code training with checkpoint support.

    Args:
        teacher_model: Loaded teacher model used to encode the texts.
        texts: Training texts to embed; output rows align 1:1 with this list.
        checkpoint_manager: When provided, a volume-backed cache
            (embeddings_cache.pt + embeddings_config.json) is consulted before
            encoding and refreshed afterwards.

    Returns:
        Float32 tensor of shape (len(texts), dim), L2-normalized rows.
    """
    logger.info(f"Generating teacher embeddings for {len(texts)} texts...")

    # Check for existing embeddings checkpoint using torch.save format
    final_embeddings = None

    if checkpoint_manager:
        # Try to load complete embeddings tensor directly
        embeddings_path = Path(VOLUME_PATH) / "embeddings_cache.pt"
        config_path = Path(VOLUME_PATH) / "embeddings_config.json"

        if embeddings_path.exists() and config_path.exists():
            try:
                # Load config first to validate compatibility
                with config_path.open("r") as f:
                    config_data = json.load(f)

                # Create a dummy checkpoint data structure for validation
                checkpoint_data = {
                    "config_hash": config_data.get("config_hash"),
                    "config": config_data.get("config", {}),
                }

                if validate_checkpoint_compatibility(checkpoint_data):
                    # Load the embeddings tensor
                    final_embeddings = torch.load(embeddings_path, map_location="cpu")
                    num_expected = config_data.get("num_texts", len(texts))

                    if final_embeddings.shape[0] >= num_expected:
                        logger.info(
                            f"✅ Loaded complete embeddings from cache ({final_embeddings.shape[0]} embeddings)"
                        )
                        return final_embeddings[: len(texts)]  # Return only the needed amount
                    logger.info(
                        f"⚠️ Cached embeddings incomplete ({final_embeddings.shape[0]}/{num_expected}), regenerating"
                    )
                    final_embeddings = None
                else:
                    logger.warning("🔄 Incompatible embeddings cache found, regenerating")
                    final_embeddings = None
            except Exception as e:
                logger.warning(f"Failed to load embeddings cache: {e}, regenerating...")
                final_embeddings = None

    # If we have complete embeddings, return them
    if final_embeddings is not None:
        return final_embeddings

    # Generate embeddings from scratch
    logger.info("Generating fresh teacher embeddings...")

    if not texts:
        # torch.cat over zero batches would raise; return an empty tensor.
        return torch.empty((0, 0), dtype=torch.float32)

    # Use optimized batch size for large models with proper type casting
    batch_size_raw = TEACHER_MODEL_CONFIG["batch_size"]
    current_batch_size: int = batch_size_raw if isinstance(batch_size_raw, int) else 16
    logger.info(f"Using optimized batch size: {current_batch_size} for 40GB VRAM (7B model)")

    embeddings_list = []

    # Bug fix: iterate with an explicit cursor instead of
    # `range(0, len(texts), current_batch_size)`. The range's step was frozen
    # at the original batch size, so after an OOM-triggered halving the retry
    # encoded only the first half of the failed window and the rest of that
    # window was silently skipped — leaving the result shorter than `texts`
    # and misaligned with it. The cursor advances by exactly the number of
    # texts actually encoded, and an OOM retries the SAME position.
    pos = 0
    while pos < len(texts):
        batch_texts = texts[pos : pos + current_batch_size]

        try:
            # Use optimized encoding with convert_to_tensor=True for efficiency
            batch_embeddings = teacher_model.encode(
                batch_texts,
                convert_to_tensor=True,
                batch_size=current_batch_size,
                show_progress_bar=False,  # Reduce overhead
                normalize_embeddings=True,  # Pre-normalize for efficiency
            )
        except torch.cuda.OutOfMemoryError:
            if current_batch_size <= 1:
                # Cannot shrink further; surface the error instead of looping.
                raise

            logger.warning(
                f"GPU OOM with batch size {current_batch_size}, reducing to {max(1, current_batch_size // 2)}"
            )

            # Clear cache and reduce batch size, then retry the same position
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            current_batch_size = max(1, current_batch_size // 2)
            continue

        embeddings_list.append(batch_embeddings)
        pos += len(batch_texts)

        if pos % (current_batch_size * 10) == 0:
            logger.info(f"Generated embeddings for {pos}/{len(texts)} texts")

    # Combine all embeddings and force fp32 precision
    teacher_embeddings = torch.cat(embeddings_list, dim=0)

    # Ensure teacher embeddings are in fp32 for maximum quality
    if teacher_embeddings.dtype != torch.float32:
        logger.info(f"Converting teacher embeddings from {teacher_embeddings.dtype} to fp32")
        teacher_embeddings = teacher_embeddings.to(torch.float32)

    logger.info(f"Generated {teacher_embeddings.shape[0]} teacher embeddings in {teacher_embeddings.dtype}")

    # Save embeddings cache using torch.save for future runs
    if checkpoint_manager:
        try:
            embeddings_path = Path(VOLUME_PATH) / "embeddings_cache.pt"
            config_path = Path(VOLUME_PATH) / "embeddings_config.json"

            # Save embeddings tensor
            torch.save(teacher_embeddings, embeddings_path)

            # Save configuration (must stay consistent with
            # validate_checkpoint_compatibility / get_current_config_hash)
            config_data = {
                "config_hash": get_current_config_hash(),
                "config": {
                    "model_name": MODEL_NAME,
                    "pca_dims": PCA_DIMS,
                    "precision": TEACHER_MODEL_CONFIG["precision"],
                    "torch_dtype": str(TEACHER_MODEL_CONFIG["torch_dtype"]),
                    "max_samples": MAX_TRAINING_SAMPLES,
                    "codesearchnet_dataset": CODESEARCHNET_DATASET,
                },
                "num_texts": len(texts),
                "embedding_shape": list(teacher_embeddings.shape),
                "timestamp": time.time(),
            }

            with config_path.open("w") as f:
                json.dump(config_data, f, indent=2)

            logger.info("💾 Saved embeddings cache for future runs")

        except Exception as e:
            # Best-effort cache; failure to save must not abort the run.
            logger.warning(f"Failed to save embeddings cache: {e}")

    return teacher_embeddings
| 418 |
+
|
| 419 |
+
|
def refine_with_code_training(
    student_model: Any,
    training_texts: list[str],
    teacher_embeddings: torch.Tensor,
    epochs: int = 2,
    checkpoint_manager: BeamCheckpointManager | None = None,
    model_manager: BeamModelManager | None = None,
) -> Any:
    """Refine the student model with code-specific training.

    Distills teacher embeddings into the static student via MSE regression:
    projects teacher embeddings to the student dimension with PCA when needed,
    trains a FinetunableStaticModel in fp32 with an adaptive (OOM-halving)
    batch size, monitors teacher-student cosine similarity before/after, and
    converts back to a static model.

    Args:
        student_model: Static model with an `embedding` array attribute.
        training_texts: Texts aligned row-by-row with `teacher_embeddings`.
        teacher_embeddings: Teacher targets, one row per training text.
        epochs: Number of training epochs.
        checkpoint_manager: Optional manager for periodic training/epoch
            checkpoints.
        model_manager: Optional manager used to persist the refined model.

    Returns:
        The refined static model, or the original `student_model` unchanged if
        training fails (failures are logged, not raised).

    Raises:
        ValueError: If `student_model` is None or lacks an `embedding`
            attribute (validated before the best-effort try block).
    """
    logger.info(f"Starting code specialization training for {epochs} epochs...")

    # Validate input parameters
    if student_model is None:
        logger.error("student_model is None - cannot proceed with code training")
        msg = "student_model cannot be None"
        raise ValueError(msg)

    if not hasattr(student_model, "embedding"):
        logger.error(f"student_model of type {type(student_model)} does not have 'embedding' attribute")
        msg = f"student_model must have 'embedding' attribute, got {type(student_model)}"
        raise ValueError(msg)

    logger.info(f"Student model type: {type(student_model)}")
    logger.info(f"Student model embedding shape: {student_model.embedding.shape}")

    try:
        # Force fp32 precision throughout for maximum quality
        target_dtype = torch.float32
        logger.info("🎯 Enforcing fp32 precision throughout for maximum quality")

        # Detect student model dtype for logging purposes
        student_dtype = student_model.embedding.dtype
        logger.info(f"Student model original embedding dtype: {student_dtype}")

        # Force teacher embeddings to fp32 if not already
        if teacher_embeddings.dtype != target_dtype:
            logger.info(f"Converting teacher embeddings from {teacher_embeddings.dtype} to {target_dtype}")
            teacher_embeddings = teacher_embeddings.to(target_dtype)

        # Get dimensions
        student_embedding_dim = student_model.embedding.shape[1]
        teacher_embedding_dim = teacher_embeddings.shape[1]

        logger.info(f"Student dims: {student_embedding_dim}, Teacher dims: {teacher_embedding_dim}")

        # Project teacher embeddings if needed with high-precision PCA
        if teacher_embedding_dim != student_embedding_dim:
            from sklearn.decomposition import PCA

            logger.info("Performing high-precision PCA projection for quality preservation...")
            pca = PCA(n_components=student_embedding_dim)

            # Use float64 for PCA computation to maximize precision
            teacher_embeddings_np = teacher_embeddings.cpu().numpy().astype(np.float64)
            teacher_embeddings_projected = pca.fit_transform(teacher_embeddings_np)

            # Convert back to fp32 (always use fp32, never fp16)
            teacher_embeddings = torch.tensor(
                teacher_embeddings_projected.astype(np.float32),
                dtype=target_dtype,
            )
            logger.info(f"PCA projection completed: {teacher_embeddings.shape} with dtype {target_dtype}")
            logger.info(
                f"PCA preserved variance ratio: {pca.explained_variance_ratio_[:5].sum():.4f} (first 5 components)"
            )

        # Create trainable model
        trainable_model = FinetunableStaticModel.from_static_model(
            model=student_model,
            out_dim=student_embedding_dim,
        )

        # Force ALL model parameters to fp32 to ensure no precision loss
        trainable_model = trainable_model.float()

        # Additional explicit conversion of embedding weights to fp32
        if hasattr(trainable_model, "embeddings") and hasattr(trainable_model.embeddings, "weight"):
            trainable_model.embeddings.weight.data = trainable_model.embeddings.weight.data.to(target_dtype)

        # Verify final model dtype after model2vec patch fix
        # (samples the dtype of the first parameter only)
        actual_model_dtype = None
        for param in trainable_model.parameters():
            actual_model_dtype = param.dtype
            break

        logger.info(f"Model parameter dtype: {actual_model_dtype}")
        logger.info(f"Embedding weight dtype: {trainable_model.embeddings.weight.dtype}")

        # Ensure teacher embeddings are definitely in fp32
        teacher_embeddings = teacher_embeddings.to(target_dtype)
        logger.info(f"Final teacher embeddings dtype: {teacher_embeddings.dtype}")
        logger.info(f"Final model parameter dtype: {actual_model_dtype}")

        # Verify we're using fp32 throughout
        if teacher_embeddings.dtype != target_dtype:
            logger.warning(f"⚠️ Teacher embeddings not in {target_dtype}: {teacher_embeddings.dtype}")
        if actual_model_dtype != target_dtype:
            logger.warning(f"⚠️ Model parameters not in {target_dtype}: {actual_model_dtype}")

        logger.info("✅ Confirmed fp32 precision throughout the training pipeline")

        # Tokenize texts
        # NOTE(review): texts whose tokenization is empty are dropped here,
        # which shifts the alignment between tokenized_texts and
        # teacher_embeddings for all subsequent texts — confirm tokenize()
        # never returns empty for this corpus.
        tokenized_texts = []
        for text in training_texts:
            tokens = trainable_model.tokenize([text])
            if tokens.shape[1] > 0:
                tokenized_texts.append(tokens[0].tolist())

        # Prepare training data with explicit fp32 casting
        targets = teacher_embeddings[: len(tokenized_texts)]

        # Force targets to fp32 to maintain maximum precision
        targets = targets.to(target_dtype)
        logger.info(f"Cast targets to fp32: {targets.dtype}")

        train_texts, val_texts, train_targets, val_targets = train_test_split(
            tokenized_texts, targets, test_size=0.2, random_state=42
        )

        logger.info(f"Train targets dtype: {train_targets.dtype}")
        logger.info(f"Val targets dtype: {val_targets.dtype}")

        # Training setup
        train_dataset = TextDataset(train_texts, train_targets)
        val_dataset = TextDataset(val_texts, val_targets)

        optimizer = optim.Adam(trainable_model.parameters(), lr=LEARNING_RATE)
        mse_loss = nn.MSELoss()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            trainable_model = trainable_model.to(device)
            logger.info(f"Training on {device}")
        except torch.cuda.OutOfMemoryError:
            logger.warning("GPU OOM loading training model, using CPU")
            device = torch.device("cpu")
            trainable_model = trainable_model.to(device)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Adaptive batch size for training (halved on OOM, never below 1)
        adaptive_batch_size = BATCH_SIZE

        # Quality monitoring: compute embedding similarity before training
        logger.info("📊 Quality monitoring: Computing pre-training teacher-student similarity...")
        trainable_model.eval()
        with torch.no_grad():
            # Take a small sample of texts for quality measurement
            sample_texts = training_texts[: min(5, len(training_texts))]
            sample_tokens = trainable_model.tokenize(sample_texts)
            sample_tokens = sample_tokens.to(device)

            _, student_embeddings_before = trainable_model(sample_tokens)
            sample_teacher_embeddings = targets[: len(sample_texts)].to(device)

            # Compute average cosine similarity
            similarities_before = []
            for i in range(len(sample_texts)):
                sim = torch.cosine_similarity(
                    student_embeddings_before[i].unsqueeze(0),
                    sample_teacher_embeddings[i].unsqueeze(0),
                ).item()
                similarities_before.append(sim)

            avg_similarity_before = np.mean(similarities_before)
            logger.info(f"📊 Pre-training average teacher-student similarity: {avg_similarity_before:.4f}")

        # Training loop with validation
        for epoch in range(epochs):
            # Training phase
            trainable_model.train()

            # Try with current batch size, reduce if OOM.
            # NOTE(review): an OOM mid-epoch restarts the WHOLE epoch at the
            # smaller batch size (optimizer steps already taken are kept).
            train_successful = False
            while not train_successful and adaptive_batch_size >= 1:
                try:
                    train_loader = train_dataset.to_dataloader(shuffle=True, batch_size=adaptive_batch_size)

                    epoch_loss = 0.0
                    num_batches = 0

                    for batch_idx, (tokens, targets_batch) in enumerate(train_loader):
                        batch_tokens = tokens.to(device)
                        batch_targets = targets_batch.to(device)

                        optimizer.zero_grad()
                        _, student_embeddings = trainable_model(batch_tokens)

                        # Debug dtype information on first batch
                        if batch_idx == 0:
                            logger.info(
                                f"Batch {batch_idx}: tokens shape {batch_tokens.shape}, dtype {batch_tokens.dtype}"
                            )
                            logger.info(
                                f"Batch {batch_idx}: targets shape {batch_targets.shape}, dtype {batch_targets.dtype}"
                            )
                            logger.info(
                                f"Batch {batch_idx}: student_embeddings shape {student_embeddings.shape}, dtype {student_embeddings.dtype}"
                            )

                        # Force both tensors to fp32 to avoid any precision loss
                        if student_embeddings.dtype != target_dtype:
                            logger.warning(
                                f"Student embeddings not in fp32: {student_embeddings.dtype}, converting to fp32"
                            )
                            student_embeddings = student_embeddings.to(target_dtype)
                        if batch_targets.dtype != target_dtype:
                            logger.info(f"Converting targets from {batch_targets.dtype} to fp32")
                            batch_targets = batch_targets.to(target_dtype)

                        try:
                            loss = mse_loss(student_embeddings, batch_targets)
                            loss.backward()
                            optimizer.step()
                        except RuntimeError as e:
                            # Last-resort recovery for dtype mismatches that
                            # slipped past the explicit casts above.
                            if "expected scalar type" in str(e):
                                logger.exception("Dtype mismatch error occurred:")
                                logger.exception(
                                    f"student_embeddings: {student_embeddings.shape}, {student_embeddings.dtype}"
                                )
                                logger.exception(f"batch_targets: {batch_targets.shape}, {batch_targets.dtype}")
                                logger.exception(
                                    f"MSE loss input dtypes: {student_embeddings.dtype} vs {batch_targets.dtype}"
                                )
                                # Force explicit casting to fp32 for maximum precision
                                batch_targets = batch_targets.to(target_dtype)
                                student_embeddings = student_embeddings.to(target_dtype)
                                logger.info("Emergency dtype fix: forced both to fp32")
                                loss = mse_loss(student_embeddings, batch_targets)
                                loss.backward()
                                optimizer.step()
                            else:
                                raise

                        epoch_loss += loss.item()
                        num_batches += 1

                        # Save training checkpoint periodically
                        if checkpoint_manager and batch_idx % 100 == 0:
                            training_state = {
                                "epoch": epoch,
                                "batch": batch_idx,
                                "model_state": trainable_model.state_dict(),
                                "optimizer_state": optimizer.state_dict(),
                                "loss": epoch_loss / max(1, num_batches),
                            }
                            checkpoint_data = create_checkpoint_data("training", training_state, epoch)
                            checkpoint_manager.save_checkpoint("training", checkpoint_data, epoch)

                    train_successful = True

                except torch.cuda.OutOfMemoryError:
                    logger.warning(
                        f"Training OOM with batch size {adaptive_batch_size}, reducing to {adaptive_batch_size // 2}"
                    )
                    adaptive_batch_size = max(1, adaptive_batch_size // 2)
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

            if not train_successful:
                logger.error("Unable to train even with batch size 1, skipping training")
                break

            avg_train_loss = epoch_loss / num_batches if num_batches > 0 else 0.0

            # Validation phase
            trainable_model.eval()
            val_loader = val_dataset.to_dataloader(shuffle=False, batch_size=adaptive_batch_size)
            val_loss = 0.0
            val_batches = 0

            with torch.no_grad():
                for tokens, targets_batch in val_loader:
                    batch_tokens = tokens.to(device)
                    batch_targets = targets_batch.to(device)

                    _, student_embeddings = trainable_model(batch_tokens)

                    # Force both tensors to fp32 to avoid any precision loss in validation
                    if student_embeddings.dtype != target_dtype:
                        student_embeddings = student_embeddings.to(target_dtype)
                    if batch_targets.dtype != target_dtype:
                        batch_targets = batch_targets.to(target_dtype)

                    loss = mse_loss(student_embeddings, batch_targets)
                    val_loss += loss.item()
                    val_batches += 1

            avg_val_loss = val_loss / val_batches if val_batches > 0 else 0.0

            logger.info(
                f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}, Batch Size: {adaptive_batch_size}"
            )

            # Save epoch checkpoint
            if checkpoint_manager:
                epoch_state = {
                    "epoch": epoch + 1,
                    "model_state": trainable_model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "train_loss": avg_train_loss,
                    "val_loss": avg_val_loss,
                }
                checkpoint_data = create_checkpoint_data("epoch", epoch_state, epoch + 1)
                checkpoint_manager.save_checkpoint("epoch", checkpoint_data, epoch + 1)

        # Quality monitoring: compute embedding similarity after training
        logger.info("📊 Quality monitoring: Computing post-training teacher-student similarity...")
        trainable_model.eval()
        with torch.no_grad():
            # Use the same sample texts as before
            sample_texts = training_texts[: min(5, len(training_texts))]
            sample_tokens = trainable_model.tokenize(sample_texts)
            sample_tokens = sample_tokens.to(device)

            _, student_embeddings_after = trainable_model(sample_tokens)
            sample_teacher_embeddings = targets[: len(sample_texts)].to(device)

            # Compute average cosine similarity
            similarities_after = []
            for i in range(len(sample_texts)):
                sim = torch.cosine_similarity(
                    student_embeddings_after[i].unsqueeze(0),
                    sample_teacher_embeddings[i].unsqueeze(0),
                ).item()
                similarities_after.append(sim)

            avg_similarity_after = np.mean(similarities_after)
            logger.info(f"📊 Post-training average teacher-student similarity: {avg_similarity_after:.4f}")

        # Quality assessment
        quality_change = avg_similarity_after - avg_similarity_before
        logger.info(f"📊 Quality change: {quality_change:+.4f}")

        if abs(quality_change) < 0.01:
            logger.info("✅ Quality well preserved during training!")
        elif quality_change > 0:
            logger.info("✅ Quality improved during training!")
        else:
            logger.warning(f"⚠️ Quality degraded by {abs(quality_change):.4f} during training")

        # Convert back to static model
        refined_model = trainable_model.to_static_model()

        # Save final refined model to beam volume
        if model_manager:
            # Save to temporary local directory first
            temp_refined_path = Path("./temp_refined_save")
            temp_refined_path.mkdir(exist_ok=True)
            refined_model.save_pretrained(str(temp_refined_path))

            # Upload to beam volume
            model_manager.save_model("refined_model", str(temp_refined_path))

            # Clean up temp directory
            import shutil

            shutil.rmtree(temp_refined_path, ignore_errors=True)

            logger.info("💾 Saved refined model to beam volume")

        logger.info("Code specialization training completed")
        return refined_model

    except Exception as e:
        # Best-effort contract: any training failure falls back to the
        # unrefined student model instead of aborting the pipeline.
        logger.warning(f"Code training failed: {e}")
        return student_model
| 788 |
+
|
| 789 |
+
|
| 790 |
+
def apply_regularization(model: Any, weight: float = 0.01) -> Any:
    """Apply light regularization with overflow protection.

    Returns a new StaticModel with shrunk + row-normalized embeddings, or the
    unmodified input model if anything goes wrong during rebuilding.

    NOTE(review): because each row is L2-normalized right after the
    (1 - weight) shrink, the shrink factor has no effect on the final
    directions — confirm this is the intended behavior.
    """
    # Fail fast on inputs we cannot work with.
    if model is None:
        logger.error("Cannot apply regularization: model is None")
        msg = "model cannot be None"
        raise ValueError(msg)

    if not hasattr(model, "embedding"):
        logger.error(f"Cannot apply regularization: model of type {type(model)} does not have 'embedding' attribute")
        msg = f"model must have 'embedding' attribute, got {type(model)}"
        raise ValueError(msg)

    logger.info(f"Applying regularization to model of type: {type(model)}")

    try:
        vecs = model.embedding.copy()

        # Clip pathological magnitudes before doing any arithmetic on them.
        peak = np.abs(vecs).max()
        if peak > 1e6:
            logger.warning(f"Large embedding values detected (max: {peak:.2e}), clipping to prevent overflow")
            vecs = np.clip(vecs, -1e6, 1e6)

        # Shrink all vectors toward zero by the regularization weight.
        shrunk = vecs * (1.0 - weight)

        # Row-wise L2 norms, guarded against zeros and overflow-sized values.
        row_norms = np.linalg.norm(shrunk, axis=1, keepdims=True)
        row_norms = np.where(row_norms == 0, 1, row_norms)
        row_norms = np.where(row_norms > 1e6, 1e6, row_norms)

        shrunk = shrunk / row_norms

        # Rebuild a fresh static model around the adjusted vectors.
        from model2vec.model import StaticModel

        rebuilt = StaticModel(
            vectors=shrunk,
            tokenizer=model.tokenizer,
            config=model.config,
            base_model_name=model.base_model_name,
            language=model.language,
            normalize=True,
        )

        logger.info("Regularization applied successfully")
        return rebuilt

    except Exception as e:
        # Best-effort: never let regularization break the pipeline.
        logger.warning(f"Regularization failed: {e}")
        return model
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
def load_teacher_model_with_cache(
    model_name: str,
    output_dir: str,
    device: str = "cuda",
    resume: bool = True,
) -> SentenceTransformer:
    """Load teacher model with local caching to avoid re-downloading.

    When `resume` is set and a cache directory exists under `output_dir`,
    the cached copy is tried first; otherwise the model is downloaded,
    configured, and (best-effort) written back to the cache.
    """
    cache_dir = Path(output_dir) / "teacher_model_cache"

    # Fast path: reuse the on-disk cache when resuming.
    if resume and cache_dir.exists():
        try:
            logger.info(f"Loading cached teacher model from {cache_dir}")
            st_model = SentenceTransformer(str(cache_dir), device=device)

            seq_limit = TEACHER_MODEL_CONFIG.get("max_seq_length", 8192)
            if isinstance(seq_limit, int):
                st_model.max_seq_length = seq_limit

            logger.info("Successfully loaded cached teacher model")
            return st_model
        except Exception as e:
            # Fall through to a fresh download on any cache problem.
            logger.warning(f"Failed to load cached teacher model: {e}")
            logger.info("Will download fresh model")

    logger.info(f"Downloading teacher model {model_name} (this may take a while)")

    # Base loading kwargs; flash attention is layered on top when enabled.
    load_kwargs = {
        "torch_dtype": TEACHER_MODEL_CONFIG["torch_dtype"],
        "device_map": TEACHER_MODEL_CONFIG["device_map"],
    }

    if TEACHER_MODEL_CONFIG.get("use_flash_attention", False):
        try:
            load_kwargs["attn_implementation"] = TEACHER_MODEL_CONFIG["attn_implementation"]
            logger.info("Flash Attention 2 enabled")
        except Exception as e:
            logger.warning(f"Flash Attention not available, using default attention: {e}")

    trust_flag = bool(TEACHER_MODEL_CONFIG["trust_remote_code"])
    try:
        st_model = SentenceTransformer(
            model_name,
            device=device,
            trust_remote_code=trust_flag,
            model_kwargs=load_kwargs,
        )
    except ImportError as e:
        if "flash_attn" not in str(e):
            raise
        # flash_attn missing at import time: strip it and retry once.
        logger.warning("Flash Attention 2 not available, falling back to default attention")
        retry_kwargs = {k: v for k, v in load_kwargs.items() if k != "attn_implementation"}
        st_model = SentenceTransformer(
            model_name,
            device=device,
            trust_remote_code=trust_flag,
            model_kwargs=retry_kwargs,
        )

    # Apply the configured sequence-length limit to the freshly loaded model.
    seq_limit = TEACHER_MODEL_CONFIG.get("max_seq_length", 8192)
    if isinstance(seq_limit, int):
        st_model.max_seq_length = seq_limit
        logger.info(f"Set max_seq_length to {seq_limit} for better performance")

    # Best-effort cache write; failures are non-fatal.
    try:
        cache_dir.mkdir(parents=True, exist_ok=True)
        st_model.save(str(cache_dir))
        logger.info(f"Cached teacher model to {cache_dir}")
    except Exception as e:
        logger.warning(f"Failed to cache teacher model: {e}")

    return st_model
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
def _log_gpu_diagnostics() -> None:
    """Log CUDA availability plus per-GPU name/memory and current allocation."""
    logger.info("=== GPU DIAGNOSTICS ===")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            gpu_name = torch.cuda.get_device_name(i)
            gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
            logger.info(f"GPU {i}: {gpu_name} ({gpu_memory:.1f} GB)")

        current_device = torch.cuda.current_device()
        allocated = torch.cuda.memory_allocated(current_device) / 1024**3
        total = torch.cuda.get_device_properties(current_device).total_memory / 1024**3
        logger.info(f"Current GPU {current_device}: {allocated:.2f}GB allocated, {total:.1f}GB total")
    else:
        logger.warning("CUDA not available - will use CPU (much slower)")
    logger.info("======================")


def _load_cached_base_model() -> Any:
    """Return a previously distilled model saved at the volume root, or None."""
    try:
        volume_root_path = Path(VOLUME_PATH)
        if (volume_root_path / "config.json").exists() and (volume_root_path / "model.safetensors").exists():
            logger.info("Found existing model files in volume root")
            from model2vec.model import StaticModel

            model = StaticModel.from_pretrained(str(volume_root_path))
            logger.info("Successfully loaded existing distilled model from volume")
            return model
        logger.info("No existing model files found in volume root")
    except Exception as e:
        logger.warning(f"Failed to load existing model from volume: {e}")
    return None


def _distill_base_model(model_name: str, pca_dims: int) -> Any:
    """Run core Model2Vec distillation; on GPU OOM, retry on CPU.

    Raises:
        ValueError: if distillation unexpectedly returns None.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        current_device = torch.cuda.current_device()
        allocated = torch.cuda.memory_allocated(current_device) / 1024**3
        total = torch.cuda.get_device_properties(current_device).total_memory / 1024**3
        logger.info(f"GPU memory before distillation: {allocated:.2f}GB allocated / {total:.1f}GB total")
    else:
        logger.info("Using CPU for distillation")

    def _run() -> Any:
        # Single source of truth for the distillation parameters (used by
        # both the GPU attempt and the CPU retry).
        return distill(
            model_name=model_name,
            pca_dims=pca_dims,
            apply_zipf=None,
            sif_coefficient=1e-4,
            trust_remote_code=True,
        )

    try:
        m2v_model = _run()
        logger.info("Basic distillation completed with preserved precision")
        if m2v_model is None:
            msg = "Distillation returned None - this should not happen"
            raise ValueError(msg) from None
    except torch.cuda.OutOfMemoryError:
        logger.warning("GPU OOM during distillation, clearing cache and retrying...")
        torch.cuda.empty_cache()
        # Hide all GPUs so the retry runs entirely on CPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        logger.info("Retrying distillation on CPU...")
        m2v_model = _run()
        logger.info("Basic distillation completed on CPU")
        if m2v_model is None:
            msg = "CPU distillation returned None - this should not happen"
            raise ValueError(msg) from None
    except Exception:
        logger.exception("Distillation failed with error")
        raise

    logger.info(f"Distilled model type: {type(m2v_model)}")
    logger.info(f"Distilled model has embedding attribute: {hasattr(m2v_model, 'embedding')}")
    return m2v_model


def _resume_refined_model(model_mgr: Any) -> Any:
    """Load a previously trained refined model from the beam volume, or None."""
    models = model_mgr.list_models()
    if not any(model["name"] == "refined_model" for model in models):
        return None

    temp_model_path = Path("./temp_refined_model")
    if not model_mgr.load_model("refined_model", temp_model_path):
        return None

    import shutil

    try:
        from model2vec.model import StaticModel

        refined_model = StaticModel.from_pretrained(str(temp_model_path / "refined_model"))
        logger.info("Resumed from existing refined model")
        return refined_model
    except Exception as e:
        logger.warning(f"Failed to load existing refined model: {e}")
        return None
    finally:
        # Temp download directory is no longer needed either way.
        shutil.rmtree(temp_model_path, ignore_errors=True)


def _train_with_teacher(
    m2v_model: Any,
    code_texts: list,
    model_name: str,
    output_dir: str,
    resume: bool,
    checkpoint_mgr: Any,
    model_mgr: Any,
) -> Any:
    """Teacher-guided refinement with a GPU attempt and a CPU OOM fallback.

    This consolidates the previously duplicated resume/fresh training paths.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Loading teacher model on {device} with optimized settings")
        logger.info(
            f"Using precision: {TEACHER_MODEL_CONFIG['precision']}, batch_size: {TEACHER_MODEL_CONFIG['batch_size']}"
        )
        logger.info("Attempting to enable Flash Attention 2 for maximum performance")

        teacher_model = load_teacher_model_with_cache(model_name, output_dir, device=device, resume=resume)

        teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
            teacher_model, code_texts, checkpoint_mgr
        )

        refined = refine_with_code_training(
            m2v_model,
            code_texts,
            teacher_embeddings,
            epochs=TRAINING_EPOCHS,
            checkpoint_manager=checkpoint_mgr,
            model_manager=model_mgr,
        )

        del teacher_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return refined

    except torch.cuda.OutOfMemoryError:
        logger.warning("GPU OOM during code training, falling back to CPU...")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Force CPU for the teacher model (no flash attention on CPU).
        try:
            teacher_model = load_teacher_model_with_cache(model_name, output_dir, device="cpu", resume=resume)
        except ImportError as e:
            if "flash_attn" in str(e):
                logger.warning("Flash Attention 2 not available on CPU, using default attention")
                teacher_model = load_teacher_model_with_cache(
                    model_name, output_dir, device="cpu", resume=resume
                )
            else:
                raise

        teacher_embeddings = generate_teacher_embeddings_with_checkpoints(
            teacher_model, code_texts, checkpoint_mgr
        )

        refined = refine_with_code_training(
            m2v_model,
            code_texts,
            teacher_embeddings,
            epochs=TRAINING_EPOCHS,
            checkpoint_manager=checkpoint_mgr,
            model_manager=model_mgr,
        )

        del teacher_model
        return refined


def code_specialized_distillation(
    model_name: str = MODEL_NAME,
    output_dir: str = OUTPUT_DIR,
    pca_dims: int = PCA_DIMS,
    max_samples: int = MAX_TRAINING_SAMPLES,
    resume: bool = True,
) -> Any:
    """Main code-specialized distillation pipeline with checkpoint support.

    Steps: (1) base Model2Vec distillation (or reuse from the beam volume when
    resuming), (2) load CodeSearchNet training texts, (3) teacher-guided
    refinement, (4) light regularization, then save to `output_dir`.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Initialize Beam utilities (eval manager currently unused here).
    volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)

    logger.info(f"Starting code-specialized distillation of {model_name}")
    logger.info(f"Using CodeSearchNet dataset: {CODESEARCHNET_DATASET}")
    logger.info(f"Resume mode: {resume}")

    _log_gpu_diagnostics()

    start_time = time.time()

    # Step 1: base distillation (reuse the cached model when resuming).
    logger.info("Step 1: Basic Model2Vec distillation...")
    m2v_model = _load_cached_base_model() if resume else None
    if m2v_model is None:
        m2v_model = _distill_base_model(model_name, pca_dims)

    if m2v_model is None:
        msg = "m2v_model is None after distillation step - cannot proceed"
        raise ValueError(msg)

    # Step 2: training data (checkpoint-aware loader).
    logger.info("Step 2: Loading CodeSearchNet training data...")
    code_texts = load_codesearchnet_dataset_with_resume(max_samples, checkpoint_mgr)

    if not code_texts:
        logger.warning("No code training data available, skipping code specialization")
    else:
        logger.info("Step 3: Code specialization training...")
        refined_model = _resume_refined_model(model_mgr) if resume else None
        if refined_model is not None:
            m2v_model = refined_model
        else:
            m2v_model = _train_with_teacher(
                m2v_model, code_texts, model_name, output_dir, resume, checkpoint_mgr, model_mgr
            )

    # Step 4: light regularization.
    logger.info("Step 4: Applying regularization...")
    m2v_model = apply_regularization(m2v_model, REGULARIZATION_WEIGHT)

    # Final validation before saving.
    logger.info("Saving code-specialized model...")
    if m2v_model is None:
        msg = "Cannot save model: m2v_model is None"
        raise ValueError(msg)
    if not hasattr(m2v_model, "save_pretrained"):
        msg = f"Cannot save model: m2v_model of type {type(m2v_model)} does not have save_pretrained method"
        raise ValueError(msg)

    logger.info(f"Final model type: {type(m2v_model)}")
    logger.info(f"Final model has embedding attribute: {hasattr(m2v_model, 'embedding')}")

    m2v_model.save_pretrained(str(output_path))

    total_time = time.time() - start_time
    logger.info(f"Code-specialized distillation completed in {total_time:.2f} seconds")

    return m2v_model
|
| 1263 |
+
|
| 1264 |
+
|
| 1265 |
+
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",  # Allow async CUDA operations
        "TORCH_CUDNN_V8_API_ENABLED": "1",  # Enable optimized cuDNN
        "OMP_NUM_THREADS": "8",  # Limit CPU threads for better GPU utilization
    },
    timeout=3600 * 12,  # 12 hours
)
def beam_code_distillation(
    model_name: str = MODEL_NAME,
    output_dir: str = OUTPUT_DIR,
    pca_dims: int = PCA_DIMS,
    max_samples: int = MAX_TRAINING_SAMPLES,
    resume: bool = True,
) -> Any:
    """Beam remote entry point: apply patches, then run the full pipeline.

    Thin wrapper around `code_specialized_distillation` that runs on the Beam
    worker. Patches must be applied first, inside the worker process, so the
    patched model2vec code is active before distillation starts.
    Patch failures are logged but non-fatal (best-effort).
    """
    # Apply all patches from the patches directory
    try:
        from .patch_utils import apply_all_patches

        logger.info("Applying all patches from patches directory...")
        patches_applied = apply_all_patches()
        logger.info(f"Successfully applied {patches_applied} patches")
    except Exception as e:
        logger.warning(f"Failed to apply patches: {e}. Continuing without patches.")

    return code_specialized_distillation(
        model_name=model_name,
        output_dir=output_dir,
        pca_dims=pca_dims,
        max_samples=max_samples,
        resume=resume,
    )
|
| 1303 |
+
|
| 1304 |
+
|
| 1305 |
+
if __name__ == "__main__":
    # Run the full pipeline locally (not via Beam) with default settings.
    code_specialized_distillation()
|
src/distiller/distill_simplified.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simplified Code-Specialized Model2Vec Distillation Script.
|
| 3 |
+
|
| 4 |
+
This script implements a focused, simplified approach for creating code-specialized embeddings
|
| 5 |
+
using only the core Model2Vec distillation without additional fine-tuning that may degrade quality.
|
| 6 |
+
|
| 7 |
+
Can run locally or on Beam with the --use-beam flag.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
import sys
|
| 14 |
+
import time
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from beam import GpuType, Image, Volume, function
|
| 19 |
+
from model2vec.distill import distill
|
| 20 |
+
|
| 21 |
+
# =============================================================================
|
| 22 |
+
# SIMPLIFIED CONFIGURATION
|
| 23 |
+
# =============================================================================
|
| 24 |
+
|
| 25 |
+
# Use a code-specialized teacher model instead of general instruction model
|
| 26 |
+
# Ordered by success likelihood and performance:
|
| 27 |
+
CODE_TEACHER_MODELS = [
|
| 28 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
| 29 |
+
"sentence-transformers/all-mpnet-base-v2",
|
| 30 |
+
"microsoft/codebert-base",
|
| 31 |
+
"microsoft/graphcodebert-base",
|
| 32 |
+
"sentence-transformers/paraphrase-MiniLM-L6-v2",
|
| 33 |
+
"Alibaba-NLP/gte-Qwen2-7B-instruct",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
OUTPUT_BASE_DIR = "code_model2vec"
|
| 37 |
+
|
| 38 |
+
# Optimal Model2Vec parameters based on successful models
|
| 39 |
+
OPTIMAL_PCA_DIMS = 256 # Match other successful Model2Vec models
|
| 40 |
+
SIF_COEFFICIENT = 1e-3 # Slightly higher than default for code specialization
|
| 41 |
+
APPLY_ZIPF = True # Enable Zipf weighting for better word importance
|
| 42 |
+
|
| 43 |
+
# =============================================================================
|
| 44 |
+
# BEAM CONFIGURATION
|
| 45 |
+
# =============================================================================
|
| 46 |
+
|
| 47 |
+
GPU_NAME = GpuType.A100_40
|
| 48 |
+
VOLUME_NAME = "code_model2vec"
|
| 49 |
+
VOLUME_PATH = "./code_model2vec"
|
| 50 |
+
IMAGE = Image(python_version="python3.12").add_python_packages(
|
| 51 |
+
[
|
| 52 |
+
"torch>=2.7.0", # Install torch first
|
| 53 |
+
"transformers>=4.40.0", # Latest transformers with flash attention support
|
| 54 |
+
"lightning>=2.5.1.post0",
|
| 55 |
+
"model2vec[train]>=0.5.0",
|
| 56 |
+
"numpy>=1.26.4",
|
| 57 |
+
"scikit-learn>=1.6.1",
|
| 58 |
+
"sentence-transformers>=4.1.0",
|
| 59 |
+
"datasets>=3.2.0", # For evaluation
|
| 60 |
+
"pandas>=2.0.0",
|
| 61 |
+
"tqdm>=4.65.0",
|
| 62 |
+
]
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# =============================================================================
|
| 66 |
+
|
| 67 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
|
| 70 |
+
# Add beam utilities for proper model persistence
|
| 71 |
+
try:
|
| 72 |
+
from .beam_utils import (
|
| 73 |
+
create_beam_utilities,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
BEAM_UTILS_AVAILABLE = True
|
| 77 |
+
except ImportError:
|
| 78 |
+
print("Beam utilities not available - models will only be saved locally")
|
| 79 |
+
BEAM_UTILS_AVAILABLE = False
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def apply_local_patches() -> bool:
    """Apply patches locally without requiring Beam utilities.

    Returns True when patch_utils applied the patches, False otherwise
    (missing module or any patching failure). Never raises.
    """
    try:
        # Prefer the shared patch_utils helper when it is importable.
        try:
            from .patch_utils import apply_all_patches

            n_applied = apply_all_patches()
        except ImportError:
            logger.warning("patch_utils not available, trying direct patching")
            return False

        logger.info(f"Successfully applied {n_applied} patches via patch_utils")
        return True

    except Exception as e:
        # Any unexpected failure is downgraded to a warning.
        logger.warning(f"Failed to apply patches: {e}")
        return False
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def simplified_code_distillation(
    teacher_model: str,
    output_dir: str,
    pca_dims: int = OPTIMAL_PCA_DIMS,
) -> Any:
    """
    Simplified code-specialized distillation using only core Model2Vec.

    This approach:
    1. Uses a teacher model that already performs well on code tasks
    2. Applies optimal Model2Vec parameters
    3. Avoids additional training that may degrade quality
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Starting simplified distillation from {teacher_model}")
    logger.info(f"Target dimensions: {pca_dims}")
    logger.info(f"SIF coefficient: {SIF_COEFFICIENT}")
    logger.info(f"Zipf weighting: {APPLY_ZIPF}")

    started_at = time.time()

    try:
        # One-shot distillation with the tuned parameters; no extra training.
        distilled = distill(
            model_name=teacher_model,
            pca_dims=pca_dims,
            apply_zipf=APPLY_ZIPF,
            sif_coefficient=SIF_COEFFICIENT,
            trust_remote_code=True,
        )

        logger.info("✅ Core distillation completed successfully")

        # Persist the result and report basic model facts.
        distilled.save_pretrained(str(target_dir))
        logger.info(f"💾 Model saved to {target_dir}")

        logger.info(f"Model type: {type(distilled)}")
        if hasattr(distilled, "embedding"):
            logger.info(f"Embedding shape: {distilled.embedding.shape}")
            logger.info(f"Embedding dtype: {distilled.embedding.dtype}")

        elapsed = time.time() - started_at
        logger.info(f"🏁 Simplified distillation completed in {elapsed:.2f} seconds")
        return distilled

    except ValueError as e:
        # Known Model2Vec incompatibility: tokenizer/vector count mismatch.
        mismatch = "Number of tokens" in str(e) and "does not match number of vectors" in str(e)
        if mismatch:
            logger.warning(f"⚠️ Token-vector mismatch with {teacher_model} - this is a Model2Vec library issue")
            logger.warning(f"Error details: {e}")
            logger.warning("💡 This model has incompatible tokenization. Skipping...")
            return None
        raise
    except Exception:
        logger.exception("❌ Distillation failed")
        return None
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def core_distill_all_teachers(use_beam_utilities: bool = False) -> dict[str, Any]:
    """
    Core logic for distilling all teacher models.

    Iterates over ``CODE_TEACHER_MODELS``, skipping any teacher whose output
    directory already contains a distilled model, and records a per-teacher
    status entry ("success", "skipped_existing", or "failed").

    Args:
        use_beam_utilities: Whether to use Beam utilities for persistence.

    Returns:
        Dictionary with distillation results: timestamp, successful model
        names, per-teacher info, and success/attempt counts.

    Raises:
        RuntimeError: If no teacher model could be distilled successfully.
    """
    # Apply patches (works around upstream model2vec issues; see patches/).
    logger.info("Applying all patches...")
    patch_success = apply_local_patches()
    if patch_success:
        logger.info("Successfully applied patches")
    else:
        logger.warning("Failed to apply patches - Microsoft models may fail")

    # Initialize Beam utilities if requested and available.
    volume_mgr = None
    model_mgr = None
    if use_beam_utilities and BEAM_UTILS_AVAILABLE:
        try:
            volume_mgr, _, model_mgr, _ = create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
            logger.info("✅ Beam utilities initialized for model persistence")
        except Exception as e:
            logger.warning(f"Failed to initialize Beam utilities: {e}")
            model_mgr = None

    results = {}
    successful_models = []

    logger.info("🚀 Starting comprehensive teacher model distillation")
    logger.info(f"📋 Processing {len(CODE_TEACHER_MODELS)} teacher models")

    # Determine output base path (Beam volume vs. local directory).
    base_output_path = VOLUME_PATH if use_beam_utilities else OUTPUT_BASE_DIR

    for teacher_model in CODE_TEACHER_MODELS:
        try:
            # Create output directory name based on teacher model.
            teacher_name = teacher_model.split("/")[-1].replace("-", "_")
            output_dir = f"{base_output_path}/final/code_model2vec_{teacher_name}"

            logger.info(f"\n{'=' * 60}")
            logger.info(f"🔄 Processing teacher model: {teacher_model}")
            logger.info(f"📁 Output directory: {output_dir}")
            logger.info(f"{'=' * 60}")

            # Check if model already exists (idempotent re-runs).
            output_path = Path(output_dir)
            if output_path.exists():
                # Check for essential model files.
                has_config = (output_path / "config.json").exists()
                has_model_file = any(
                    [
                        (output_path / "model.safetensors").exists(),
                        (output_path / "model.bin").exists(),
                        (output_path / "pytorch_model.bin").exists(),
                    ]
                )

                if has_config and has_model_file:
                    logger.info(f"✅ Model {teacher_name} already exists - skipping distillation")

                    # Still record it as successful.
                    model_info = {
                        "teacher_model": teacher_model,
                        "output_dir": output_dir,
                        "teacher_name": teacher_name,
                        "distillation_time": 0.0,
                        "status": "skipped_existing",
                    }

                    results[teacher_name] = model_info
                    successful_models.append(teacher_name)
                    logger.info(f"📂 Using existing model at: {output_dir}")
                    continue

            # Perform distillation.
            start_time = time.time()
            model = simplified_code_distillation(
                teacher_model=teacher_model,
                output_dir=output_dir,
            )
            distill_time = time.time() - start_time

            if model is not None:
                logger.info(f"✅ Distillation successful for {teacher_model}")

                # Save to Beam volume for persistence if available.
                if model_mgr:
                    try:
                        # Save model to beam volume with teacher-specific name.
                        beam_model_name = f"{teacher_name}_model"
                        model_mgr.save_model(beam_model_name, output_dir)
                        logger.info(f"💾 Saved {teacher_name} to Beam volume as {beam_model_name}")
                    except Exception as e:
                        logger.warning(f"Failed to save {teacher_name} to Beam volume: {e}")

                # Store results.
                model_info = {
                    "teacher_model": teacher_model,
                    "output_dir": output_dir,
                    "teacher_name": teacher_name,
                    "distillation_time": distill_time,
                    "status": "success",
                }

                results[teacher_name] = model_info
                successful_models.append(teacher_name)

                logger.info(f"💾 Model saved to: {output_dir}")
            else:
                # BUG FIX: previously a None return (e.g. incompatible
                # tokenization, distillation error) was silently dropped and
                # the teacher never appeared in the results dict at all.
                # Record the failure explicitly so the summary is complete.
                results[teacher_name] = {
                    "teacher_model": teacher_model,
                    "output_dir": output_dir,
                    "teacher_name": teacher_name,
                    "distillation_time": distill_time,
                    "status": "failed",
                    "error": "distillation returned None",
                }
                logger.warning(f"⚠️ Distillation returned no model for {teacher_model}")

        except Exception as e:
            logger.exception(f"❌ Failed with {teacher_model}")
            results[teacher_model.split("/")[-1]] = {
                "teacher_model": teacher_model,
                "status": "failed",
                "error": str(e),
            }
            continue

    # Summary.
    if successful_models:
        logger.info("\n🎉 DISTILLATION COMPLETE!")
        logger.info(f"📊 Successful models: {len(successful_models)}")

        for model_name in successful_models:
            model_info = results[model_name]
            logger.info(f"✅ {model_name}: {model_info['teacher_model']}")

        # Save comprehensive results.
        results_summary = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "successful_models": successful_models,
            "all_results": results,
            "total_successful": len(successful_models),
            "total_attempted": len(CODE_TEACHER_MODELS),
        }

        # Save results to file.
        results_file = Path(f"{base_output_path}/distillation_results.json")
        results_file.parent.mkdir(parents=True, exist_ok=True)
        with results_file.open("w") as f:
            json.dump(results_summary, f, indent=2)

        logger.info(f"📄 Results summary saved to: {results_file}")

        return results_summary

    logger.error("❌ No models succeeded")
    msg = "All teacher models failed distillation"
    raise RuntimeError(msg)
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def run_local_distillation() -> dict[str, Any]:
    """Run the full teacher-distillation pipeline on the local machine (no Beam)."""
    logger.info("🖥️ Running simplified distillation locally")
    summary = core_distill_all_teachers(use_beam_utilities=False)
    return summary
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# Remote entry point: this function body executes inside a Beam container with
# the configured GPU, mounted volume, and environment.
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:512",
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",  # Allow async CUDA operations
        "TORCH_CUDNN_V8_API_ENABLED": "1",  # Enable optimized cuDNN
    },
    timeout=3600 * 12,  # 12 hours
)
def beam_distill_all_teachers() -> dict[str, Any]:
    """
    Beam version: Try all teacher models and create distilled models from each.

    Delegates to ``core_distill_all_teachers`` with Beam persistence enabled so
    distilled models are saved to the mounted volume.

    Returns information about all models that were successfully created.
    """
    logger.info("☁️ Running simplified distillation on Beam")
    return core_distill_all_teachers(use_beam_utilities=True)
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def main() -> None:
    """CLI entry point: parse arguments and dispatch to local or Beam execution."""
    # OUTPUT_BASE_DIR is a module-level setting read by the local pipeline;
    # it is rebound below when --output-dir is supplied.
    global OUTPUT_BASE_DIR  # Declare global at the top  # noqa: PLW0603

    parser = argparse.ArgumentParser(
        description="Simplified Code-Specialized Model2Vec Distillation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m src.distiller.distill_simplified              # Run locally
  python -m src.distiller.distill_simplified --use-beam   # Run on Beam
  distiller distill-simple                                # CLI shortcut (runs on Beam)
        """,
    )

    parser.add_argument(
        "--use-beam",
        action="store_true",
        help="Run on Beam instead of locally",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default=OUTPUT_BASE_DIR,
        help=f"Output directory for models (default: {OUTPUT_BASE_DIR})",
    )

    args = parser.parse_args()

    # Update output directory if specified (only rebind on an actual change).
    if args.output_dir != OUTPUT_BASE_DIR:
        OUTPUT_BASE_DIR = args.output_dir

    try:
        if args.use_beam:
            logger.info("🚀 Starting Beam execution...")
            results = beam_distill_all_teachers()
        else:
            logger.info("🖥️ Starting local execution...")
            results = run_local_distillation()

        # Print final summary.
        print("\n🎉 Distillation complete!")
        print(f"📊 Successfully created {results['total_successful']} models")

        # Models land under the Beam volume or the local output dir, depending
        # on the execution mode chosen above.
        if args.use_beam:
            print(f"📁 Models location: {VOLUME_PATH}/final/")
        else:
            print(f"📁 Models location: {OUTPUT_BASE_DIR}/final/")

        print("\n✅ Created models:")
        for model_name in results["successful_models"]:
            model_info = results["all_results"][model_name]
            print(f"  • {model_name} (from {model_info['teacher_model']})")

    except KeyboardInterrupt:
        logger.info("🛑 Distillation interrupted by user")
        sys.exit(1)
    except Exception:
        logger.exception("❌ Distillation failed with error")
        sys.exit(1)


if __name__ == "__main__":
    main()
|
src/distiller/evaluate.py
ADDED
|
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
CodeSearchNet Evaluation Script for Code-Specialized Embedding Models.

This script evaluates embedding models on code search tasks using the CodeSearchNet
dataset and methodology. It implements the same evaluation approach as the original
CodeSearchNet challenge, including NDCG and other information retrieval metrics.

Usage:
    distiller evaluate  # Run evaluation on all default models with Beam
"""

import json
import logging
import time
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from beam import GpuType, Image, Volume, function
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from .beam_utils import (
    BeamCheckpointManager,
    BeamEvaluationManager,
    create_beam_utilities,
)

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# =============================================================================
# BEAM CONFIGURATION
# =============================================================================

GPU_NAME = GpuType.A100_40
VOLUME_NAME = "code_model2vec"  # Same volume as distill_simplified.py
VOLUME_PATH = "./code_model2vec"  # Same mount path as distill_simplified.py
EVALUATION_RESULTS_DIR = "evaluation_results"  # Subdirectory within volume
EVALUATION_CACHE_DIR = "evaluation_cache"  # Cache for datasets and models

# Remote container image: pinned minimum versions for the evaluation stack.
IMAGE = Image(python_version="python3.12").add_python_packages(
    [
        "torch>=2.7.0",
        "transformers>=4.40.0",
        "datasets>=3.2.0",
        "sentence-transformers>=4.1.0",
        "model2vec[train]>=0.5.0",
        "numpy>=1.26.4",
        "scikit-learn>=1.6.1",
        "pandas>=2.0.0",
        "tqdm>=4.65.0",
    ]
)

# =============================================================================
# CONFIGURATION
# =============================================================================

CODESEARCHNET_EVAL_DATASET = "code_search_net"
BATCH_SIZE = 32  # Encoding batch size used by CodeSearchNetEvaluator
DEFAULT_OUTPUT_DIR = "code_evaluation_results"  # Local fallback directory
EVALUATION_LANGUAGES = ["python", "javascript", "java", "php", "ruby", "go"]

# Default models to evaluate (can be overridden via command line)
DEFAULT_EVALUATION_MODELS = [
    # Established Code Models
    "sentence-transformers/all-MiniLM-L6-v2",
    "microsoft/codebert-base",
    "microsoft/graphcodebert-base",
    "huggingface/CodeBERTa-small-v1",
    "sentence-transformers/all-mpnet-base-v2",
    "sentence-transformers/all-MiniLM-L12-v2",
    # Model2Vec & Efficiency Models (Direct Competitors)
    "minishlab/potion-base-8M",
    "minishlab/potion-retrieval-32M",
    # Small Transformer-Based Code Models
    "Salesforce/codet5-base",
]

# =============================================================================
# CHECKPOINT CONFIGURATION
# =============================================================================

# Prevent conflicts with distill.py checkpoints by using different prefixes
EVAL_CHECKPOINT_PREFIX = "evaluation_checkpoints"
DATASET_CHECKPOINT_PREFIX = "dataset_cache"
MODEL_CACHE_PREFIX = "model_cache"

# =============================================================================
# CORE EVALUATION CLASSES
# =============================================================================
| 99 |
+
class CodeSearchNetEvaluator:
    """Evaluator for CodeSearchNet-style code search tasks.

    Each docstring in the test split serves as a natural-language query and
    the function it documents is the single relevant result, so retrieval
    quality is measured along the diagonal of the query/code similarity matrix.
    """

    def __init__(
        self,
        model_path: str,
        model_name: str | None = None,
        checkpoint_manager: BeamCheckpointManager | None = None,
        eval_manager: BeamEvaluationManager | None = None,
    ) -> None:
        """Initialize the evaluator with a model and optional Beam utilities.

        Args:
            model_path: Local path or hub id accepted by SentenceTransformer.
            model_name: Display name; defaults to the last path component.
            checkpoint_manager: Optional per-language checkpoint persistence.
            eval_manager: Optional comprehensive-result cache/persistence.
        """
        self.model_path = model_path
        self.model_name = model_name or Path(model_path).name
        self.model: SentenceTransformer | None = None
        self.checkpoint_manager = checkpoint_manager
        self.eval_manager = eval_manager
        self._load_model()

    def _load_model(self) -> None:
        """Load the embedding model, logging whether cached results exist."""
        logger.info(f"Loading model from {self.model_path}")

        # Check if we have a cached evaluation result for this model.
        if self.eval_manager:
            cached_result = self.eval_manager.load_evaluation_results(self.model_name)
            if cached_result:
                logger.info(f"✅ Found cached evaluation results for {self.model_name}")
                # Note: We still need to load the model for new evaluations.

        try:
            self.model = SentenceTransformer(self.model_path, trust_remote_code=True)
            logger.info(f"Successfully loaded model: {self.model_name}")
        except Exception:
            logger.exception(f"Failed to load model from {self.model_path}")
            raise

    def encode_texts(self, texts: list[str], desc: str = "Encoding") -> np.ndarray:
        """Encode texts into L2-normalized embeddings with batching.

        Args:
            texts: Strings to embed (non-empty; callers filter empty sets).
            desc: Progress-bar label.

        Returns:
            2-D array of shape ``(len(texts), embedding_dim)``.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            msg = "Model not loaded"
            raise RuntimeError(msg)

        embeddings = []

        for i in tqdm(range(0, len(texts), BATCH_SIZE), desc=desc):
            batch = texts[i : i + BATCH_SIZE]
            # normalize_embeddings=True lets cosine similarity reduce to a dot product.
            batch_embeddings = self.model.encode(batch, convert_to_tensor=False, normalize_embeddings=True)
            embeddings.append(batch_embeddings)

        return np.vstack(embeddings)

    def evaluate_language(self, language: str, max_queries: int = 1000) -> dict[str, Any]:
        """Evaluate on a specific programming language with checkpoint support.

        Returns an empty dict when the dataset cannot be loaded or contains no
        usable query/code pairs, so callers can skip the language gracefully.
        """
        logger.info(f"Evaluating on {language} language (max {max_queries} queries)")

        # Check for existing evaluation checkpoint (only valid for same model).
        if self.checkpoint_manager:
            cached_result = self.checkpoint_manager.load_checkpoint(f"{EVAL_CHECKPOINT_PREFIX}_{language}", 0)
            if cached_result and cached_result.get("data", {}).get("model_name") == self.model_name:
                logger.info(f"✅ Resuming from cached {language} evaluation")
                return cached_result.get("data", {})

        try:
            # Load test split for the language.
            dataset = load_dataset(
                CODESEARCHNET_EVAL_DATASET,
                language,
                split="test",
                trust_remote_code=True,
            )

            # Ensure we have a Dataset object.
            if not isinstance(dataset, Dataset):
                logger.error(f"Unexpected dataset type for {language}: {type(dataset)}")
                return {}

            # Sample queries for evaluation (to make it manageable).
            if len(dataset) > max_queries:
                rng = np.random.default_rng(42)  # Use seeded generator for reproducibility
                indices = rng.choice(len(dataset), max_queries, replace=False)
                dataset = dataset.select(indices)

            queries = []
            codes = []
            query_ids = []

            for i, example in enumerate(dataset):
                doc_string = example.get("func_documentation_string", "").strip()
                code_string = example.get("func_code_string", "").strip()

                # Require a minimally descriptive docstring (>= 3 words).
                if doc_string and code_string and len(doc_string.split()) >= 3:
                    queries.append(doc_string)
                    codes.append(code_string)
                    query_ids.append(f"{language}_{i}")

            if len(queries) == 0:
                logger.warning(f"No valid query-code pairs found for {language}")
                return {}

            logger.info(f"Found {len(queries)} valid query-code pairs for {language}")

            # Encode queries and codes.
            query_embeddings = self.encode_texts(queries, f"Encoding {language} queries")
            code_embeddings = self.encode_texts(codes, f"Encoding {language} codes")

            # Compute similarities: row i is query i scored against every code.
            similarities = cosine_similarity(query_embeddings, code_embeddings)

            # Evaluate retrieval metrics.
            metrics = self._compute_retrieval_metrics(similarities)

            result = {
                "language": language,
                "num_queries": len(queries),
                "metrics": metrics,
                "model_name": self.model_name,
            }

            # Save checkpoint so a crashed multi-language run can resume.
            if self.checkpoint_manager:
                checkpoint_data = {
                    "data": result,
                    "timestamp": time.time(),
                    "config": {
                        "language": language,
                        "max_queries": max_queries,
                        "model_name": self.model_name,
                    },
                }
                self.checkpoint_manager.save_checkpoint(f"{EVAL_CHECKPOINT_PREFIX}_{language}", checkpoint_data, 0)
                logger.info(f"💾 Saved {language} evaluation checkpoint")

            return result

        except Exception:
            logger.exception(f"Error evaluating {language}")
            return {}

    def _compute_retrieval_metrics(self, similarities: np.ndarray) -> dict[str, float]:
        """Compute retrieval metrics (MRR, NDCG@k, Recall@k, rank statistics).

        The correct code for query ``i`` sits at column ``i`` (the diagonal).

        PERF FIX: the original recomputed ``np.argsort`` twice more per query
        (once each for NDCG@1 and NDCG@5); all three cutoffs now reuse the
        single ranking computed in the main loop — identical values, one sort
        per query instead of three.
        """
        num_queries = similarities.shape[0]

        ranks = []
        reciprocal_ranks = []
        ndcg_at_1 = []
        ndcg_at_5 = []
        ndcg_at_10 = []

        for i in range(num_queries):
            # Rank all codes by similarity to query i (descending).
            ranked_indices = np.argsort(similarities[i])[::-1]

            # Find the 1-indexed rank of the correct code (index i).
            correct_rank = int(np.where(ranked_indices == i)[0][0]) + 1
            ranks.append(correct_rank)
            reciprocal_ranks.append(1.0 / correct_rank)

            # Reuse the same ranking for every NDCG cutoff.
            ndcg_at_1.append(self._compute_ndcg(ranked_indices, i, k=1))
            ndcg_at_5.append(self._compute_ndcg(ranked_indices, i, k=5))
            ndcg_at_10.append(self._compute_ndcg(ranked_indices, i, k=10))

        return {
            "mrr": float(np.mean(reciprocal_ranks)),
            "ndcg@1": float(np.mean(ndcg_at_1)),
            "ndcg@5": float(np.mean(ndcg_at_5)),
            "ndcg@10": float(np.mean(ndcg_at_10)),
            "recall@1": float(np.mean([1.0 if rank == 1 else 0.0 for rank in ranks])),
            "recall@5": float(np.mean([1.0 if rank <= 5 else 0.0 for rank in ranks])),
            "recall@10": float(np.mean([1.0 if rank <= 10 else 0.0 for rank in ranks])),
            "mean_rank": float(np.mean(ranks)),
            "median_rank": float(np.median(ranks)),
        }

    def _compute_ndcg(self, ranked_indices: np.ndarray, correct_idx: int, k: int) -> float:
        """Compute NDCG@k for a single query with exactly one relevant item.

        With a single relevant document the ideal DCG is 1, so NDCG reduces to
        ``1 / log2(position + 2)`` when the item appears in the top-k, else 0.
        """
        if k == 0:
            return 0.0

        # Find position of correct item in top-k.
        top_k = ranked_indices[:k]
        if correct_idx in top_k:
            position = np.where(top_k == correct_idx)[0][0]
            return 1.0 / np.log2(position + 2)  # +2 because log2(1) is 0
        return 0.0

    def evaluate_all_languages(
        self, max_queries_per_lang: int = 1000, languages: list[str] | None = None
    ) -> dict[str, Any]:
        """Evaluate on all supported programming languages with comprehensive result saving.

        Args:
            max_queries_per_lang: Cap on evaluation queries per language.
            languages: Languages to evaluate; defaults to EVALUATION_LANGUAGES.

        Returns:
            Dict with per-language results plus cross-language averages.
        """
        if languages is None:
            languages = EVALUATION_LANGUAGES

        logger.info(f"Starting evaluation on all languages for model: {self.model_name}")

        # Check for existing comprehensive evaluation results.
        if self.eval_manager:
            cached_comprehensive = self.eval_manager.load_evaluation_results(self.model_name)
            if cached_comprehensive:
                logger.info(f"✅ Found comprehensive cached evaluation for {self.model_name}")
                return cached_comprehensive

        start_time = time.time()

        results: dict[str, Any] = {
            "model_name": self.model_name,
            "model_path": self.model_path,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "languages": {},
            "overall": {},
        }

        all_metrics = []

        for language in languages:
            logger.info(f"Evaluating {language}...")
            lang_results = self.evaluate_language(language, max_queries_per_lang)

            if lang_results:
                results["languages"][language] = lang_results
                all_metrics.append(lang_results["metrics"])
            else:
                logger.warning(f"Skipping {language} due to evaluation error")

        # Compute overall metrics (unweighted average across languages).
        if all_metrics:
            overall_metrics = {}
            for metric_name in all_metrics[0]:
                values = [m[metric_name] for m in all_metrics if metric_name in m]
                overall_metrics[metric_name] = np.mean(values)

            results["overall"] = overall_metrics

        total_time = time.time() - start_time
        results["evaluation_time_seconds"] = total_time

        # Save comprehensive results to Beam volume.
        if self.eval_manager:
            self.eval_manager.save_evaluation_results(self.model_name, results)
            logger.info("💾 Saved comprehensive evaluation results to Beam volume")

        logger.info(f"Evaluation completed in {total_time:.2f} seconds")
        return results
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def load_peer_models(peers_file: str) -> list[tuple[str, str]]:
    """Load peer models from a CSV file.

    The CSV may use either ``model_name``/``model_path`` or ``Model``/``Path``
    column headers. Rows without a model name are skipped; when no path column
    exists the model name doubles as the path.

    Args:
        peers_file: Path to the CSV file listing peer models.

    Returns:
        List of ``(model_name, model_path)`` tuples; empty list on any read error.
    """
    try:
        df = pd.read_csv(peers_file)
        models = []
        for _, row in df.iterrows():
            model_name = row.get("model_name", row.get("Model", ""))
            model_path = row.get("model_path", row.get("Path", model_name))
            if model_name:
                models.append((model_name, model_path))
        logger.info(f"Loaded {len(models)} peer models from {peers_file}")
        return models
    except Exception:
        # BUG FIX: the original passed the literal string "{peers_file}"
        # (missing f-prefix); use lazy %-formatting so the real path is logged.
        logger.exception("Error loading peer models from %s", peers_file)
        return []
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def save_results(
    results: dict[str, Any],
    output_dir: str,
    model_name: str,
    eval_manager: BeamEvaluationManager | None = None,
    volume_results_dir: Path | None = None,
) -> None:
    """Save evaluation results to JSON file with Beam volume support.

    Writes up to three copies: one into the Beam volume (when mounted),
    one through the evaluation manager (for compatibility), and always a
    local backup under *output_dir*.
    """
    # 1) Persist into the Beam volume when a results directory was given.
    if volume_results_dir:
        volume_output_path = volume_results_dir / f"codesearchnet_eval_{model_name}.json"
        try:
            with volume_output_path.open("w") as fh:
                json.dump(results, fh, indent=2, default=str)
            logger.info(f"πΎ Results saved to Beam volume: {volume_output_path}")
        except Exception as exc:
            logger.warning(f"β οΈ Failed to save to Beam volume: {exc}")

    # 2) Also route through the eval manager if one is available.
    if eval_manager:
        if eval_manager.save_evaluation_results(model_name, results):
            logger.info(f"πΎ Results also saved via eval_manager for {model_name}")
        else:
            logger.warning(f"β οΈ Failed to save via eval_manager for {model_name}")

    # 3) Always keep a local backup copy.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Strip characters that are unsafe in filenames.
    safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
    filepath = output_path / f"codesearchnet_eval_{safe_name}.json"

    with filepath.open("w") as fh:
        json.dump(results, fh, indent=2, default=str)

    logger.info(f"π Local backup saved to {filepath}")
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def print_results_summary(results: dict[str, Any]) -> None:
    """Print a summary of evaluation results."""
    model_name = results["model_name"]
    overall = results.get("overall", {})

    divider = "=" * 60
    print(f"\n{divider}")
    print(f"CodeSearchNet Evaluation Results: {model_name}")
    print(divider)

    if overall:
        print("\nOverall Metrics (averaged across languages):")
        # Table-driven printing keeps the labels and metric keys in one place.
        for label, key in (
            ("MRR", "mrr"),
            ("NDCG@1", "ndcg@1"),
            ("NDCG@5", "ndcg@5"),
            ("NDCG@10", "ndcg@10"),
            ("Recall@1", "recall@1"),
            ("Recall@5", "recall@5"),
            ("Recall@10", "recall@10"),
        ):
            print(f" {label}: {overall.get(key, 0):.4f}")

    print("\nPer-Language Results:")
    for lang, lang_results in results.get("languages", {}).items():
        metrics = lang_results.get("metrics", {})
        line = (
            f" {lang:12s}: MRR={metrics.get('mrr', 0):.3f}, "
            f"NDCG@10={metrics.get('ndcg@10', 0):.3f}, "
            f"Recall@5={metrics.get('recall@5', 0):.3f}"
        )
        print(line)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def create_comparison_report(all_results: list[dict[str, Any]], output_dir: str) -> None:
    """Create a comparison report across all evaluated models.

    Writes a CSV sorted by NDCG@10 into *output_dir* and prints the
    table to stdout. No-op when *all_results* is empty.
    """
    if not all_results:
        return

    output_path = Path(output_dir)

    # Column label -> key in the per-model "overall" metrics dict.
    metric_columns = (
        ("MRR", "mrr"),
        ("NDCG@1", "ndcg@1"),
        ("NDCG@5", "ndcg@5"),
        ("NDCG@10", "ndcg@10"),
        ("Recall@1", "recall@1"),
        ("Recall@5", "recall@5"),
        ("Recall@10", "recall@10"),
        ("Mean Rank", "mean_rank"),
    )

    comparison_data = []
    for results in all_results:
        overall = results.get("overall", {})
        entry: dict[str, Any] = {"Model": results["model_name"]}
        entry.update({label: overall.get(key, 0) for label, key in metric_columns})
        comparison_data.append(entry)

    # Sort by NDCG@10 so the strongest model appears first.
    df = pd.DataFrame(comparison_data).sort_values("NDCG@10", ascending=False)

    csv_path = output_path / "codesearchnet_comparison.csv"
    df.to_csv(csv_path, index=False, float_format="%.4f")
    logger.info(f"Comparison report saved to {csv_path}")

    banner = "=" * 80
    print(f"\n{banner}")
    print("CodeSearchNet Model Comparison")
    print(banner)
    print(df.to_string(index=False, float_format="%.4f"))
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def beam_evaluate_models(
    models: list[str],
    max_queries: int = 1000,
    languages: list[str] | None = None,
    output_dir: str = DEFAULT_OUTPUT_DIR,
    volume_name: str = VOLUME_NAME,
    mount_path: str = VOLUME_PATH,
) -> list[dict[str, Any]]:
    """Main evaluation function for Beam execution with checkpoint support.

    Args:
        models: Model identifiers - local filesystem paths or HuggingFace
            repo names (``org/name``).
        max_queries: Per-language query cap forwarded to the evaluator.
        languages: Languages to evaluate; ``EVALUATION_LANGUAGES`` when None.
        output_dir: Directory for local backup copies of the result JSON.
        volume_name: Beam volume holding cached models and results.
        mount_path: Mount point of that volume inside the container.

    Returns:
        One results dict per model that was evaluated or loaded from a
        previously saved result file (checkpoint behavior).
    """
    logger.info("π Starting Beam-powered CodeSearchNet evaluation")
    logger.info(f"π Evaluating {len(models)} models on {len(languages or EVALUATION_LANGUAGES)} languages")

    # Initialize Beam utilities (volume/checkpoint/model/eval managers).
    volume_mgr, checkpoint_mgr, model_mgr, eval_mgr = create_beam_utilities(volume_name, mount_path)

    # Create evaluation results directory in the volume.
    results_dir = Path(mount_path) / EVALUATION_RESULTS_DIR
    results_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"π Using Beam volume: {volume_name} at {mount_path}")
    logger.info(f"πΎ Evaluation results directory: {results_dir}")

    all_results = []
    skipped_models = []

    for model_path in models:
        model_name = Path(model_path).name

        # Checkpointing: a result file already present in the volume means
        # this model was evaluated in a previous run - reuse it.
        existing_result_file = results_dir / f"codesearchnet_eval_{model_name}.json"
        if existing_result_file.exists():
            logger.info(f"β Model {model_name} already evaluated - loading existing results")
            try:
                with existing_result_file.open("r") as f:
                    existing_results = json.load(f)
                all_results.append(existing_results)
                skipped_models.append(model_name)
                continue
            except Exception as e:
                logger.warning(f"β οΈ Failed to load existing results for {model_name}: {e}")
                # Continue with evaluation if loading fails.

        logger.info(f"\n{'=' * 60}")
        logger.info(f"π Evaluating model: {model_name}")
        logger.info(f"π Path: {model_path}")
        logger.info(f"{'=' * 60}")

        try:
            # Heuristic: a slash-containing, non-absolute, non-existent path
            # is treated as a HuggingFace repo name rather than a local dir.
            is_huggingface_model = (
                "/" in model_path and not model_path.startswith("/") and not Path(model_path).exists()
            )

            if is_huggingface_model:
                # HuggingFace model name - pass directly to the evaluator.
                logger.info(f"π₯ Loading HuggingFace model: {model_path}")
                evaluator = CodeSearchNetEvaluator(
                    model_path,
                    model_name,
                    checkpoint_manager=checkpoint_mgr,
                    eval_manager=eval_mgr,
                )
            else:
                # Local path - fall back to the Beam volume cache when the
                # path does not exist on the local filesystem.
                actual_model_path = model_path  # Default to original path.
                if not Path(model_path).exists() and not model_path.startswith("/"):
                    local_model_path = Path(mount_path) / MODEL_CACHE_PREFIX / model_name
                    logger.info(f"π Trying to load {model_name} from Beam volume: {local_model_path}")
                    if model_mgr.load_model(model_name, local_model_path.parent):
                        actual_model_path = str(local_model_path)
                        logger.info(f"β Loaded model from Beam volume: {actual_model_path}")
                    else:
                        logger.warning(f"β οΈ Model not found locally or in Beam volume: {model_name}")
                        continue

                evaluator = CodeSearchNetEvaluator(
                    actual_model_path,
                    model_name,
                    checkpoint_manager=checkpoint_mgr,
                    eval_manager=eval_mgr,
                )

            results = evaluator.evaluate_all_languages(max_queries, languages)

            # Persist to volume + local backup, then echo a summary.
            save_results(results, output_dir, model_name, eval_mgr, results_dir)
            print_results_summary(results)

            all_results.append(results)

        except Exception:
            # One failing model must not abort the whole sweep.
            logger.exception(f"β Failed to evaluate {model_name}")
            continue

    # Cross-model comparison only makes sense with at least two results.
    if len(all_results) > 1:
        comparison_dir = Path(mount_path) / EVALUATION_RESULTS_DIR / "comparisons"
        comparison_dir.mkdir(parents=True, exist_ok=True)
        create_comparison_report(all_results, str(comparison_dir))
        logger.info(f"π Comparison report saved to Beam volume: {comparison_dir}")

    # Log summary of what was done.
    newly_evaluated = len(all_results) - len(skipped_models)
    logger.info("\nβ Beam evaluation complete!")
    logger.info(f"π Newly evaluated: {newly_evaluated} models")
    logger.info(f"βοΈ Skipped (already done): {len(skipped_models)} models")
    logger.info(f"π Total results: {len(all_results)} models")
    logger.info(f"πΎ Results available in Beam volume: {volume_name}")

    if skipped_models:
        logger.info(f"βοΈ Skipped models: {', '.join(skipped_models)}")

    return all_results
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
    },
    timeout=3600 * 6,  # 6 hours for evaluation
)
def main(skip_third_party: bool = False) -> None:
    """Main evaluation function - runs all default models on Beam.

    Args:
        skip_third_party: When True, evaluate only locally discovered
            simplified distillation models and exclude the third-party
            peer models from DEFAULT_EVALUATION_MODELS.
    """
    logger.info("π Starting comprehensive CodeSearchNet evaluation on Beam")

    # Use default models or skip them based on flag.
    if skip_third_party:
        logger.info("βοΈ Skipping 3rd party models - evaluating only simplified distillation models")
        models = []
    else:
        logger.info("π Including 3rd party peer models for comparison")
        models = DEFAULT_EVALUATION_MODELS.copy()

    # Discover simplified distillation models in the current directory.
    logger.info("π Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    # Add discovered models (they're already sorted alphabetically).
    if discovered_models:
        logger.info(f"β Found {len(discovered_models)} simplified models:")
        for model_path in discovered_models:
            models.append(model_path)
            logger.info(f" π {model_path}")
    else:
        logger.warning("β οΈ No simplified distillation models found")
        if skip_third_party:
            # Nothing to evaluate at all - bail out early.
            logger.error("β No models to evaluate! Either create simplified models or include 3rd party models.")
            return

    logger.info(f"π Evaluating {len(models)} models:")
    for i, model in enumerate(models, 1):
        logger.info(f" {i}. {model}")

    logger.info("\nπ‘ Checkpoint Info:")
    logger.info(" - Already evaluated models will be skipped")
    logger.info(" - Results are saved persistently to Beam volume")

    # Run comprehensive evaluation using Beam utilities.
    results = beam_evaluate_models(
        models=models,
        max_queries=1000,
        languages=EVALUATION_LANGUAGES,
        output_dir=str(Path(VOLUME_PATH) / EVALUATION_RESULTS_DIR),
        volume_name=VOLUME_NAME,
        mount_path=VOLUME_PATH,
    )

    # Print final summary.
    print("\nπ― Evaluation Summary:")
    print(f"π Total models processed: {len(results)}")
    print(f"πΎ Results saved to Beam volume: {VOLUME_NAME}")
    print(f"π Directory: {EVALUATION_RESULTS_DIR}")
    if skip_third_party:
        print("βοΈ 3rd party models were skipped")
    print("\nπ To view analysis:")
    print(" beam run src.distiller.analyze:beam_analysis")
    print("\nπ To run evaluations again:")
    print(" distiller evaluate (will skip already completed models)")
    print(" distiller evaluate --skip-third-party (evaluate only simplified models)")
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def discover_simplified_models(base_path: str = ".") -> list[str]:
    """
    Discover all simplified distillation models in the correct directory.

    Looks for directories matching the pattern: ./code_model2vec/final/code_model2vec_*
    """
    # Models are written by distill_simplified.py under code_model2vec/final.
    final_dir = Path(base_path) / "code_model2vec" / "final"

    if not final_dir.exists():
        logger.warning(f"Models directory not found: {final_dir}")
        return []

    found: list[str] = []
    # A directory only counts as a model when it carries a config.json.
    for candidate in final_dir.glob("code_model2vec_*"):
        if candidate.is_dir() and (candidate / "config.json").exists():
            found.append(str(candidate))
            logger.info(f"π Discovered simplified model: {candidate}")

    # Alphabetical order keeps downstream processing deterministic.
    return sorted(found)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
@function(
    gpu=GPU_NAME,
    volumes=[Volume(name=VOLUME_NAME, mount_path=VOLUME_PATH)],
    image=IMAGE,
    secrets=["HF_ACCESS_TOKEN"],
    env={
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_LAUNCH_BLOCKING": "0",
    },
    timeout=3600 * 6,  # 6 hours for evaluation
)
def evaluate_simplified_only() -> None:
    """Evaluate only simplified distillation models, skipping 3rd party models.

    Thin Beam entry point that delegates to main(skip_third_party=True);
    a separate function is needed because Beam invokes functions by name.
    """
    main(skip_third_party=True)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def run_local_evaluation(
    models: list[str] | None = None,
    max_queries: int = 1000,
    languages: list[str] | None = None,
    output_dir: str = DEFAULT_OUTPUT_DIR,
) -> list[dict[str, Any]]:
    """Main evaluation function for local execution without Beam utilities.

    Args:
        models: Model paths/names to evaluate; defaults to a copy of
            DEFAULT_EVALUATION_MODELS. Locally discovered simplified models
            are appended automatically.
        max_queries: Per-language query cap forwarded to the evaluator.
        languages: Languages to evaluate; EVALUATION_LANGUAGES when None.
        output_dir: Local directory for result JSON files and reports.

    Returns:
        One results dict per model evaluated or loaded from a cached file.
    """
    logger.info("π₯οΈ Running CodeSearchNet evaluation locally")

    if models is None:
        models = DEFAULT_EVALUATION_MODELS.copy()
    else:
        # Bug fix: copy so the appends below never mutate the caller's list.
        models = list(models)

    # Discover simplified distillation models in the current directory.
    logger.info("π Discovering simplified distillation models...")
    discovered_models = discover_simplified_models(".")

    # Add discovered models.
    if discovered_models:
        logger.info(f"β Found {len(discovered_models)} simplified models:")
        for model_path in discovered_models:
            # Bug fix: skip entries the caller already passed in (e.g. via
            # run_local_evaluation_simplified) to avoid evaluating twice.
            if model_path not in models:
                models.append(model_path)
                logger.info(f" π {model_path}")
    else:
        logger.warning("β οΈ No simplified distillation models found")

    if languages is None:
        languages = EVALUATION_LANGUAGES

    logger.info(f"π Evaluating {len(models)} models on {len(languages)} languages")
    logger.info(f"π Using local output directory: {output_dir}")

    # Create local output directory.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    all_results = []
    skipped_models = []

    for model_path in models:
        model_name = Path(model_path).name

        # Checkpointing: reuse an existing local result file when present.
        safe_name = "".join(c for c in model_name if c.isalnum() or c in ("-", "_", "."))
        result_file = output_path / f"codesearchnet_eval_{safe_name}.json"

        if result_file.exists():
            logger.info(f"β Model {model_name} already evaluated - loading existing results")
            try:
                with result_file.open("r") as f:
                    existing_results = json.load(f)
                all_results.append(existing_results)
                skipped_models.append(model_name)
                continue
            except Exception as e:
                logger.warning(f"β οΈ Failed to load existing results for {model_name}: {e}")

        logger.info(f"\n{'=' * 60}")
        logger.info(f"π Evaluating model: {model_name}")
        logger.info(f"π Path: {model_path}")
        logger.info(f"{'=' * 60}")

        try:
            # Create evaluator without Beam utilities (no checkpointing).
            evaluator = CodeSearchNetEvaluator(
                model_path,
                model_name,
                checkpoint_manager=None,  # No checkpointing for local evaluation
                eval_manager=None,
            )

            results = evaluator.evaluate_all_languages(max_queries, languages)

            # Save results locally only.
            save_results(results, output_dir, model_name, eval_manager=None, volume_results_dir=None)

            # Print summary.
            print_results_summary(results)

            all_results.append(results)

        except Exception:
            # One failing model must not abort the whole run.
            logger.exception(f"β Failed to evaluate {model_name}")
            continue

    # Create comparison report locally (needs at least two results).
    if len(all_results) > 1:
        create_comparison_report(all_results, output_dir)
        logger.info(f"π Comparison report saved locally: {output_dir}")

    # Log summary.
    newly_evaluated = len(all_results) - len(skipped_models)
    logger.info("\nβ Local evaluation complete!")
    logger.info(f"π Newly evaluated: {newly_evaluated} models")
    logger.info(f"βοΈ Skipped (already done): {len(skipped_models)} models")
    logger.info(f"π Total results: {len(all_results)} models")
    logger.info(f"πΎ Results available locally: {output_dir}")

    if skipped_models:
        logger.info(f"βοΈ Skipped models: {', '.join(skipped_models)}")

    return all_results
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
def run_local_evaluation_simplified(
    max_queries: int = 1000,
    languages: list[str] | None = None,
    output_dir: str = DEFAULT_OUTPUT_DIR,
) -> list[dict[str, Any]]:
    """Local evaluation function for simplified models only."""
    logger.info("π₯οΈ Running simplified model evaluation locally")

    # Only locally distilled simplified models are considered here.
    logger.info("π Discovering simplified distillation models...")
    simplified = discover_simplified_models(".")

    if not simplified:
        logger.error("β No simplified distillation models found! Run 'distiller distill-simple' first.")
        return []

    logger.info(f"β Found {len(simplified)} simplified models:")
    for path in simplified:
        logger.info(f" π {path}")

    # Delegate the actual work to the general local runner.
    return run_local_evaluation(
        models=simplified,
        max_queries=max_queries,
        languages=languages,
        output_dir=output_dir,
    )
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
# Allow running the evaluation module directly as a script.
if __name__ == "__main__":
    main()
|
src/distiller/patch_utils.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Patch utilities for applying fixes to installed packages.
|
| 3 |
+
|
| 4 |
+
This module provides functionality to automatically apply all patches
|
| 5 |
+
from the patches directory to fix bugs in third-party libraries.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import subprocess
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def find_patches_directory() -> Path:
    """Find the patches directory relative to the current script location."""
    # Walk up from src/distiller/ to the project root, then into patches/.
    project_root = Path(__file__).parent.parent.parent
    candidate = project_root / "patches"
    if candidate.exists():
        return candidate
    # Fallback: a patches/ directory relative to the current working dir.
    return Path("patches")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_site_packages_path() -> Path:
    """Get the site-packages directory path.

    Returns:
        The first site-packages directory reported by the ``site`` module,
        else a ``site-packages`` found under the interpreter's ``lib``
        directory (standard virtual-environment layout), else the current
        directory as a last resort.
    """
    import site

    # Try to get the site-packages from the current environment.
    site_packages_dirs = site.getsitepackages()
    if site_packages_dirs:
        return Path(site_packages_dirs[0])

    # Fallback: standard venv layout (.../bin/python -> .../lib/pythonX.Y/site-packages).
    python_path = Path(sys.executable)
    # Fix: startswith("python") already covers the exact name "python",
    # so the former redundant equality check was removed.
    if python_path.name.startswith("python"):
        venv_lib = python_path.parent.parent / "lib"
        # Fix: guard against a missing lib/ dir - iterdir() raises otherwise.
        if venv_lib.is_dir():
            for item in venv_lib.iterdir():
                if item.name.startswith("python"):
                    site_packages = item / "site-packages"
                    if site_packages.exists():
                        return site_packages

    # Last resort: use current directory.
    return Path()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def apply_patch_file(patch_file: Path, target_dir: Path) -> bool:
    """
    Apply a single patch file to the target directory.

    Args:
        patch_file: Path to the .patch file
        target_dir: Target directory (usually site-packages)

    Returns:
        True if patch was applied successfully, False otherwise
    """
    try:
        logger.info(f"Applying patch: {patch_file.name}")

        # Invoke the system `patch` tool:
        #   -p1 strip one leading path component,
        #   -d  switch into the target directory first,
        #   -f  never prompt interactively,
        #   -N  skip patches that appear already applied.
        completed = subprocess.run(  # noqa: S603
            ["patch", "-p1", "-d", str(target_dir), "-f", "-N"],  # noqa: S607
            input=patch_file.read_text(),
            text=True,
            capture_output=True,
            check=False,  # inspect returncode ourselves instead of raising
        )

        if completed.returncode == 0:
            logger.info(f"Successfully applied patch: {patch_file.name}")
            return True
        stderr_lower = completed.stderr.lower()
        # `patch -N` reports already-applied hunks as "reversed" or
        # "already applied"; treat both as success.
        if "already applied" in stderr_lower or "reversed" in stderr_lower:
            logger.info(f"Patch {patch_file.name} already applied")
            return True
        logger.warning(f"Failed to apply patch {patch_file.name}: {completed.stderr}")
        return False

    except FileNotFoundError:
        logger.exception("'patch' command not found. Please install patch utility.")
        return False
    except Exception:
        logger.exception(f"Error applying patch {patch_file.name}")
        return False
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def apply_all_patches() -> int:
    """
    Apply all patches from the patches directory.

    Returns:
        Number of patches successfully applied
    """
    patches_dir = find_patches_directory()
    if not patches_dir.exists():
        logger.warning(f"Patches directory not found: {patches_dir}")
        return 0

    # Sorted for consistent application ordering across runs.
    patch_files = sorted(patches_dir.glob("*.patch"))
    if not patch_files:
        logger.info("No patch files found")
        return 0

    # Patches target the active environment's site-packages.
    target_dir = get_site_packages_path()
    logger.info(f"Applying patches to: {target_dir}")

    success_count = sum(1 for pf in patch_files if apply_patch_file(pf, target_dir))

    logger.info(f"Applied {success_count}/{len(patch_files)} patches successfully")
    return success_count
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def main() -> None:
    """Entry point for standalone execution: configure logging, apply patches."""
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    print("Applying all patches...")
    success_count = apply_all_patches()
    print(f"Done. Applied {success_count} patches.")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# Standalone usage: apply all patches immediately when executed as a script.
if __name__ == "__main__":
    main()
|
src/distiller/sync.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sync utility for downloading files from Beam volume to local directory.
|
| 3 |
+
|
| 4 |
+
This module provides functionality to download generated files from the Beam volume
|
| 5 |
+
back to the local filesystem, including:
|
| 6 |
+
- Final distilled model files (model.safetensors, tokenizer.json, etc.)
|
| 7 |
+
- Analysis reports and charts (README.md, comparison charts, etc.)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import shutil
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
from .beam_utils import create_beam_utilities
|
| 15 |
+
|
| 16 |
+
# Configure logging
|
| 17 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Beam volume configuration (must match distill.py)
|
| 21 |
+
VOLUME_NAME = "gte_qwen2_m2v_code"
|
| 22 |
+
VOLUME_PATH = "./gte_qwen2_m2v_code"
|
| 23 |
+
|
| 24 |
+
# Model files to sync
|
| 25 |
+
MODEL_FILES = [
|
| 26 |
+
"model.safetensors",
|
| 27 |
+
"tokenizer.json",
|
| 28 |
+
"modules.json",
|
| 29 |
+
"config.json",
|
| 30 |
+
"pytorch_model.bin", # Backup format
|
| 31 |
+
"vocab.txt", # If present
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
# Analysis directories and files
|
| 35 |
+
ANALYSIS_DIRS = [
|
| 36 |
+
"analysis_results/reports",
|
| 37 |
+
"analysis_results/charts",
|
| 38 |
+
"evaluation_results",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
ANALYSIS_FILES = [
|
| 42 |
+
"analysis_results/reports/analysis_report.md",
|
| 43 |
+
"analysis_results/reports/README.md",
|
| 44 |
+
"analysis_results/charts/*.png",
|
| 45 |
+
"analysis_results/charts/*.html",
|
| 46 |
+
"evaluation_results/*.json",
|
| 47 |
+
"evaluation_results/comparisons/*.csv",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
|
def sync_model_files(output_dir: str) -> bool:
    """Download final model files from the Beam volume into *output_dir*.

    Searches the known model locations on the mounted volume (in priority
    order) and copies any of the expected ``MODEL_FILES`` it finds.  Each
    file is copied at most once: the first location that provides it wins,
    so a file found in the volume root is no longer silently overwritten by
    a copy from a lower-priority location (and no longer double-counted).

    Args:
        output_dir: Local destination directory; created if missing.

    Returns:
        True if at least one model file was copied, False otherwise.
    """
    logger.info("πŸ“₯ Syncing model files from Beam volume...")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    volume_root = Path(VOLUME_PATH)
    if not volume_root.exists():
        logger.error("❌ Volume root does not exist: %s", volume_root)
        return False

    # Log the volume layout first so failed syncs are diagnosable from logs.
    _debug_volume_contents(volume_root)

    # Candidate locations, in priority order (first match wins per file).
    model_locations = [
        volume_root,  # Root of volume (where final model was saved)
        volume_root / "models" / "refined_model",  # Refined model directory
    ]

    synced_files: list[str] = []

    for location in model_locations:
        logger.info("πŸ” Checking model location: %s", location)

        if not location.exists():
            logger.info("  ⚠️ Location does not exist: %s", location)
            continue

        for model_file in MODEL_FILES:
            if model_file in synced_files:
                # Already copied from a higher-priority location.
                continue
            source_path = location / model_file
            if not source_path.exists():
                continue
            try:
                shutil.copy2(source_path, output_path / model_file)
                synced_files.append(model_file)
                logger.info("βœ… Downloaded: %s", model_file)
            except Exception as e:  # best-effort: keep syncing remaining files
                logger.warning("⚠️ Failed to copy %s: %s", model_file, e)

    if synced_files:
        logger.info("πŸŽ‰ Successfully synced %d model files:", len(synced_files))
        for file in synced_files:
            logger.info("  βœ“ %s", file)
        return True
    logger.error("❌ No model files found to sync")
    return False


def _debug_volume_contents(volume_root: Path) -> None:
    """Log a shallow tree of the volume contents to help diagnose sync issues."""
    logger.info("πŸ” Debugging volume contents at: %s", volume_root)
    logger.info("πŸ“‹ Volume root directory contents:")
    for item in volume_root.iterdir():
        if item.is_file():
            logger.info("  πŸ“„ %s (%d bytes)", item.name, item.stat().st_size)
        elif item.is_dir():
            logger.info("  πŸ“ %s/ (directory)", item.name)
            # Only descend into the directories that may hold model artifacts.
            if item.name in {"models", "checkpoints", "gte_qwen2_m2v_code"}:
                try:
                    logger.info("    Contents of %s/:", item.name)
                    for subitem in item.iterdir():
                        if subitem.is_file():
                            logger.info("      πŸ“„ %s (%d bytes)", subitem.name, subitem.stat().st_size)
                        else:
                            logger.info("      πŸ“ %s/", subitem.name)
                            # Check one level deeper for model files.
                            if subitem.is_dir():
                                for subsubitem in subitem.iterdir():
                                    if subsubitem.is_file() and subsubitem.name in MODEL_FILES:
                                        logger.info("        🎯 FOUND MODEL FILE: %s", subsubitem)
                except Exception as e:  # permissions/races: log and keep walking
                    logger.warning("    Error exploring %s: %s", item.name, e)

    # Also check for model files directly in the volume root.
    logger.info("πŸ” Checking for model files directly in volume root:")
    for model_file in MODEL_FILES:
        root_file = volume_root / model_file
        if root_file.exists():
            logger.info("  🎯 FOUND: %s in root (%d bytes)", model_file, root_file.stat().st_size)
def sync_analysis_files(output_dir: str) -> bool:
    """Download analysis reports, charts and evaluation results from the volume.

    Copies, when the corresponding volume directory exists:
    - ``analysis_results/reports/*.md``  -> *output_dir* (the main report is
      additionally copied to *output_dir*/README.md)
    - ``analysis_results/charts/*``      -> *output_dir*/charts/
    - ``evaluation_results/*.json``      -> *output_dir*/evaluation_results/

    Args:
        output_dir: Local destination directory; created if missing.

    Returns:
        True if at least one file was copied, False otherwise.
    """
    logger.info("πŸ“Š Syncing analysis files from Beam volume...")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    volume_root = Path(VOLUME_PATH)
    synced_files: list[str] = []

    # Sync analysis reports (including README.md).
    reports_dir = volume_root / "analysis_results" / "reports"
    if reports_dir.exists():
        for report_file in reports_dir.glob("*.md"):
            try:
                shutil.copy2(report_file, output_path / report_file.name)
                synced_files.append(report_file.name)
                logger.info("βœ… Downloaded report: %s", report_file.name)

                # The main report doubles as the repository README; skip the
                # redundant self-copy when the report is already README.md.
                if report_file.name == "analysis_report.md":
                    shutil.copy2(report_file, output_path / "README.md")
                    logger.info("βœ… Updated root README.md")
                elif report_file.name == "README.md":
                    logger.info("βœ… Updated root README.md")
            except Exception as e:  # best-effort: keep syncing remaining files
                logger.warning("⚠️ Failed to copy %s: %s", report_file.name, e)

    # Sync charts (any file type present in the charts directory).
    charts_dir = volume_root / "analysis_results" / "charts"
    if charts_dir.exists():
        local_charts_dir = output_path / "charts"
        local_charts_dir.mkdir(exist_ok=True)

        for chart_file in charts_dir.glob("*"):
            if not chart_file.is_file():
                continue
            try:
                shutil.copy2(chart_file, local_charts_dir / chart_file.name)
                synced_files.append(f"charts/{chart_file.name}")
                logger.info("βœ… Downloaded chart: %s", chart_file.name)
            except Exception as e:
                logger.warning("⚠️ Failed to copy chart %s: %s", chart_file.name, e)

    # Sync evaluation results (JSON only).
    eval_dir = volume_root / "evaluation_results"
    if eval_dir.exists():
        local_eval_dir = output_path / "evaluation_results"
        local_eval_dir.mkdir(exist_ok=True)

        for eval_file in eval_dir.glob("*.json"):
            try:
                shutil.copy2(eval_file, local_eval_dir / eval_file.name)
                synced_files.append(f"evaluation_results/{eval_file.name}")
                logger.info("βœ… Downloaded evaluation: %s", eval_file.name)
            except Exception as e:
                logger.warning("⚠️ Failed to copy evaluation %s: %s", eval_file.name, e)

    if synced_files:
        logger.info("πŸŽ‰ Successfully synced %d analysis files:", len(synced_files))
        for file in synced_files[:10]:  # show at most the first 10
            logger.info("  βœ“ %s", file)
        if len(synced_files) > 10:
            logger.info("  ... and %d more files", len(synced_files) - 10)
        return True
    logger.error("❌ No analysis files found to sync")
    return False
def sync_files(
    model_files: bool = False,
    analysis_files: bool = False,
    all_files: bool = False,
    output_dir: str = ".",
) -> None:
    """Download requested file groups from the Beam volume to *output_dir*.

    Args:
        model_files: Sync the distilled model artifacts.
        analysis_files: Sync analysis reports, charts and evaluation results.
        all_files: Shorthand for enabling both of the above.
        output_dir: Local destination directory (created by the sync helpers).
    """
    logger.info("πŸš€ Starting file sync from Beam volume")
    logger.info("πŸ“ Local output directory: %s", output_dir)

    # Resolve and validate the request BEFORE doing any remote work, so a
    # missing flag fails fast instead of after a volume connection attempt.
    sync_model = model_files or all_files
    sync_analysis = analysis_files or all_files
    if not (sync_model or sync_analysis):
        logger.error("❌ No file types specified. Use --model-files, --analysis-files, or --all")
        return

    # Initialize Beam utilities (read-only).  The returned managers are not
    # needed here; the call only verifies the volume is reachable.
    try:
        create_beam_utilities(VOLUME_NAME, VOLUME_PATH)
        logger.info("βœ… Connected to Beam volume: %s", VOLUME_NAME)
    except Exception:
        logger.exception("❌ Failed to connect to Beam volume")
        logger.info("Make sure you have run the distillation/evaluation on Beam first")
        return

    success_count = 0

    # Sync model files.
    if sync_model:
        logger.info("\n" + "=" * 60)  # noqa: G003
        logger.info("MODEL FILES SYNC")
        logger.info("=" * 60)
        if sync_model_files(output_dir):
            success_count += 1

    # Sync analysis files.
    if sync_analysis:
        logger.info("\n" + "=" * 60)  # noqa: G003
        logger.info("ANALYSIS FILES SYNC")
        logger.info("=" * 60)
        if sync_analysis_files(output_dir):
            success_count += 1

    # Summary.
    logger.info("\n" + "=" * 60)  # noqa: G003
    logger.info("SYNC SUMMARY")
    logger.info("=" * 60)

    total_requested = sum([sync_model, sync_analysis])

    if success_count == total_requested:
        logger.info("πŸŽ‰ All requested files synced successfully!")
    elif success_count > 0:
        logger.info("⚠️ Partial sync: %d/%d file types synced", success_count, total_requested)
    else:
        logger.error("❌ No files were synced")

    logger.info("πŸ“ Files saved to: %s", Path(output_dir).absolute())