# xray_generator/inference.py
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer

from .models.diffusion import DiffusionModel
from .utils.processing import get_device

logger = logging.getLogger(__name__)

class XrayGenerator:
    """
    Wrapper class for chest X-ray generation from text prompts.
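
    Example (illustrative; the checkpoint path is a placeholder):
        >>> generator = XrayGenerator("checkpoints/best_model.pt")
        >>> result = generator.generate(prompt="Normal chest X-ray", seed=42)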
"""
def __init__(
self,
model_path: str,
device: Optional[torch.device] = None,
tokenizer_name: str = "dmis-lab/biobert-base-cased-v1.1",
):
"""
Initialize the X-ray generator.
Args:
model_path: Path to the saved model weights
device: Device to run the model on (defaults to CUDA if available)
tokenizer_name: Name of the HuggingFace tokenizer
"""
self.device = device if device is not None else get_device()
self.model_path = Path(model_path)
# Load tokenizer
try:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
logger.info(f"Loaded tokenizer: {tokenizer_name}")
except Exception as e:
logger.error(f"Error loading tokenizer: {e}")
raise RuntimeError(f"Failed to load tokenizer: {e}")
# Load model
self.model = self._load_model()
# Set model to evaluation mode
self.model.vae.eval()
self.model.text_encoder.eval()
self.model.unet.eval()
logger.info("XrayGenerator initialized successfully")

    def _load_model(self) -> DiffusionModel:
        """Load the diffusion model from saved weights."""
        logger.info(f"Loading model from {self.model_path}")
        try:
            # Load checkpoint
            checkpoint = torch.load(self.model_path, map_location=self.device)

            # Import model components here to avoid circular imports
            from .models.text_encoder import MedicalTextEncoder
            from .models.unet import DiffusionUNet
            from .models.vae import MedicalVAE

            # Get model configuration, falling back to the training defaults
            config = checkpoint.get('config', {})
            latent_channels = config.get('latent_channels', 8)
            model_channels = config.get('model_channels', 48)

            # Initialize model components
            vae = MedicalVAE(
                in_channels=1,
                out_channels=1,
                latent_channels=latent_channels,
                hidden_dims=[model_channels, model_channels * 2, model_channels * 4, model_channels * 8],
            ).to(self.device)
            text_encoder = MedicalTextEncoder(
                model_name=config.get('text_model', "dmis-lab/biobert-base-cased-v1.1"),
                projection_dim=768,
                freeze_base=True,
            ).to(self.device)
            unet = DiffusionUNet(
                in_channels=latent_channels,
                model_channels=model_channels,
                out_channels=latent_channels,
                num_res_blocks=2,
                attention_resolutions=(8, 16, 32),
                dropout=0.1,
                channel_mult=(1, 2, 4, 8),
                context_dim=768,
            ).to(self.device)

            # Load state dictionaries
            if 'vae_state_dict' in checkpoint:
                vae.load_state_dict(checkpoint['vae_state_dict'])
                logger.info("Loaded VAE weights")
            if 'text_encoder_state_dict' in checkpoint:
                text_encoder.load_state_dict(checkpoint['text_encoder_state_dict'])
                logger.info("Loaded text encoder weights")
            if 'unet_state_dict' in checkpoint:
                unet.load_state_dict(checkpoint['unet_state_dict'])
                logger.info("Loaded UNet weights")

            # Create diffusion model
            model = DiffusionModel(
                vae=vae,
                unet=unet,
                text_encoder=text_encoder,
                scheduler_type=config.get('scheduler_type', "ddim"),
                num_train_timesteps=config.get('num_train_timesteps', 1000),
                beta_schedule=config.get('beta_schedule', "linear"),
                prediction_type=config.get('prediction_type', "epsilon"),
                guidance_scale=config.get('guidance_scale', 7.5),
                device=self.device,
            )
            return model
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise RuntimeError(f"Failed to load model: {e}") from e

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]],
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        guidance_scale: float = 10.0,
        eta: float = 0.0,
        output_type: str = "pil",
        return_dict: bool = True,
        seed: Optional[int] = None,
    ) -> Union[Dict, List[Image.Image]]:
"""
Generate chest X-rays from text prompts.
Args:
prompt: Text prompt(s) describing the X-ray
height: Output image height
width: Output image width
num_inference_steps: Number of denoising steps (more = higher quality, slower)
guidance_scale: Controls adherence to the text prompt (higher = more faithful)
eta: Controls randomness in sampling (0 = deterministic, 1 = stochastic)
output_type: Output format, one of ["pil", "np", "tensor"]
return_dict: Whether to return a dictionary with additional metadata
seed: Random seed for reproducible generation
Returns:
Images and optionally metadata
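
        Example (illustrative):
            >>> result = generator.generate(
            ...     prompt="Cardiomegaly with mild pulmonary edema.",
            ...     num_inference_steps=100,
            ...     seed=0,
            ... )
            >>> result['images'][0]  # a PIL.Image with the default output_type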
"""
# Set seed for reproducibility if provided
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Generate images
try:
results = self.model.sample(
text=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
eta=eta,
tokenizer=self.tokenizer
)
            images_tensor = results['images']

            # Convert to the requested output format
            if output_type == "tensor":
                images = images_tensor
            elif output_type == "np":
                images = [img.cpu().numpy().transpose(1, 2, 0) for img in images_tensor]
            elif output_type == "pil":
                images = []
                for img in images_tensor:
                    img_np = img.cpu().numpy().transpose(1, 2, 0)
                    # Clamp to [0, 1] before quantizing to 8-bit
                    img_np = (np.clip(img_np, 0.0, 1.0) * 255).astype(np.uint8)
                    if img_np.shape[-1] == 1:  # Drop channel dimension for grayscale
                        img_np = img_np.squeeze(-1)
                    images.append(Image.fromarray(img_np))
            else:
                raise ValueError(f"Unknown output type: {output_type}")
            # Return results
            if return_dict:
                return {
                    'images': images,
                    'latents': results['latents'].cpu(),
                    'prompt': prompt,
                    'parameters': {
                        'height': height,
                        'width': width,
                        'num_inference_steps': num_inference_steps,
                        'guidance_scale': guidance_scale,
                        'eta': eta,
                        'seed': seed,
                    },
                }
            else:
                return images
        except Exception as e:
            logger.error(f"Error generating images: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise

    def save_images(self, images, output_dir, base_filename="generated", add_prompt=True, prompts=None):
        """
        Save generated images to disk.

        Args:
            images: List of images (PIL, numpy, or tensor).
            output_dir: Directory to save images into.
            base_filename: Base name for saved files.
            add_prompt: Whether to include the prompt in the filename.
            prompts: Prompt or list of prompts corresponding to the images.
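
        Example (illustrative):
            >>> generator.save_images(result['images'], "outputs", prompts=result['prompt'])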
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Convert to PIL if needed
if isinstance(images[0], torch.Tensor):
images_pil = []
for img in images:
img_np = img.cpu().numpy().transpose(1, 2, 0)
img_np = (img_np * 255).astype(np.uint8)
if img_np.shape[-1] == 1:
img_np = img_np.squeeze(-1)
images_pil.append(Image.fromarray(img_np))
images = images_pil
elif isinstance(images[0], np.ndarray):
images_pil = []
for img in images:
img_np = (img * 255).astype(np.uint8)
if img_np.shape[-1] == 1:
img_np = img_np.squeeze(-1)
images_pil.append(Image.fromarray(img_np))
images = images_pil
        # Save each image
        for i, img in enumerate(images):
            # Build the filename, optionally embedding a sanitized prompt
            if add_prompt and prompts is not None:
                prompt_str = prompts[i] if isinstance(prompts, list) else prompts
                prompt_str = prompt_str.replace(" ", "_").replace(".", "").lower()
                prompt_str = ''.join(c for c in prompt_str if c.isalnum() or c == '_')
                prompt_str = prompt_str[:50]  # Limit length
                filename = f"{base_filename}_{i + 1}_{prompt_str}.png"
            else:
                filename = f"{base_filename}_{i + 1}.png"

            # Save image
            file_path = output_dir / filename
            img.save(file_path)
            logger.info(f"Saved image to {file_path}")
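

# Minimal usage sketch (illustrative only): the checkpoint path below is a
# placeholder, and the model/tokenizer files are assumed to exist locally.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    generator = XrayGenerator(model_path="checkpoints/best_model.pt")  # hypothetical path
    result = generator.generate(
        prompt="Normal chest X-ray with clear lung fields.",
        num_inference_steps=50,
        guidance_scale=10.0,
        seed=42,
    )
    generator.save_images(result['images'], output_dir="outputs", prompts=result['prompt'])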