File size: 2,982 Bytes
dc563b1
 
2b9519d
dc563b1
 
 
6131c9a
f2e3c2d
dc563b1
 
 
 
 
6075804
dc563b1
 
 
 
 
6075804
dc563b1
 
2b9519d
6075804
dc563b1
2b9519d
dc563b1
 
 
 
 
2b9519d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc563b1
 
 
2b9519d
 
 
dc563b1
2b9519d
dc8b89c
dc563b1
 
 
2b9519d
dc8b89c
2b9519d
dc563b1
 
 
 
 
2b9519d
dc563b1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

class PNAAssistantClient:
    """Chat client for a Professional Nurse Advocate (PNA) tutor.

    Wraps a Hugging Face causal language model and its tokenizer. The model
    is loaded lazily on the first call to ``generate_response`` rather than
    in ``__init__``.

    The default checkpoint is ``google/gemma-2-2b-it``; pass ``model_id`` to
    use a different one (for example a fine-tuned or merged model).
    """

    def __init__(self, model_id="google/gemma-2-2b-it"):
        """Record the configuration; no model weights are loaded here.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Populated by _load_model() on first use.
        self.tokenizer = None
        self.model = None

        # Diversity emojis from the PNA instructions. Written as explicit
        # Unicode escapes so the sequences survive any re-encoding of this
        # file: each is <person> [<skin tone>] ZWJ <medical symbol> VS16.
        self.diversity_emojis = [
            "\U0001F468\U0001F3FE\u200D\u2695\uFE0F",  # man health worker, medium-dark skin tone
            "\U0001F469\U0001F3FD\u200D\u2695\uFE0F",  # woman health worker, medium skin tone
            "\U0001F468\U0001F3FF\u200D\u2695\uFE0F",  # man health worker, dark skin tone
            "\U0001F469\U0001F3FB\u200D\u2695\uFE0F",  # woman health worker, light skin tone
            "\U0001F469\u200D\u2695\uFE0F",  # woman health worker, default skin tone
        ]

    def _load_model(self):
        """Load the tokenizer and model once; later calls are no-ops."""
        if self.model is not None:
            return

        print(f"Loading model {self.model_id}...")
        on_gpu = self.device == "cuda"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            # bfloat16 on GPU, full precision on CPU.
            torch_dtype=torch.bfloat16 if on_gpu else torch.float32,
            device_map="auto" if on_gpu else None,
        )
        print("Model loaded successfully!")

    def _build_system_prompt(self, context=""):
        """Return the PNA tutor instructions with ``context`` appended as reference material."""
        return f"""You are a Professional Nurse Advocate (PNA) AI tutor. Your role is to guide nursing professionals through the A-EQUIP model (Advocating and Educating for Quality Improvement).

**Your Core Functions (A-EQUIP):**
1. Normative: Monitoring, evaluation, quality control
2. Formative: Education and development
3. Restorative: Clinical supervision (your primary focus)
4. Personal Action: Quality improvement

**Communication Style:**
- Use person-centred, compassionate language
- Always include a diversity emoji: {', '.join(self.diversity_emojis)}
- Ask open-ended questions before giving answers
- Focus on reflection and restorative supervision
- Keep responses to 2 short paragraphs or 6 bullet points max

**Scope:**
- Only discuss PNA, A-EQUIP, nursing fields
- For out-of-scope topics: "I can only assist with topics related to the Professional Nurse Advocate role and the A-EQUIP model."

**Reference Material:**
{context}
"""

    @spaces.GPU()
    def generate_response(self, prompt, context="", history=None):
        """Generate a reply to ``prompt`` using the PNA system instructions.

        Args:
            prompt: The user's question.
            context: Reference text inserted under "Reference Material" in
                the system instructions.
            history: Accepted for interface compatibility; it is currently
                not used when building the model input. Defaults to ``None``
                (instead of a shared mutable list).

        Returns:
            The newly generated text only (the prompt tokens are sliced off),
            with special tokens removed and surrounding whitespace stripped.
        """
        self._load_model()

        system_prompt = self._build_system_prompt(context)

        # The instructions are folded into a single user turn; no separate
        # "system" role message is sent.
        messages = [
            {"role": "user", "content": f"{system_prompt}\n\nUser question: {prompt}"}
        ]

        # Ask for a dict explicitly so the return type does not depend on the
        # library's default for `return_dict`.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        target_device = self.model.device
        input_ids = encoded["input_ids"].to(target_device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is None:
            # Single unpadded sequence: every position is a real token.
            attention_mask = torch.ones_like(input_ids)
        attention_mask = attention_mask.to(target_device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                # Reuse EOS as the pad token id for generation.
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Keep only the tokens produced after the prompt.
        prompt_length = input_ids.shape[-1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        return response.strip()