Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,25 +1,76 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
| 4 |
-
import torch
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
# Label vocabularies attached to every synthetic empathy sample.
emotions = ["lonely", "confused", "sadness", "frustrated", "scared"]
intents = ["seek_reassurance", "express_confusion", "seek_support", "vent", "request_help"]
difficulties = ["easy", "moderate", "hard"]
tags_pool = [
    ["memory", "forgetting", "support"],
    ["family", "burden", "dementia"],
    ["confusion", "orientation", "safety"],
    ["fear", "loneliness", "aging"],
    ["support", "validation", "love"]
]

# (user prompt, empathetic reply) pairs; "{emotion}" is substituted later.
# Templates without the placeholder are left untouched by str.format.
en_templates = [
    ("I feel like {emotion}.", "You're not alone. You're understood."),
    ("I'm always {emotion}, it's scary.", "That sounds tough. I'm here to support you."),
    ("Sometimes I just feel so {emotion}.", "You're valued. Let's work through this together."),
    ("Why do I keep forgetting things?", "It's okay. Memory issues happen, and you’re not alone."),
    ("I don’t know where I am.", "Let’s take a deep breath. I can help guide you.")
]

fr_templates = [
    ("Je me sens {emotion}.", "Tu n'es pas seul. Je suis là pour toi."),
    ("Je suis toujours {emotion}, c'est effrayant.", "C'est dur. Je te comprends."),
    ("Parfois je me sens si {emotion}.", "Tu comptes beaucoup. Ensemble, on va y arriver."),
    ("Pourquoi j'oublie tout le temps ?", "Ce n'est pas grave. Ça arrive à beaucoup de gens."),
    ("Je ne sais pas où je suis.", "Respire un peu. Je vais t'aider.")
]


def create_entry(lang):
    """Build one randomly labelled empathy sample.

    Args:
        lang: "en" selects the English templates; any other value falls
            back to the French ones.

    Returns:
        A dict carrying the rendered prompt/response pair plus its
        annotation metadata (emotion, intent, tags, difficulty, ...).
    """
    # NOTE: the sequence of random.choice calls is observable behaviour
    # under a seeded RNG, so their order is deliberately kept fixed.
    pool = en_templates if lang == "en" else fr_templates
    prompt_template, reply = random.choice(pool)
    chosen_emotion = random.choice(emotions)
    chosen_intent = random.choice(intents)
    chosen_tags = random.choice(tags_pool)
    chosen_difficulty = random.choice(difficulties)

    entry = {
        "input": prompt_template.format(emotion=chosen_emotion),
        "response": reply,
        "emotion": chosen_emotion,
        "intent": chosen_intent,
        "tags": chosen_tags,
        "care_mode": True,
        "language": lang,
        "difficulty": chosen_difficulty,
        "is_dementia_related": True
    }
    return entry
|
| 54 |
|
| 55 |
+
# Build the corpus: 100 English and 70 French samples, then mix them.
en_data = [create_entry("en") for _ in range(100)]
fr_data = [create_entry("fr") for _ in range(70)]
all_data = en_data + fr_data
random.shuffle(all_data)

# Split 70% / 20% / 10% into train / validation / test.
total = len(all_data)
first_cut = int(0.7 * total)
second_cut = int(0.9 * total)
train = all_data[:first_cut]
validation = all_data[first_cut:second_cut]
test = all_data[second_cut:]

# Persist each split as pretty-printed UTF-8 JSON (accents kept readable).
output_files = (
    ("dementia_train_split.json", train),
    ("dementia_validation_split.json", validation),
    ("dementia_test_multilang.json", test),
)
for path, subset in output_files:
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(subset, handle, indent=2, ensure_ascii=False)

print("✅ Dataset splits created (train/validation/test)")
|