Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import sys | |
| import os | |
| import shutil | |
| import time | |
| from datetime import datetime | |
| import csv | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| sys.path.append('Utils') | |
| sys.path.append('model') | |
| from model.CBAM.reunet_cbam import reunet_cbam | |
| from model.transform import transforms | |
| from model.unet import UNET | |
| from Utils.area import pixel_to_sqft, process_and_overlay_image | |
| from Utils.convert import read_pansharpened_rgb | |
@st.cache_resource
def load_model(checkpoint_path='latest.pth'):
    """Build the ReUNet-CBAM network and load its trained weights.

    Decorated with ``st.cache_resource`` because Streamlit re-executes this whole
    script on every widget interaction; without caching, the checkpoint would be
    re-read from disk and the network rebuilt on each rerun. The cached instance is
    shared across reruns and sessions; in this file it is only used for inference.

    Args:
        checkpoint_path: path to a checkpoint dict whose ``'model_state_dict'``
            entry holds the weights. Defaults to ``'latest.pth'`` relative to the
            working directory.

    Returns:
        The network with weights loaded (mapped to CPU) and switched to ``eval`` mode.
    """
    net = reunet_cbam()
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


# Load the model at import time; the cache decorator makes later reruns reuse it.
model = load_model()
def refine_mask(mask, blur_kernel=5, threshold_value=127, morph_kernel_size=3, min_object_size=100):
    """Refine and clean a segmentation mask.

    Pipeline: convert to single-channel grayscale if needed, Gaussian-blur to smooth
    jagged edges, re-binarise with a fixed threshold, apply a morphological opening
    followed by a closing, and finally erase 8-connected foreground components whose
    pixel area is below ``min_object_size``.

    Args:
        mask: 2-D grayscale, 3-channel BGR, 4-channel BGRA or (H, W, 1) image. It
            should be 8-bit, because ``cv2.connectedComponentsWithStats`` requires
            an 8-bit single-channel input.
        blur_kernel: side length of the square Gaussian kernel (OpenCV requires it
            to be odd).
        threshold_value: pixels strictly above this become 255; all others become 0.
        morph_kernel_size: side length of the square structuring element used for
            opening and closing.
        min_object_size: connected components with fewer pixels than this are removed.

    Returns:
        A single-channel mask containing only the values 0 and 255.
    """
    if mask.ndim == 3:
        channels = mask.shape[2]
        if channels == 1:
            mask = mask[:, :, 0]
        elif channels == 4:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGRA2GRAY)
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Smooth edges, then snap back to a hard 0/255 mask.
    mask = cv2.GaussianBlur(mask, (blur_kernel, blur_kernel), 0)
    _, mask = cv2.threshold(mask, threshold_value, 255, cv2.THRESH_BINARY)

    # Opening removes isolated specks; closing fills small holes.
    kernel = np.ones((morph_kernel_size, morph_kernel_size), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    # Remove small components in one vectorised pass: build a per-label "too small"
    # lookup table and index it with the label image, instead of scanning the whole
    # image once per label.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    if num_labels > 1:
        too_small = stats[:, cv2.CC_STAT_AREA] < min_object_size
        too_small[0] = False  # label 0 is the background; never touch it
        mask[too_small[labels]] = 0
    return mask
# Storage locations, anchored to the working directory at import time.
base = os.getcwd()
UPLOAD_DIR = os.path.join(base, "Images")  # uploaded and converted input images
MASK_DIR = os.path.join(base, "Masks")     # refined prediction masks
# Anchored to the same base as the directories so all artefacts live together.
CSV_LOG_PATH = os.path.join(base, "image_log.csv")

# Ensure the output directories exist (created with default permissions;
# existing directories are left untouched).
for directory in (UPLOAD_DIR, MASK_DIR):
    os.makedirs(directory, exist_ok=True)
def predict(image):
    """Run the module-level ``model`` on a single, un-batched input tensor.

    A leading batch dimension is added before the forward pass. The output has all
    size-1 dimensions squeezed away, is moved to CPU and is returned as a NumPy
    array. Autograd is disabled throughout.
    """
    with torch.no_grad():
        batch = image.unsqueeze(0)
        raw_output = model(batch)
        flattened = raw_output.squeeze()
        return flattened.cpu().numpy()
def split_image(image, patch_size=512):
    """Cut an (H, W, C) image into non-overlapping square tiles, row by row.

    Tiles on the right and bottom edges are smaller when H or W is not a multiple
    of ``patch_size``. Each tile is a NumPy slice (a view) of ``image``.

    Returns:
        A list of ``(name, tile)`` pairs, where ``name`` is ``"patch_{y}_{x}.png"``
        and ``(y, x)`` is the tile's top-left pixel in ``image``.
    """
    height, width, _channels = image.shape
    return [
        (
            f"patch_{top}_{left}.png",
            image[top:min(top + patch_size, height), left:min(left + patch_size, width)],
        )
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
def merge(patch_folder, dest_image='out.png', image_shape=None):
    """Stitch per-tile PNGs back into one 3-channel image and write it to disk.

    Every ``*.png`` in ``patch_folder`` must be named ``<prefix>_<y>_<x>.png`` (the
    convention used by ``split_image``), where ``(y, x)`` is the tile's top-left
    pixel in the full image. Tiles larger than the canvas are cropped. Tiles that
    would overhang the right or bottom edge are shifted back inside the canvas.

    Args:
        patch_folder: directory containing the tile PNGs.
        dest_image: output path passed to ``cv2.imwrite``.
        image_shape: shape of the full image; only ``(height, width)`` (the first
            two entries) are used.

    Returns:
        The stitched ``uint8`` array of shape ``(height, width, 3)``.

    Raises:
        ValueError: if ``image_shape`` is missing, a PNG name carries no valid
            coordinates, or a tile cannot be read.
    """
    if image_shape is None:
        raise ValueError("image_shape is required to size the merged image")
    height, width = image_shape[0], image_shape[1]
    merged = np.zeros((height, width, 3), dtype=np.uint8)

    # Sorted so that, if shifted tiles overlap, the result is deterministic rather
    # than dependent on the filesystem's listdir order.
    for filename in sorted(os.listdir(patch_folder)):
        if not filename.endswith(".png"):
            continue

        # "<prefix>_<y>_<x>.png" -> y, x
        parts = filename[:-len(".png")].split("_")
        try:
            y, x = int(parts[-2]), int(parts[-1])
        except (IndexError, ValueError):
            raise ValueError(f"Invalid filename: {filename}") from None
        if x < 0 or y < 0:
            raise ValueError(f"Invalid filename: {filename}")

        patch = cv2.imread(os.path.join(patch_folder, filename))
        if patch is None:
            raise ValueError(f"Failed to read patch: {filename}")

        # Crop tiles bigger than the canvas, then shift overhanging tiles back inside.
        patch = patch[:height, :width]
        patch_height, patch_width = patch.shape[:2]
        x = min(x, width - patch_width)
        y = min(y, height - patch_height)

        merged[y:y + patch_height, x:x + patch_width, :] = patch

    cv2.imwrite(dest_image, merged)
    return merged
def process_large_image(model, image_path, patch_size=512):
    """Segment a large image tile by tile and return the stitched mask.

    The image is split with ``split_image``. Each tile is converted to RGB, passed
    through ``transforms`` and ``model``, and thresholded at 0.5 into a 0/255 mask.
    Each mask is written as a PNG into a per-call temporary directory, and ``merge``
    stitches the tiles (also writing ``merged_mask.png`` to the working directory).

    Args:
        model: segmentation network applied to each tile.
        image_path: path readable by ``cv2.imread``.
        patch_size: tile edge length in pixels.

    Returns:
        The stitched mask as returned by ``merge`` (uint8, 3 channels).

    Raises:
        ValueError: if the image cannot be read.
    """
    import tempfile

    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Failed to read image from {image_path}")
    h, w, _ = img.shape
    st.write(f"Processing image of size {w}x{h}")

    # A fresh directory per call removes the dependency on an undefined global
    # patch folder. It also guarantees that tiles from a previous, differently sized
    # image cannot leak into ``merge``, which stitches every PNG in the folder.
    with tempfile.TemporaryDirectory(prefix="pred_patches_") as patch_dir:
        for filename, patch in split_image(img, patch_size):
            patch_pil = Image.fromarray(cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
            patch_transformed = transforms(patch_pil)

            # Use the model that was passed in (previously this argument was ignored).
            with torch.no_grad():
                prediction = model(patch_transformed.unsqueeze(0)).squeeze().cpu().numpy()
            mask = (prediction > 0.5).astype(np.uint8) * 255

            # ``transforms`` may resize its input; bring the mask back to the tile's
            # footprint so ``merge`` places it at the right offset. This is a no-op when
            # the sizes already match.
            # NOTE(review): assumes the squeezed prediction is 2-D (H, W) -- confirm.
            if mask.shape[:2] != patch.shape[:2]:
                mask = cv2.resize(mask, (patch.shape[1], patch.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)

            cv2.imwrite(os.path.join(patch_dir, filename), mask)

        merged_mask = merge(patch_dir, dest_image='merged_mask.png', image_shape=img.shape)
    return merged_mask
def log_image_details(image_id, image_filename, mask_filename):
    """Append one record to the CSV log at ``CSV_LOG_PATH``.

    A header row is written first when the log does not exist or is empty. The
    serial number equals the number of rows already in the file (header included),
    so numbering continues after the last data row.

    Args:
        image_id: identifier stored in the "Image ID" column.
        image_filename: name of the processed image.
        mask_filename: name of the saved mask.
    """
    now = datetime.now()
    date_str = now.strftime('%Y-%m-%d')
    time_str = now.strftime('%H:%M:%S')  # not named ``time``: avoid shadowing the module

    # Count existing rows before opening for append, so the file is not read while
    # a write handle to it is open.
    try:
        with open(CSV_LOG_PATH, mode='r', newline='') as existing:
            row_count = sum(1 for _ in csv.reader(existing))
    except FileNotFoundError:
        row_count = 0

    # An empty (but existing) file also needs the header; previously it got none
    # and the first record was numbered 0.
    needs_header = row_count == 0
    sno = 1 if needs_header else row_count

    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([sno, date_str, time_str, image_id, image_filename, mask_filename])
def upload_page():
    """Render the upload page: save the uploaded image, segment it, and store the mask.

    On the first run after a file is chosen (``st.session_state.file_uploaded`` is
    False) this:
      1. writes the raw upload into ``UPLOAD_DIR``; GeoTIFFs are converted to an
         8-bit PNG via ``read_pansharpened_rgb``, and other formats are re-saved as PNG;
      2. runs the model, tile by tile via ``process_large_image`` when either side
         exceeds 650 px, otherwise on the whole image at once;
      3. cleans the mask with ``refine_mask``, writes it to ``MASK_DIR`` and records
         the pair with ``log_image_details``.
    After a successful run, a "View result" button switches
    ``st.session_state.page`` to ``'result'``.
    """
    import traceback

    if 'file_uploaded' not in st.session_state:
        st.session_state.file_uploaded = False
    if 'filename' not in st.session_state:
        st.session_state.filename = None
    if 'mask_filename' not in st.session_state:
        st.session_state.mask_filename = None

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
    if image is not None and not st.session_state.file_uploaded:
        try:
            bytes_data = image.getvalue()
            timestamp = int(time.time())
            original_filename = image.name
            file_extension = os.path.splitext(original_filename)[1].lower()
            if file_extension in ['.tiff', '.tif']:
                filename = f"image_{timestamp}.tif"
                converted_filename = f"image_{timestamp}_converted.png"
            else:
                filename = f"image_{timestamp}.png"
                converted_filename = filename
            filepath = os.path.join(UPLOAD_DIR, filename)
            converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)
            with open(filepath, "wb") as f:
                f.write(bytes_data)

            if file_extension in ['.tiff', '.tif']:
                st.info('Processing GeoTIFF image...')
                rgb_image = read_pansharpened_rgb(filepath)
                cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
                st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')
                img = Image.open(converted_filepath)
            else:
                # Re-encode as a real PNG (the raw bytes may be JPEG despite the name).
                img = Image.open(filepath)
                img.save(converted_filepath)

            if os.path.exists(converted_filepath):
                st.success(f"Image saved successfully: {converted_filepath}")
                file_size = os.path.getsize(converted_filepath)
                st.write(f"File size: {file_size} bytes")
            else:
                st.error(f"Failed to save image: {converted_filepath}")

            st.image(img, caption='Uploaded Image', use_column_width=True)
            st.success(f'Image processed and saved as {converted_filename}')
            st.session_state.filename = converted_filename

            # PIL already knows the size; no need to copy all pixels into an array.
            if img.height > 650 or img.width > 650:
                st.info('Large image detected. Using patch-based processing.')
                with st.spinner('Analyzing large image...'):
                    full_mask = process_large_image(model, converted_filepath)
            else:
                st.info('Small image detected. Processing whole image at once.')
                with st.spinner('Analyzing image...'):
                    # Uploaded PNG/JPEG files may be RGBA, grayscale or palette images.
                    # The tile path always gets 3 channels (cv2.imread colour mode), so
                    # match it here; presumably the model expects 3-channel input.
                    img_transformed = transforms(img.convert('RGB'))
                    prediction = predict(img_transformed)
                    full_mask = (prediction > 0.5).astype(np.uint8) * 255

            full_mask = refine_mask(full_mask)
            mask_filename = f"mask_{timestamp}.png"
            mask_filepath = os.path.join(MASK_DIR, mask_filename)
            cv2.imwrite(mask_filepath, full_mask)
            st.session_state.mask_filename = mask_filename
            log_image_details(timestamp, converted_filename, mask_filename)
            st.session_state.file_uploaded = True
            st.success("Image processed successfully")
        except Exception as e:
            # Top-level UI boundary: report to the user and keep the app alive.
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check the logs for more details.")
            print(f"Error in upload_page: {str(e)}")
            # The message above points users at the logs, so put the traceback there.
            traceback.print_exc()

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()
def _reset_to_upload():
    """Clear the stored upload state, switch back to the upload page and rerun."""
    st.session_state.page = 'upload'
    st.session_state.file_uploaded = False
    st.session_state.filename = None
    st.session_state.mask_filename = None
    st.rerun()


def result_page():
    """Render the result page: original image, predicted mask and area overlay.

    Reads ``st.session_state.filename`` (relative to ``UPLOAD_DIR``) and
    ``st.session_state.mask_filename`` (relative to ``MASK_DIR``). The overlay comes
    from ``process_and_overlay_image``, which is also passed ``'output.png'``. A
    "Back to Upload" button resets the session state and returns to the upload page.
    """
    st.title('Analysis Result')

    # upload_page initialises both keys to None, so check the values and not just key
    # presence; otherwise os.path.join(..., None) below raises TypeError.
    if st.session_state.get('filename') is None or st.session_state.get('mask_filename') is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            _reset_to_upload()
        return

    col1, col2 = st.columns(2)

    # Display original image
    original_img_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
    if os.path.exists(original_img_path):
        original_img = Image.open(original_img_path)
        col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    # Display predicted mask
    mask_path = os.path.join(MASK_DIR, st.session_state.mask_filename)
    if os.path.exists(mask_path):
        mask = Image.open(mask_path)
        col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")
    # Display overlayed image
    if os.path.exists(original_img_path) and os.path.exists(mask_path):
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if original_np is None or mask_np is None:
            st.error("Image or mask file could not be read for overlay.")
        else:
            height, width = original_np.shape[:2]
            if mask_np.shape[:2] != (height, width):
                # Nearest-neighbour keeps the mask strictly 0/255; the default bilinear
                # interpolation would blend edges into intermediate grey levels.
                mask_np = cv2.resize(mask_np, (width, height), interpolation=cv2.INTER_NEAREST)
            # Binarise after any resize so the overlay always gets a clean 0/255 mask.
            _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
            overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
            # NOTE(review): original_np is BGR (cv2.imread) and st.image assumes RGB by
            # default -- confirm the channel order process_and_overlay_image returns.
            st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        _reset_to_upload()
def main():
    """Streamlit entry point: draw the app title, then render the active page.

    ``st.session_state.page`` selects the page (``'upload'`` by default, or
    ``'result'``); any other value renders nothing below the title.
    """
    st.title('Building area estimation')
    if 'page' not in st.session_state:
        st.session_state.page = 'upload'

    renderers = {'upload': upload_page, 'result': result_page}
    render = renderers.get(st.session_state.page)
    if render is not None:
        render()


if __name__ == '__main__':
    main()