Spaces:
Sleeping
Sleeping
| import cv2 | |
| import numpy as np | |
| import pywt | |
| from skimage import exposure | |
| import gradio as gr | |
| from PIL import Image | |
| from io import BytesIO | |
| import matplotlib.pyplot as plt | |
def _read_grayscale(file):
    """Load the uploaded file as a single-channel image at its native bit depth.

    ``file`` may be a filesystem path (``str``, what ``gr.File`` returns by
    default in Gradio 4+) or an object exposing ``.name`` (older Gradio
    tempfile wrappers).
    """
    if file is None:
        raise gr.Error("Please upload a TIFF file first")
    path = file if isinstance(file, str) else file.name
    try:
        # IMREAD_GRAYSCALE is 0, so this keeps 16-bit depth and drops colour.
        img = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_GRAYSCALE)
    except Exception as e:
        raise gr.Error(f"Error reading file: {str(e)}") from e
    if img is None:
        raise gr.Error("Error reading file: Invalid or corrupted TIFF file")
    return img


def _to_unit_range(img):
    """Scale raw pixel values to float32 in roughly [0, 1]."""
    if np.issubdtype(img.dtype, np.integer):
        # Divide by the dtype's full-scale value (65535 for uint16, 255 for
        # uint8) instead of assuming every input is 16-bit.
        return img.astype(np.float32) / np.iinfo(img.dtype).max
    # Float images have no intrinsic full scale: min-max stretch them.
    img = img.astype(np.float32)
    lo, hi = float(img.min()), float(img.max())
    if hi <= lo:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


def _histogram_entropy(values, bins=256):
    """Return the Shannon entropy (bits) of a [0, 1] image's intensity histogram."""
    hist, _ = np.histogram(values, bins=bins, range=(0.0, 1.0))
    total = hist.sum()
    if total == 0:
        return 0.0
    p = hist[hist > 0] / total
    return float(-np.sum(p * np.log2(p)))


def _enhance(img_norm):
    """Run wavelet denoise/boost, CLAHE, gamma and detail enhancement; return uint8."""
    coeffs = pywt.wavedec2(img_norm, 'bior1.3', level=3)
    cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs

    # Soft-threshold the two finest diagonal bands. The threshold is a fraction
    # of the peak coefficient *magnitude*, so large negative coefficients count.
    cD1 = pywt.threshold(cD1, 0.05 * np.max(np.abs(cD1)), 'soft')
    cD2 = pywt.threshold(cD2, 0.07 * np.max(np.abs(cD2)), 'soft')
    # Amplify the finest horizontal/vertical detail bands by 20%.
    cH1 = cH1 * 1.2
    cV1 = cV1 * 1.2

    recon = pywt.waverec2(
        [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)], 'bior1.3'
    )
    # Defensive crop: waverec2 may return one extra row/column for odd sizes.
    recon = recon[:img_norm.shape[0], :img_norm.shape[1]]
    recon = np.clip(recon, 0, 1)

    # CLAHE: gentler clipping when the histogram is already information-rich
    # (entropy above 7 of the 8 possible bits for a 256-bin histogram).
    clip_limit = 0.02 if _histogram_entropy(recon) > 7 else 0.05
    clahe_img = exposure.equalize_adapthist(recon, clip_limit=clip_limit, kernel_size=64)

    # Brighten more aggressively when the 5-95 percentile spread is narrow.
    p5, p95 = np.percentile(clahe_img, (5, 95))
    gamma = 0.7 if (p95 - p5) < 0.3 else 0.9
    gamma_img = exposure.adjust_gamma(clahe_img, gamma)

    # cv2.detailEnhance only accepts 3-channel 8-bit input, so round-trip via BGR.
    sharp = cv2.detailEnhance(
        cv2.cvtColor((gamma_img * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
        sigma_s=12,
        sigma_r=0.15,
    )
    return cv2.cvtColor(sharp, cv2.COLOR_BGR2GRAY)


def _display_original(img):
    """Map the raw image to uint8 for display, saturating above the 99.5th percentile."""
    scale = float(np.percentile(img, 99.5))
    if scale <= 0:
        # Mostly-black image: fall back to the maximum (or 1) to avoid dividing by zero.
        scale = float(img.max()) if img.max() > 0 else 1.0
    return (np.clip(img / scale, 0, 1) * 255).astype(np.uint8)


def _histogram_image(values):
    """Render a 256-bin histogram of uint8 ``values`` to a PIL image."""
    # Use an explicit Figure rather than pyplot's global "current figure",
    # which is not safe when Gradio runs handlers concurrently.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.hist(values.ravel(), bins=256, range=(0, 255))
    ax.set_title("Enhanced Histogram")
    ax.set_xlabel("Pixel Value")
    ax.set_ylabel("Frequency")

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    hist_img = Image.open(buf)
    hist_img.load()  # decode now so the image no longer depends on the lazy buffer read
    return hist_img


def process_tiff(file):
    """Enhance an uploaded grayscale TIFF for display.

    Args:
        file: The ``gr.File`` value: a path string (Gradio 4+) or an object
            with a ``.name`` path attribute; ``None`` if nothing was uploaded.

    Returns:
        ``(original_display, enhanced, histogram)``:
        - ``original_display``: uint8 view of the input, clipped at its 99.5th percentile.
        - ``enhanced``: the uint8 enhanced image.
        - ``histogram``: a PIL image of the enhanced image's histogram.

    Raises:
        gr.Error: If no file was given, it cannot be read, its dimensions are
            not multiples of 8, or any processing step fails.
    """
    img = _read_grayscale(file)
    # Keeps every level of the 3-level decomposition on even sizes.
    if img.shape[0] % 8 != 0 or img.shape[1] % 8 != 0:
        raise gr.Error("Image dimensions must be divisible by 8 for wavelet processing")

    try:
        sharp = _enhance(_to_unit_range(img))
    except Exception as e:
        raise gr.Error(f"Processing error: {str(e)}") from e

    return _display_original(img), sharp, _histogram_image(sharp)
# Gradio UI: an upload column and an original-preview column on top, then the
# enhanced result and its histogram below. "Process" and the example both run
# process_tiff.
with gr.Blocks(title="MUSICA Enhancement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ MUSICA X-ray Image Enhancement")
    gr.Markdown("Upload a 16-bit grayscale TIFF for wavelet-based enhancement")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Input TIFF",
                # Gradio's file_types takes extensions with a leading dot
                # (e.g. ".json"); bare "tif"/"tiff" do not name extensions.
                file_types=[".tif", ".tiff"],
                height=100,
            )
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            original_output = gr.Image(
                label="Original (Clipped)",
                height=400,
                type="numpy",
            )

    with gr.Row():
        enhanced_output = gr.Image(
            label="Enhanced Result",
            type="numpy",
            height=400,
        )
        hist_output = gr.Image(
            label="Histogram",
            type="pil",
            height=400,
        )

    # The output order matches process_tiff's return tuple.
    submit_btn.click(
        process_tiff,
        inputs=file_input,
        outputs=[original_output, enhanced_output, hist_output],
    )

    # NOTE(review): assumes sample.tif ships next to this script. Confirm this,
    # since a missing file may break example loading/caching.
    gr.Examples(
        examples=[["sample.tif"]],
        inputs=file_input,
        outputs=[original_output, enhanced_output, hist_output],
        fn=process_tiff,
    )

if __name__ == "__main__":
    demo.launch()