| # -*- coding: utf-8 -*- | |
| from dial_ir import DIALIR, evaluate_ir | |
| # =========================== | |
| # Test du modèle DIAL-IR | |
| # =========================== | |
| if __name__ == "__main__": | |
| # Chemin vers ton corpus MSA médical | |
| corpus_file = "/content/drive/MyDrive/msa_medical_corpus.txt" | |
| dial_ir = DIALIR(corpus_file) | |
| # Sauvegarder les embeddings pour un usage futur | |
| dial_ir.save_embeddings("corpus_embeddings.pt") | |
| # Requête unique | |
| query = "ادوية ضغط الدم" | |
| results = dial_ir.search(query, top_k=5) | |
| print(f"\n=== Résultats pertinents pour : '{query}' ===") | |
| for score, doc in results: | |
| if score > 0.5: | |
| print(f"{score:.4f} → {doc}") | |
| # Créer un fichier de test temporaire pour l'évaluation IR | |
| test_file = "test_temp.txt" | |
| with open(test_file, "w", encoding="utf-8") as f: | |
| # on suppose que le 1er document du corpus est pertinent pour cette requête | |
| f.write(f"{query}\t{dial_ir.corpus[0]}\n") | |
| # Évaluation IR | |
| metrics = evaluate_ir(dial_ir, test_file, top_k=5) | |
| print("\n=== Métriques DIAL-IR ===") | |
| for name, value in metrics.items(): | |
| print(f"{name}: {value:.4f}") | |