obx0x3 commited on
Commit
eabf640
·
verified ·
1 Parent(s): c1a3351

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -76
app.py CHANGED
@@ -1,76 +1,49 @@
1
- import json
2
- import random
3
-
4
- # Predefined attributes for empathy training
5
- emotions = ["lonely", "confused", "sadness", "frustrated", "scared"]
6
- intents = ["seek_reassurance", "express_confusion", "seek_support", "vent", "request_help"]
7
- difficulties = ["easy", "moderate", "hard"]
8
- tags_pool = [
9
- ["memory", "forgetting", "support"],
10
- ["family", "burden", "dementia"],
11
- ["confusion", "orientation", "safety"],
12
- ["fear", "loneliness", "aging"],
13
- ["support", "validation", "love"]
14
- ]
15
-
16
- # English and French empathetic templates
17
- en_templates = [
18
- ("I feel like {emotion}.", "You're not alone. You're understood."),
19
- ("I'm always {emotion}, it's scary.", "That sounds tough. I'm here to support you."),
20
- ("Sometimes I just feel so {emotion}.", "You're valued. Let's work through this together."),
21
- ("Why do I keep forgetting things?", "It's okay. Memory issues happen, and you’re not alone."),
22
- ("I don’t know where I am.", "Let’s take a deep breath. I can help guide you.")
23
- ]
24
-
25
- fr_templates = [
26
- ("Je me sens {emotion}.", "Tu n'es pas seul. Je suis pour toi."),
27
- ("Je suis toujours {emotion}, c'est effrayant.", "C'est dur. Je te comprends."),
28
- ("Parfois je me sens si {emotion}.", "Tu comptes beaucoup. Ensemble, on va y arriver."),
29
- ("Pourquoi j'oublie tout le temps ?", "Ce n'est pas grave. Ça arrive à beaucoup de gens."),
30
- ("Je ne sais pas où je suis.", "Respire un peu. Je vais t'aider.")
31
- ]
32
-
33
- # Generate a single sample
34
- def create_entry(lang):
35
- templates = en_templates if lang == "en" else fr_templates
36
- input_template, response = random.choice(templates)
37
- emotion = random.choice(emotions)
38
- intent = random.choice(intents)
39
- tags = random.choice(tags_pool)
40
- difficulty = random.choice(difficulties)
41
- input_text = input_template.format(emotion=emotion)
42
-
43
- return {
44
- "input": input_text,
45
- "response": response,
46
- "emotion": emotion,
47
- "intent": intent,
48
- "tags": tags,
49
- "care_mode": True,
50
- "language": lang,
51
- "difficulty": difficulty,
52
- "is_dementia_related": True
53
- }
54
-
55
- # Generate 100 English and 70 French samples
56
- en_data = [create_entry("en") for _ in range(100)]
57
- fr_data = [create_entry("fr") for _ in range(70)]
58
- all_data = en_data + fr_data
59
- random.shuffle(all_data)
60
-
61
- # Smart split (70% train, 20% validation, 10% test)
62
- train = all_data[:int(0.7 * len(all_data))]
63
- validation = all_data[int(0.7 * len(all_data)):int(0.9 * len(all_data))]
64
- test = all_data[int(0.9 * len(all_data)):]
65
-
66
- # Write to JSON
67
- with open("dementia_train_split.json", "w", encoding="utf-8") as f:
68
- json.dump(train, f, indent=2, ensure_ascii=False)
69
-
70
- with open("dementia_validation_split.json", "w", encoding="utf-8") as f:
71
- json.dump(validation, f, indent=2, ensure_ascii=False)
72
-
73
- with open("dementia_test_multilang.json", "w", encoding="utf-8") as f:
74
- json.dump(test, f, indent=2, ensure_ascii=False)
75
-
76
- print("✅ Dataset splits created (train/validation/test)")
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
+ import torch
5
+ import uvicorn
6
+
7
+ app = FastAPI()
8
+
9
+ MODEL_NAME = "obx0x3/empathy-dementia"
10
+ tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
11
+ model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
12
+
13
+ class PromptRequest(BaseModel):
14
+ message: str
15
+ lang: str = None # Optional
16
+
17
+ def detect_language(text: str):
18
+ fr_keywords = ["je", "tu", "c’est", "j’ai", "où", "suis", "pas", "peux"]
19
+ return "fr" if any(word in text.lower() for word in fr_keywords) else "en"
20
+
21
+ def prefix_message(message: str, lang: str) -> str:
22
+ if lang == "fr":
23
+ return f"émotion: {message}"
24
+ elif any(q in message.lower() for q in ["why", "how", "what", "when", "where", "?"]):
25
+ return f"chat: {message}"
26
+ elif any(e in message.lower() for e in ["feel", "i’m", "i am", "sad", "scared", "lonely", "happy", "forgot"]):
27
+ return f"emotion: {message}"
28
+ else:
29
+ return f"chat: {message}"
30
+
31
+ @app.get("/")
32
+ def root():
33
+ return {"message": "✅ Empathy model running!"}
34
+
35
+ @app.post("/generate")
36
+ async def generate_response(payload: PromptRequest):
37
+ lang = payload.lang or detect_language(payload.message)
38
+ input_text = prefix_message(payload.message, lang)
39
+
40
+ inputs = tokenizer.encode(input_text, return_tensors="pt")
41
+ with torch.no_grad():
42
+ outputs = model.generate(inputs, max_length=128, num_beams=4, early_stopping=True)
43
+
44
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
+ return {"reply": result.strip(), "language": lang}
46
+
47
+ if __name__ == "__main__":
48
+ # For local testing: uvicorn app:app --reload --port 7860
49
+ uvicorn.run(app, host="0.0.0.0", port=7860)