import ast
import multiprocessing
import os
import time
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
from neo4j.exceptions import TransientError
from tqdm import tqdm


BATCH_SIZE = 5000
NUM_THREADS = multiprocessing.cpu_count()
PROCESSED_IDS_FILE = "processed_ids.txt"


def reset_database(driver):
    """
    Deletes ALL data from the Neo4j database AND the processed-IDs
    tracking file. Use only for a complete reset.
    """
    with driver.session() as session:
        # Quick connectivity check before wiping anything.
        result = session.run("RETURN 1 AS test")
        print("Connection OK, test result:", result.single()["test"])
        session.run("MATCH (n) DETACH DELETE n")

    if os.path.exists(PROCESSED_IDS_FILE):
        os.remove(PROCESSED_IDS_FILE)


def parse_list_field(value):
    """
    Parses a string representing a list (e.g. "['model1', 'model2']") into an
    actual Python list. Scalars are wrapped (e.g. "m1" -> ["m1"]); empty or
    missing values yield [].
    """
    if isinstance(value, str):
        try:
            parsed = ast.literal_eval(value)
            if isinstance(parsed, list):
                return parsed
        except (ValueError, SyntaxError):
            pass
    # NaN (a float in pandas) and None both mean "no value".
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return []
    return [value] if value else []


def load_processed_ids():
    """Loads the set of already-processed IDs from the tracking file."""
    if os.path.exists(PROCESSED_IDS_FILE):
        with open(PROCESSED_IDS_FILE, "r", encoding="utf-8") as f:
            return set(line.strip() for line in f)
    return set()


def append_processed_ids(ids):
    """Appends a list of IDs to the tracking file."""
    with open(PROCESSED_IDS_FILE, "a", encoding="utf-8") as f:
        for i in ids:
            f.write(f"{i}\n")


def run_with_retry(session, query, parameters=None, retries=3, delay=1):
    """
    Runs a Cypher query with retry logic for transient errors
    (deadlocks, leader switches, etc.).
    """
    for attempt in range(retries):
        try:
            session.run(query, parameters)
            return
        except TransientError:
            if attempt < retries - 1:
                time.sleep(delay)
                continue
            raise


def process_batch(rows, fieldnames, driver):
    """
    Processes one batch of CSV rows and inserts them into Neo4j.
    This is the worker function executed in parallel.
    """
    # Column names may contain spaces or dashes; turn them into valid
    # namedtuple field names.
    normalized_fields = [f.strip().replace(" ", "_").replace("-", "_") for f in fieldnames]
    Row = namedtuple("Row", normalized_fields)
    ids_successfully_processed = []

    with driver.session() as session:
        for row in rows:
            # Remap raw CSV keys with the same normalization, otherwise the
            # namedtuple constructor fails on columns with spaces or dashes.
            obj = Row(**{
                k.strip().replace(" ", "_").replace("-", "_"): (v if v != "" else None)
                for k, v in row.items()
            })
            data = obj._asdict()

            # Skip rows without a usable id or author.
            if (not data.get("id") or pd.isna(data.get("id"))
                    or pd.isna(data.get("author")) or str(data.get("id")).strip() == ""):
                continue

            base_models = parse_list_field(data.get("base_model"))
            base_model_rels = parse_list_field(data.get("base_model_relation"))
            datasets = parse_list_field(data.get("dataset"))
            orgs_author_model = parse_list_field(data.get("organizations_author_model"))
            orgs_author_dataset = parse_list_field(data.get("organizations_author_dataset"))

            # Create/refresh the model node and its author. Relationship
            # display names are stored in French (kept as-is: they are data,
            # not code).
            if data.get("id") and data.get("author"):
                run_with_retry(session, """
                    MERGE (m:Model {name: $id})
                    SET m.downloads = $downloadsAllTime,
                        m.task = $pipeline_tag,
                        m.createdAt = $createdAt,
                        m.parameters = $total_parameters_formatted,
                        m.likes = $likes,
                        m.license = $license
                    MERGE (a:Author {name: $author})
                    SET a.type = $author_type,
                        a.followers = $followers_count_author_model
                    WITH a
                    MATCH (m:Model {name: $id})
                    MERGE (a)-[p:POSTED]->(m)
                    SET p.name = "A publié"
                """, data)

            # Link the model's author to their organizations.
            orgs_data = [
                {"org": org, "author": data["author"]}
                for org in orgs_author_model
                if pd.notna(org) and data.get("author")
            ]
            if orgs_data:
                run_with_retry(session, """
                    UNWIND $orgs_data AS row
                    MERGE (o:Author {name: row.org})
                    WITH o, row
                    MATCH (a:Author {name: row.author})
                    MERGE (a)-[r:IS_IN]->(o)
                    SET r.name = "Fait partie de cette organisation",
                        a.type = "personne",
                        o.type = "organisation"
                """, {"orgs_data": orgs_data})

            # Pair base models with their relation labels. The CSV sometimes
            # provides a single relation for several base models, so the two
            # lists may have different lengths. Initialize to [] so a shorter
            # base_models list cannot leak data from a previous row.
            base_model_data = []
            if len(base_models) == len(base_model_rels):
                base_model_data = [
                    {"bm": bm, "id": data["id"], "rel": rel}
                    for bm, rel in zip(base_models, base_model_rels)
                    if pd.notna(bm) and data.get("id")
                ]
            elif len(base_models) > len(base_model_rels):
                rel = "merge" if base_model_rels == ["merge"] else "A généré"
                base_model_data = [
                    {"bm": bm, "id": data["id"], "rel": rel}
                    for bm in base_models
                    if pd.notna(bm) and data.get("id")
                ]

            if base_model_data:
                run_with_retry(session, """
                    UNWIND $base_model_data AS row
                    MERGE (bm:Model {name: row.bm})
                    WITH bm, row
                    MATCH (m:Model {name: row.id})
                    MERGE (bm)-[r:USED_IN]->(m)
                    SET r.name = row.rel
                """, {"base_model_data": base_model_data})

            # Attach datasets to the model.
            datasets_data = [
                {"ds": ds, "downloads": data.get("downloads_dataset"),
                 "createdAt_dataset": data.get("createdAt_dataset"), "id": data["id"]}
                for ds in datasets
                if pd.notna(ds) and data.get("id")
            ]
            if datasets_data and data.get("author_dataset") and data.get("dataset") and pd.notna(data.get("author_dataset")):
                run_with_retry(session, """
                    UNWIND $datasets_data AS row
                    MERGE (d:Dataset {name: row.ds})
                    SET d.downloads = row.downloads,
                        d.createdAt_dataset = row.createdAt_dataset
                    WITH d, row
                    MATCH (m:Model {name: row.id})
                    MERGE (d)-[r:USED_IN]->(m)
                    SET r.name = "A été utilisé dans ce modèle"
                """, {"datasets_data": datasets_data})

                # The raw "dataset" column may hold a stringified list, so
                # match datasets through the parsed entries rather than the
                # raw $dataset value.
                run_with_retry(session, """
                    UNWIND $datasets_data AS row
                    MERGE (ad:Author {name: $author_dataset})
                    SET ad.type = $author_dataset_type,
                        ad.followers = $followers_count_author_dataset
                    WITH ad, row
                    MATCH (d:Dataset {name: row.ds})
                    MERGE (ad)-[r:POSTED]->(d)
                    SET r.name = "A publié"
                """, {**data, "datasets_data": datasets_data})

            # Link the dataset's author to their organizations.
            orgs_dataset_data = [
                {"org": org, "author_dataset": data["author_dataset"]}
                for org in orgs_author_dataset
                if pd.notna(org) and data.get("author_dataset")
            ]
            if orgs_dataset_data and pd.notna(data.get("author_dataset")):
                run_with_retry(session, """
                    UNWIND $orgs_data AS row
                    MERGE (o:Author {name: row.org})
                    WITH o, row
                    MATCH (ad:Author {name: row.author_dataset})
                    MERGE (ad)-[r:IS_IN]->(o)
                    SET r.name = "Fait partie de cette organisation",
                        ad.type = "personne",
                        o.type = "organisation"
                """, {"orgs_data": orgs_dataset_data})

            ids_successfully_processed.append(data["id"])

    # Record which rows were fully processed so a restart can skip them.
    if ids_successfully_processed:
        append_processed_ids(ids_successfully_processed)


def insert_parallel(csv_file_path, driver, processed_ids):
    """
    Reads the CSV, drops already-processed rows, and inserts the rest into
    Neo4j in parallel batches of BATCH_SIZE rows.
    """
    df = pd.read_csv(csv_file_path)
    df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

    # Drop rows without a usable id.
    df = df[~df["id"].isnull()]
    df = df[df["id"].astype(str).str.strip() != ""]

    # Skip rows already recorded in the tracking file.
    df = df[~df["id"].isin(processed_ids)]

    records = df.to_dict(orient="records")
    fieldnames = list(df.columns)

    batch = []
    futures = []

    with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
        for row in tqdm(records, desc="Reading CSV"):
            batch.append(row)
            if len(batch) == BATCH_SIZE:
                futures.append(executor.submit(process_batch, batch.copy(), fieldnames, driver))
                batch = []

        if batch:
            futures.append(executor.submit(process_batch, batch.copy(), fieldnames, driver))

        # Wait for completion and surface any worker exception.
        for future in tqdm(futures, desc="Parallel processing"):
            future.result()
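

# --- Usage sketch ---
# A minimal, hypothetical entry point showing how the pieces fit together.
# The URI, credentials, and CSV path below are placeholder assumptions, not
# values from the original script; adapt them to your environment.
if __name__ == "__main__":
    from neo4j import GraphDatabase

    NEO4J_URI = "bolt://localhost:7687"  # assumption: local Neo4j instance
    NEO4J_AUTH = ("neo4j", "password")   # assumption: replace with real credentials
    CSV_FILE = "models.csv"              # hypothetical input file

    driver = GraphDatabase.driver(NEO4J_URI, auth=NEO4J_AUTH)
    try:
        # reset_database(driver)  # uncomment for a full wipe before reloading
        processed_ids = load_processed_ids()
        insert_parallel(CSV_FILE, driver, processed_ids)
    finally:
        driver.close()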