Spaces:
Running on Zero
Running on Zero
| """Prompt builders for conditional inference.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from .constants import KNOWN_CONDITIONS, SEVERITY_MODIFIERS | |
@dataclass
class ConditionConfig:
    """Configuration for a single inference run with specific conditions.

    The demographic, view, condition, severity and custom-prompt fields are
    read by ``build_condition_prompt``; the ``*_scale`` fields are not read
    there.
    """

    # Label for this run (presumably used for output naming — confirm with caller).
    name: str
    # Condition keys; the prompt builder looks each one up case-insensitively in
    # KNOWN_CONDITIONS and falls back to the raw key when it is unknown.
    conditions: list[str] = field(default_factory=list)
    # Patient age in years; None omits it from the prompt.
    age: int | None = None
    # Patient sex as free text (lower-cased in the prompt); None omits it.
    sex: str | None = None
    # Radiographic view label, upper-cased in the prompt.
    view: str = "AP"
    # When non-empty, used verbatim instead of the generated prompt.
    custom_prompt: str | None = None
    # Key into SEVERITY_MODIFIERS; its modifier prefixes the first condition only.
    severity: str | None = None
    # Presumably per-region scale factors for the conditioning mask — TODO confirm.
    heart_scale: float = 1.0
    left_lung_scale: float = 1.0
    right_lung_scale: float = 1.0
@dataclass
class InferenceConfig:
    """Top-level configuration for the condition-inference script.

    Holds global sampling settings plus the list of per-run
    ``ConditionConfig`` entries.
    """

    # Number of samples to generate (presumably per condition — confirm).
    num_samples: int = 10
    # Presumably the number of sampler/denoising steps — confirm.
    num_steps: int = 50
    # Output image size (presumably pixels — confirm).
    height: int = 512
    width: int = 512
    # Presumably the classifier-free guidance scale — confirm.
    cfg_scale: float = 4.0
    # Random seed.
    seed: int = 0
    # Per-run condition configurations; a fresh list for each instance.
    conditions: list[ConditionConfig] = field(default_factory=list)
def _demographics_phrase(age: int | None, sex: str | None) -> str:
    """Return the patient phrase, e.g. ``"a 45-year-old female patient"``."""
    words = ["a"]
    # ``is not None`` so that an age of 0 (infant) is kept, not dropped as falsy.
    if age is not None:
        words.append(f"{age}-year-old")
    if sex:
        words.append(sex.lower())
    words.append("patient")
    return " ".join(words)


def _pathology_clause(condition_keys: list[str] | None, severity: str | None) -> str:
    """Return the ``"with ..."`` clause listing the requested pathologies.

    Each key is stripped, looked up case-insensitively in ``KNOWN_CONDITIONS``,
    and falls back to the stripped key itself. Blank keys are skipped. A known
    ``severity`` modifier is prefixed to the first pathology only.
    """
    pathologies: list[str] = []
    for raw_key in condition_keys or ():
        key = raw_key.strip()
        if not key:
            # A blank entry would otherwise render as an empty list item ("with , x").
            continue
        pathologies.append(KNOWN_CONDITIONS.get(key.lower(), key))

    if not pathologies:
        return "with no significant abnormality"

    # Unknown severities are silently ignored (lookup is case-sensitive).
    if severity and severity in SEVERITY_MODIFIERS:
        pathologies[0] = SEVERITY_MODIFIERS[severity] + " " + pathologies[0]
    return f"with {', '.join(pathologies)}"


def build_condition_prompt(condition: ConditionConfig) -> str:
    """Build a CheXpert-style prompt from a ``ConditionConfig``.

    Args:
        condition: Run configuration. If ``custom_prompt`` is non-empty it is
            returned verbatim. Otherwise the prompt is composed from ``view``,
            ``age``, ``sex``, ``conditions`` and ``severity``.

    Returns:
        The text prompt, which also describes the three-channel conditioning
        mask (red=left lung, green=right lung, blue=heart).
    """
    if condition.custom_prompt:
        return condition.custom_prompt

    # NOTE(review): the template always says "frontal"; a non-frontal view such
    # as "LATERAL" would yield a contradictory phrase — confirm intended views.
    view = condition.view.upper() if condition.view else "AP"
    demographics = _demographics_phrase(condition.age, condition.sex)
    pathology_str = _pathology_clause(condition.conditions, condition.severity)

    return (
        f"frontal {view} chest radiograph of {demographics} {pathology_str}. "
        "The conditioning mask image provides three channels "
        "(red=left lung, green=right lung, blue=heart). "
        "Reconstruct a CheXpert-style chest X-ray that aligns "
        "with the segmentation and follows the described pathology."
    )