SynthCXR / synthcxr /prompt.py
gradientguild's picture
Upload folder using huggingface_hub
a4aa5c5 verified
"""Prompt builders for conditional inference."""
from __future__ import annotations
from dataclasses import dataclass, field
from .constants import KNOWN_CONDITIONS, SEVERITY_MODIFIERS
@dataclass
class ConditionConfig:
    """Settings describing one conditional-inference run.

    Only ``name`` must be supplied; every other field has a default, so a
    bare ``ConditionConfig("label")`` describes an unspecified patient with
    no listed conditions and unit scale factors.
    """

    # Label for this run.
    name: str

    # Condition keys for this run. ``default_factory`` gives each instance
    # its own empty list rather than one shared mutable default.
    conditions: list[str] = field(default_factory=list)

    # Optional demographics; ``None`` means "not specified".
    age: int | None = None
    sex: str | None = None

    # Radiographic view label.
    view: str = "AP"

    # Free-text prompt; presumably overrides the generated prompt when set
    # (TODO confirm against callers).
    custom_prompt: str | None = None

    # Optional severity key.
    severity: str | None = None

    # Per-structure scale factors, 1.0 by default. Presumably applied to the
    # conditioning mask regions elsewhere — TODO confirm.
    heart_scale: float = 1.0
    left_lung_scale: float = 1.0
    right_lung_scale: float = 1.0
@dataclass
class InferenceConfig:
    """Top-level settings for the condition-inference script.

    All fields have defaults, so ``InferenceConfig()`` is valid on its own
    and starts with an empty list of per-run ``ConditionConfig`` entries.
    """

    # Presumably the number of images generated per run — TODO confirm.
    num_samples: int = 10
    # Presumably the number of sampler steps — TODO confirm.
    num_steps: int = 50

    # Output image size.
    height: int = 512
    width: int = 512

    # Presumably the classifier-free-guidance scale — TODO confirm.
    cfg_scale: float = 4.0
    seed: int = 0

    # Runs to execute. A fresh list is created per instance.
    conditions: list[ConditionConfig] = field(default_factory=list)
def build_condition_prompt(condition: ConditionConfig) -> str:
    """Build a CheXpert-style prompt from a ``ConditionConfig``.

    If ``condition.custom_prompt`` is a non-empty string, it is returned
    unchanged. Otherwise the prompt is assembled from the following parts:

    * The view, upper-cased. ``"AP"`` is used when ``condition.view`` is empty.
    * Demographics from ``age`` and ``sex``. Either or both may be omitted.
    * The pathologies. Each key in ``condition.conditions`` is looked up in
      ``KNOWN_CONDITIONS`` by its lower-cased form, and the raw key is used
      when no entry exists. If ``condition.severity`` is a key of
      ``SEVERITY_MODIFIERS``, its modifier is prefixed to the *first*
      pathology only. Unrecognised severities are ignored, as is a
      severity given with no conditions.
    * A fixed instruction describing the three-channel conditioning mask.

    Args:
        condition: Per-run condition settings.

    Returns:
        The prompt text.
    """
    if condition.custom_prompt:
        return condition.custom_prompt

    view = condition.view.upper() if condition.view else "AP"

    # Compare with None rather than truthiness so that an age of 0
    # (e.g. a newborn) is rendered instead of being silently dropped.
    age_str = f"{condition.age}-year-old" if condition.age is not None else ""
    sex_str = condition.sex.lower() if condition.sex else ""
    descriptors = " ".join(part for part in (age_str, sex_str) if part)
    demographics = f"a {descriptors} patient" if descriptors else "a patient"

    pathologies = [
        KNOWN_CONDITIONS.get(key.lower(), key) for key in condition.conditions
    ]
    severity = condition.severity
    # The severity modifier qualifies only the first (primary) finding.
    if pathologies and severity and severity in SEVERITY_MODIFIERS:
        pathologies[0] = SEVERITY_MODIFIERS[severity] + " " + pathologies[0]

    pathology_str = (
        f"with {', '.join(pathologies)}"
        if pathologies
        else "with no significant abnormality"
    )
    return (
        f"frontal {view} chest radiograph of {demographics} {pathology_str}. "
        "The conditioning mask image provides three channels "
        "(red=left lung, green=right lung, blue=heart). "
        "Reconstruct a CheXpert-style chest X-ray that aligns "
        "with the segmentation and follows the described pathology."
    )