Speed: 4-bit default on Spaces, SDPA option, lower token limits; CUDA greedy fix
SebAustin committed · Commit b904a07 · Parent(s): b6d4c5f
Files changed:
- config.py +6 -4
- src/agents/care_agent.py +1 -1
- src/agents/communication_agent.py +1 -1
- src/agents/intake_agent.py +2 -2
- src/agents/symptom_agent.py +1 -1
- src/models/medgemma_client.py +4 -2
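The "CUDA greedy fix" named in the commit title is not among the hunks captured below, so its exact shape is unknown. One common fix of this kind is to pass no sampling kwargs at all when decoding greedily, since `temperature`/`top_p` combined with `do_sample=False` produce warnings, and sampling under fp16/4-bit on CUDA is a known source of inf/nan probability errors. A minimal sketch under that assumption (the helper name and defaults are hypothetical, not this repo's code):

# Hypothetical sketch only -- not the actual hunk from this commit.
def build_generation_kwargs(temperature: float, max_new_tokens: int) -> dict:
    """Build model.generate() kwargs; greedy decoding gets no sampling params."""
    kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0.0:
        kwargs.update(do_sample=True, temperature=temperature, top_p=0.9, top_k=50)
    else:
        kwargs["do_sample"] = False  # pure greedy: argmax at every step
    return kwargs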
config.py CHANGED

@@ -38,15 +38,17 @@ class ModelConfig:
     # Model parameters
     USE_GPU: bool = os.getenv("USE_GPU", "true").lower() == "true"
     MAX_LENGTH: int = int(os.getenv("MAX_LENGTH", "2048"))
-    MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", "
+    MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", "384"))
     TEMPERATURE: float = float(os.getenv("TEMPERATURE", "0.7"))
     TOP_P: float = 0.9
     TOP_K: int = 50

-    # Performance
+    # Performance: 4-bit greatly speeds up inference on GPU (e.g. HF Spaces). Default on when SPACE_ID is set.
+    _default_4bit = "true" if os.getenv("SPACE_ID") else "false"
     LOAD_IN_8BIT: bool = os.getenv("LOAD_IN_8BIT", "false").lower() == "true"
-    LOAD_IN_4BIT: bool = os.getenv("LOAD_IN_4BIT", "false").lower() == "true"
-
+    LOAD_IN_4BIT: bool = os.getenv("LOAD_IN_4BIT", _default_4bit).lower() == "true"
+    # Attention: "eager" (default), "sdpa" (faster on GPU), "flash_attention_2" (fastest, needs flash-attn; Gemma can be flaky)
+    ATTN_IMPLEMENTATION: str = os.getenv("ATTN_IMPLEMENTATION", "eager").lower()

     @classmethod
     def get_device(cls) -> str:
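These defaults are resolved at import time (the class attributes call os.getenv when config.py is first imported), so environment overrides must be in place before the import. An illustrative check, with a made-up SPACE_ID value:

import os

# SPACE_ID is set automatically on Hugging Face Spaces; simulate it here.
os.environ["SPACE_ID"] = "user/space-name"   # hypothetical value
os.environ["ATTN_IMPLEMENTATION"] = "sdpa"   # opt in to the faster attention path

from config import ModelConfig

print(ModelConfig.LOAD_IN_4BIT)         # True -- 4-bit is now the default on Spaces
print(ModelConfig.MAX_NEW_TOKENS)       # 384, unless MAX_NEW_TOKENS is overridden
print(ModelConfig.ATTN_IMPLEMENTATION)  # "sdpa"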
src/agents/care_agent.py CHANGED

@@ -37,7 +37,7 @@ class CareRecommendationAgent(BaseAgent):
             urgency_reasoning=urgency_reasoning
         )

-        recommendations = self._generate(prompt, temperature=0.5, max_length=1536, max_new_tokens=
+        recommendations = self._generate(prompt, temperature=0.5, max_length=1536, max_new_tokens=384)

         # Extract structured components
         care_setting = self._extract_care_setting(recommendations, urgency_level)
src/agents/communication_agent.py CHANGED

@@ -45,7 +45,7 @@ class CommunicationAgent(BaseAgent):
             care_recommendation=care_recommendation_text
         )

-        report = self._generate(prompt, temperature=0.6, max_length=2048, max_new_tokens=
+        report = self._generate(prompt, temperature=0.6, max_length=2048, max_new_tokens=512)

         # Create structured formatted report
         formatted_report = self._create_formatted_report(
src/agents/intake_agent.py CHANGED

@@ -48,7 +48,7 @@ class IntakeAgent(BaseAgent):
         })

         # Generate response
-        response = self._generate(prompt, temperature=0.7, max_length=1024, max_new_tokens=
+        response = self._generate(prompt, temperature=0.7, max_length=1024, max_new_tokens=256)

         self.conversation_history.append({
             "role": "assistant",

@@ -155,7 +155,7 @@ Provide a structured summary including:

 Be concise and focus on medically relevant information."""

-        summary = self._generate(summary_prompt, temperature=0.5, max_length=1024, max_new_tokens=
+        summary = self._generate(summary_prompt, temperature=0.5, max_length=1024, max_new_tokens=256)

         logger.info(f"{self.name} generated case summary")
         return summary
src/agents/symptom_agent.py CHANGED

@@ -32,7 +32,7 @@ class SymptomAssessmentAgent(BaseAgent):

         # Generate symptom analysis with lower temperature for more focused analysis
         prompt = PromptTemplates.format_symptom_assessment(case_summary)
-        analysis = self._generate(prompt, temperature=0.4, max_length=1536, max_new_tokens=
+        analysis = self._generate(prompt, temperature=0.4, max_length=1536, max_new_tokens=384)

         # Extract key components
         primary_symptoms = self._extract_primary_symptoms(analysis)
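All four agent changes lower max_new_tokens on the same BaseAgent._generate helper, which is not part of this diff. A sketch of the signature these call sites imply (names, defaults, and the client attribute are inferred, not the repo's actual code):

# Inferred from the call sites above -- a sketch, not the repo's BaseAgent code.
def _generate(self, prompt: str, temperature: float = 0.7,
              max_length: int = 2048, max_new_tokens: int = 256) -> str:
    """Run the shared MedGemma client: max_length caps prompt + completion,
    max_new_tokens caps the completion alone (the knob this commit lowers)."""
    return self.client.generate(
        prompt,
        temperature=temperature,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
    )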
src/models/medgemma_client.py CHANGED

@@ -80,9 +80,11 @@ class MedGemmaClient:
             "cache_dir": ModelConfig.MODEL_CACHE_DIR,
             "token": self.token,
             "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
-            "low_cpu_mem_usage": True
+            "low_cpu_mem_usage": True,
         }
-
+        if getattr(ModelConfig, "ATTN_IMPLEMENTATION", "eager") != "eager" and self.device == "cuda":
+            model_kwargs["attn_implementation"] = ModelConfig.ATTN_IMPLEMENTATION
+
         # Use BitsAndBytesConfig for 4/8-bit (Gemma 3 etc. don't accept load_in_4bit kwarg). Only on GPU.
         if self.device == "cuda" and ModelConfig.LOAD_IN_8BIT:
             model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)