lhoestq's picture
lhoestq HF Staff
fix port
33841d3
# --- Import et initialisation Flask ---
import os
from flask import Flask, request, render_template, jsonify
from neo4j import GraphDatabase, basic_auth
from graphdatascience import GraphDataScience
import app_algorithms as algo
import pandas as pd
from typing import Dict
from collections import defaultdict
import argparse
app = Flask(__name__, static_url_path="/static/") # Application Flask
# --- Connexion à Neo4j et GDS ---
NEO4J_URI = "bolt://localhost:7687"
GDS_GRAPH_NAME = "genealogie_gds"
# --- Configuration des arguments du script ---
parser = argparse.ArgumentParser(description="Script pour lancer la web application.")
parser.add_argument(
'neo4j_user', # Le nom de l'argument (positionnel, car sans tirets)
type=str, # On attend une chaîne de caractères
help="Nom de votre instance Neo4J"
)
parser.add_argument(
'neo4j_password', # Le nom de l'argument (positionnel, car sans tirets)
type=str, # On attend une chaîne de caractères
help="Mot de passe de votre instance Neo4J"
)
args = parser.parse_args()
NEO4J_PASSWORD = args.neo4j_password
NEO4J_USER = args.neo4j_user
try:
driver = GraphDatabase.driver(NEO4J_URI, auth=basic_auth(NEO4J_USER, NEO4J_PASSWORD))
gds = GraphDataScience(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
except Exception as e:
print(f"FATAL: Could not connect to Neo4j/GDS at startup. Error: {e}")
exit() # Si la connexion échoue, l'application ne peut pas tourner
# --- Projection du graphe pour GDS ---
def ensure_graph_projected(gds: GraphDataScience, graph_name):
"""
Projette deux graphes dans :
- Graphe "naturel" : relations descendant (source -> target)
- Graphe "inverse" : relations ascendant (target -> source)
Les projections sont nécessaires pour lancer des algorithmes GDS comme BFS
"""
natural_graph_name = f"{graph_name}_natural"
reverse_graph_name = f"{graph_name}_reverse"
# Suppression de toute projection existante
g_natural_exists = gds.graph.exists(natural_graph_name).exists
if g_natural_exists:
gds.graph.get(natural_graph_name).drop()
g_reverse_exists = gds.graph.exists(reverse_graph_name).exists
if g_reverse_exists:
gds.graph.get(reverse_graph_name).drop()
# --- 1. Projection pour les descendants (sens normal) ---
# On sélectionne les relations et on les projette en gardant source -> target
print("Projection du graphe naturel (descendant)...")
gds.run_cypher(f"""
MATCH (source)-[r:IS_IN|POSTED|USED_IN]->(target)
WITH gds.graph.project(
'{natural_graph_name}',
source,
target,
{{
relationshipType: type(r)
}}
) AS g
RETURN g.graphName AS graph, g.nodeCount AS nodes, g.relationshipCount AS rels
""")
print(f"Graphe '{natural_graph_name}' projeté.")
# --- 2. Projection pour les ascendants (sens inversé) ---
# On sélectionne les mêmes relations, mais on inverse source et target dans l'appel
print("Projection du graphe inversé (ascendant)...")
gds.run_cypher(f"""
MATCH (source)-[r:IS_IN|POSTED|USED_IN]->(target)
WITH gds.graph.project(
'{reverse_graph_name}',
target, // <<< Le 'target' devient la source dans la projection
source, // <<< Le 'source' devient la cible dans la projection
{{
relationshipType: type(r)
}}
) AS g
RETURN g.graphName AS graph, g.nodeCount AS nodes, g.relationshipCount AS rels
""")
print(f"Graphe '{reverse_graph_name}' projeté.")
@app.route("/") # Page d'accueil
def home():
return render_template("index.html")
@app.route("/info") # Page d'informations
def info_page():
return render_template("infos.html")
# Suggestions pour l'autocomplétion
@app.route("/autocomplete")
def autocomplete():
query = request.args.get("q", "")
node_filter = request.args.get("filter")
if not query:
return jsonify([])
# Filtrage optionnel par label (Model ou Dataset)
label_cypher = ""
if node_filter and node_filter in ["Model", "Dataset"]: # Mesure de sécurité
label_cypher = f":{node_filter}"
# Récupère les noms commençant par le préfixe fourni
cypher = f"""
MATCH (n{label_cypher})
WHERE toLower(n.name) STARTS WITH toLower($prefix)
AND n.name IS NOT NULL
RETURN n.name AS name, labels(n)[0] as label
ORDER BY size(n.name) ASC
LIMIT 10
"""
try:
results_df = gds.run_cypher(cypher, {"prefix": query})
suggestions = results_df.to_dict('records')
return jsonify(suggestions)
except Exception as e:
print(f"Autocomplete error: {e}")
return jsonify([])
@app.route("/search", methods=["GET", "POST"]) # Recherche d'un noeud
def findnode():
"""
1. Récupère le nom à chercher et la profondeur
2. Appelle ensure_graph_projected pour s'assurer que les graphes GDS existent
3. Lance l'algorithme BFS via algo.run_gds_bfs
4. Traite les résultats et construit le sous-graphe à afficher
5. Met à jour les données pour le template Flask
"""
message = None
highlights = {}
graph_data = {"nodes": [], "edges": [], "models_count": []}
current_filters = []
if request.method == "GET":
# Pour les liens depuis la page d'accueil
filter_from_url = request.args.get('filter')
if filter_from_url:
current_filters.append(filter_from_url)
else: # POST
# Pour le formulaire soumis avec les checkboxes
current_filters = request.form.getlist('filters')
search_info = {
"name": request.form.get("name", ""),
"depth": int(request.form.get("depth", 3)),
"unlimited": 'unlimited_depth' in request.form,
"filters": current_filters # On passe la liste des filtres au template
}
if request.method == "POST" and request.form.get("submit") == "find_node":
name = request.form.get("name", "").strip()
is_unlimited = 'unlimited_depth' in request.form
depth = None if is_unlimited else int(request.form.get("depth", 3))
search_info = {"name": name, "depth": request.form.get("depth", 3), "unlimited": is_unlimited}
if not name:
message = "Veuillez entrer un nom à rechercher."
return render_template("search.html", message=message, search=search_info, graph_data=graph_data, highlights=highlights)
try:
# Projection des graphes ascendant et descendant
ensure_graph_projected(gds, GDS_GRAPH_NAME)
natural_graph_name = f"{GDS_GRAPH_NAME}_natural"
reverse_graph_name = f"{GDS_GRAPH_NAME}_reverse"
# Appeler la fonction GDS BFS
gds_result = algo.run_gds_bfs(gds, natural_graph_name,reverse_graph_name, name, depth,False)
if not gds_result :
message = f"Le modèle/dataset '{name}' n'a pas été trouvé."
return render_template("search.html", message=message, search=search_info, graph_data=graph_data,highlights=highlights)
process_gds_bfs_results(gds_result, graph_data, name)
if gds_result["source_label"] == "Model" :
highlights = algo.get_genealogy_highlights(gds, name)
return render_template("search.html", message=message, search=search_info, graph_data=graph_data,highlights=highlights)
if gds_result["source_label"] == "Dataset" :
return render_template("search_dataset.html", message=message, search=search_info, graph_data=graph_data)
if not graph_data["nodes"] and not graph_data["edges"]:
message = f"Le nœud '{name}' a été trouvé, mais il n'a pas de voisins dans la profondeur spécifiée."
return render_template("search.html", message=message, search=search_info, graph_data=graph_data,highlights=highlights)
except Exception as e:
# Gérer le cas où le nœud source n'existe pas du tout
if "Failed to find a node" in str(e):
message = f"Le nœud '{name}' n'a pas été trouvé dans le graphe."
else:
print(f"GDS BFS Error: {e}")
message = f"Erreur lors de la recherche GDS: {str(e)}"
return render_template("search.html", message=message, search=search_info, graph_data=graph_data,highlights=highlights)
@app.route("/expert", methods=["GET", "POST"]) # Recherche d'un noeud
def findnode_expert():
"""
1. Récupère le nom à chercher et la profondeur
2. Appelle ensure_graph_projected pour s'assurer que les graphes GDS existent
3. Lance l'algorithme BFS via algo.run_gds_bfs
4. Traite les résultats et construit le sous-graphe à afficher
5. Met à jour les données pour le template Flask
"""
message = None
graph_data = {"nodes": [], "edges": [], "models_count": []}
current_filters = []
if request.method == "GET":
# Pour les liens depuis la page d'accueil
filter_from_url = request.args.get('filter')
if filter_from_url:
current_filters.append(filter_from_url)
else: # POST
# Pour le formulaire soumis avec les checkboxes
current_filters = request.form.getlist('filters')
search_info = {
"name": request.form.get("name", ""),
"depth": int(request.form.get("depth", 3)),
"unlimited": 'unlimited_depth' in request.form,
"filters": current_filters # On passe la liste des filtres au template
}
if request.method == "POST" and request.form.get("submit") == "findnode_expert":
name = request.form.get("name", "").strip()
is_unlimited = 'unlimited_depth' in request.form
depth = None if is_unlimited else int(request.form.get("depth", 3))
search_info = {"name": name, "depth": request.form.get("depth", 3), "unlimited": is_unlimited,"filters": current_filters }
if not name:
message = "Veuillez entrer un nom à rechercher."
return render_template("expert.html", message=message, search=search_info, graph_data=graph_data)
try:
# Projection des graphes ascendant et descendant
ensure_graph_projected(gds, GDS_GRAPH_NAME)
natural_graph_name = f"{GDS_GRAPH_NAME}_natural"
reverse_graph_name = f"{GDS_GRAPH_NAME}_reverse"
# Appeler la fonction GDS BFS
gds_result = algo.run_gds_bfs(gds, natural_graph_name,reverse_graph_name, name, depth,True)
if not gds_result :
message = f"Le noeud '{name}' n'a pas été trouvé."
return render_template("expert.html", message=message, search=search_info, graph_data=graph_data)
process_gds_bfs_results(gds_result, graph_data, name)
if not graph_data["nodes"] and not graph_data["edges"]:
message = f"Le nœud '{name}' a été trouvé, mais il n'a pas de voisins dans la profondeur spécifiée."
except Exception as e:
# Gérer le cas où le nœud source n'existe pas du tout
if "Failed to find a node" in str(e):
message = f"Le nœud '{name}' n'a pas été trouvé dans le graphe."
else:
print(f"GDS BFS Error: {e}")
message = f"Erreur lors de la recherche GDS: {str(e)}"
return render_template("expert.html", message=message, search=search_info, graph_data=graph_data)
def process_gds_bfs_results(gds_result: Dict, graph_data: Dict, origin_name: str):
"""
Transforme les résultats d'un BFS GDS en un sous-graphe utilisable pour le front-end.
Étapes principales :
1. Collecte des IDs de tous les nœuds parcourus par le BFS (ascendants et descendants).
2. Récupération des nœuds et relations réels dans Neo4j via une requête Cypher.
3. Formatage des nœuds et arêtes pour construire le dictionnaire `graph_data`.
"""
# --- PHASE 1 : Collecte des IDs de tous les nœuds visités ---
all_discovered_node_ids = set()
# Ajouter le nœud source
source_id = gds_result.get("source_node")
if source_id is not None:
all_discovered_node_ids.add(source_id)
# Ajouter les descendants (profondeur positive)
desc_df = gds_result.get("descendant")
if desc_df is not None and not desc_df.empty:
node_ids = desc_df["nodeIds"].iloc[0]
all_discovered_node_ids.update(node_ids)
# Ajouter les ascendants (profondeur négative)
asc_df = gds_result.get("ascendant")
if asc_df is not None and not asc_df.empty:
node_ids = asc_df["nodeIds"].iloc[0]
all_discovered_node_ids.update(node_ids)
if not all_discovered_node_ids:
return # Aucun nœud découvert → rien à faire
# --- PHASE 2 : Récupération du sous-graphe réel dans Neo4j ---
with driver.session() as session:
# Requête Cypher : récupère les nœuds, auteurs, datasets et relations entre eux
results = session.run("""
MATCH (n) WHERE id(n) IN $ids
OPTIONAL MATCH (author:Author)-[:POSTED]->(n)
OPTIONAL MATCH (dataset:Dataset)-[:USED_IN]->(n)
OPTIONAL MATCH (o:Model) WHERE o.name = $origin_name
CALL {
WITH n,o
OPTIONAL MATCH p = (n)-[:USED_IN*1..]->(o)
WITH n,o, length(p) AS rel_asc
OPTIONAL MATCH p= (n)<-[r:USED_IN*1..]-(o)
WITH n,o, rel_asc, length(p) AS rel_desc
OPTIONAL MATCH (ancestor:Model)-[:USED_IN*1..]->(n)
WITH n, rel_asc, rel_desc, count(DISTINCT ancestor) AS ascendantsCount
OPTIONAL MATCH (descendant:Model)<-[:USED_IN*1..]-(n)
WITH n,rel_asc, rel_desc, ascendantsCount, count(DISTINCT descendant) AS descendantsCount
OPTIONAL MATCH (citation:Model)<-[:USED_IN]-(n)
RETURN ascendantsCount, descendantsCount, count(DISTINCT citation) AS citationCount,rel_asc, rel_desc
}
WITH n, author,dataset, ascendantsCount, descendantsCount, citationCount,rel_asc, rel_desc
WITH collect({
id: id(n),
node: n,
dataset: properties(dataset),
author: properties(author),
task: n.task,
license: n.license,
createdAt:n.createdAt,
likes:n.likes,
properties: properties(n),
labels: labels(n),
ascendantsCount: ascendantsCount,
descendantsCount: descendantsCount,
citationCount: citationCount,
distance: CASE
WHEN rel_asc IS NOT NULL THEN -rel_asc
WHEN rel_desc IS NOT NULL THEN rel_desc
ELSE 0
END
}) AS nodes_data
CALL {
WITH nodes_data
UNWIND [item IN nodes_data | item.node] AS n1
UNWIND [item IN nodes_data | item.node] AS n2
MATCH (n1)-[r]-(n2)
RETURN collect(r) AS rels
}
RETURN nodes_data, rels
""", {"ids": list(all_discovered_node_ids), "origin_name": origin_name})
subgraph = results.single()
if not subgraph:
return
count=0
# --- PHASE 3 : Formatage des nœuds et des arêtes ---
for node_obj in subgraph["nodes_data"]:
# Extraire et compléter les propriétés du nœud
node_properties = dict(node_obj["properties"])
node_properties.update({
"ascendantsCount": node_obj["ascendantsCount"],
"descendantsCount": node_obj["descendantsCount"],
"citationCount": node_obj["citationCount"],
"task": node_obj.get("task"),
"license": node_obj.get("license"),
"createdAt": node_obj.get("createdAt"),
"likes": node_obj.get("likes"),
"distance": node_obj.get("distance")
})
if node_obj.get("author"):
node_properties["author"] = node_obj["author"].get("name")
if node_obj.get("dataset"):
node_properties["dataset"] = node_obj["dataset"].get("name")
node_label = list(node_obj["labels"])[0]
if node_label == "Model" :
count = count +1
node_data = algo.create_node_data(node_properties, node_label)
graph_data["nodes"].append(node_data)
graph_data["models_count"].append(count)
# Construction des arêtes, en évitant les doublons
added_edges_canonical_keys = set()
for rel_obj in subgraph["rels"]:
source_name = rel_obj.start_node["name"]
target_name = rel_obj.end_node["name"]
if not source_name or not target_name:
continue
canonical_key = tuple(sorted((source_name, target_name)))
if canonical_key not in added_edges_canonical_keys:
added_edges_canonical_keys.add(canonical_key)
edge_id = f"{source_name}_{target_name}_{type(rel_obj).__name__}"
rel_props = dict(rel_obj.items())
graph_data["edges"].append({
"id": edge_id,
"source": source_name,
"target": target_name,
"relation": rel_props.get("name")
})
if __name__ == '__main__':
# Le script de démarrage n'a plus besoin de projeter le graphe.
# Il se contente de lancer l'application. La projection se fera à la demande.
app.run(host='0.0.0.0', port=7860, debug=True)
# Le driver doit être fermé quand l'application s'arrête.
# Une manière simple est d'utiliser `atexit`
import atexit
atexit.register(lambda: driver.close())
atexit.register(lambda: gds.close())