obx0x3 committed on
Commit
c1a3351
·
verified ·
1 Parent(s): c1de85f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -19
app.py CHANGED
@@ -1,25 +1,76 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration
4
- import torch
5
 
6
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
7
 
8
- model = T5ForConditionalGeneration.from_pretrained("obx0x3/empathy-dementia")
9
- tokenizer = T5Tokenizer.from_pretrained("obx0x3/empathy-dementia")
 
 
 
 
 
 
10
 
11
- class PromptRequest(BaseModel):
12
- message: str
13
- lang: str = "en"
 
 
 
 
14
 
15
- @app.post("/generate")
16
- async def generate_response(payload: PromptRequest):
17
- prefix = "émotion: " if payload.lang == "fr" else ""
18
- text = prefix + payload.message
 
 
 
 
 
19
 
20
- input_ids = tokenizer.encode(text, return_tensors="pt")
21
- with torch.no_grad():
22
- outputs = model.generate(input_ids, max_length=50)
 
 
 
 
 
 
 
 
23
 
24
- result = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
- return {"reply": result.strip()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
 
 
3
 
4
# Attribute pools sampled uniformly at random when synthesising
# empathy-training examples for the dementia-care dataset.
emotions = ["lonely", "confused", "sadness", "frustrated", "scared"]
intents = [
    "seek_reassurance",
    "express_confusion",
    "seek_support",
    "vent",
    "request_help",
]
difficulties = ["easy", "moderate", "hard"]
# Each entry is a coherent triple of topic tags attached verbatim to a sample.
tags_pool = [
    "memory forgetting support".split(),
    "family burden dementia".split(),
    "confusion orientation safety".split(),
    "fear loneliness aging".split(),
    "support validation love".split(),
]
15
 
16
# Empathetic dialogue templates, one pool per language.
# Each pair is (user_utterance_template, empathetic_reply); an "{emotion}"
# placeholder in the utterance is filled in later from the emotions pool.
en_templates = [
    ("I feel like {emotion}.",
     "You're not alone. You're understood."),
    ("I'm always {emotion}, it's scary.",
     "That sounds tough. I'm here to support you."),
    ("Sometimes I just feel so {emotion}.",
     "You're valued. Let's work through this together."),
    ("Why do I keep forgetting things?",
     "It's okay. Memory issues happen, and you’re not alone."),
    ("I don’t know where I am.",
     "Let’s take a deep breath. I can help guide you."),
]

fr_templates = [
    ("Je me sens {emotion}.",
     "Tu n'es pas seul. Je suis là pour toi."),
    ("Je suis toujours {emotion}, c'est effrayant.",
     "C'est dur. Je te comprends."),
    ("Parfois je me sens si {emotion}.",
     "Tu comptes beaucoup. Ensemble, on va y arriver."),
    ("Pourquoi j'oublie tout le temps ?",
     "Ce n'est pas grave. Ça arrive à beaucoup de gens."),
    ("Je ne sais pas où je suis.",
     "Respire un peu. Je vais t'aider."),
]
32
 
33
# Synthesise one labelled training sample in the requested language.
def create_entry(lang):
    """Return a dict describing a single randomly generated dialogue sample.

    ``lang`` selects the template pool: ``"en"`` uses the English templates,
    any other value falls back to the French pool.  All remaining attributes
    (emotion, intent, tags, difficulty) are drawn uniformly at random from
    the module-level pools.
    """
    pool = en_templates if lang == "en" else fr_templates
    prompt, reply = random.choice(pool)
    # Draw the emotion first: it is needed both as a label and to fill
    # the "{emotion}" placeholder in the prompt template (a no-op for
    # templates without the placeholder).
    chosen_emotion = random.choice(emotions)
    chosen_intent = random.choice(intents)
    chosen_tags = random.choice(tags_pool)
    chosen_difficulty = random.choice(difficulties)

    return {
        "input": prompt.format(emotion=chosen_emotion),
        "response": reply,
        "emotion": chosen_emotion,
        "intent": chosen_intent,
        "tags": chosen_tags,
        "care_mode": True,
        "language": lang,
        "difficulty": chosen_difficulty,
        "is_dementia_related": True,
    }
54
 
55
# --- Dataset generation ------------------------------------------------------
# Build 100 English and 70 French samples, shuffle them together, then split
# 70% / 20% / 10% into train / validation / test and dump each split to JSON.
en_data = [create_entry("en") for _ in range(100)]
fr_data = [create_entry("fr") for _ in range(70)]
all_data = en_data + fr_data
random.shuffle(all_data)

total = len(all_data)
train_end = int(0.7 * total)
val_end = int(0.9 * total)
train = all_data[:train_end]
validation = all_data[train_end:val_end]
test = all_data[val_end:]

# ensure_ascii=False keeps the accented French text human-readable on disk.
splits = {
    "dementia_train_split.json": train,
    "dementia_validation_split.json": validation,
    "dementia_test_multilang.json": test,
}
for path, rows in splits.items():
    with open(path, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2, ensure_ascii=False)

print("✅ Dataset splits created (train/validation/test)")