Spaces:
Sleeping
Sleeping
QnxprU69yCNg8XJ
Replace with demo classifier that generates realistic random scores (Low/Moderate/High)
#!/usr/bin/env python3
"""
Script pour réentraîner le classifier avec les embeddings OpenL3
"""
| import numpy as np | |
| import joblib | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import classification_report, accuracy_score | |
| from inference_service import preprocess_audio, generate_embeddings | |
def _embed_files(audio_files):
    """Compute one clip-averaged OpenL3 feature vector per audio file.

    Each file is split into clips by ``preprocess_audio``; the embeddings
    returned by ``generate_embeddings`` are averaged over axis 0 to give a
    single vector per file. Files that yield no clips are skipped and reported
    on stdout rather than dropped silently.

    Args:
        audio_files: iterable of audio file paths.

    Returns:
        list of numpy arrays, one per file that produced at least one clip.
        (Presumably embeddings are (n_clips, embedding_dim), making each entry
        1-D — TODO confirm against inference_service.generate_embeddings.)
    """
    features = []
    for audio_file in audio_files:
        clips = preprocess_audio(audio_file)
        if len(clips) == 0:
            print(f" ⚠️ Aucun clip extrait, fichier ignoré: {audio_file}")
            continue
        embeddings = generate_embeddings(clips)
        # Average over all clips of the file -> one sample per recording.
        features.append(np.mean(embeddings, axis=0))
    return features


def _cv_accuracy(model, X_train, y_train, max_folds=5):
    """Return mean stratified k-fold accuracy of ``model`` on the training split.

    The fold count is capped by the size of the smallest class so that
    stratified cross-validation still runs on small datasets. ``model`` itself
    is not fitted (``cross_val_score`` works on clones).

    Returns:
        float mean accuracy, or None when fewer than two folds are possible.
    """
    from sklearn.model_selection import cross_val_score

    n_folds = min(max_folds, int(np.bincount(y_train).min()))
    if n_folds < 2:
        return None
    # With an int ``cv`` and a classifier, sklearn uses StratifiedKFold.
    scores = cross_val_score(model, X_train, y_train, cv=n_folds, scoring="accuracy")
    return float(np.mean(scores))


def train_classifier_with_openl3(audio_files_pneumonia=None, audio_files_normal=None,
                                 output_path="pneumonia_classifier_openl3.joblib"):
    """Retrain the pneumonia classifier on OpenL3 embeddings and save it with joblib.

    Pipeline: embed every file (clip-averaged vector), make a stratified 80/20
    train/test split, choose between LogisticRegression and RandomForest by
    cross-validated accuracy on the training split (falling back to test
    accuracy when the split is too small for CV), print held-out test metrics,
    and dump the chosen model to ``output_path``.

    Args:
        audio_files_pneumonia: paths of pneumonia recordings (label 1). When
            None, the list written inside this function is used.
        audio_files_normal: paths of normal recordings (label 0). When None,
            the list written inside this function is used.
        output_path: destination file for the trained model.

    Returns:
        None. Problems (missing data, too few samples) are printed and the
        function returns early without writing a model.
    """
    print("=" * 60)
    print("RÉENTRAÎNEMENT DU CLASSIFIER AVEC OPENL3")
    print("=" * 60)

    # TODO: replace these with your real data paths, or pass them as arguments.
    # Expected format:
    #   - audio_files_pneumonia: list of paths to pneumonia audio files
    #   - audio_files_normal: list of paths to normal audio files
    if audio_files_pneumonia is None:
        audio_files_pneumonia = [
            # "path/to/pneumonia_sample1.wav",
            # "path/to/pneumonia_sample2.wav",
            # ...
        ]
    if audio_files_normal is None:
        audio_files_normal = [
            # "path/to/normal_sample1.wav",
            # "path/to/normal_sample2.wav",
            # ...
        ]
    # Materialize so len() works for any iterable the caller passes.
    audio_files_pneumonia = list(audio_files_pneumonia)
    audio_files_normal = list(audio_files_normal)

    if not audio_files_pneumonia or not audio_files_normal:
        print("❌ ERREUR: Vous devez fournir vos fichiers audio d'entraînement")
        print("Éditez ce script et ajoutez les chemins dans audio_files_pneumonia et audio_files_normal")
        return

    print(f"Fichiers pneumonie: {len(audio_files_pneumonia)}")
    print(f"Fichiers normaux: {len(audio_files_normal)}")

    print("\n🔄 Génération des embeddings pour pneumonie...")
    pneumonia_features = _embed_files(audio_files_pneumonia)
    print("🔄 Génération des embeddings pour fichiers normaux...")
    normal_features = _embed_files(audio_files_normal)

    if not pneumonia_features or not normal_features:
        # Nothing to learn without both classes; an empty X would also make
        # X.shape[1] raise IndexError below.
        print("❌ ERREUR: aucun clip exploitable pour au moins une des deux classes")
        return

    # Pneumonia samples first, then normal (same ordering as before, so the
    # seeded split below is reproducible against earlier runs).
    X = np.array(pneumonia_features + normal_features)
    y = np.array([1] * len(pneumonia_features) + [0] * len(normal_features))

    print(f"\n✅ Dataset créé: {X.shape[0]} samples, {X.shape[1]} features")
    print(f" Pneumonie: {np.sum(y == 1)}, Normal: {np.sum(y == 0)}")

    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
    except ValueError as err:
        # Stratified splitting needs >= 2 samples per class and a test split
        # large enough to contain both classes.
        print(f"❌ ERREUR: jeu de données trop petit pour un split stratifié ({err})")
        return

    print(f"\n📊 Train: {len(X_train)} samples")
    print(f"📊 Test: {len(X_test)} samples")

    models = {
        "LogisticRegression": LogisticRegression(max_iter=1000, random_state=42),
        "RandomForest": RandomForestClassifier(n_estimators=100, random_state=42),
    }

    best_model = None
    best_name = ""
    best_selection_score = -1.0
    best_test_accuracy = 0.0

    for name, model in models.items():
        print(f"\n🤖 Entraînement: {name}")
        # Select on training-split CV so the test set is not used for model
        # selection and the reported test accuracy is not inflated by it.
        cv_accuracy = _cv_accuracy(model, X_train, y_train)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)

        if cv_accuracy is not None:
            print(f" CV accuracy (train): {cv_accuracy:.4f}")
        print(f" Accuracy: {accuracy:.4f}")
        # Explicit labels keep target_names aligned with classes 0/1.
        print(classification_report(y_test, y_pred, labels=[0, 1],
                                    target_names=['Normal', 'Pneumonia']))

        # Training split too small for CV: fall back to held-out accuracy.
        selection_score = accuracy if cv_accuracy is None else cv_accuracy
        # `best_model is None` guarantees a model is chosen even if every
        # score is 0.0; strict `>` keeps the first model on ties.
        if best_model is None or selection_score > best_selection_score:
            best_selection_score = selection_score
            best_test_accuracy = accuracy
            best_model = model
            best_name = name

    print(f"\n🏆 Meilleur modèle: {best_name} (Accuracy: {best_test_accuracy:.4f})")

    joblib.dump(best_model, output_path)
    print(f"✅ Modèle sauvegardé: {output_path}")
    print("\n⚠️ IMPORTANT: Renommez ce fichier en 'pneumonia_classifier.joblib'")
    print(f" ou mettez à jour app.py pour utiliser '{output_path}'")
def _run_as_script():
    """Command-line entry point: launch the OpenL3 retraining pipeline."""
    train_classifier_with_openl3()


if __name__ == "__main__":
    _run_as_script()