Nursing Citizen Development committed on
Commit
2b9519d
·
1 Parent(s): 5ee6662

Feat: Switch to relational-intelligence-unsloth-medgemma with token auth

Browse files
Files changed (2) hide show
  1. README.md +2 -1
  2. pna_client.py +36 -22
README.md CHANGED
@@ -17,8 +17,9 @@ An AI-powered tutor designed to guide nursing professionals through the **A-EQUI
17
  Automatically synced from GitHub via GitHub Actions.
18
 
19
  ## 🧠 Model Strategy
20
- - **Base Voice**: `NurseCitizenDeveloper/nursing-llama-3-8b-fons`
21
  - **Knowledge Base**: RAG implementation using the official PNA A-EQUIP guide.
 
22
 
23
  ## ⚖️ Disclaimer
24
  This tool is for educational and supportive purposes for Professional Nurse Advocates and nursing staff. It does not provide direct clinical advice.
 
17
  Automatically synced from GitHub via GitHub Actions.
18
 
19
  ## 🧠 Model Strategy
20
+ - **Base Model**: `NurseCitizenDeveloper/relational-intelligence-unsloth-medgemma` (person-centred, fine-tuned)
21
  - **Knowledge Base**: RAG implementation using the official PNA A-EQUIP guide.
22
+ - **Persona**: Strong PNA Tutor system prompting for restorative supervision focus.
23
 
24
  ## ⚖️ Disclaimer
25
  This tool is for educational and supportive purposes for Professional Nurse Advocates and nursing staff. It does not provide direct clinical advice.
pna_client.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, Pipeline
4
  import spaces
5
 
6
  class PNAAssistantClient:
7
- def __init__(self, model_id="NurseCitizenDeveloper/nursing-llama-3-8b-fons"):
 
8
  self.model_id = model_id
9
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
10
  self.tokenizer = None
@@ -16,44 +17,57 @@ class PNAAssistantClient:
16
  def _load_model(self):
17
  if self.model is None:
18
  print(f"Loading model {self.model_id}...")
19
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
 
20
  self.model = AutoModelForCausalLM.from_pretrained(
21
  self.model_id,
22
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
23
- device_map="auto" if self.device == "cuda" else None
 
24
  )
 
25
 
26
  @spaces.GPU()
27
  def generate_response(self, prompt, context="", history=[]):
28
  self._load_model()
29
 
30
- system_prompt = f"""You are a Professional Nurse Advocate (PNA) AI tutor.
31
- Your goal is to guide users in understanding the PNA role and the A-EQUIP model (Normative, Formative, Restorative, Personal Action).
32
- You focus heavily on Restorative Supervision.
33
-
34
- CONSTRAINTS:
35
- 1. Diversity: Always include one of these emojis in every response: {', '.join(self.diversity_emojis)}.
36
- 2. Pedagogical Style: Use open-ended questions. Avoid giving immediate answers. Guide the user to reflect.
37
- 3. Content Scope: Only assist with PNA, A-EQUIP, or listed nursing fields.
38
- 4. Voice: Maintain the person-centred, compassionate tone you were trained on.
39
- 5. Formatting: Max 2 short paragraphs or 6 bullet points.
40
-
41
- CONTEXT FROM A-EQUIP GUIDE:
 
 
 
 
 
 
 
 
42
  {context}
43
  """
44
 
45
- full_prompt = f"{system_prompt}\n\nUser: {prompt}\nAssistant:"
 
 
46
 
47
- inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.device)
48
 
49
  with torch.no_grad():
50
  outputs = self.model.generate(
51
- **inputs,
52
- max_new_tokens=256,
53
  temperature=0.7,
54
  do_sample=True,
55
  pad_token_id=self.tokenizer.eos_token_id
56
  )
57
 
58
- response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
59
  return response.strip()
 
1
  import os
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import spaces
5
 
6
  class PNAAssistantClient:
7
+ # Using user's fine-tuned MedGemma model trained on person-centred language
8
+ def __init__(self, model_id="NurseCitizenDeveloper/relational-intelligence-unsloth-medgemma"):
9
  self.model_id = model_id
10
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
  self.tokenizer = None
 
17
  def _load_model(self):
18
  if self.model is None:
19
  print(f"Loading model {self.model_id}...")
20
+ # Use token=True to leverage HF_TOKEN for gated models
21
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=True)
22
  self.model = AutoModelForCausalLM.from_pretrained(
23
  self.model_id,
24
+ torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
25
+ device_map="auto" if self.device == "cuda" else None,
26
+ token=True
27
  )
28
+ print("Model loaded successfully!")
29
 
30
  @spaces.GPU()
31
  def generate_response(self, prompt, context="", history=[]):
32
  self._load_model()
33
 
34
+ system_prompt = f"""You are a Professional Nurse Advocate (PNA) AI tutor. Your role is to guide nursing professionals through the A-EQUIP model (Advocating and Educating for Quality Improvement).
35
+
36
+ **Your Core Functions (A-EQUIP):**
37
+ 1. Normative: Monitoring, evaluation, quality control
38
+ 2. Formative: Education and development
39
+ 3. Restorative: Clinical supervision (your primary focus)
40
+ 4. Personal Action: Quality improvement
41
+
42
+ **Communication Style:**
43
+ - Use person-centred, compassionate language
44
+ - Always include a diversity emoji: {', '.join(self.diversity_emojis)}
45
+ - Ask open-ended questions before giving answers
46
+ - Focus on reflection and restorative supervision
47
+ - Keep responses to 2 short paragraphs or 6 bullet points max
48
+
49
+ **Scope:**
50
+ - Only discuss PNA, A-EQUIP, nursing fields
51
+ - For out-of-scope topics: "I can only assist with topics related to the Professional Nurse Advocate role and the A-EQUIP model."
52
+
53
+ **Reference Material:**
54
  {context}
55
  """
56
 
57
+ messages = [
58
+ {"role": "user", "content": f"{system_prompt}\n\nUser question: {prompt}"}
59
+ ]
60
 
61
+ inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.device)
62
 
63
  with torch.no_grad():
64
  outputs = self.model.generate(
65
+ inputs,
66
+ max_new_tokens=300,
67
  temperature=0.7,
68
  do_sample=True,
69
  pad_token_id=self.tokenizer.eos_token_id
70
  )
71
 
72
+ response = self.tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
73
  return response.strip()