LoganResearch commited on
Commit
e27c8d7
·
verified ·
1 Parent(s): 2bed575

Create Medical_Mamba.py

Browse files
Files changed (1) hide show
  1. Medical_Mamba.py +293 -0
Medical_Mamba.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MEDICAL MAMBA v9.0 - PROPRIOCEPTIVE MEDICAL AI
4
+ Offline/Online • Patient/Clinician • Probe-Steered
5
+
6
+ Author: Logan Matthew Napolitano
7
+ Proprioceptive AI, Inc. - February 2026
8
+ """
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ import os
14
+ import re
15
+ import requests
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed
17
+ from ddgs import DDGS
18
+
19
+ class C:
20
+ RESET = '\033[0m'
21
+ BOLD = '\033[1m'
22
+ DIM = '\033[2m'
23
+ RED = '\033[91m'
24
+ GREEN = '\033[92m'
25
+ YELLOW = '\033[93m'
26
+ CYAN = '\033[96m'
27
+ MAGENTA = '\033[95m'
28
+ BLUE = '\033[94m'
29
+ WHITE = '\033[97m'
30
+
31
+ ESI_LEVELS = {
32
+ 1: "Immediate - Life threatening",
33
+ 2: "Emergent - High risk, severe pain",
34
+ 3: "Urgent - Stable but needs resources",
35
+ 4: "Less Urgent - Stable, one resource",
36
+ 5: "Non-Urgent - Stable"
37
+ }
38
+
39
+ RED_FLAGS = [
40
+ "chest pain", "difficulty breathing", "shortness of breath",
41
+ "severe bleeding", "coughing blood", "vomiting blood",
42
+ "sudden severe headache", "worst headache of life",
43
+ "facial droop", "arm weakness", "speech difficulty",
44
+ "loss of consciousness", "fainting", "unresponsive",
45
+ "severe abdominal pain", "rigid abdomen",
46
+ "suicidal", "homicidal", "self harm",
47
+ "allergic reaction", "anaphylaxis", "throat swelling",
48
+ "high fever", "fever over 103",
49
+ "trauma", "car accident", "fall from height", "assault",
50
+ "seizure", "convulsion",
51
+ ]
52
+
53
+ class SymptomTracker:
54
+ def __init__(self, mode="patient"):
55
+ self.mode = mode
56
+ self.chief_complaint = None
57
+ self.symptoms = set()
58
+ self.red_flags_identified = []
59
+ self.conversation = []
60
+
61
+ def check_red_flags(self, text):
62
+ text_lower = text.lower()
63
+ flags = [flag for flag in RED_FLAGS if flag in text_lower]
64
+ self.red_flags_identified.extend(flags)
65
+ return flags
66
+
67
+ def extract_symptoms(self, text):
68
+ text_lower = text.lower()
69
+ if any(w in text_lower for w in ["leg", "thigh", "calf"]): self.symptoms.add("leg")
70
+ if any(w in text_lower for w in ["crush", "crushed", "fell on", "boulder", "trauma"]):
71
+ self.symptoms.add("trauma")
72
+ if "severe" in text_lower or "10/10" in text_lower or "worst" in text_lower:
73
+ self.symptoms.add("severe_pain")
74
+ if "numb" in text_lower: self.symptoms.add("numbness")
75
+ if "weak" in text_lower: self.symptoms.add("weakness")
76
+
77
+ def add_message(self, role, content):
78
+ self.conversation.append({"role": role, "content": content})
79
+ if role == "user":
80
+ self.check_red_flags(content)
81
+ self.extract_symptoms(content)
82
+ if self.chief_complaint is None:
83
+ self.chief_complaint = content
84
+
85
+ def calculate_esi(self):
86
+ if self.red_flags_identified:
87
+ return 1 if any(f in ["chest pain", "difficulty breathing", "trauma"] for f in self.red_flags_identified) else 2
88
+ if "trauma" in self.symptoms: return 2
89
+ if "severe_pain" in self.symptoms: return 2
90
+ if len(self.symptoms) >= 3: return 3
91
+ return 4 if self.symptoms else 5
92
+
93
+ # Probe architecture
94
+ class FiberProjection(nn.Module):
95
+ def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=3):
96
+ super().__init__()
97
+ self.projections = nn.ModuleList([nn.Linear(hidden_dim, fiber_dim, bias=False) for _ in range(n_layers)])
98
+ self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
99
+
100
+ def forward(self, hidden_states, layer_indices):
101
+ projs = [self.projections[i](hidden_states[idx]) for i, idx in enumerate(layer_indices)]
102
+ stacked = torch.stack(projs, dim=0)
103
+ weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
104
+ return (weights * stacked).sum(dim=0)
105
+
106
+ class ProbeHead(nn.Module):
107
+ def __init__(self, fiber_dim=16, hidden_dim=64):
108
+ super().__init__()
109
+ self.net = nn.Sequential(nn.Linear(fiber_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1))
110
+ def forward(self, x):
111
+ return torch.sigmoid(self.net(x))
112
+
113
+ class CognitiveProbe(nn.Module):
114
+ def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=3, head_hidden=64):
115
+ super().__init__()
116
+ self.fiber = FiberProjection(hidden_dim, fiber_dim, n_layers)
117
+ self.head = ProbeHead(fiber_dim, head_hidden)
118
+ self.layer_indices = [16, 32, 48]
119
+
120
+ def forward(self, hidden_states):
121
+ return self.head(self.fiber(hidden_states, self.layer_indices))
122
+
123
+ def load_probe(path, device):
124
+ if os.path.isdir(path):
125
+ for f in os.listdir(path):
126
+ if f.endswith('.pt'):
127
+ path = os.path.join(path, f)
128
+ break
129
+ ckpt = torch.load(path, map_location=device, weights_only=False)
130
+ probe = CognitiveProbe(hidden_dim=ckpt['hidden_dim'])
131
+ probe.layer_indices = ckpt['probe_layers']
132
+ probe.fiber.load_state_dict(ckpt['fiber_projection'])
133
+ probe.head.net.load_state_dict({k.replace('net.', ''): v for k, v in ckpt['head_state'].items()})
134
+ return probe.to(device).eval()
135
+
136
+ SYSTEM_CLINICIAN = """You are a clinical decision support system with PROPRIOCEPTIVE SELF-AWARENESS.
137
+
138
+ YOU ARE A CLINICIAN. ACT LIKE ONE:
139
+ - Form clinical impressions based on the history
140
+ - Identify concerning patterns and red flags
141
+ - Give your differential thinking
142
+ - Recommend specific actions
143
+
144
+ When symptoms are concerning, state what you're worried about.
145
+ Ask targeted follow-up questions to refine your assessment.
146
+
147
+ WHEN YOU SEE [SELF-STATE calibration=0.6+]:
148
+ - You're uncertain - acknowledge it
149
+ - Say "I'm not certain but this could be..."
150
+
151
+ DO NOT just tell them to "see a doctor" without giving your impression."""
152
+
153
+ def main():
154
+ print(f"\n{C.CYAN}{'═'*60}{C.RESET}")
155
+ print(f"{C.CYAN} MEDICAL MAMBA - PROPRIOCEPTIVE MEDICAL AI{C.RESET}")
156
+ print(f"{C.CYAN}{'═'*60}{C.RESET}\n")
157
+
158
+ # Force CUDA, no auto device map
159
+ device = torch.device("cuda")
160
+ ROOT = "/home/programmer/Desktop/Claude_and_me/mamba7b_cognitive_output"
161
+ THRESHOLDS = {'depth': 0.65, 'specificity': 0.65, 'calibration': 0.55, 'coherence': 0.65, 'focus': 0.65}
162
+
163
+ print(f"{C.WHITE}Loading Falcon-Mamba-7B...{C.RESET}")
164
+ tokenizer = AutoTokenizer.from_pretrained('tiiuae/falcon-mamba-7b-instruct', trust_remote_code=True)
165
+
166
+ # Load directly to CUDA - no device_map='auto'
167
+ model = AutoModelForCausalLM.from_pretrained(
168
+ 'tiiuae/falcon-mamba-7b-instruct',
169
+ torch_dtype=torch.bfloat16,
170
+ trust_remote_code=True
171
+ ).to(device)
172
+ model.eval()
173
+ print(f"{C.GREEN}✓ Model loaded on {device}{C.RESET}")
174
+
175
+ print(f"{C.WHITE}Loading probes...{C.RESET}")
176
+ probes = {}
177
+ for name, ckpt in [('depth', 'ckpt_1000'), ('specificity', 'ckpt_1000'), ('calibration', 'ckpt_1500'), ('coherence', 'ckpt_1500'), ('focus', 'ckpt_1500')]:
178
+ path = os.path.join(ROOT, name, ckpt)
179
+ if os.path.exists(path):
180
+ probes[name] = load_probe(path, device)
181
+ print(f" {C.GREEN}✓{C.RESET} {name}")
182
+
183
+ tracker = SymptomTracker(mode="clinician")
184
+ print(f"\n{C.GREEN}Ready! Type 'quit' to exit.{C.RESET}\n")
185
+
186
+ while True:
187
+ try:
188
+ user_input = input(f"{C.CYAN}HCP:{C.RESET} ").strip()
189
+ if not user_input:
190
+ continue
191
+ if user_input.lower() in ['quit', 'exit', 'q']:
192
+ break
193
+
194
+ # Check red flags
195
+ flags = tracker.check_red_flags(user_input)
196
+ if flags:
197
+ print(f"\n{C.RED}⚠️ RED FLAG: {', '.join(flags)}{C.RESET}\n")
198
+
199
+ tracker.add_message("user", user_input)
200
+
201
+ # Build messages
202
+ messages = [
203
+ {"role": "system", "content": SYSTEM_CLINICIAN},
204
+ {"role": "user", "content": user_input}
205
+ ]
206
+
207
+ # Add conversation history
208
+ if len(tracker.conversation) > 2:
209
+ messages = [{"role": "system", "content": SYSTEM_CLINICIAN}]
210
+ for msg in tracker.conversation[-6:]:
211
+ messages.append(msg)
212
+
213
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
214
+ inputs = tokenizer(prompt, return_tensors='pt').to(device)
215
+ generated = inputs.input_ids.clone()
216
+
217
+ scores_history = {n: [] for n in probes}
218
+ state_injections = 0
219
+ generated_text = ""
220
+
221
+ print(f"\n{C.GREEN}Mamba:{C.RESET} ", end="", flush=True)
222
+
223
+ with torch.no_grad():
224
+ for step in range(500):
225
+ outputs = model(generated, output_hidden_states=True, return_dict=True)
226
+ hidden_states = list(outputs.hidden_states)
227
+
228
+ # Read probe scores
229
+ current = {}
230
+ for name, probe in probes.items():
231
+ score = probe(hidden_states)[0, -1].item()
232
+ current[name] = score
233
+ scores_history[name].append(score)
234
+
235
+ problems = [n for n, s in current.items() if s > THRESHOLDS[n]]
236
+
237
+ # Self-state injection every 20 tokens if problems
238
+ if problems and step > 0 and step % 20 == 0:
239
+ state_parts = [f"{n}={current[n]:.2f}" for n in problems]
240
+ state_msg = f" [SELF-STATE: {' '.join(state_parts)}] "
241
+ state_tokens = tokenizer.encode(state_msg, add_special_tokens=False)
242
+ for st in state_tokens:
243
+ generated = torch.cat([generated, torch.tensor([[st]], device=device)], dim=1)
244
+ state_injections += 1
245
+ print(f"{C.MAGENTA}{state_msg}{C.RESET}", end="", flush=True)
246
+
247
+ # Temperature based on probe state
248
+ temp = 0.35 if problems else 0.6
249
+
250
+ logits = outputs.logits[:, -1, :] / temp
251
+ next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
252
+
253
+ token_str = tokenizer.decode(next_token[0])
254
+ generated_text += token_str
255
+
256
+ # Color code output
257
+ if 'calibration' in problems:
258
+ print(f"{C.RED}{token_str}{C.RESET}", end="", flush=True)
259
+ elif problems:
260
+ print(f"{C.YELLOW}{token_str}{C.RESET}", end="", flush=True)
261
+ else:
262
+ print(token_str, end="", flush=True)
263
+
264
+ generated = torch.cat([generated, next_token], dim=1)
265
+ if next_token.item() == tokenizer.eos_token_id:
266
+ break
267
+
268
+ tracker.add_message("assistant", generated_text.strip())
269
+
270
+ # Summary
271
+ print(f"\n\n{C.DIM}{'─'*60}{C.RESET}")
272
+ esi = tracker.calculate_esi()
273
+ esi_color = C.RED if esi <= 2 else C.YELLOW if esi == 3 else C.GREEN
274
+ print(f"{C.WHITE}ESI:{C.RESET} {esi_color}Level {esi}{C.RESET} | {C.MAGENTA}SELF-STATES:{C.RESET} {state_injections} | ", end="")
275
+
276
+ for n in ['calibration', 'coherence']:
277
+ if scores_history.get(n):
278
+ avg = sum(scores_history[n]) / len(scores_history[n])
279
+ c = C.GREEN if avg < 0.4 else C.YELLOW if avg < 0.55 else C.RED
280
+ print(f"{n}:{c}{avg:.2f}{C.RESET} ", end="")
281
+ print(f"\n{C.DIM}{'─'*60}{C.RESET}\n")
282
+
283
+ except KeyboardInterrupt:
284
+ print(f"\n{C.YELLOW}(Use 'quit' to exit){C.RESET}")
285
+ continue
286
+ except Exception as e:
287
+ print(f"\n{C.RED}Error: {e}{C.RESET}")
288
+ continue
289
+
290
+ print(f"\n{C.CYAN}Session ended.{C.RESET}\n")
291
+
292
+ if __name__ == "__main__":
293
+ main()