|
|
| import os |
| import sys |
| import argparse |
|
|
| import datetime |
| import torch |
| import json |
| from pathlib import Path |
| from tqdm import tqdm |
| from dotenv import load_dotenv |
| import huggingface_hub |
| import logging |
| import multiprocessing as mp |
| from collections import Counter, defaultdict |
|
|
# Set explicit level thresholds on third-party loggers: WARNING for httpx and
# huggingface_hub, ERROR for transformers.
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)


# Append the parent of this file's directory to sys.path so that the
# `scripts`, `data` and `models` packages imported below can be resolved when
# this file is executed directly (presumably that parent is the repository
# root -- confirm against the project layout).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps |
| from scripts.compute_quality_score import get_token_quality_scores |
| from data.data_loader import summarize_context_window |
| from data.quant_ohlc_feature_schema import FEATURE_VERSION |
|
|
| from clickhouse_driver import Client as ClickHouseClient |
| from neo4j import GraphDatabase |
# Process-wide state. Every slot starts as None, is filled in by
# _init_worker(), and is read by _process_single_token_context().
_worker_dataset = None  # OracleDataset built by _init_worker()
_worker_return_class_map = None  # mapping looked up by mint address -> class id
_worker_quality_scores_map = None  # mapping looked up by mint address -> quality score
_worker_encoder = None  # MultiModalEncoder, or None if its construction failed
|
|
|
|
| def _to_int_list(values): |
| if values is None: |
| return [] |
| if isinstance(values, torch.Tensor): |
| return [int(v) for v in values.tolist()] |
| return [int(v) for v in values] |
|
|
|
|
| def _to_float_list(values): |
| if values is None: |
| return [] |
| if isinstance(values, torch.Tensor): |
| return [float(v) for v in values.tolist()] |
| return [float(v) for v in values] |
|
|
|
|
def _representative_context_polarity(context):
    """Classify a context as ``"positive"`` or ``"negative"`` from its labels.

    Args:
        context: Mapping with optional ``"labels"`` and ``"labels_mask"``
            entries (each ``None``, a tensor or an iterable of numbers).

    Returns:
        ``"positive"`` if at least one label whose mask entry is greater than
        zero is itself greater than ``0.0``; ``"negative"`` otherwise. That
        includes the cases where either entry is missing or no label is
        unmasked.
    """
    labels = _to_float_list(context.get("labels"))
    mask = _to_int_list(context.get("labels_mask"))
    # `any(...)` gives the same answer as `max(valid) > 0.0` for ordinary
    # floats but, unlike max(), does not depend on element order when a NaN
    # is present (max([nan, 1.0]) is nan while max([1.0, nan]) is 1.0).
    # zip() pairs labels with mask entries and stops at the shorter sequence.
    if any(keep > 0 and label > 0.0 for label, keep in zip(labels, mask)):
        return "positive"
    return "negative"
|
|
|
|
| def _class_polarity_targets(class_id, target_contexts_per_class, positive_balance_min_class, positive_ratio): |
| if class_id >= positive_balance_min_class: |
| positive_target = int(round(target_contexts_per_class * positive_ratio)) |
| positive_target = min(max(positive_target, 0), target_contexts_per_class) |
| else: |
| positive_target = 0 |
| return { |
| "positive": positive_target, |
| "negative": max(0, target_contexts_per_class - positive_target), |
| } |
|
|
|
|
def _remaining_polarity_targets(class_id, accepted_counts, target_contexts_per_class, positive_balance_min_class, positive_ratio):
    """Return how many positive/negative contexts a class still needs.

    Args:
        class_id: Class being queried.
        accepted_counts: Mapping ``class_id -> {"positive": int,
            "negative": int, ...}`` of contexts accepted so far. It is
            indexed with ``[class_id]``, so a ``defaultdict`` creates the
            entry on first access.
        target_contexts_per_class: Total quota per class.
        positive_balance_min_class: Forwarded to ``_class_polarity_targets``.
        positive_ratio: Forwarded to ``_class_polarity_targets``.

    Returns:
        ``{"positive": p, "negative": n}`` with each value being the
        polarity target minus the accepted count, floored at zero.
    """
    quota = _class_polarity_targets(
        class_id=class_id,
        target_contexts_per_class=target_contexts_per_class,
        positive_balance_min_class=positive_balance_min_class,
        positive_ratio=positive_ratio,
    )
    taken = accepted_counts[class_id]
    return {
        polarity: max(0, quota[polarity] - taken[polarity])
        for polarity in ("positive", "negative")
    }
|
|
|
|
def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None):
    """Pick at most *max_keep* contexts, steering the positive/negative mix.

    Every context is tagged in place with a
    ``"representative_context_polarity"`` entry.

    Args:
        contexts: List of context mappings to choose from.
        max_keep: Maximum number of contexts to return.
        desired_positive: Optional number of positive contexts to take first.
        desired_negative: Optional number of negative contexts to take next.

    Returns:
        ``(selected, polarity_counts)``. When ``len(contexts) <= max_keep``
        the input list itself is returned and the counts only contain the
        polarities that occurred. Otherwise the counts always hold both
        ``"positive"`` and ``"negative"`` keys.
    """
    tag_key = "representative_context_polarity"

    # Nothing to trim: tag everything and hand the list back unchanged.
    if len(contexts) <= max_keep:
        tally = {}
        for ctx in contexts:
            tag = _representative_context_polarity(ctx)
            ctx[tag_key] = tag
            tally[tag] = tally.get(tag, 0) + 1
        return contexts, tally

    pools = {"positive": [], "negative": []}
    for ctx in contexts:
        tag = _representative_context_polarity(ctx)
        ctx[tag_key] = tag
        pools["positive" if tag == "positive" else "negative"].append(ctx)

    chosen = []
    tally = {"positive": 0, "negative": 0}

    def take(side):
        # Contexts are consumed from the end of each pool.
        chosen.append(pools[side].pop())
        tally[side] += 1

    want_positive = None if desired_positive is None else max(0, int(desired_positive))
    want_negative = None if desired_negative is None else max(0, int(desired_negative))

    # Phase 1: honour the explicit requests, limited by max_keep and supply.
    if not (want_positive is None and want_negative is None):
        quota_positive = min(want_positive or 0, max_keep, len(pools["positive"]))
        quota_negative = min(want_negative or 0, max_keep - quota_positive, len(pools["negative"]))
        while tally["positive"] < quota_positive and pools["positive"]:
            take("positive")
        while tally["negative"] < quota_negative and pools["negative"]:
            take("negative")

    # Phase 2: fill the remaining slots, alternating polarity and starting
    # with whichever pool is currently at least as large.
    lead_positive = len(pools["positive"]) >= len(pools["negative"])
    while len(chosen) < max_keep and (pools["positive"] or pools["negative"]):
        if lead_positive:
            preferred, fallback = "positive", "negative"
        else:
            preferred, fallback = "negative", "positive"
        take(preferred if pools[preferred] else fallback)
        lead_positive = not lead_positive

    return chosen[:max_keep], tally
|
|
|
|
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
    """Populate the module-level worker state used by the per-token task.

    Opens a ClickHouse client and a Neo4j driver, tries to build the
    ``MultiModalEncoder`` (leaving ``_worker_encoder`` as ``None`` and
    printing a warning if that fails), builds the ``OracleDataset`` and
    stores the lookup maps.

    Args:
        db_config: Mapping with ``clickhouse_host``, ``clickhouse_port``,
            ``neo4j_uri``, ``neo4j_user`` and ``neo4j_password``.
        dataset_config: Mapping with ``start_date``, ``horizons_seconds``,
            ``quantiles``, ``min_trade_usd``, ``max_seq_len``,
            ``sampled_mints`` and optionally ``p99_clamps``.
        return_class_map: Stored as ``_worker_return_class_map``.
        quality_scores_map: Stored as ``_worker_quality_scores_map``.
    """
    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map, _worker_encoder
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher

    fetcher = DataFetcher(
        clickhouse_client=ClickHouseClient(
            host=db_config['clickhouse_host'],
            port=db_config['clickhouse_port'],
        ),
        neo4j_driver=GraphDatabase.driver(
            db_config['neo4j_uri'],
            auth=(db_config['neo4j_user'], db_config['neo4j_password']),
        ),
    )

    from models.multi_modal_processor import MultiModalEncoder

    # Encoder construction is best-effort: a failure is reported and the
    # worker continues with no encoder.
    try:
        _worker_encoder = MultiModalEncoder(
            model_id="google/siglip-so400m-patch16-256-i18n",
            device="cuda",
            dtype=torch.float16,
        )
    except Exception as exc:
        print(f"WARN: Failed to initialize MultiModalEncoder on worker: {exc}")
        _worker_encoder = None

    required_keys = ('start_date', 'horizons_seconds', 'quantiles', 'min_trade_usd', 'max_seq_len')
    dataset_kwargs = {key: dataset_config[key] for key in required_keys}
    _worker_dataset = OracleDataset(
        data_fetcher=fetcher,
        p99_clamps=dataset_config.get('p99_clamps'),
        **dataset_kwargs,
    )
    # Replace the dataset's mint list with the pre-filtered one from the caller.
    _worker_dataset.sampled_mints = dataset_config['sampled_mints']
    _worker_return_class_map = return_class_map
    _worker_quality_scores_map = quality_scores_map
|
|
|
|
def _process_single_token_context(args):
    """Build and select cached contexts for one token.

    Relies on the module-level worker state set up by ``_init_worker``.

    Args:
        args: Tuple ``(idx, mint_addr, samples_per_token, oversample_factor,
            desired_positive, desired_negative)``.

    Returns:
        A dict whose ``"status"`` is one of:

        * ``"skipped"`` -- with ``"reason"`` and ``"mint"``; returned when
          the mint has no class id, no quality score, or no contexts.
        * ``"success"`` -- with ``"mint"``, ``"class_id"``, ``"q_score"``,
          ``"n_contexts"``, ``"n_events"``, ``"contexts"`` and
          ``"polarity_counts"``.
        * ``"error"`` -- with ``"mint"``, ``"error"`` and ``"traceback"``;
          returned for any exception raised while processing.
    """
    idx, mint_addr, samples_per_token, oversample_factor, desired_positive, desired_negative = args
    try:
        class_id = _worker_return_class_map.get(mint_addr)
        if class_id is None:
            return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}

        # Cheap dictionary lookup first: a token without a quality score is
        # skipped anyway, so do not pay for context generation beforehand.
        q_score = _worker_quality_scores_map.get(mint_addr)
        if q_score is None:
            return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}

        encoder = _worker_encoder
        if encoder is None:
            # Reported but not fatal: the dataset is still called with encoder=None.
            print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True)

        # Request more candidates than needed so the polarity selection
        # below has something to choose from.
        candidate_contexts = _worker_dataset.__cacheitem_context__(
            idx,
            num_samples_per_token=max(samples_per_token, samples_per_token * max(1, oversample_factor)),
            encoder=encoder,
        )
        contexts, polarity_counts = _select_contexts_by_polarity(
            candidate_contexts,
            samples_per_token,
            desired_positive=desired_positive,
            desired_negative=desired_negative,
        )
        if not contexts:
            return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}

        return {
            'status': 'success',
            'mint': mint_addr,
            'class_id': class_id,
            'q_score': q_score,
            'n_contexts': len(contexts),
            # `contexts` is known to be non-empty at this point.
            'n_events': len(contexts[0].get('event_sequence', [])),
            'contexts': contexts,
            'polarity_counts': polarity_counts,
        }
    except Exception as e:
        # Boundary of the per-token task: report the failure to the caller
        # instead of aborting the whole run.
        import traceback
        return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Cache up to a fixed number of contexts per return class to disk.

    Steps, in order:

    1. Load ``.env``, force the ``spawn`` multiprocessing start method and
       log in to Hugging Face when ``HF_TOKEN`` is set.
    2. Parse the command line and validate it.
    3. Connect to ClickHouse and Neo4j, fetch the return-class map, the P99
       clamp values and the quality scores, and build the dataset.
    4. Keep only mints that are in the class map, have at least
       ``--min_trades`` trades and have a quality score.
    5. Count ``sample_*.pt`` files already in the output directory toward
       the per-class quotas (resume).
    6. Visit the remaining tokens in random order, saving selected contexts
       as ``sample_<mint[:16]>_<index>.pt`` until every class quota is met
       or the tokens run out.
    7. Write ``cache_errors.log`` (first 20 errors, if any) and
       ``class_metadata.json``.

    Raises:
        RuntimeError: If ``--num_workers`` does not resolve to 1, or if
            ``--target_contexts_per_class`` or ``--samples_per_token`` is
            not positive.
    """
    load_dotenv()
    mp.set_start_method('spawn', force=True)

    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("INFO: Logging in to Hugging Face...")
        huggingface_hub.login(token=hf_token)

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="data/cache")
    parser.add_argument("--start_date", type=str, default=None)
    parser.add_argument("--min_trade_usd", type=float, default=0.0)
    parser.add_argument("--context_length", type=int, default=8192)
    parser.add_argument("--min_trades", type=int, default=10)
    parser.add_argument("--samples_per_token", type=int, default=1)
    parser.add_argument("--target_contexts_per_class", type=int, default=2500)
    parser.add_argument("--context_oversample_factor", type=int, default=4)
    parser.add_argument("--positive_balance_min_class", type=int, default=2)
    parser.add_argument("--positive_context_ratio", type=float, default=0.5)
    parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
    parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
    parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
    parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
    parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
    parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
    args = parser.parse_args()

    if args.num_workers == 0:
        args.num_workers = max(1, mp.cpu_count() - 4)
    if args.num_workers != 1:
        raise RuntimeError("Quota-based caching requires --num_workers 1 so class counters remain exact.")
    if args.target_contexts_per_class <= 0:
        raise RuntimeError("--target_contexts_per_class must be positive.")
    if args.samples_per_token <= 0:
        # With a non-positive value no context can ever be kept, so the run
        # would query every token and save nothing.
        raise RuntimeError("--samples_per_token must be positive.")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") if args.start_date else None

    print("INFO: Initializing DB Connections...")
    clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
    neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))

    try:
        import random
        import time as _time
        from collections import deque

        from data.data_loader import OracleDataset
        from data.data_fetcher import DataFetcher
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        print("INFO: Fetching Return Classification Map...")
        return_class_map, _ = get_return_class_map(clickhouse_client)
        print(f"INFO: Loaded {len(return_class_map)} classified tokens.")

        print("INFO: Computing P99 clamp values...")
        p99_clamps = compute_p99_clamps(clickhouse_client)

        print("INFO: Fetching Quality Scores...")
        quality_scores_map = get_token_quality_scores(clickhouse_client)
        print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")

        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            start_date=start_date_dt,
            horizons_seconds=args.horizons_seconds,
            quantiles=args.quantiles,
            min_trade_usd=args.min_trade_usd,
            max_seq_len=args.context_length,
            p99_clamps=p99_clamps,
        )

        if len(dataset) == 0:
            print("WARNING: No samples. Exiting.")
            return

        # Filter 1: only tokens that have a return class.
        original_size = len(dataset.sampled_mints)
        filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
        print(f"INFO: Filtered by class map: {original_size} -> {len(filtered_mints)} tokens")

        # Filter 2: only tokens with at least --min_trades rows in `trades`.
        print(f"INFO: Pre-filtering tokens by trade count (>= {args.min_trades} trades)...")
        trade_counts = clickhouse_client.execute("""
            SELECT base_address, count() as cnt
            FROM trades
            GROUP BY base_address
            HAVING cnt >= %(min_trades)s
        """, {'min_trades': args.min_trades})
        valid_tokens = {row[0] for row in trade_counts}
        pre_filter_size = len(filtered_mints)
        filtered_mints = [m for m in filtered_mints if m['mint_address'] in valid_tokens]
        print(f"INFO: Pre-filtered by trade count: {pre_filter_size} -> {len(filtered_mints)} tokens (removed {pre_filter_size - len(filtered_mints)} with < {args.min_trades} trades)")

        # Filter 3: only tokens that have a quality score.
        pre_quality_size = len(filtered_mints)
        filtered_mints = [m for m in filtered_mints if m['mint_address'] in quality_scores_map]
        print(f"INFO: Filtered by quality score: {pre_quality_size} -> {len(filtered_mints)} tokens")

        if len(filtered_mints) == 0:
            print("WARNING: No tokens after filtering.")
            return

        print(f"INFO: Workers: {args.num_workers}")

        db_config = {
            'clickhouse_host': args.clickhouse_host,
            'clickhouse_port': args.clickhouse_port,
            'neo4j_uri': args.neo4j_uri,
            'neo4j_user': args.neo4j_user,
            'neo4j_password': args.neo4j_password,
        }
        dataset_config = {
            'start_date': start_date_dt,
            'horizons_seconds': args.horizons_seconds,
            'quantiles': args.quantiles,
            'min_trade_usd': args.min_trade_usd,
            'max_seq_len': args.context_length,
            'sampled_mints': filtered_mints,
            'p99_clamps': p99_clamps,
        }

        eligible_class_counts = Counter()
        for m in filtered_mints:
            cid = return_class_map.get(m['mint_address'])
            if cid is not None:
                eligible_class_counts[cid] += 1

        print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")

        # Every class that has at least one eligible token gets the same quota.
        class_targets = {
            int(class_id): int(args.target_contexts_per_class)
            for class_id in sorted(eligible_class_counts.keys())
        }
        class_polarity_targets = {
            class_id: _class_polarity_targets(
                class_id=class_id,
                target_contexts_per_class=args.target_contexts_per_class,
                positive_balance_min_class=args.positive_balance_min_class,
                positive_ratio=args.positive_context_ratio,
            )
            for class_id in class_targets
        }

        target_total = args.target_contexts_per_class * len(class_targets)
        print(f"INFO: Target total: {target_total}, Target per class: {args.target_contexts_per_class}")
        print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}")
        print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}")

        # Keep each mint's index into `filtered_mints` (the worker dataset's
        # sampled_mints) while randomising the visiting order.
        tasks = list(enumerate(filtered_mints))
        random.shuffle(tasks)
        print(f"INFO: Total candidate tokens: {len(tasks)}")

        success_count, skipped_count, error_count = 0, 0, 0
        class_distribution = defaultdict(int)
        polarity_distribution = defaultdict(int)
        file_class_map = {}
        file_context_bucket_map = {}
        file_context_summary_map = {}
        accepted_counts = defaultdict(lambda: {"total": 0, "positive": 0, "negative": 0})

        # Resume: count samples already on disk toward the quotas.
        cached_paths = sorted(output_dir.glob("sample_*.pt"))
        if cached_paths:
            already_cached = 0
            unreadable = 0
            for cached_path in cached_paths:
                try:
                    # NOTE: weights_only=False unpickles arbitrary objects;
                    # only point --output_dir at a cache you trust.
                    cached = torch.load(cached_path, map_location="cpu", weights_only=False)
                except Exception as exc:
                    # Best-effort resume: report the file instead of hiding
                    # it, then carry on without counting it.
                    unreadable += 1
                    print(f"WARN: Could not load cached sample {cached_path.name}: {exc}")
                    continue
                if not isinstance(cached, dict):
                    unreadable += 1
                    print(f"WARN: Cached sample {cached_path.name} is not a dict; ignoring it.")
                    continue
                class_id = cached.get("class_id")
                if class_id is None or int(class_id) not in class_targets:
                    continue
                class_id = int(class_id)
                context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
                polarity = _representative_context_polarity(cached)
                file_class_map[cached_path.name] = class_id
                file_context_bucket_map[cached_path.name] = context_summary["context_bucket"]
                file_context_summary_map[cached_path.name] = context_summary
                class_distribution[class_id] += 1
                polarity_distribution[polarity] += 1
                accepted_counts[class_id]["total"] += 1
                accepted_counts[class_id][polarity] += 1
                already_cached += 1
            print(f"INFO: Resume: counted {already_cached} cached contexts toward quotas.")
            if unreadable:
                print(f"WARN: Resume: ignored {unreadable} unreadable cached files.")

        print(f"INFO: Starting to cache {len(tasks)} tokens...")
        process_fn = _process_single_token_context

        def _log_progress(task_num, total, start_time, recent_times, success_count, skipped_count, error_count):
            """Print progress with rolling ETA every 10 tokens."""
            if (task_num + 1) % 10 == 0 and recent_times:
                avg_time = sum(recent_times) / len(recent_times)
                remaining = total - (task_num + 1)
                eta_seconds = avg_time * remaining
                eta_hours = eta_seconds / 3600
                wall_elapsed = _time.perf_counter() - start_time
                speed = (task_num + 1) / wall_elapsed
                tqdm.write(
                    f" [PROGRESS] {task_num+1}/{total} | "
                    f"Speed: {speed:.1f} tok/s ({speed*60:.0f} tok/min) | "
                    f"Avg: {avg_time:.1f}s/tok | "
                    f"ETA: {eta_hours:.1f}h | "
                    f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
                )

        error_log_path = Path(args.output_dir) / "cache_errors.log"
        error_samples = []

        print("INFO: Single-threaded mode...")
        _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
        start_time = _time.perf_counter()
        # Rolling window of the 50 most recent per-token durations.
        recent_times = deque(maxlen=50)
        completed_classes = set()
        for task_num, (idx, mint_record) in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
            mint_addr = mint_record["mint_address"]
            class_id = return_class_map.get(mint_addr)
            if class_id is None:
                skipped_count += 1
                _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
                continue
            if accepted_counts[class_id]["total"] >= class_targets[class_id]:
                if class_id not in completed_classes:
                    completed_classes.add(class_id)
                    tqdm.write(f"INFO: Class {class_id} quota filled. Skipping remaining tokens for this class.")
                skipped_count += 1
                _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
                continue

            # Ask for no more than this class still needs, split by polarity.
            remaining_total = class_targets[class_id] - accepted_counts[class_id]["total"]
            remaining_polarity = _remaining_polarity_targets(
                class_id=class_id,
                accepted_counts=accepted_counts,
                target_contexts_per_class=args.target_contexts_per_class,
                positive_balance_min_class=args.positive_balance_min_class,
                positive_ratio=args.positive_context_ratio,
            )
            samples_to_keep = min(args.samples_per_token, remaining_total)
            desired_positive = min(remaining_polarity["positive"], samples_to_keep)
            desired_negative = min(
                remaining_polarity["negative"],
                max(0, samples_to_keep - desired_positive),
            )
            task = (
                idx,
                mint_addr,
                samples_to_keep,
                args.context_oversample_factor,
                desired_positive,
                desired_negative,
            )

            t0 = _time.perf_counter()
            result = process_fn(task)
            recent_times.append(_time.perf_counter() - t0)

            if result["status"] == "success":
                saved_contexts = 0
                for ctx in result.get("contexts", []):
                    if accepted_counts[class_id]["total"] >= class_targets[class_id]:
                        break

                    polarity = _representative_context_polarity(ctx)
                    remaining_polarity = _remaining_polarity_targets(
                        class_id=class_id,
                        accepted_counts=accepted_counts,
                        target_contexts_per_class=args.target_contexts_per_class,
                        positive_balance_min_class=args.positive_balance_min_class,
                        positive_ratio=args.positive_context_ratio,
                    )
                    other_polarity = "negative" if polarity == "positive" else "positive"
                    # Drop a context whose polarity quota is full while the
                    # opposite polarity still has room; if both are full the
                    # context is accepted so the class total can be reached.
                    if remaining_polarity[polarity] <= 0 and remaining_polarity[other_polarity] > 0:
                        continue

                    ctx["quality_score"] = result["q_score"]
                    ctx["class_id"] = class_id
                    ctx["source_token"] = mint_addr
                    context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
                    ctx["context_bucket"] = context_summary["context_bucket"]
                    ctx["context_score"] = context_summary["context_score"]
                    # NOTE(review): the index is the class's running total and
                    # tokens cached by a previous run are not removed from
                    # `tasks`, so a resumed run can cache the same token
                    # again -- confirm whether resume should skip tokens by
                    # their stored "source_token".
                    file_idx = accepted_counts[class_id]["total"]
                    filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
                    output_path = output_dir / filename
                    torch.save(ctx, output_path)

                    file_class_map[filename] = class_id
                    file_context_bucket_map[filename] = context_summary["context_bucket"]
                    file_context_summary_map[filename] = context_summary
                    class_distribution[class_id] += 1
                    polarity_distribution[polarity] += 1
                    accepted_counts[class_id]["total"] += 1
                    accepted_counts[class_id][polarity] += 1
                    saved_contexts += 1

                if saved_contexts > 0:
                    success_count += 1
                else:
                    skipped_count += 1
            elif result["status"] == "skipped":
                skipped_count += 1
            else:
                error_count += 1
                err_msg = result.get("error", "unknown")
                tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
                # Only the first 20 tracebacks are kept for the error log.
                if len(error_samples) < 20:
                    error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})

            if all(accepted_counts[cid]["total"] >= class_targets[cid] for cid in class_targets):
                tqdm.write("INFO: All class quotas filled. Stopping early.")
                _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
                break

            _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)

        if error_samples:
            with open(error_log_path, 'w', encoding="utf-8") as ef:
                for i, es in enumerate(error_samples):
                    ef.write(f"=== Error {i+1} === Token: {es['mint']}\n")
                    ef.write(f"Error: {es['error']}\n")
                    ef.write(f"Traceback:\n{es['traceback']}\n\n")
            print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")

        print("INFO: Building metadata...")
        with open(output_dir / "class_metadata.json", 'w', encoding="utf-8") as f:
            json.dump({
                'file_class_map': file_class_map,
                'file_context_bucket_map': file_context_bucket_map,
                'file_context_summary_map': file_context_summary_map,
                'class_distribution': {str(k): v for k, v in class_distribution.items()},
                'num_workers': args.num_workers,
                'horizons_seconds': args.horizons_seconds,
                'quantiles': args.quantiles,
                'target_total': target_total,
                'target_contexts_per_class': args.target_contexts_per_class,
                'context_polarity_distribution': polarity_distribution,
                'class_targets': {str(k): v for k, v in class_targets.items()},
                'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()},
                'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
                'positive_balance_min_class': args.positive_balance_min_class,
                'positive_context_ratio': args.positive_context_ratio,
                'quant_feature_version': FEATURE_VERSION,
            }, f, indent=2)

        print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")

    finally:
        # Close the connections opened by main() on every exit path.
        clickhouse_client.disconnect()
        neo4j_driver.close()
|
|
|
|
# Run the caching pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|