SebAustin committed
Commit b904a07 · 1 Parent(s): b6d4c5f

Speed: 4-bit default on Spaces, SDPA option, lower token limits; CUDA greedy fix

config.py CHANGED
@@ -38,15 +38,17 @@ class ModelConfig:
     # Model parameters
     USE_GPU: bool = os.getenv("USE_GPU", "true").lower() == "true"
     MAX_LENGTH: int = int(os.getenv("MAX_LENGTH", "2048"))
-    MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", "512"))
+    MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", "384"))
     TEMPERATURE: float = float(os.getenv("TEMPERATURE", "0.7"))
     TOP_P: float = 0.9
     TOP_K: int = 50
 
-    # Performance (set LOAD_IN_4BIT=true in .env for faster inference and less VRAM)
+    # Performance: 4-bit greatly speeds up inference on GPU (e.g. HF Spaces). Default on when SPACE_ID is set.
+    _default_4bit = "true" if os.getenv("SPACE_ID") else "false"
     LOAD_IN_8BIT: bool = os.getenv("LOAD_IN_8BIT", "false").lower() == "true"
-    LOAD_IN_4BIT: bool = os.getenv("LOAD_IN_4BIT", "false").lower() == "true"
-    USE_FLASH_ATTENTION: bool = False
+    LOAD_IN_4BIT: bool = os.getenv("LOAD_IN_4BIT", _default_4bit).lower() == "true"
+    # Attention: "eager" (default), "sdpa" (faster on GPU), "flash_attention_2" (fastest, needs flash-attn; Gemma can be flaky)
+    ATTN_IMPLEMENTATION: str = os.getenv("ATTN_IMPLEMENTATION", "eager").lower()
 
     @classmethod
     def get_device(cls) -> str:
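
For reference, a standalone sketch of how the new defaults resolve. SPACE_ID is set automatically by Hugging Face Spaces; the value used here is a made-up example, not from this repo:

import os

# Simulate running on a Space (SPACE_ID value is hypothetical).
os.environ["SPACE_ID"] = "user/medgemma-demo"

_default_4bit = "true" if os.getenv("SPACE_ID") else "false"
LOAD_IN_4BIT = os.getenv("LOAD_IN_4BIT", _default_4bit).lower() == "true"
ATTN_IMPLEMENTATION = os.getenv("ATTN_IMPLEMENTATION", "eager").lower()

print(LOAD_IN_4BIT)         # True: 4-bit defaults on whenever SPACE_ID is present
print(ATTN_IMPLEMENTATION)  # "eager" unless overridden, e.g. ATTN_IMPLEMENTATION=sdpa

Both knobs remain plain environment variables, so local runs can still opt out with LOAD_IN_4BIT=false.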
src/agents/care_agent.py CHANGED
@@ -37,7 +37,7 @@ class CareRecommendationAgent(BaseAgent):
             urgency_reasoning=urgency_reasoning
         )
 
-        recommendations = self._generate(prompt, temperature=0.5, max_length=1536, max_new_tokens=512)
+        recommendations = self._generate(prompt, temperature=0.5, max_length=1536, max_new_tokens=384)
 
         # Extract structured components
         care_setting = self._extract_care_setting(recommendations, urgency_level)
src/agents/communication_agent.py CHANGED
@@ -45,7 +45,7 @@ class CommunicationAgent(BaseAgent):
             care_recommendation=care_recommendation_text
         )
 
-        report = self._generate(prompt, temperature=0.6, max_length=2048, max_new_tokens=768)
+        report = self._generate(prompt, temperature=0.6, max_length=2048, max_new_tokens=512)
 
         # Create structured formatted report
         formatted_report = self._create_formatted_report(
src/agents/intake_agent.py CHANGED
@@ -48,7 +48,7 @@ class IntakeAgent(BaseAgent):
         })
 
         # Generate response
-        response = self._generate(prompt, temperature=0.7, max_length=1024, max_new_tokens=384)
+        response = self._generate(prompt, temperature=0.7, max_length=1024, max_new_tokens=256)
 
         self.conversation_history.append({
             "role": "assistant",
@@ -155,7 +155,7 @@ Provide a structured summary including:
 
 Be concise and focus on medically relevant information."""
 
-        summary = self._generate(summary_prompt, temperature=0.5, max_length=1024, max_new_tokens=384)
+        summary = self._generate(summary_prompt, temperature=0.5, max_length=1024, max_new_tokens=256)
 
         logger.info(f"{self.name} generated case summary")
         return summary
src/agents/symptom_agent.py CHANGED
@@ -32,7 +32,7 @@ class SymptomAssessmentAgent(BaseAgent):
 
         # Generate symptom analysis with lower temperature for more focused analysis
         prompt = PromptTemplates.format_symptom_assessment(case_summary)
-        analysis = self._generate(prompt, temperature=0.4, max_length=1536, max_new_tokens=512)
+        analysis = self._generate(prompt, temperature=0.4, max_length=1536, max_new_tokens=384)
 
         # Extract key components
         primary_symptoms = self._extract_primary_symptoms(analysis)
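
All four agents get the same treatment: max_new_tokens drops by 128-256 per call. Since autoregressive decoding cost grows roughly linearly with tokens generated, the saving per call is easy to estimate; the throughput figure below is an illustrative assumption, not a measurement from this repo:

# Back-of-the-envelope decode-time model; 20 tok/s is an assumed throughput.
def estimated_decode_seconds(max_new_tokens: int, tokens_per_second: float = 20.0) -> float:
    """Decode latency scales roughly linearly with tokens generated."""
    return max_new_tokens / tokens_per_second

print(estimated_decode_seconds(512))  # 25.6 s -> old care/symptom-agent ceiling
print(estimated_decode_seconds(384))  # 19.2 s -> new ceiling, ~25% faster worst case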
src/models/medgemma_client.py CHANGED
@@ -80,9 +80,11 @@ class MedGemmaClient:
             "cache_dir": ModelConfig.MODEL_CACHE_DIR,
             "token": self.token,
             "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
-            "low_cpu_mem_usage": True
+            "low_cpu_mem_usage": True,
         }
-
+        if getattr(ModelConfig, "ATTN_IMPLEMENTATION", "eager") != "eager" and self.device == "cuda":
+            model_kwargs["attn_implementation"] = ModelConfig.ATTN_IMPLEMENTATION
+
         # Use BitsAndBytesConfig for 4/8-bit (Gemma 3 etc. don't accept load_in_4bit kwarg). Only on GPU.
         if self.device == "cuda" and ModelConfig.LOAD_IN_8BIT:
             model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
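
The hunk above only shows the 8-bit branch; the 4-bit path that LOAD_IN_4BIT enables presumably mirrors it just below. A minimal sketch of what such a branch typically looks like, where the bnb_4bit_* settings are illustrative assumptions rather than values from this commit:

import torch
from transformers import BitsAndBytesConfig

def build_quantization_config(device: str, load_in_8bit: bool, load_in_4bit: bool):
    """Pick a bitsandbytes config for GPU runs; None means full precision."""
    if device != "cuda":
        return None  # as in the diff, quantization is applied only on GPU
    if load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    if load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,  # matches the fp16 torch_dtype above
            bnb_4bit_quant_type="nf4",             # common default for 4-bit inference
        )
    return None

The resulting config is passed via model_kwargs["quantization_config"], which is why the attn_implementation kwarg is added separately: quantization and attention backend are independent knobs in transformers' from_pretrained.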