yasserrmd committed (verified)
Commit 21ddb34 · Parent: 7cc675b

Update app.py

Files changed (1): app.py +97 -59
app.py CHANGED
@@ -1,15 +1,15 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer, Mistral3ForConditionalGeneration, TextIteratorStreamer
+from threading import Thread
 import re
+import time
 import os
-from typing import List, Tuple
+from typing import Iterator, List, Tuple
 import spaces



-
-
 # Model configuration
 MODEL_NAME = "yasserrmd/SinaReason-Magistral-2509"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -19,49 +20,52 @@ MEDICAL_SYSTEM_PROMPT = """
 You are SinaReason, a medical reasoning assistant for educational and clinical support.
 Your goal is to carefully reason through clinical problems for a professional audience (clinicians, students).
 **Never provide medical advice directly to a patient.**
+
 First, draft your detailed thought process (inner monologue) inside <think> ... </think>.
 - Use this section to work through symptoms, differential diagnoses, and investigation plans.
 - Be explicit and thorough in your reasoning.
+
 After closing </think>, provide a clear, self-contained medical summary appropriate for a clinical professional.
 - Summarize the most likely diagnosis and your reasoning.
 - Suggest next steps for investigation or management.
 """

-
-
 class SinaReasonMedicalChat:
     def __init__(self):
         self.tokenizer = None
         self.model = None
-        # The PixtralProcessor requires an image argument, even if it's None.
-        # This is a mandatory part of the call signature.
-        self.dummy_image = None
         self.load_model()
-
+
     def load_model(self):
-        """Load the SinaReason medical model and tokenizer using Unsloth"""
+        """Load the SinaReason medical model and tokenizer"""
         try:
-            print(f"Loading medical model with Unsloth: {MODEL_NAME}")
-            print("cuda" if torch.cuda.is_available() else "cpu")
+            print(f"Loading medical model: {MODEL_NAME}")
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                "mistralai/Magistral-Small-2509"
+            )
+
+            # Add padding token if not present
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token

-            self.model = AutoModelForCausalLM.from_pretrained(
+            self.model = Mistral3ForConditionalGeneration.from_pretrained(
                 MODEL_NAME,
-                torch_dtype=torch.bfloat16,  # Use bfloat16 for modern GPUs
-                device_map="auto",  # Automatically map to the available GPU
+                dtype=torch.bfloat16
             )

-            # Load the standard tokenizer
-            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-
-            print("SinaReason medical model loaded successfully with Unsloth!")
+            print("SinaReason medical model loaded successfully!")

         except Exception as e:
-            print(f"Error loading model with Unsloth: {e}")
+            print(f"Error loading model: {e}")
             raise e

     def extract_thinking_and_response(self, text: str) -> Tuple[str, str]:
         """Extract thinking process from <think>...</think> tags and clinical response"""
+        # Look for the <think>...</think> pattern used by SinaReason
         think_pattern = r'<think>(.*?)</think>'
+
         thinking = ""
         response = text
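Only the pattern and the fallback defaults of `extract_thinking_and_response` are visible in these hunks; the matching logic between them is unchanged and not shown. A standalone sketch of the same extraction idea, where the function name and the `re.DOTALL` flag are assumptions rather than a copy of the hidden body:

```python
import re

THINK_PATTERN = r"<think>(.*?)</think>"  # non-greedy: stop at the first </think>

def split_thinking(text: str) -> tuple[str, str]:
    # Assumed behavior: pull the <think>...</think> block out of a generation
    # and treat everything outside it as the clinical summary.
    match = re.search(THINK_PATTERN, text, re.DOTALL)  # DOTALL lets reasoning span lines
    if not match:
        return "", text
    thinking = match.group(1).strip()
    response = re.sub(THINK_PATTERN, "", text, count=1, flags=re.DOTALL).strip()
    return thinking, response

print(split_thinking("<think>ddx: PE vs pneumonia</think>Most likely PE; order CTPA."))
# -> ('ddx: PE vs pneumonia', 'Most likely PE; order CTPA.')
```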
@@ -73,32 +77,48 @@ class SinaReasonMedicalChat:
         return thinking, response

     @spaces.GPU(duration=120)
-    def medical_chat(self, message: str, history: List[List[str]], max_tokens: int = 1024,
-                     temperature: float = 0.7, top_p: float = 0.95) -> Tuple[str, List[List[str]]]:
-        """Generate medical reasoning responses using the Unsloth model."""
-
+    def medical_chat_stream(self, message: str, history: List[List[str]], max_tokens: int = 1024,
+                            temperature: float = 0.7, top_p: float = 0.95) -> Iterator[Tuple[str, List[List[str]]]]:
+        """Stream medical reasoning responses with a live thinking display."""
+        self.model.to(DEVICE).eval()
         if not message.strip():
-            return "", history
-
-        self.model.to("cuda")
-        self.model.eval()
+            return

         # Apply the chat template with the medical system prompt
-        messages = [{"role": "system", "content": MEDICAL_SYSTEM_PROMPT}]
+        messages = [
+            {"role": "system", "content": MEDICAL_SYSTEM_PROMPT},
+        ]
+
+        # Add conversation history
         for user_msg, assistant_msg in history:
-            raw_assistant_msg = assistant_msg.split("🩺 **Clinical Summary**")[-1].strip()
             messages.append({"role": "user", "content": user_msg})
-            messages.append({"role": "assistant", "content": raw_assistant_msg})
+            messages.append({"role": "assistant", "content": assistant_msg})
+
+        # Add current message
         messages.append({"role": "user", "content": message})

-        formatted_prompt = self.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True,
+        # Apply chat template
+        prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
         )

-        # THE HACK IS GONE: Standard tokenization without any 'images' argument.
-        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
-
-        # THE HACK IS GONE: Standard generation call.
+        # Tokenize input and move to the same device as the model
+        inputs = self.tokenizer(
+            text=prompt,
+            return_tensors="pt"
+        ).to(DEVICE)
+
+        # Set up the streamer
+        streamer = TextIteratorStreamer(
+            self.tokenizer,
+            timeout=30.0,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        # Generation parameters for medical reasoning
         generation_kwargs = {
             **inputs,
             "max_new_tokens": max_tokens,
@@ -106,30 +126,48 @@
             "top_p": top_p,
             "do_sample": True,
             "pad_token_id": self.tokenizer.eos_token_id,
+            "streamer": streamer,
+            "repetition_penalty": 1.1,
         }

-        output = self.model.generate(**generation_kwargs)[0]
-        full_response = self.tokenizer.decode(output[inputs.input_ids.shape[1]:], skip_special_tokens=True)
-
-        # Extract thinking and clinical summary
-        thinking, response = self.extract_thinking_and_response(full_response)
-
-        # Format the final display
-        final_display = ""
-        if thinking:
-            final_display += f"""🧠 **Medical Reasoning Process**
-<details>
-<summary>🔍 Click to view detailed thinking process</summary>
-*{thinking}*
-</details>
----
-"""
-
-        final_display += f"""🩺 **Clinical Summary**
-{response}"""
-
-        new_history = history + [[message, final_display]]
-        return "", new_history
+        # model.generate blocks until the completion is finished, so run it in
+        # a worker thread and drain the streamer from this one.
+        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # Stream the response
+        partial_response = ""
+        current_thinking = ""
+        current_response = ""
+
+        for new_token in streamer:
+            partial_response += new_token
+
+            # Extract thinking and response
+            thinking, response = self.extract_thinking_and_response(partial_response)
+
+            # Show the thinking phase while it is being generated
+            if thinking and thinking != current_thinking:
+                current_thinking = thinking
+                display_text = f"🧠 **Medical Reasoning in Progress...**\n\n<details>\n<summary>🔍 Click to see thinking process</summary>\n\n*{current_thinking}*\n\n</details>"
+                new_history = history + [[message, display_text]]
+                yield "", new_history
+                time.sleep(0.1)  # Smooth streaming
+
+            # Show the clinical response as it is generated
+            if response and response != current_response:
+                current_response = response
+
+                final_display = f"""🧠 **Medical Reasoning Process**
+<details>
+<summary>🔍 Click to view detailed thinking process</summary>
+*{current_thinking}*
+</details>
+---
+🩺 **Clinical Summary**
+{current_response}"""
+
+                new_history = history + [[message, final_display]]
+                yield "", new_history
+
+        thread.join()


 # Initialize the medical chat model
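`TextIteratorStreamer` is designed to be drained from a different thread than the one running `generate`, since `generate` does not return until the whole completion is done. The hunk above therefore starts generation on a worker `Thread`; the reusable shape of that pattern looks roughly like this, with `model`, `tokenizer`, and pre-tokenized `inputs` assumed to exist:

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, inputs, **gen_kwargs):
    # The worker thread runs the blocking generate() call; this thread
    # yields decoded text chunks as the streamer receives them.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
    )
    thread = Thread(
        target=model.generate,
        kwargs={**inputs, **gen_kwargs, "streamer": streamer},
    )
    thread.start()
    try:
        for chunk in streamer:
            yield chunk
    finally:
        thread.join()  # generation is finished once the streamer is exhausted
```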
@@ -137,7 +175,8 @@ medical_chat_model = SinaReasonMedicalChat()

 def respond(message, history, max_tokens, temperature, top_p):
     """Gradio response function for medical reasoning"""
-    return medical_chat_model.medical_chat(message, history, max_tokens, temperature, top_p)
+    for response in medical_chat_model.medical_chat_stream(message, history, max_tokens, temperature, top_p):
+        yield response

 # Custom CSS for medical interface
 css = """
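One behavioral change worth noting: the old `medical_chat` stripped everything before "🩺 **Clinical Summary**" from prior assistant turns before replaying them, while `medical_chat_stream` appends `assistant_msg` verbatim, so the rendered display text (emoji headers, `<details>` blocks) is now fed back into the prompt on multi-turn chats. A sketch of the message assembly as now written (the helper name is illustrative):

```python
from typing import Dict, List

def build_messages(system_prompt: str,
                   history: List[List[str]],
                   message: str) -> List[Dict[str, str]]:
    # System prompt first, then alternating user/assistant turns from the
    # Gradio history ([[user, assistant], ...]), then the new user message.
    messages = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        # assistant_msg is the rendered display string, markdown and all.
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return messages

# prompt = tokenizer.apply_chat_template(
#     build_messages(MEDICAL_SYSTEM_PROMPT, history, message),
#     tokenize=False, add_generation_prompt=True)
```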
@@ -322,7 +361,6 @@ with gr.Blocks(css=css, title="SinaReason Medical Reasoning", theme=gr.themes.So
     </div>
     """)

-
 # Launch configuration for HF Spaces
 if __name__ == "__main__":
     demo.launch(
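Since `respond` is now a generator yielding `("", new_history)` pairs, the Blocks wiring (outside this diff) must register it as a streaming handler whose outputs line up with those two values. A minimal compatible wiring, where every component name is an assumption about the unshown UI code:

```python
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()                              # receives new_history on each yield
    msg = gr.Textbox(placeholder="Describe the case...")
    max_tokens = gr.Slider(64, 4096, value=1024, label="Max new tokens")
    temperature = gr.Slider(0.0, 1.5, value=0.7, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")

    # Gradio streams generator handlers automatically: each yield of
    # ("", new_history) clears the textbox and refreshes the chat window.
    msg.submit(respond, [msg, chatbot, max_tokens, temperature, top_p], [msg, chatbot])
```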
 