ASI-Engineer commited on
Commit
bb20631
·
verified ·
1 Parent(s): 61876ea

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. exemples/demo_batch_hf.py +70 -80
exemples/demo_batch_hf.py CHANGED
@@ -1,24 +1,31 @@
1
  #!/usr/bin/env python3
2
  """
3
- 📦 Prédiction BATCH via API Hugging Face
4
 
5
  Usage: python demo_batch_hf.py
6
  - Utilise par défaut les CSV d'exemple du dossier
7
- - Envoie les 3 fichiers à la Space HF
8
  - Sauvegarde un CSV de résultats
9
 
10
- Option: définir HF_API_URL pour surcharger l'URL par défaut.
11
  """
12
 
13
  import os
 
14
  import pandas as pd
15
- import requests
16
  from datetime import datetime
17
 
 
 
 
 
 
 
 
18
  API_URL = os.getenv("HF_API_URL", "https://asi-engineer-oc-p5.hf.space")
19
 
20
  print("╔══════════════════════════════════════════════════════════╗")
21
- print("║ 📦 Prédiction BATCH - API Hugging Face ║")
22
  print("╚══════════════════════════════════════════════════════════╝\n")
23
  print(f"🌐 API: {API_URL}\n")
24
 
@@ -32,88 +39,71 @@ sirh_path = os.path.join(script_dir, "02_predict_batch_sirh.csv")
32
  for path in [sondage_path, eval_path, sirh_path]:
33
  if not os.path.exists(path):
34
  print(f"❌ Fichier introuvable: {path}")
35
- raise SystemExit(1)
36
 
37
  print("✅ Fichiers d'exemple détectés:")
38
  print(f" - {os.path.basename(sondage_path)}")
39
  print(f" - {os.path.basename(eval_path)}")
40
  print(f" - {os.path.basename(sirh_path)}\n")
41
 
42
- print("⏳ Envoi des fichiers à l'API HF...")
43
- files = {
44
- "sondage_file": open(sondage_path, "rb"),
45
- "eval_file": open(eval_path, "rb"),
46
- "sirh_file": open(sirh_path, "rb"),
47
- }
48
- headers = {}
49
- api_key = os.getenv("HF_API_KEY")
50
- if api_key:
51
- headers["X-API-Key"] = api_key
52
 
 
53
  try:
54
- # 1) Tente FastAPI (si exposé)
55
- r = requests.post(
56
- f"{API_URL}/predict/batch", files=files, headers=headers, timeout=90
 
 
57
  )
58
- if r.status_code == 404:
59
- # 2) Fallback: endpoint Gradio API
60
- print(
61
- "\nℹ️ Endpoint FastAPI indisponible, tentative via Gradio API (/api/predict_batch)..."
62
- )
63
- r = requests.post(
64
- f"{API_URL}/api/predict_batch", files=files, headers=headers, timeout=90
65
- )
66
- if r.status_code == 404:
67
- print(
68
- "\n❌ Endpoint HF introuvable (/predict/batch et /api/predict_batch)."
69
- )
70
- print(
71
- " Vérifiez que la Space expose l'API FastAPI ou l'onglet Batch Gradio."
72
- )
73
- print(" Sinon, utilisez l'API locale (lancer_api.sh).")
74
- raise SystemExit(1)
75
- r.raise_for_status()
76
- result = r.json()
77
-
78
- # Construire le CSV de sortie
79
- predictions_data = []
80
- for pred in result.get("predictions", []):
81
- predictions_data.append(
82
- {
83
- "employee_id": pred.get("employee_id"),
84
- "prediction": (
85
- "VA PARTIR" if pred.get("prediction") == 1 else "VA RESTER"
86
- ),
87
- "prediction_code": pred.get("prediction"),
88
- "risk_level": pred.get("risk_level"),
89
- "probability_stay": f"{pred.get('probability_stay', 0):.2%}",
90
- "probability_leave": f"{pred.get('probability_leave', 0):.2%}",
91
- }
92
- )
93
-
94
- df = pd.DataFrame(predictions_data)
95
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
96
- output_path = os.path.join(script_dir, f"predictions_batch_hf_{timestamp}.csv")
97
- df.to_csv(output_path, index=False, encoding="utf-8-sig")
98
-
99
- # Affichage
100
- summary = result.get("summary", {})
101
- print("\n" + "═" * 60)
102
- print(" 📊 RÉSULTAT (HF)")
103
- print("═" * 60)
104
- print(
105
- f"\n✅ Traités: {result.get('total_employees')} | RESTER: {summary.get('total_stay')} | PARTIR: {summary.get('total_leave')}"
106
- )
107
- print(
108
- f"🔴 High: {summary.get('high_risk_count')} 🟡 Medium: {summary.get('medium_risk_count')} 🟢 Low: {summary.get('low_risk_count')}\n"
109
  )
110
- print("📄 Aperçu:")
111
- print(df.head(5).to_string(index=False))
112
- print(f"\n💾 Sauvegardé: {output_path}")
113
-
114
- finally:
115
- for f in files.values():
116
- try:
117
- f.close()
118
- except Exception:
119
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ 📦 Prédiction BATCH via API Hugging Face (Gradio Client)
4
 
5
  Usage: python demo_batch_hf.py
6
  - Utilise par défaut les CSV d'exemple du dossier
7
+ - Envoie les 3 fichiers à la Space HF via Gradio Client
8
  - Sauvegarde un CSV de résultats
9
 
10
+ Prérequis: pip install gradio_client
11
  """
12
 
13
  import os
14
+ import sys
15
  import pandas as pd
 
16
  from datetime import datetime
17
 
18
+ try:
19
+ from gradio_client import Client, handle_file
20
+ except ImportError:
21
+ print("❌ gradio_client non installé. Installez-le avec:")
22
+ print(" pip install gradio_client")
23
+ sys.exit(1)
24
+
25
  API_URL = os.getenv("HF_API_URL", "https://asi-engineer-oc-p5.hf.space")
26
 
27
  print("╔══════════════════════════════════════════════════════════╗")
28
+ print("║ 📦 Prédiction BATCH - API Hugging Face (Gradio) ║")
29
  print("╚══════════════════════════════════════════════════════════╝\n")
30
  print(f"🌐 API: {API_URL}\n")
31
 
 
39
  for path in [sondage_path, eval_path, sirh_path]:
40
  if not os.path.exists(path):
41
  print(f"❌ Fichier introuvable: {path}")
42
+ sys.exit(1)
43
 
44
  print("✅ Fichiers d'exemple détectés:")
45
  print(f" - {os.path.basename(sondage_path)}")
46
  print(f" - {os.path.basename(eval_path)}")
47
  print(f" - {os.path.basename(sirh_path)}\n")
48
 
49
+ print("⏳ Connexion à l'API Gradio...")
50
+ try:
51
+ client = Client(API_URL)
52
+ print(" Connecté à l'API Gradio\n")
53
+ except Exception as e:
54
+ print(f"❌ Impossible de se connecter: {e}")
55
+ sys.exit(1)
 
 
 
56
 
57
+ print("⏳ Envoi des fichiers pour prédiction batch...")
58
  try:
59
+ result = client.predict(
60
+ sondage_path=handle_file(sondage_path),
61
+ eval_path=handle_file(eval_path),
62
+ sirh_path=handle_file(sirh_path),
63
+ api_name="/predict_batch",
64
  )
65
+ except Exception as e:
66
+ print(f"❌ Erreur lors de la prédiction: {e}")
67
+ sys.exit(1)
68
+
69
+ # Vérifier si erreur dans le résultat
70
+ if isinstance(result, dict) and "error" in result:
71
+ print(f"\n❌ Erreur API: {result.get('error')}")
72
+ print(f" Message: {result.get('message')}")
73
+ sys.exit(1)
74
+
75
+ # Construire le CSV de sortie
76
+ predictions_data = []
77
+ for pred in result.get("predictions", []):
78
+ predictions_data.append(
79
+ {
80
+ "employee_id": pred.get("employee_id"),
81
+ "prediction": "VA PARTIR" if pred.get("prediction") == 1 else "VA RESTER",
82
+ "prediction_code": pred.get("prediction"),
83
+ "risk_level": pred.get("risk_level"),
84
+ "probability_stay": f"{pred.get('probability_stay', 0):.2%}",
85
+ "probability_leave": f"{pred.get('probability_leave', 0):.2%}",
86
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  )
88
+
89
+ df = pd.DataFrame(predictions_data)
90
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
91
+ output_path = os.path.join(script_dir, f"predictions_batch_hf_{timestamp}.csv")
92
+ df.to_csv(output_path, index=False, encoding="utf-8-sig")
93
+
94
+ # Affichage résumé
95
+ summary = result.get("summary", {})
96
+ total = result.get("total_employees", len(predictions_data))
97
+
98
+ print("\n" + "=" * 50)
99
+ print("📊 RÉSULTATS DE LA PRÉDICTION BATCH")
100
+ print("=" * 50)
101
+ print(f"\n👥 Total employés analysés: {total}")
102
+ print(f"✅ Vont rester: {summary.get('total_stay', 'N/A')}")
103
+ print(f"❌ Vont partir: {summary.get('total_leave', 'N/A')}")
104
+ print(f"\n🔴 Risque élevé: {summary.get('high_risk_count', 'N/A')}")
105
+ print(f"🟠 Risque moyen: {summary.get('medium_risk_count', 'N/A')}")
106
+ print(f"🟢 Risque faible: {summary.get('low_risk_count', 'N/A')}")
107
+
108
+ print(f"\n💾 Résultats sauvegardés: {os.path.basename(output_path)}")
109
+ print("\n✅ Prédiction batch terminée avec succès!")