# Source: DIAL_IR / dial_ir.py — uploaded by ALTAH (commit 1898ed1, verified).
# -*- coding: utf-8 -*-
import json
import math
import os
import re

import numpy as np
import openai
import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import ndcg_score
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoModelForSeq2SeqLM
# ===========================
# OpenAI settings
# ===========================
# The API key is read from the OPENAI_API_KEY environment variable.
# Never hard-code a key in this file: any key that has been committed to a
# shared repository must be treated as compromised and revoked.
openai.api_key = os.getenv("OPENAI_API_KEY")
# ===========================
# Part 1: NER + placeholders (ETMAN-BERT)
# ===========================
MODEL_NER = "ALTAH/ETMAN-BERT"

# Tokenizer and token-classification model come from the same Hub checkpoint
# and are loaded once, at import time.
tokenizer_ner = AutoTokenizer.from_pretrained(MODEL_NER)
model_ner = AutoModelForTokenClassification.from_pretrained(MODEL_NER)

# "simple" aggregation groups token predictions into entity spans; callers
# read the "entity_group", "start", "end" and "word" keys of each result.
ner_pipeline = pipeline(
    "ner",
    model=model_ner,
    tokenizer=tokenizer_ner,
    aggregation_strategy="simple",
)
# Label inventory, indexed by class id; "O" marks tokens that are not part of
# an entity and is skipped by the placeholder step.
# NOTE(review): assumed to match the class order of ALTAH/ETMAN-BERT — confirm
# against the model's config.
icd11_labels = [
    "O", "SYMPTOM", "DISEASE", "DRUG", "BODY_PART", "PROCEDURE", "TEST",
    "ANATOMY", "CONDITION", "FINDING", "SIGN", "ALLERGY", "VACCINE", "OTHER",
]
id2label = dict(enumerate(icd11_labels))
def ner_and_placeholders(text):
    """Replace each medical entity found by the NER pipeline with a typed placeholder.

    Placeholders are numbered per label in reading order (``SYMPTOM_1`` is the
    first symptom in *text*). Each placeholder maps to the exact span of *text*
    it replaced. Spans that overlap an earlier entity are left untouched.

    Args:
        text: the dialectal input query.

    Returns:
        ``(text_with_placeholders, placeholders)``, where *placeholders* maps
        each placeholder string (e.g. ``"DRUG_2"``) to the original surface text.

    Raises:
        ValueError: if the pipeline emits an entity group that is neither a
            label name from ``id2label`` nor of the form ``<prefix>_<id>``.
    """
    counter, placeholders, spans = {}, {}, []
    last_end = 0
    for ent in sorted(ner_pipeline(text), key=lambda e: e["start"]):
        label_name = _entity_label(ent["entity_group"])
        start, end = ent["start"], ent["end"]
        if label_name == "O" or start < last_end:
            continue  # not an entity, or overlaps the span already replaced
        counter[label_name] = counter.get(label_name, 0) + 1
        placeholder = f"{label_name}_{counter[label_name]}"
        # Take the surface form from the input itself: the pipeline's "word"
        # is re-decoded from sub-word tokens and need not match the text.
        placeholders[placeholder] = text[start:end]
        spans.append((start, end, placeholder))
        last_end = end

    # Substitute right to left so the offsets of earlier spans stay valid.
    text_with_placeholders = text
    for start, end, placeholder in reversed(spans):
        text_with_placeholders = text_with_placeholders[:start] + placeholder + text_with_placeholders[end:]
    return text_with_placeholders, placeholders


def _entity_label(entity_group):
    """Map a pipeline entity group (``"LABEL_3"`` or a label name) to a label name."""
    # A group that already is a label name (e.g. "BODY_PART") is used as-is.
    if entity_group in id2label.values():
        return entity_group
    # Otherwise expect the generic "<prefix>_<id>" form and look the id up.
    prefix, _, index = entity_group.rpartition("_")
    if prefix and index.isdigit():
        return id2label.get(int(index), "O")
    raise ValueError(f"Unrecognised NER entity group: {entity_group!r}")
# ===========================
# Part 2: dialect → MSA translation
# ===========================
MODEL_TRANSLATE = "ALTAH/ADT-MSA"

# Seq2seq translator and its tokenizer, both loaded from the same checkpoint
# when the module is imported.
tokenizer_translate = AutoTokenizer.from_pretrained(
    MODEL_TRANSLATE,
)
model_translate = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_TRANSLATE,
)
def translate_text_keep_placeholders(text_with_placeholders, placeholders, translate_fn=None):
    """Translate dialectal text to MSA while keeping its entity placeholders.

    Placeholders are removed before translation, so the model only sees natural
    text. Afterwards each one is re-inserted at the same *relative* position
    in the translation, snapped to the nearest word boundary and kept in its
    original order. Word order may change during translation, so positions
    are approximate; the later polishing step asks GPT to fix the syntax.

    Args:
        text_with_placeholders: text produced by ``ner_and_placeholders``.
        placeholders: mapping whose keys are the placeholder strings.
        translate_fn: optional ``str -> str`` translator. Defaults to the
            module-level ADT-MSA model.

    Returns:
        The MSA translation with every placeholder occurrence re-inserted and
        whitespace-separated from its neighbours. If no placeholder occurs in
        the text, the plain translation is returned unchanged.
    """
    translate = translate_fn if translate_fn is not None else _translate_to_msa
    if not placeholders:
        return translate(text_with_placeholders)

    # Longest alternative first so "SYMPTOM_10" is never matched as
    # "SYMPTOM_1" followed by a stray "0".
    pattern = re.compile(
        "|".join(re.escape(ph) for ph in sorted(placeholders, key=len, reverse=True))
    )

    # Strip the placeholders, recording each one's offset in the *stripped*
    # text (i.e. discounting the placeholders removed before it).
    pieces, anchors, kept_len, last = [], [], 0, 0
    for match in pattern.finditer(text_with_placeholders):
        piece = text_with_placeholders[last:match.start()]
        pieces.append(piece)
        kept_len += len(piece)
        anchors.append((kept_len, match.group()))
        last = match.end()
    if not anchors:
        return translate(text_with_placeholders)
    pieces.append(text_with_placeholders[last:])
    source = "".join(pieces)

    # Input made only of placeholders has no text to translate; don't feed
    # the model an empty / whitespace-only string.
    translated = translate(source) if source.strip() else ""

    # Map every anchor proportionally onto the translation and snap it to a
    # word boundary; max() keeps positions non-decreasing so the original
    # placeholder order is preserved.
    src_len, out_len = max(len(source), 1), len(translated)
    positions = []
    for offset, _ in anchors:
        pos = _nearest_word_boundary(translated, (offset * out_len + src_len // 2) // src_len)
        positions.append(max(pos, positions[-1]) if positions else pos)

    # Insert right to left so earlier positions stay valid; pad with spaces so
    # a placeholder never fuses with a neighbouring word, then tidy spacing.
    result = translated
    for pos, (_, ph) in zip(reversed(positions), reversed(anchors)):
        result = f"{result[:pos]} {ph} {result[pos:]}"
    return " ".join(result.split())


def _nearest_word_boundary(text, pos):
    """Return the index closest to *pos* (ties go left) that does not split a word."""
    size = len(text)
    pos = min(max(pos, 0), size)
    for delta in range(size + 1):
        for cand in (pos - delta, pos + delta):
            if 0 <= cand <= size and (
                cand in (0, size) or text[cand - 1].isspace() or text[cand].isspace()
            ):
                return cand
    return pos


def _translate_to_msa(text):
    """Run the module-level ADT-MSA seq2seq model on *text* and decode the first output."""
    inputs = tokenizer_translate(text, return_tensors="pt", truncation=True)
    output_ids = model_translate.generate(**inputs, max_length=512)
    return tokenizer_translate.decode(output_ids[0], skip_special_tokens=True)
# ===========================
# Part 3: entity translation with GPT
# ===========================
def translate_entities_with_gpt(placeholders, model="gpt-4"):
    """Translate each dialectal entity to MSA with one chat call per distinct entity.

    Uses the legacy ``openai.ChatCompletion`` interface (pre-1.0 OpenAI SDK).

    Args:
        placeholders: mapping placeholder -> dialectal entity text, as built by
            ``ner_and_placeholders``.
        model: OpenAI chat model name; defaults to the original ``"gpt-4"``.

    Returns:
        dict mapping each placeholder to the model's whitespace-stripped answer.
        An empty mapping returns ``{}`` without any API call.
    """
    by_entity = {}
    translated_entities = {}
    for ph, ent in placeholders.items():
        # The same surface form can occur several times in one query: ask once,
        # so repeats cost no extra call and always get the same translation.
        if ent not in by_entity:
            prompt = f"Traduisez uniquement cette entité médicale dialectale vers l'arabe standard (MSA) : {ent}"
            response = openai.ChatCompletion.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            by_entity[ent] = response.choices[0].message["content"].strip()
        translated_entities[ph] = by_entity[ent]
    return translated_entities
# ===========================
# Part 4: re-insertion + polish
# ===========================
def reinsert_and_polish(text_translated_msa, translated_entities, model="gpt-4"):
    """Ask GPT to substitute translated entities for placeholders and smooth the MSA text.

    Any placeholder that is still present in the model's answer is replaced
    verbatim by its translation. This keeps tokens like ``SYMPTOM_1`` out of
    the final query.

    Args:
        text_translated_msa: MSA text still containing placeholders.
        translated_entities: mapping placeholder -> MSA entity text.
        model: OpenAI chat model name; defaults to the original ``"gpt-4"``.

    Returns:
        The final MSA text (whitespace-stripped).
    """
    prompt = f"""
Réinsérez les entités traduites dans le texte MSA en remplaçant les placeholders.
Ajustez la syntaxe pour que la phrase soit correcte et naturelle.
Texte MSA avec placeholders :
{text_translated_msa}
Entités traduites :
{json.dumps(translated_entities, ensure_ascii=False, indent=2)}
Réponse attendue : texte final MSA uniquement.
"""
    response = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    polished = response.choices[0].message["content"].strip()
    if translated_entities:
        # Fallback for placeholders GPT left untouched; longest first so
        # "SYMPTOM_10" is not consumed as "SYMPTOM_1" + "0".
        leftover = re.compile(
            "|".join(re.escape(ph) for ph in sorted(translated_entities, key=len, reverse=True))
        )
        polished = leftover.sub(lambda m: translated_entities[m.group()], polished)
    return polished
# ===========================
# Part 5: normalization
# ===========================
def normalize_query(query_msa: str) -> str:
    """Return the MSA query with its leading and trailing whitespace removed.

    The inner text, including internal spacing, is passed through unchanged.
    """
    trimmed = query_msa.strip()
    return trimmed
# ===========================
# DIAL-IR class
# ===========================
class DIALIR:
    """Dense retriever over a one-document-per-line corpus.

    Queries are first passed through ``preprocess_query``:
    NER placeholders → MSA translation → GPT entity translation and polish →
    normalization. They are then ranked against the corpus by cosine similarity
    of multilingual MiniLM sentence embeddings.
    """

    def __init__(self, corpus_file, embeddings_file=None):
        """Load the corpus and its embeddings.

        Args:
            corpus_file: UTF-8 text file, one document per non-blank line.
            embeddings_file: optional file written with ``torch.save`` holding
                one embedding row per corpus document, in file order. If
                omitted, the corpus is embedded now.

        Raises:
            ValueError: if the loaded embeddings do not have exactly one row per
                corpus document (results would point at the wrong documents).
        """
        self.embed_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        self.corpus = self.load_corpus(corpus_file)
        if embeddings_file:
            # NOTE(security): torch.load unpickles the file — only load embedding
            # files you produced yourself.
            # map_location puts the tensor on the encoder's device, so files
            # saved on another device (e.g. a GPU) load and compare correctly.
            self.corpus_embeddings = torch.load(embeddings_file, map_location=self.embed_model.device)
            if len(self.corpus_embeddings) != len(self.corpus):
                raise ValueError(
                    f"{embeddings_file} holds {len(self.corpus_embeddings)} embeddings "
                    f"but the corpus has {len(self.corpus)} documents"
                )
        else:
            self.corpus_embeddings = self.embed_model.encode(self.corpus, convert_to_tensor=True)

    def load_corpus(self, file_path):
        """Return the stripped, non-blank lines of a UTF-8 text file."""
        with open(file_path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    def preprocess_query(self, query):
        """Turn a dialectal query into normalized MSA text (NER → translate → GPT polish)."""
        text_ph, placeholders = ner_and_placeholders(query)
        text_translated = translate_text_keep_placeholders(text_ph, placeholders)
        translated_entities = translate_entities_with_gpt(placeholders)
        query_msa = reinsert_and_polish(text_translated, translated_entities)
        return normalize_query(query_msa)

    def search(self, query, top_k=5):
        """Return up to *top_k* ``(score, document)`` pairs, highest cosine score first.

        *top_k* is clamped to the corpus size, because ``torch.topk`` raises
        when k exceeds the number of candidates. An empty corpus or
        ``top_k <= 0`` returns ``[]`` without preprocessing the query.
        """
        k = min(top_k, len(self.corpus))
        if k <= 0:
            return []
        query_msa = self.preprocess_query(query)
        query_embedding = self.embed_model.encode(query_msa, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, self.corpus_embeddings)[0]
        top_results = torch.topk(cos_scores, k=k)
        return [(float(score), self.corpus[int(idx)]) for score, idx in zip(top_results.values, top_results.indices)]
# ===========================
# IR evaluation
# ===========================
def evaluate_ir(dial_ir, test_file, top_k=5):
    """Evaluate a retriever against a tab-separated relevance file.

    Each non-blank line of *test_file* is ``query<TAB>doc1|doc2|...``, where
    the documents are the corpus lines relevant to that query (empty entries,
    e.g. from a trailing ``|``, are ignored). For every query,
    ``dial_ir.search(query, top_k=top_k)`` is called, and its ``(score, doc)``
    results are scored in the order returned (assumed best first).

    Args:
        dial_ir: any object with a ``search(query, top_k=...)`` method.
        test_file: path to the UTF-8 relevance file.
        top_k: cut-off k for the @k metrics.

    Returns:
        dict with the mean over all queries of Precision@k, Recall@k, F1@k,
        MRR, MAP and nDCG@k (binary relevance).

    Raises:
        ValueError: if a non-blank line has no tab separator.
    """
    per_query = []
    with open(test_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            # Split on the first tab only: relevant docs are corpus lines and
            # may themselves contain tabs.
            query, relevant_field = line.split("\t", 1)
            relevant_docs = [doc for doc in relevant_field.split("|") if doc]
            results = dial_ir.search(query, top_k=top_k)
            retrieved_docs = [doc for _, doc in results]
            per_query.append(_query_metrics(retrieved_docs, relevant_docs, top_k))

    names = ("Precision@k", "Recall@k", "F1@k", "MRR", "MAP", "nDCG@k")
    return {name: np.mean([metrics[i] for metrics in per_query]) for i, name in enumerate(names)}


def _query_metrics(retrieved_docs, relevant_docs, top_k):
    """Return (precision, recall, f1, reciprocal rank, AP, nDCG) for one ranked list."""
    relevant = set(relevant_docs)
    flags = [doc in relevant for doc in retrieved_docs]
    hits = sum(flags)

    precision = hits / top_k if top_k > 0 else 0
    recall = hits / len(relevant) if relevant else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0

    reciprocal_rank = next((1 / rank for rank, hit in enumerate(flags, start=1) if hit), 0)

    ap, seen = 0.0, 0
    for rank, hit in enumerate(flags, start=1):
        if hit:
            seen += 1
            ap += seen / rank
    ap = ap / len(relevant) if relevant else 0

    # nDCG@k with binary gains. The ideal ranking puts *all* relevant docs first,
    # so relevant documents that were not retrieved still lower the score.
    # (Normalising only over the retrieved list would inflate it, and
    # sklearn's ndcg_score also rejects lists of a single document.)
    dcg = sum(1 / math.log2(rank + 1) for rank, hit in enumerate(flags[:top_k], start=1) if hit)
    idcg = sum(1 / math.log2(rank + 1) for rank in range(1, min(len(relevant), top_k) + 1))
    ndcg = dcg / idcg if idcg > 0 else 0

    return precision, recall, f1, reciprocal_rank, ap, ndcg