|
|
| import re, json, torch, openai, numpy as np
|
| from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModelForSeq2SeqLM
|
| from sentence_transformers import SentenceTransformer, util
|
| from sklearn.metrics import ndcg_score
|
|
|
|
|
|
|
|
|
|
# NOTE(review): this API-key setup is wrapped in a bare string literal, so it
# never runs and `os` is never imported by this module. The `openai` calls
# further down therefore depend on whatever key the client resolves by itself
# (presumably the OPENAI_API_KEY environment variable in the pre-1.0 SDK that
# `openai.ChatCompletion` belongs to) — confirm before deploying.
"""
import os
openai.api_key = os.getenv("OPENAI_API_KEY")
"""
|
|
|
|
|
|
|
|
|
|
|
# Token-classification (NER) model, loaded once at import time. With
# aggregation_strategy="simple" the pipeline merges sub-word predictions into
# entity spans, each reported with "entity_group", "start" and "end" keys.
MODEL_NER = "ALTAH/ETMAN-BERT"
tokenizer_ner, model_ner = (
    AutoTokenizer.from_pretrained(MODEL_NER),
    AutoModelForTokenClassification.from_pretrained(MODEL_NER),
)
ner_pipeline = pipeline(
    task="ner",
    model=model_ner,
    tokenizer=tokenizer_ner,
    aggregation_strategy="simple",
)
|
|
|
# Label inventory for the NER model. The model is expected to emit generic
# "LABEL_<i>" groups; id2label maps index i to a readable name, and index 0
# ("O") means "not an entity". NOTE(review): the order presumably mirrors the
# model's training label ids — confirm against the model config.
icd11_labels = [
    "O", "SYMPTOM", "DISEASE", "DRUG", "BODY_PART", "PROCEDURE", "TEST",
    "ANATOMY", "CONDITION", "FINDING", "SIGN", "ALLERGY", "VACCINE", "OTHER",
]
id2label = dict(enumerate(icd11_labels))
|
|
|
def ner_and_placeholders(text):
    """Replace recognised entities in *text* with typed placeholders.

    Runs ``ner_pipeline`` on *text*. The character span of every entity whose
    label is not ``"O"`` is replaced with ``"<LABEL>_<n>"`` (for example
    ``"DISEASE_1"``). Numbering is per label and is assigned while walking the
    entities from the end of the text, so the right-most entity of a label
    gets ``_1``.

    Args:
        text: Raw query text.

    Returns:
        ``(text_with_placeholders, placeholders)``, where ``placeholders``
        maps each placeholder to the exact substring of *text* it replaced.
    """
    ner_results = ner_pipeline(text)
    placeholders, counter = {}, {}
    text_with_placeholders = text

    # Walk right-to-left: each replacement only shifts text that has already
    # been processed, so the offsets of entities further left stay valid.
    for ent in sorted(ner_results, key=lambda x: x["start"], reverse=True):
        label_name = _entity_label_name(ent["entity_group"])
        if label_name == "O":
            continue
        counter[label_name] = counter.get(label_name, 0) + 1
        placeholder = f"{label_name}_{counter[label_name]}"
        start, end = ent["start"], ent["end"]
        # Record the span actually being replaced; ent["word"] is re-decoded
        # from sub-word tokens and can differ from the source text.
        placeholders[placeholder] = text[start:end]
        text_with_placeholders = (
            text_with_placeholders[:start]
            + placeholder
            + text_with_placeholders[end:]
        )
    return text_with_placeholders, placeholders


def _entity_label_name(entity_group):
    """Map a pipeline ``entity_group`` to a label name from ``id2label``.

    Generic ``LABEL_<i>`` groups are looked up by index, and unknown indices
    map to ``"O"``. A group that is already one of the readable names in
    ``id2label`` (e.g. ``"BODY_PART"``) is kept as is. Anything else is
    treated as ``"O"`` rather than raising on a failed ``int()`` parse.
    """
    match = re.fullmatch(r"LABEL_(\d+)", entity_group)
    if match:
        return id2label.get(int(match.group(1)), "O")
    return entity_group if entity_group in id2label.values() else "O"
|
|
|
|
|
|
|
|
|
# Seq2seq model (loaded once at import time) that translates the non-entity
# parts of a query; its output is later treated as MSA text.
MODEL_TRANSLATE = "ALTAH/ADT-MSA"
tokenizer_translate, model_translate = (
    AutoTokenizer.from_pretrained(MODEL_TRANSLATE),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_TRANSLATE),
)
|
|
|
def translate_text_keep_placeholders(text_with_placeholders, placeholders, *, translate_fn=None):
    """Translate the text around the placeholders, keeping each placeholder in place.

    The text is split on the placeholders, and every non-blank segment between
    them is translated on its own (after stripping). The translated segments
    and the untouched placeholders are then joined with single spaces in their
    original order. Every placeholder therefore appears exactly once, between
    the translations of its original neighbours. The trade-off is that each
    segment is translated without the context of the others.

    Args:
        text_with_placeholders: Text produced by ``ner_and_placeholders``.
        placeholders: Mapping whose keys are the placeholders to protect.
        translate_fn: Optional ``str -> str`` translator (keyword-only).
            Defaults to the module's seq2seq model.

    Returns:
        The translated text with placeholders preserved. When there are no
        placeholders, the whole text is translated in a single call.
    """
    if translate_fn is None:
        translate_fn = _translate_with_model

    keys = sorted((ph for ph in placeholders if ph), key=len, reverse=True)
    if not keys:
        return translate_fn(text_with_placeholders)

    # Splicing placeholders back in at source-text character offsets cannot
    # work: the translation has a different length and layout. Splitting
    # keeps them anchored instead. Longest-first alternation stops
    # "DISEASE_1" from matching inside "DISEASE_10".
    pattern = "|".join(re.escape(ph) for ph in keys)
    # With one capturing group, re.split puts the placeholders at odd indices.
    parts = re.split(f"({pattern})", text_with_placeholders)

    pieces = []
    for index, part in enumerate(parts):
        if index % 2:
            pieces.append(part)
        elif part.strip():
            pieces.append(translate_fn(part.strip()))
    return " ".join(pieces)


def _translate_with_model(text):
    """Translate *text* with the module-level seq2seq model and tokenizer."""
    inputs = tokenizer_translate(text, return_tensors="pt", truncation=True)
    translated_ids = model_translate.generate(**inputs, max_length=512)
    return tokenizer_translate.decode(translated_ids[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
def translate_entities_with_gpt(placeholders, *, model="gpt-4"):
    """Translate each extracted entity to Modern Standard Arabic with GPT.

    One chat-completion request (temperature 0) is made per *distinct* entity
    string. Placeholders that share the same entity text reuse that
    translation instead of triggering another API call.

    Args:
        placeholders: Mapping placeholder -> original entity text, as returned
            by ``ner_and_placeholders``.
        model: Chat-completion model name (keyword-only; defaults to the
            previously hard-coded ``"gpt-4"``).

    Returns:
        Mapping placeholder -> translated entity text (whitespace-stripped).
    """
    translated_entities = {}
    by_entity = {}  # entity text -> translation, so duplicates cost one call
    for ph, ent in placeholders.items():
        if ent not in by_entity:
            prompt = f"Traduisez uniquement cette entité médicale dialectale vers l'arabe standard (MSA) : {ent}"
            response = openai.ChatCompletion.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            by_entity[ent] = response.choices[0].message["content"].strip()
        translated_entities[ph] = by_entity[ent]
    return translated_entities
|
|
|
|
|
|
|
|
|
def reinsert_and_polish(text_translated_msa, translated_entities, *, model="gpt-4"):
    """Insert the translated entities into the MSA text and let GPT polish it.

    The chat model is asked to replace every placeholder in
    *text_translated_msa* with its translation from *translated_entities* and
    to smooth the syntax. Any placeholder the model leaves in its answer is
    then substituted locally, so a known placeholder never leaks into the
    final query.

    Args:
        text_translated_msa: MSA text still containing placeholders.
        translated_entities: Mapping placeholder -> translated entity text.
        model: Chat-completion model name (keyword-only; defaults to the
            previously hard-coded ``"gpt-4"``).

    Returns:
        The model's stripped answer, with any leftover placeholders replaced.
    """
    entities_json = json.dumps(translated_entities, ensure_ascii=False, indent=2)
    prompt = (
        "\n"
        "Réinsérez les entités traduites dans le texte MSA en remplaçant les placeholders.\n"
        "Ajustez la syntaxe pour que la phrase soit correcte et naturelle.\n"
        "\n"
        "Texte MSA avec placeholders :\n"
        f"{text_translated_msa}\n"
        "\n"
        "Entités traduites :\n"
        f"{entities_json}\n"
        "\n"
        "Réponse attendue : texte final MSA uniquement.\n"
    )
    response = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    polished = response.choices[0].message["content"].strip()

    # LLMs sometimes skip a substitution. Fill survivors in one regex pass,
    # longest placeholder first so "DRUG_10" is not read as "DRUG_1" + "0".
    keys = sorted((ph for ph in translated_entities if ph), key=len, reverse=True)
    if keys:
        pattern = "|".join(re.escape(ph) for ph in keys)
        polished = re.sub(pattern, lambda m: translated_entities[m.group(0)], polished)
    return polished
|
|
|
|
|
|
|
|
|
def normalize_query(query_msa: str) -> str:
    """Return *query_msa* without leading or trailing whitespace.

    This is the only normalisation applied before the query is embedded;
    internal whitespace and characters are left untouched.
    """
    cleaned = query_msa.strip()
    return cleaned
|
|
|
|
|
|
|
|
|
class DIALIR:
    """Dense retriever for dialectal queries over a line-per-document corpus.

    Each query is rewritten by ``preprocess_query`` (entity placeholders,
    seq2seq translation, GPT entity translation and polishing). It is then
    ranked against the corpus by cosine similarity of sentence embeddings
    from a multilingual MiniLM ``SentenceTransformer``.
    """

    def __init__(self, corpus_file, embeddings_file=None):
        """Load the corpus and its embeddings.

        Args:
            corpus_file: UTF-8 text file with one document per non-blank line.
            embeddings_file: Optional file of precomputed corpus embeddings
                readable by ``torch.load``; row ``i`` must correspond to
                document ``i``. When falsy, the corpus is encoded here.

        Raises:
            ValueError: If the loaded embeddings do not have exactly one row
                per corpus document. Otherwise ranking would return the wrong
                documents or index out of range.
        """
        self.embed_model = SentenceTransformer(
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        )
        self.corpus = self.load_corpus(corpus_file)
        if embeddings_file:
            # NOTE(security): torch.load unpickles its input, so only load
            # trusted files. map_location places the tensor on the encoder's
            # device so cos_sim never mixes CPU and GPU tensors.
            self.corpus_embeddings = torch.load(
                embeddings_file, map_location=self.embed_model.device
            )
            if len(self.corpus_embeddings) != len(self.corpus):
                raise ValueError(
                    f"{embeddings_file} holds {len(self.corpus_embeddings)} embeddings "
                    f"but {corpus_file} has {len(self.corpus)} documents"
                )
        else:
            self.corpus_embeddings = self.embed_model.encode(
                self.corpus, convert_to_tensor=True
            )

    def load_corpus(self, file_path):
        """Return the stripped, non-blank lines of the UTF-8 file *file_path*."""
        with open(file_path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    def preprocess_query(self, query):
        """Rewrite a raw query into normalised MSA.

        Steps:
        1. Replace entities with placeholders.
        2. Translate the surrounding text.
        3. Translate the entities with GPT.
        4. Let GPT re-insert the entities and polish the sentence.
        5. Normalise the result.
        """
        text_ph, placeholders = ner_and_placeholders(query)
        text_translated = translate_text_keep_placeholders(text_ph, placeholders)
        translated_entities = translate_entities_with_gpt(placeholders)
        query_msa = reinsert_and_polish(text_translated, translated_entities)
        return normalize_query(query_msa)

    def search(self, query, top_k=5):
        """Return up to *top_k* ``(score, document)`` pairs, best match first.

        *top_k* is capped at the corpus size, because ``torch.topk`` raises
        when ``k`` exceeds the number of scores. When nothing can be returned
        (empty corpus or ``top_k <= 0``), the result is ``[]`` and the query
        is not preprocessed, which avoids pointless model and API calls.
        """
        k = min(top_k, len(self.corpus))
        if k <= 0:
            return []
        query_msa = self.preprocess_query(query)
        query_embedding = self.embed_model.encode(query_msa, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, self.corpus_embeddings)[0]
        top_results = torch.topk(cos_scores, k=k)
        return [
            (float(score), self.corpus[int(idx)])
            for score, idx in zip(top_results.values, top_results.indices)
        ]
|
|
|
|
|
|
|
def evaluate_ir(dial_ir, test_file, top_k=5):
    """Evaluate a retriever against a tab-separated relevance file.

    Each non-blank line of *test_file* (UTF-8) has the form
    ``query<TAB>doc1|doc2|...`` and lists the corpus lines relevant to that
    query. ``dial_ir`` only needs a ``search(query, top_k=...)`` method that
    returns ``(score, doc)`` pairs in rank order, such as ``DIALIR.search``.
    Relevance is binary, and a relevant document counts at most once, so
    duplicate corpus lines cannot push recall or nDCG above 1.

    Args:
        dial_ir: Retriever exposing ``search``.
        test_file: Path to the relevance file.
        top_k: Cut-off ``k`` used by every metric.

    Returns:
        The mean over queries of ``Precision@k``, ``Recall@k``, ``F1@k``,
        ``MRR``, ``MAP`` (AP@k normalised by the number of relevant documents)
        and binary ``nDCG@k``. Every value is ``nan`` when the file contains
        no queries.

    Raises:
        ValueError: If a non-blank line does not have exactly two
            tab-separated fields.
    """
    metric_names = ("Precision@k", "Recall@k", "F1@k", "MRR", "MAP", "nDCG@k")
    per_query = []
    with open(test_file, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            fields = line.split("\t")
            if len(fields) != 2:
                raise ValueError(
                    f"{test_file}:{line_no}: expected 'query<TAB>doc1|doc2|...', "
                    f"got {len(fields)} tab-separated field(s)"
                )
            query, relevant_field = fields
            # Strip each entry like DIALIR.load_corpus strips corpus lines, so
            # entries compare equal to the documents search() returns.
            relevant = {d.strip() for d in relevant_field.split("|") if d.strip()}
            results = dial_ir.search(query, top_k=top_k)
            retrieved_docs = [doc for _, doc in results]
            per_query.append(_query_metrics(retrieved_docs, relevant, top_k))

    if not per_query:
        return {name: np.nan for name in metric_names}
    means = np.mean(per_query, axis=0)
    return dict(zip(metric_names, means))


def _query_metrics(retrieved_docs, relevant, top_k):
    """Return ``(P@k, R@k, F1@k, RR, AP@k, nDCG@k)`` for one ranked list."""
    ranked = retrieved_docs[:top_k] if top_k > 0 else []
    # Binary relevance per rank; a relevant document only scores the first
    # time it appears in the ranking.
    seen = set()
    flags = []
    for doc in ranked:
        hit = doc in relevant and doc not in seen
        if hit:
            seen.add(doc)
        flags.append(hit)

    hits = sum(flags)
    n_relevant = len(relevant)
    precision = hits / top_k if top_k > 0 else 0.0
    recall = hits / n_relevant if n_relevant else 0.0
    f1 = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

    hit_ranks = [rank for rank, hit in enumerate(flags, start=1) if hit]
    rr = 1 / hit_ranks[0] if hit_ranks else 0.0
    ap = sum(i / rank for i, rank in enumerate(hit_ranks, start=1))
    ap = ap / n_relevant if n_relevant else 0.0

    # nDCG@k: the ideal ranking places min(|relevant|, k) relevant documents
    # first, including relevant documents the retriever missed entirely.
    # sklearn.ndcg_score over only the retrieved items ignores those, and it
    # raises for k == 1.
    dcg = sum(1.0 / np.log2(rank + 1) for rank in hit_ranks)
    idcg = sum(1.0 / np.log2(rank + 1) for rank in range(1, min(n_relevant, top_k) + 1))
    ndcg = dcg / idcg if idcg > 0 else 0.0
    return precision, recall, f1, rr, ap, ndcg
|
|
|
|
|