yasserrmd committed
Commit c9f8cc4 · verified · 1 Parent(s): 00657a2

Update app.py

Files changed (1):
  1. app.py  +37 -38

app.py CHANGED
@@ -5,6 +5,7 @@ import re
 import os
 from typing import List, Tuple
 import spaces
+from unsloth import FastLanguageModel
 
 
 
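Note: this hunk makes unsloth a hard import-time dependency of the Space. A guarded import, sketched below and not part of this commit, would surface a clearer error when the package is missing from the environment:

    # Sketch only (not in this commit): fail fast with a readable message
    # instead of a bare ModuleNotFoundError at app startup.
    try:
        from unsloth import FastLanguageModel
    except ImportError as err:
        raise ImportError(
            "app.py now loads the model through Unsloth; install it with pip install unsloth"
        ) from err
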
@@ -25,39 +26,39 @@ After closing </think>, provide a clear, self-contained medical summary appropri
 - Suggest next steps for investigation or management.
 """
 
+
+
 class SinaReasonMedicalChat:
     def __init__(self):
         self.tokenizer = None
         self.model = None
+        # The PixtralProcessor requires an image argument, even if it's None.
+        # This is a mandatory part of the call signature.
+        self.dummy_image = None
         self.load_model()
 
     def load_model(self):
-        """Load the SinaReason medical model and tokenizer"""
+        """Load the SinaReason medical model and tokenizer using Unsloth"""
         try:
-            print(f"Loading medical model: {MODEL_NAME}")
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                MODEL_NAME, tokenizer_type="mistral"
-            )
-
-            # Add padding token if not present
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
+            print(f"Loading medical model with Unsloth: {MODEL_NAME}")
 
-            self.model = Mistral3ForConditionalGeneration.from_pretrained(
-                MODEL_NAME,
-                dtype=torch.bfloat16
+            # Use FastLanguageModel from Unsloth to load the model and tokenizer
+            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+                model_name=MODEL_NAME,
+                dtype=torch.bfloat16,
+                load_in_4bit=True,  # Or False if you have enough VRAM for 16-bit
+                # device_map="auto",
             )
 
-            print("SinaReason medical model loaded successfully!")
+            print("SinaReason medical model loaded successfully with Unsloth!")
 
         except Exception as e:
-            print(f"Error loading model: {e}")
+            print(f"Error loading model with Unsloth: {e}")
             raise e
 
     def extract_thinking_and_response(self, text: str) -> Tuple[str, str]:
         """Extract thinking process from <think>...</think> tags and clinical response"""
         think_pattern = r'<think>(.*?)</think>'
-
         thinking = ""
         response = text
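For reference, the new loading path in isolation. This is a sketch under assumptions: unsloth is installed, MODEL_NAME resolves to the Mistral-family checkpoint this Space uses (its value is defined outside this hunk and not shown), and the GPU has room for the 4-bit weights:

    import torch
    from unsloth import FastLanguageModel

    MODEL_NAME = "..."  # defined elsewhere in app.py; value not shown in this diff

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        dtype=torch.bfloat16,
        load_in_4bit=True,  # 4-bit quantized weights; set False to keep bf16
    )
    FastLanguageModel.for_inference(model)  # enable Unsloth's inference-optimized mode
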
 
@@ -71,53 +72,51 @@ class SinaReasonMedicalChat:
     @spaces.GPU(duration=120)
     def medical_chat(self, message: str, history: List[List[str]], max_tokens: int = 1024,
                      temperature: float = 0.7, top_p: float = 0.95) -> Tuple[str, List[List[str]]]:
-        """Generate medical reasoning responses without streaming."""
-        self.model.to(DEVICE).eval()
+        """Generate medical reasoning responses using the Unsloth model."""
+        # Move the model to the target device and switch to eval mode.
+        self.model.to(DEVICE)
+        self.model.eval()
         if not message.strip():
             return "", history
 
         # Apply the chat template with the medical system prompt
-        messages = [
-            {"role": "system", "content": MEDICAL_SYSTEM_PROMPT},
-        ]
-
-        # Add conversation history
+        messages = [{"role": "system", "content": MEDICAL_SYSTEM_PROMPT}]
         for user_msg, assistant_msg in history:
-            # Reconstruct the assistant turn for the model; for simplicity only the
-            # clinical-summary part of the stored message is reused. A more robust
-            # solution would store the full generated text.
             raw_assistant_msg = assistant_msg.split("🩺 **Clinical Summary**")[-1].strip()
             messages.append({"role": "user", "content": user_msg})
             messages.append({"role": "assistant", "content": raw_assistant_msg})
-
-        # Add current message
         messages.append({"role": "user", "content": message})
 
-        tokenized = self.tokenizer.apply_chat_template(messages, return_dict=True)
-
-        input_ids = torch.tensor(tokenized.input_ids, device="cuda").unsqueeze(0)
-        attention_mask = torch.tensor(tokenized.attention_mask, device="cuda").unsqueeze(0)
+        # Format the prompt using the chat template
+        formatted_prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        # Tokenize the input, passing images=None for this text-only call
+        inputs = self.tokenizer(
+            text=formatted_prompt,
+            images=self.dummy_image,
+            return_tensors="pt"
+        ).to(self.model.device)
 
         # Generation parameters
         generation_kwargs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
+            **inputs,
+            "images": self.dummy_image,  # This MUST be passed to model.generate
             "max_new_tokens": max_tokens,
             "temperature": temperature,
             "top_p": top_p,
             "do_sample": True,
             "pad_token_id": self.tokenizer.eos_token_id,
-            "repetition_penalty": 1.1
         }
 
         # Generate the full response
         output = self.model.generate(**generation_kwargs)[0]
 
-        # Decode the response
-        full_response = self.tokenizer.decode(output[len(tokenized.input_ids) : -1 if output[-1] == self.tokenizer.eos_token_id else len(output)])
+        # Decode only the newly generated tokens
+        full_response = self.tokenizer.decode(output[inputs.input_ids.shape[1]:], skip_special_tokens=True)
 
         # Extract thinking and clinical summary
         thinking, response = self.extract_thinking_and_response(full_response)
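The net effect of this hunk, reduced to essentials: render the message list to a prompt string, tokenize it, generate, then decode only the tokens past the prompt (generate() returns prompt plus completion). A minimal sketch assuming model and tokenizer were loaded as above; the images= argument from the commit is omitted here since it only applies to the Pixtral-style processor:

    def generate_reply(model, tokenizer, messages, max_new_tokens=1024):
        # Render the chat into one prompt string, ending with the assistant
        # turn marker so the model continues as the assistant.
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )[0]
        # Slice off the prompt tokens; keep only the new completion.
        return tokenizer.decode(output[inputs.input_ids.shape[1]:], skip_special_tokens=True)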
 
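extract_thinking_and_response is unchanged by this commit and its body sits mostly outside the hunks. Below is a sketch of what the visible pattern implies, with re.DOTALL added as an assumption, since the pattern as shown would not match <think> blocks that span multiple lines:

    import re
    from typing import Tuple

    THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)

    def extract_thinking_and_response(text: str) -> Tuple[str, str]:
        """Split '<think>reasoning</think> summary' into (reasoning, summary)."""
        match = THINK_PATTERN.search(text)
        if match is None:
            return "", text.strip()
        return match.group(1).strip(), text[match.end():].strip()

    # extract_thinking_and_response("<think>ddx: viral vs bacterial</think>Likely viral URI.")
    # -> ("ddx: viral vs bacterial", "Likely viral URI.")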