ALTAH commited on
Commit
791e05d
·
verified ·
1 Parent(s): de725ce

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. README.md +8 -0
  2. dial_ir.py +134 -0
  3. msa_medical_corpus.txt +12 -0
  4. requirements.txt +7 -0
  5. test_dial_ir.py +34 -0
  6. test_temp.txt +2 -0
README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # DIAL-IR
2
+
3
+ DIAL-IR est un système de **recherche d’information en arabe dialectal** avec traduction en MSA et gestion des entités médicales.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install -r requirements.txt
dial_ir.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import re, json, torch, openai, numpy as np
3
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModelForSeq2SeqLM
4
+ from sentence_transformers import SentenceTransformer, util
5
+ from sklearn.metrics import ndcg_score
6
+
7
+ # ===========================
8
+ # Paramètres OpenAI
9
+ # ===========================
10
+ import os
11
+ openai.api_key = os.getenv("OPENAI_API_KEY")
12
+
13
+
14
+ # ===========================
15
+ # Portion 1 : NER + placeholders (ETMAN-BERT)
16
+ # ===========================
17
+ MODEL_NER = "ALTAH/ETMAN-BERT"
18
+ tokenizer_ner = AutoTokenizer.from_pretrained(MODEL_NER)
19
+ model_ner = AutoModelForTokenClassification.from_pretrained(MODEL_NER)
20
+ ner_pipeline = pipeline("ner", model=model_ner, tokenizer=tokenizer_ner, aggregation_strategy="simple")
21
+
22
+ icd11_labels = ["O","SYMPTOM","DISEASE","DRUG","BODY_PART","PROCEDURE","TEST",
23
+ "ANATOMY","CONDITION","FINDING","SIGN","ALLERGY","VACCINE","OTHER"]
24
+ id2label = {i: label for i,label in enumerate(icd11_labels)}
25
+
26
+ def ner_and_placeholders(text):
27
+ ner_results = ner_pipeline(text)
28
+ placeholders, counter = {}, {}
29
+ text_with_placeholders = text
30
+
31
+ for ent in sorted(ner_results, key=lambda x: x["start"], reverse=True):
32
+ label_id = int(ent["entity_group"].split("_")[1])
33
+ label_name = id2label.get(label_id, "O")
34
+ if label_name != "O":
35
+ counter[label_name] = counter.get(label_name, 0) + 1
36
+ placeholder = f"{label_name}_{counter[label_name]}"
37
+ placeholders[placeholder] = ent["word"]
38
+ text_with_placeholders = text_with_placeholders[:ent["start"]] + placeholder + text_with_placeholders[ent["end"]:]
39
+ return text_with_placeholders, placeholders
40
+
41
+ # ===========================
42
+ # Portion 2 : Traduction dialectal → MSA
43
+ # ===========================
44
+ MODEL_TRANSLATE = "ALTAH/ADT-MSA"
45
+ tokenizer_translate = AutoTokenizer.from_pretrained(MODEL_TRANSLATE)
46
+ model_translate = AutoModelForSeq2SeqLM.from_pretrained(MODEL_TRANSLATE)
47
+
48
+ def translate_text_keep_placeholders(text_with_placeholders, placeholders):
49
+ pattern = "|".join(re.escape(ph) for ph in placeholders.keys())
50
+ placeholder_positions = [(m.start(), m.end(), m.group()) for m in re.finditer(pattern, text_with_placeholders)]
51
+ text_no_placeholders = re.sub(pattern, "", text_with_placeholders)
52
+
53
+ inputs = tokenizer_translate(text_no_placeholders, return_tensors="pt", truncation=True)
54
+ translated_ids = model_translate.generate(**inputs, max_length=512)
55
+ text_translated_no_placeholders = tokenizer_translate.decode(translated_ids[0], skip_special_tokens=True)
56
+
57
+ # Réinsérer les placeholders
58
+ for start, end, ph in sorted(placeholder_positions, key=lambda x: x[0], reverse=True):
59
+ text_translated_no_placeholders = text_translated_no_placeholders[:start] + ph + text_translated_no_placeholders[start:]
60
+ return text_translated_no_placeholders
61
+
62
+ # ===========================
63
+ # Portion 3 : Traduction entités avec GPT
64
+ # ===========================
65
+ def translate_entities_with_gpt(placeholders):
66
+ translated_entities = {}
67
+ for ph, ent in placeholders.items():
68
+ prompt = f"Traduisez uniquement cette entité médicale dialectale vers l'arabe standard (MSA) : {ent}"
69
+ response = openai.ChatCompletion.create(
70
+ model="gpt-4",
71
+ messages=[{"role": "user", "content": prompt}],
72
+ temperature=0
73
+ )
74
+ translated_entities[ph] = response.choices[0].message["content"].strip()
75
+ return translated_entities
76
+
77
+ # ===========================
78
+ # Portion 4 : Réinsertion + polish
79
+ # ===========================
80
+ def reinsert_and_polish(text_translated_msa, translated_entities):
81
+ prompt = f"""
82
+ Réinsérez les entités traduites dans le texte MSA en remplaçant les placeholders.
83
+ Ajustez la syntaxe pour que la phrase soit correcte et naturelle.
84
+
85
+ Texte MSA avec placeholders :
86
+ {text_translated_msa}
87
+
88
+ Entités traduites :
89
+ {json.dumps(translated_entities, ensure_ascii=False, indent=2)}
90
+
91
+ Réponse attendue : texte final MSA uniquement.
92
+ """
93
+ response = openai.ChatCompletion.create(
94
+ model="gpt-4",
95
+ messages=[{"role":"user","content":prompt}],
96
+ temperature=0
97
+ )
98
+ return response.choices[0].message["content"].strip()
99
+
100
+ # ===========================
101
+ # Portion 5 : Normalisation
102
+ # ===========================
103
+ def normalize_query(query_msa: str) -> str:
104
+ return query_msa.strip()
105
+
106
+ # ===========================
107
+ # Classe DIAL-IR
108
+ # ===========================
109
+ class DIALIR:
110
+ def __init__(self, corpus_file, embeddings_file=None):
111
+ self.embed_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
112
+ self.corpus = self.load_corpus(corpus_file)
113
+ if embeddings_file:
114
+ self.corpus_embeddings = torch.load(embeddings_file)
115
+ else:
116
+ self.corpus_embeddings = self.embed_model.encode(self.corpus, convert_to_tensor=True)
117
+
118
+ def load_corpus(self, file_path):
119
+ with open(file_path, "r", encoding="utf-8") as f:
120
+ return [line.strip() for line in f if line.strip()]
121
+
122
+ def preprocess_query(self, query):
123
+ text_ph, placeholders = ner_and_placeholders(query)
124
+ text_translated = translate_text_keep_placeholders(text_ph, placeholders)
125
+ translated_entities = translate_entities_with_gpt(placeholders)
126
+ query_msa = reinsert_and_polish(text_translated, translated_entities)
127
+ return normalize_query(query_msa)
128
+
129
+ def search(self, query, top_k=5):
130
+ query_msa = self.preprocess_query(query)
131
+ query_embedding = self.embed_model.encode(query_msa, convert_to_tensor=True)
132
+ cos_scores = util.cos_sim(query_embedding, self.corpus_embeddings)[0]
133
+ top_results = torch.topk(cos_scores, k=top_k)
134
+ return [(float(score), self.corpus[idx]) for score, idx in zip(top_results.values, top_results.indices)]
msa_medical_corpus.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ المريض يعاني من ارتفاع ضغط الدم، ويوصى بتناول دواء أملوديبين 5 ملغ مرة يومياً ومراقبة ضغط الدم أسبوعياً.
2
+ المريض مصاب بالسكري من النوع الثاني، ينصح بتعديل النظام الغذائي وممارسة الرياضة، مع تناول ميتفورمين 500 ملغ صباحاً ومساءً.
3
+ المريض يعاني من التهاب الحلق، ينصح بتناول مضاد حيوي أموكسيسيلين 500 ملغ ثلاث مرات يومياً لمدة سبعة أيام.
4
+ المريض يشكو من صداع نصفي متكرر، يمكن تناول دواء سوماتريبتان 50 ملغ عند ظهور الأعراض وعدم قيادة السيارة بعد تناوله.
5
+ المريض يعاني من حموضة المعدة، يوصى بتجنب الأطعمة الدهنية والحارة، وتناول أوميبرازول 20 ملغ قبل النوم.
6
+ المريض لديه أعراض نزلة برد، ينصح بالراحة في المنزل، شرب السوائل الدافئة، واستخدام شراب خافض للحرارة عند الحاجة.
7
+ المريض يعاني من التهاب المفاصل، ينصح بممارسة تمارين خفيفة، استخدام كمادات دافئة، وتناول دواء إيبوبروفين 400 ملغ عند الحاجة.
8
+ المريض يشكو من صعوبة في التنفس بسبب الربو، ينصح باستخدام جهاز استنشاق سالبوتامول عند الحاجة ومراجعة الطبيب عند زيادة الأعراض.
9
+ المريض يعاني من أرق متكرر، ينصح بمراعاة روتين نوم ثابت، تقليل الكافيين، ويمكن استخدام أقراص ميلاتونين 3 ملغ قبل النوم.
10
+ المريض يشكو من ألم في المعدة بعد تناول الطعام الدهني، ينصح بتناول مضادات الحموضة مثل رانيتيدين 150 ملغ بعد الوجبات.
11
+ ذهبت الى العمل منذ الصباح الباكر.
12
+ جاء أبي من العمل
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ sentence-transformers
3
+ transformers
4
+ scikit-learn
5
+ numpy
6
+ openai==0.28
7
+ huggingface_hub
test_dial_ir.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from dial_ir import DIALIR, evaluate_ir
3
+
4
+ # ===========================
5
+ # Test du modèle DIAL-IR
6
+ # ===========================
7
+ if __name__ == "__main__":
8
+ # Chemin vers ton corpus MSA médical
9
+ corpus_file = "/content/drive/MyDrive/msa_medical_corpus.txt"
10
+ dial_ir = DIALIR(corpus_file)
11
+
12
+ # Sauvegarder les embeddings pour un usage futur
13
+ dial_ir.save_embeddings("corpus_embeddings.pt")
14
+
15
+ # Requête unique
16
+ query = "ادوية ضغط الدم"
17
+ results = dial_ir.search(query, top_k=5)
18
+
19
+ print(f"\n=== Résultats pertinents pour : '{query}' ===")
20
+ for score, doc in results:
21
+ if score > 0.5:
22
+ print(f"{score:.4f} → {doc}")
23
+
24
+ # Créer un fichier de test temporaire pour l'évaluation IR
25
+ test_file = "test_temp.txt"
26
+ with open(test_file, "w", encoding="utf-8") as f:
27
+ # on suppose que le 1er document du corpus est pertinent pour cette requête
28
+ f.write(f"{query}\t{dial_ir.corpus[0]}\n")
29
+
30
+ # Évaluation IR
31
+ metrics = evaluate_ir(dial_ir, test_file, top_k=5)
32
+ print("\n=== Métriques DIAL-IR ===")
33
+ for name, value in metrics.items():
34
+ print(f"{name}: {value:.4f}")
test_temp.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ادوية ضغط الدم المريض يعاني من ارتفاع ضغط الدم، ويوصى بتناول دواء أملوديبين 5 ملغ مرة يومياً ومراقبة ضغط الدم أسبوعياً.|ارتفاع ضغط الدم قد يؤدي إلى مشاكل في القلب
2
+ أعراض السكري السكري يتم علاجه بالأنسولين أو الحمية|ارتفاع نسبة السكر في الدم قد تسبب العطش المتكرر