| | from datasets import load_dataset |
| | from collections import defaultdict |
| | import json |
| | import re |
| | import random |
| |
|
# Output directory for the generated JSONL split files.
folder = "data/"
# Shared system prompt written into every conversation record.
system_message = "You are a medical diagnosis classifier. Given a description of symptoms, provide ONLY the name of the most likely diagnosis. Do not include explanations, reasoning, or additional text."

# Load the full training split of the HF dataset and shuffle it with a
# fixed seed so the prefix-based undersampling below is reproducible.
dataset = load_dataset("sajjadhadi/disease-diagnosis-dataset", split="train")
dataset = dataset.shuffle(seed=42)
| |
|
| | |
# Pre-compiled once: an optional boilerplate lead-in, the symptom list
# itself (lazy), and a required trailing question/conclusion phrase that
# gets discarded along with everything after it.
_SYMPTOM_PATTERN = re.compile(
    r'(?:patient reported the following symptoms:|symptoms include:?)?'
    r'\s*(.*?)'
    r'(?:\s*(?:may indicate|based on these symptoms|what disease may the patient have\?|what is the most likely diagnosis\?).*)',
    re.IGNORECASE,
)


def clean_symptom_text(text):
    """Extract the bare symptom list from a dataset prompt string.

    Strips the boilerplate lead-in ("patient reported the following
    symptoms:" / "symptoms include:") and the trailing question phrase,
    then normalizes comma spacing. Returns *text* unchanged when no
    trailing phrase is found (the pattern requires one to match).
    """
    match = _SYMPTOM_PATTERN.search(text)
    if not match:
        return text
    symptoms = match.group(1).strip()
    # Normalize "a,b ,c" -> "a, b, c". Note the sub rewrites a trailing
    # "," as ", " — so strip both chars; the original rstrip(',') left
    # the freshly-added trailing space (and the comma before it) behind.
    symptoms = re.sub(r'\s*,\s*', ', ', symptoms)
    return symptoms.rstrip(', ')
| |
|
| | |
# Group dataset row indices by their diagnosis label.
diagnosis_to_samples = defaultdict(list)
for row_index, record in enumerate(dataset):
    diagnosis_to_samples[record["diagnosis"]].append(row_index)
| |
|
| | |
# Per-class sampling targets.
TARGET_SAMPLES = 300  # samples per diagnosis after balancing
MIN_SAMPLES = 75      # minimum raw samples required to keep a diagnosis

# Keep the most frequent diagnoses that clear the minimum-sample bar.
# NOTE(review): the list is then truncated to MIN_SAMPLES (75) classes —
# the same constant doubles as the class count; confirm this is intended
# rather than a separate TOP_N that happens to equal 75.
ranked_diagnoses = sorted(
    diagnosis_to_samples.items(),
    key=lambda item: len(item[1]),
    reverse=True,
)
top_diagnoses = [
    diag for diag, idxs in ranked_diagnoses if len(idxs) >= MIN_SAMPLES
][:MIN_SAMPLES]

print(top_diagnoses)
| | |
# Build a balanced index list holding exactly TARGET_SAMPLES rows per
# kept diagnosis. A seeded generator makes the oversampling draw
# reproducible — the original used the unseeded global `random`, which
# broke run-to-run determinism despite every other step using seed=42.
_rng = random.Random(42)
balanced_indices = []
for diag in top_diagnoses:
    indices = diagnosis_to_samples[diag]
    if len(indices) >= TARGET_SAMPLES:
        # Undersample: the dataset was shuffled, so a prefix is a random subset.
        selected_indices = indices[:TARGET_SAMPLES]
    else:
        # Oversample: repeat whole copies, then top up with a random partial draw.
        selected_indices = indices * (TARGET_SAMPLES // len(indices))
        remaining = TARGET_SAMPLES % len(indices)
        if remaining:
            selected_indices.extend(_rng.sample(indices, remaining))
    balanced_indices.extend(selected_indices)
| |
|
| | |
# Materialize the balanced subset (duplicate indices are allowed by
# `select`) and report the resulting sizes.
balanced_dataset = dataset.select(balanced_indices)
print(
    f"Original dataset size: {len(dataset)}, "
    f"Balanced dataset size: {len(balanced_indices)}"
)
print(f"Number of unique diagnoses: {len(top_diagnoses)}")
| |
|
| | |
# 80/20 train/holdout split, then halve the holdout into test and
# validation (final ratio 80/10/10), both with a fixed seed.
splits = balanced_dataset.train_test_split(test_size=0.2, seed=42)
test_valid_splits = splits['test'].train_test_split(test_size=0.5, seed=42)
| |
|
| | |
def save_as_jsonl(dataset, filename):
    """Write *dataset* to *filename* in chat-format JSONL.

    Each sample becomes one line holding a three-message conversation:
    the shared system prompt, the cleaned symptom text as the user turn,
    and the raw diagnosis label as the assistant turn.
    """
    # Explicit UTF-8: the platform default encoding may differ, and
    # relying on it makes the output files machine-dependent.
    with open(filename, 'w', encoding='utf-8') as file:
        for sample in dataset:
            cleaned_text = clean_symptom_text(sample["text"])
            conversation = {
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": cleaned_text},
                    {"role": "assistant", "content": sample["diagnosis"]},
                ]
            }
            file.write(json.dumps(conversation) + '\n')
| |
|
| | |
# Persist the three splits; the holdout's halves become test and valid.
for split_data, split_name in (
    (splits["train"], "train"),
    (test_valid_splits["train"], "test"),
    (test_valid_splits["test"], "valid"),
):
    save_as_jsonl(split_data, folder + split_name + ".jsonl")
| |
|
| | |
# Report the final split sizes.
print("Dataset splits:")
for label, part in (
    ("Train", splits["train"]),
    ("Test", test_valid_splits["train"]),
    ("Validation", test_valid_splits["test"]),
):
    print(f" {label}: {len(part)}")
| |
|
| | |
# Spot-check: read back and display the first three training examples.
print("\nSample validation:")
with open(folder + "train.jsonl", 'r') as handle:
    for index, raw_line in enumerate(handle):
        if index >= 3:
            break
        record = json.loads(raw_line)
        system_msg, user_msg, assistant_msg = record['messages']
        print(f"Example {index+1}:")
        print(f" System: {system_msg['content']}")
        print(f" User: {user_msg['content']}")
        print(f" Assistant: {assistant_msg['content']}")
        print()
| |
|
| | |
# Tally assistant-turn labels in the written training file and show the
# ten most common, as a sanity check on the balancing.
class_counts = defaultdict(int)
with open(folder + "train.jsonl", 'r') as handle:
    for raw_line in handle:
        record = json.loads(raw_line)
        label = record['messages'][2]['content']
        class_counts[label] += 1

print("\nClass distribution in training set:")
ranked = sorted(class_counts.items(), key=lambda pair: pair[1], reverse=True)
for label, count in ranked[:10]:
    print(f" {label}: {count}")