Testing / app.py
HeshamAI's picture
Update app.py
84bc2f7 verified
raw
history blame
6.24 kB
import gradio as gr
import cv2
import numpy as np
import pandas as pd
import pydicom
import io
from PIL import Image
# Progress marker: reaching this line means every import above succeeded.
print("Starting imports completed...")
class DicomAnalyzer:
def __init__(self):
    """Create an analyzer with no image loaded, no marks and the default view."""
    # Measurement state.
    self.results = []  # filled elsewhere; presumably one entry per measurement
    self.circle_diameter = 9  # presumably the default diameter for new marks
    self.marks = []  # one (x, y, diameter) tuple per placed mark
    # Image buffers: raw DICOM dataset, rescaled pixels, and 8-bit display copies.
    self.dicom_data = None
    self.current_image = self.original_image = None
    self.display_image = self.original_display = None
    # View state: magnification and pan offsets (with their upper limits).
    self.zoom_factor = 1.0
    self.pan_x = self.pan_y = 0
    self.max_pan_x = self.max_pan_y = 0
    print("DicomAnalyzer initialized...")
def load_dicom(self, file):
    """Read a DICOM upload and prepare it for display.

    Args:
        file: ``None``, an object exposing a ``.name`` path, or anything
            ``pydicom.dcmread`` accepts directly (e.g. a path string).

    Returns:
        ``(display_image, status_message)``. ``display_image`` is ``None``
        when nothing was uploaded or reading failed; in the failure case the
        message carries the exception text.
    """
    try:
        if file is None:
            return None, "No file uploaded"
        source = file.name if hasattr(file, 'name') else file
        dataset = pydicom.dcmread(source)

        raw_pixels = dataset.pixel_array.astype(np.float32)
        # Apply the linear rescale tags; identity when they are absent.
        slope = getattr(dataset, 'RescaleSlope', 1)
        intercept = getattr(dataset, 'RescaleIntercept', 0)
        pixels = raw_pixels * slope + intercept

        self.current_image = pixels
        self.original_image = pixels.copy()
        self.dicom_data = dataset
        self.display_image = self.normalize_image(pixels)
        self.original_display = self.display_image.copy()
        # A freshly loaded image starts at 1x zoom with no pan.
        self.reset_view()

        message = "DICOM file loaded successfully"
        print(message)
        return self.display_image, message
    except Exception as e:
        error = f"Error loading DICOM file: {str(e)}"
        print(error)
        return None, error
def normalize_image(self, image):
    """Min-max stretch ``image`` to 8-bit and return it as an RGB array.

    Single-channel input is replicated into three channels; input that
    already has more than two dimensions is returned as stretched.
    Returns ``None`` (after printing the error) if OpenCV rejects the input.
    """
    try:
        scaled = cv2.normalize(image, None, alpha=0, beta=255,
                               norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        if scaled.ndim != 2:
            return scaled
        return cv2.cvtColor(scaled, cv2.COLOR_GRAY2RGB)
    except Exception as e:
        print(f"Error normalizing image: {str(e)}")
        return None
def reset_view(self):
    """Return to the default view: 1x zoom and no pan.

    Returns the re-rendered view, or ``None`` when no image is loaded yet.
    """
    self.zoom_factor = 1.0
    self.pan_x = self.pan_y = 0
    if self.original_display is None:
        return None
    return self.update_display()
def zoom_in(self, image):
    """Raise magnification by 0.5x (capped at 20x) and re-render the view.

    ``image`` is accepted but not used; presumably it exists so the method
    matches the UI event-handler signature.
    """
    print("Zooming in...")
    stepped = self.zoom_factor + 0.5
    self.zoom_factor = stepped if stepped < 20.0 else 20.0
    return self.update_display()
def zoom_out(self, image):
    """Lower magnification by 0.5x (never below 1x) and re-render the view.

    ``image`` is accepted but not used; presumably it exists so the method
    matches the UI event-handler signature.
    """
    print("Zooming out...")
    stepped = self.zoom_factor - 0.5
    self.zoom_factor = stepped if stepped > 1.0 else 1.0
    return self.update_display()
def update_display(self):
    """Render the current view: zoom the display image, draw marks, crop.

    The stored display image is upscaled by ``self.zoom_factor``; each
    ``(x, y, diameter)`` in ``self.marks`` (original-image coordinates) is
    drawn as an ImageJ-style yellow circle; then a window the size of the
    unzoomed image is cut out at ``(self.pan_x, self.pan_y)``.

    Side effects: clamps ``pan_x``/``pan_y`` into range and refreshes
    ``max_pan_x``/``max_pan_y`` to the largest valid offsets.

    Returns:
        The visible window, ``None`` if no image is loaded, or the unzoomed
        display image if rendering fails.
    """
    try:
        if self.original_display is None:
            return None

        height, width = self.original_display.shape[:2]
        new_height = int(height * self.zoom_factor)
        new_width = int(width * self.zoom_factor)

        # NOTE(review): the whole image is upscaled before cropping, so memory
        # grows with zoom_factor**2 (~300 MB for a 512x512 RGB image at 20x).
        zoomed = cv2.resize(self.original_display, (new_width, new_height),
                            interpolation=cv2.INTER_CUBIC)

        # The display buffer is RGB (normalize_image uses COLOR_GRAY2RGB), so
        # yellow is (255, 255, 0); OpenCV's BGR yellow (0, 255, 255) would
        # show up as cyan here.
        mark_color = (255, 255, 0)
        num_points = 8
        for x, y, diameter in self.marks:
            zoomed_x = int(x * self.zoom_factor)
            zoomed_y = int(y * self.zoom_factor)
            zoomed_diameter = int(diameter * self.zoom_factor)
            radius = zoomed_diameter / 2

            # Main one-pixel, anti-aliased outline.
            cv2.circle(zoomed, (zoomed_x, zoomed_y), zoomed_diameter // 2,
                       mark_color, 1, lineType=cv2.LINE_AA)
            # Small filled dots around the perimeter (ImageJ style).
            for i in range(num_points):
                angle = 2 * np.pi * i / num_points
                point_x = int(zoomed_x + radius * np.cos(angle))
                point_y = int(zoomed_y + radius * np.sin(angle))
                cv2.circle(zoomed, (point_x, point_y), 1, mark_color, -1,
                           lineType=cv2.LINE_AA)

        # Largest offsets that keep a full window inside the zoomed image.
        # These are stored because keyboard panning clamps against them.
        self.max_pan_x = max(0, new_width - width)
        self.max_pan_y = max(0, new_height - height)
        self.pan_x = max(0, min(self.pan_x, self.max_pan_x))
        self.pan_y = max(0, min(self.pan_y, self.max_pan_y))

        visible_height = min(height, new_height)
        visible_width = min(width, new_width)
        return zoomed[
            self.pan_y:self.pan_y + visible_height,
            self.pan_x:self.pan_x + visible_width
        ]
    except Exception as e:
        print(f"Error updating display: {str(e)}")
        return self.original_display
def handle_keyboard(self, key):
    """Pan the zoomed view for an arrow-key name, then re-render.

    Args:
        key: ``'ArrowLeft'``, ``'ArrowRight'``, ``'ArrowUp'`` or
            ``'ArrowDown'``; any other value just re-renders.

    Returns:
        The re-rendered view, or ``self.display_image`` if panning fails.
    """
    try:
        print(f"Handling key press: {key}")
        # The step is in zoomed pixels, so each press moves ~10 original pixels.
        pan_amount = int(10 * self.zoom_factor)

        # Derive the pan limits from the current zoom instead of trusting
        # previously stored values. Otherwise a stale limit of 0 would block
        # panning right/down entirely.
        if self.original_display is not None:
            height, width = self.original_display.shape[:2]
            self.max_pan_x = max(0, int(width * self.zoom_factor) - width)
            self.max_pan_y = max(0, int(height * self.zoom_factor) - height)

        if key == 'ArrowLeft':
            self.pan_x = max(0, self.pan_x - pan_amount)
        elif key == 'ArrowRight':
            self.pan_x = min(self.max_pan_x, self.pan_x + pan_amount)
        elif key == 'ArrowUp':
            self.pan_y = max(0, self.pan_y - pan_amount)
        elif key == 'ArrowDown':
            self.pan_y = min(self.max_pan_y, self.pan_y + pan_amount)
        return self.update_display()
    except Exception as e:
        print(f"Error handling keyboard input: {str(e)}")
        return self.display_image