Spaces:
Paused
Paused
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import requests | |
| from datetime import datetime | |
| import io | |
| import os | |
| from urllib.parse import urljoin | |
| from bs4 import BeautifulSoup | |
# Default values for the analysis sliders defined in the Gradio UI.
low_int, high_int = 10, 255          # intensity window handed to cv2.inRange in preprocessing
edge_thresh, accum_thresh = 100, 20  # HoughCircles param1 (edge) / param2 (accumulator)
center_tol = 60                      # max |dx| and |dy| (px) between circle centre and image centre
morph_dia = 2                        # number of dilation iterations in preprocessing
min_rad = 40                         # smallest circle radius (px) searched for
def _is_timestamped_listing_entry(href, suffix):
    """Return True if ``href`` has the form ``YYYYMMDD_HHMM<suffix>``."""
    return (
        bool(href)
        and href.endswith(suffix)
        and len(href) >= 13
        and href[:8].isdigit()
        and href[8] == "_"
        and href[9:13].isdigit()
    )


def fetch_sdo_images(max_images, size="1024by960", tool="ccor1"):
    """Download the newest frames from a NOAA SWPC animation directory.

    Scrapes the HTML listing at
    ``https://services.swpc.noaa.gov/images/animations/{tool}/`` for files named
    ``YYYYMMDD_HHMM_{tool}_{size}.jpg``. It downloads up to ``max_images`` of the
    most recent ones and converts each to a grayscale numpy array.

    Args:
        max_images: Maximum number of images to download. This is coerced to
            int because UI sliders may deliver floats.
        size: Resolution suffix used in the file names, e.g. ``"1024by960"``.
        tool: Instrument directory name, e.g. ``"ccor1"``.

    Returns:
        A tuple ``(frames, message, total_images)``.
        ``frames`` is a list of grayscale arrays ordered oldest to newest, or
        ``None`` on failure. ``message`` is a human-readable status string.
        ``total_images`` is the number of distinct matching files in the
        listing, or 0 if the listing itself could not be read.
    """
    try:
        base_url = f"https://services.swpc.noaa.gov/images/animations/{tool}/"
        response = requests.get(base_url, timeout=10)
        if response.status_code != 200:
            return None, f"Failed to access directory {base_url}: Status {response.status_code}", 0

        suffix = f"_{tool}_{size}.jpg"
        soup = BeautifulSoup(response.text, 'html.parser')
        # A set ensures a file linked more than once is counted and downloaded once.
        hrefs = {link.get('href') for link in soup.find_all('a')}
        # Newest first: the fixed-width YYYYMMDD_HHMM prefix sorts chronologically as text.
        image_files = sorted(
            (href for href in hrefs if _is_timestamped_listing_entry(href, suffix)),
            key=lambda name: name[:13],
            reverse=True,
        )
        total_images = len(image_files)
        if not image_files:
            return None, f"No images found with size={size}, tool={tool}.", total_images

        limit = max(int(max_images), 0)
        frames = []
        for img_file in image_files[:limit]:
            url = urljoin(base_url, img_file)
            try:
                img_response = requests.get(url, timeout=5)
                if img_response.status_code == 200:
                    # Context manager closes the decoder once the pixels are copied out.
                    with Image.open(io.BytesIO(img_response.content)) as img:
                        frames.append(np.array(img.convert('L')))  # grayscale
                else:
                    print(f"Failed to fetch {url}: Status {img_response.status_code}")
            except Exception as e:
                # Best effort: skip an unreadable image and keep fetching the rest.
                print(f"Error fetching {url}: {str(e)}")

        if not frames:
            return None, "No valid images fetched.", total_images
        frames.reverse()  # display oldest -> newest
        return frames, f"Fetched {len(frames)} images. Total available: {total_images}.", total_images
    except Exception as e:
        return None, f"Error fetching images: {str(e)}", 0
def extract_frames(gif_path):
    """Decode every frame of an animated image file into grayscale arrays.

    Args:
        gif_path: Path to the GIF, or any multi-frame image Pillow can open.

    Returns:
        ``(frames, None)`` on success, where ``frames`` is a list of numpy
        arrays (one per frame, converted to Pillow mode ``'L'``) in playback
        order. Returns ``(None, message)`` if the file cannot be opened or
        decoded.
    """
    try:
        # The context manager releases the file handle; a bare Image.open()
        # would keep it open until garbage collection.
        with Image.open(gif_path) as img:
            frames = []
            while True:
                frames.append(np.array(img.convert('L')))  # grayscale copy of current frame
                try:
                    img.seek(img.tell() + 1)
                except EOFError:
                    break  # no more frames
        return frames, None
    except Exception as e:
        return None, f"Error loading GIF: {str(e)}"
def preprocess_frame(frame, lower_bound, upper_bound, morph_iterations):
    """Build a dilated binary mask of pixels whose smoothed intensity is in a window.

    Steps:
    1. Smooth the frame with a 9x9 Gaussian.
    2. Keep pixels inside ``[lower_bound, upper_bound]`` via ``cv2.inRange``
       (255 inside the window, 0 outside).
    3. Dilate the mask ``morph_iterations`` times with a 5x5 elliptical
       kernel to thicken the kept regions.
    """
    smoothed = cv2.GaussianBlur(frame, (9, 9), 0)
    in_window = cv2.inRange(smoothed, lower_bound, upper_bound)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.dilate(in_window, ellipse, iterations=morph_iterations)
def detect_circles(frame_diff, image_center, center_tolerance, param1, param2, min_radius=20, max_radius=200):
    """Find Hough circles whose centre lies close to ``image_center``.

    Args:
        frame_diff: Image to search. ``cv2.HoughCircles`` requires 8-bit
            single-channel input; the caller passes a difference of two
            preprocessed frames.
        image_center: ``(x, y)`` pixel coordinates the circles should be centred on.
        center_tolerance: A circle is kept only if both ``|x - cx|`` and
            ``|y - cy|`` are strictly below this value.
        param1: Upper Canny edge threshold for ``cv2.HOUGH_GRADIENT``.
        param2: Accumulator threshold; lower values accept weaker circles.
        min_radius, max_radius: Radius search range in pixels.

    Returns:
        A list of ``(x, y, r)`` tuples of plain Python ints, or ``None`` when
        no circle passes the centre filter.
    """
    circles = cv2.HoughCircles(
        frame_diff,
        cv2.HOUGH_GRADIENT,
        dp=1.5,
        minDist=100,
        param1=param1,
        param2=param2,
        minRadius=int(min_radius),
        maxRadius=int(max_radius),
    )
    if circles is None:
        return None

    cx, cy = image_center
    # Convert numpy scalars to Python ints. cv2.circle() wants integer
    # coordinates, and under NumPy >= 2 a tuple of numpy ints formats as
    # "(np.int64(512), ...)" in the text report.
    centred = [
        (int(x), int(y), int(r))
        for x, y, r in np.round(circles[0, :]).astype(int)
        if abs(int(x) - cx) < center_tolerance and abs(int(y) - cy) < center_tolerance
    ]
    return centred or None
def create_gif(frames, output_path, duration=0.5, scale_to_512=False):
    """Write ``frames`` to ``output_path`` as an endlessly looping animated GIF.

    Args:
        frames: Non-empty sequence of numpy image arrays accepted by
            ``PIL.Image.fromarray``.
        output_path: Destination file path.
        duration: Display time of each frame, in seconds.
        scale_to_512: If true, resize every frame to exactly 512x512 with
            Lanczos resampling. The aspect ratio is not preserved.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``frames`` is empty, since a GIF needs at least one frame.
    """
    # len() rather than truthiness so a stacked numpy array of frames also works.
    if len(frames) == 0:
        raise ValueError("create_gif() requires at least one frame")
    pil_frames = [Image.fromarray(frame) for frame in frames]
    if scale_to_512:
        pil_frames = [frame.resize((512, 512), Image.Resampling.LANCZOS) for frame in pil_frames]
    pil_frames[0].save(
        output_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(duration * 1000),  # Pillow expects milliseconds per frame
        loop=0,  # 0 = loop forever
    )
    return output_path
def handle_fetch(max_images, size, tool):
    """Gradio callback for the "Fetch Images from URL" button.

    Returns ``(status_message, preview_images, frames_for_state, total_images)``.
    When fetching fails, the preview list is empty and ``frames_for_state`` is
    ``None``.
    """
    frames, message, total_images = fetch_sdo_images(max_images, size, tool)
    preview_frames = [] if frames is None else [Image.fromarray(arr) for arr in frames]
    return message, preview_frames, frames, total_images
def _detect_centred_circles(frames, lower_bound, upper_bound, param1, param2,
                            center_tolerance, morph_iterations, min_radius, max_radius,
                            image_center):
    """Return the largest centred circle found in each consecutive-frame difference.

    Each entry is ``{"frame": idx, "center": (x, y), "radius": r}``.
    ``idx`` is the 0-based index in ``frames`` of the later frame of the pair,
    which is the frame the circle is drawn on.
    """
    # Preprocess each frame once; every frame takes part in up to two pairs.
    masks = [preprocess_frame(f, lower_bound, upper_bound, morph_iterations) for f in frames]
    detections = []
    for idx, (prev_mask, curr_mask) in enumerate(zip(masks, masks[1:]), start=1):
        frame_diff = cv2.absdiff(curr_mask, prev_mask)
        frame_diff = cv2.convertScaleAbs(frame_diff, alpha=3.0, beta=0)
        circles = detect_circles(frame_diff, image_center, center_tolerance,
                                 param1, param2, min_radius, max_radius)
        if circles:
            x, y, r = max(circles, key=lambda c: c[2])
            # Plain ints keep cv2.circle() happy and the report free of numpy reprs.
            detections.append({"frame": idx, "center": (int(x), int(y)), "radius": int(r)})
    return detections


def _longest_growing_series(detections):
    """Return the longest run of successive detections with strictly increasing radius.

    Successive means adjacent in ``detections``, not necessarily in adjacent
    frames. A run of fewer than two detections is not "growing", so in that
    case an empty list is returned. On ties, the earliest run wins.
    """
    best, run = [], []
    for det in detections:
        if run and det["radius"] > run[-1]["radius"]:
            run.append(det)
        else:
            if len(run) > len(best):
                best = run
            run = [det]
    if len(run) > len(best):
        best = run
    return best if len(best) >= 2 else []


def _render_frames(frames, detections, growing, display_mode):
    """Convert every frame to an RGB PIL image, drawing circles per ``display_mode``.

    An unknown ``display_mode`` yields an empty list.
    """
    if display_mode == "Raw Frames":
        to_draw = {}
    elif display_mode == "CME Detection":
        to_draw = {c["frame"]: c for c in growing}
    elif display_mode == "All Detection":
        to_draw = {c["frame"]: c for c in detections}
    else:
        return []

    growing_frames = {c["frame"] for c in growing}
    rendered = []
    for idx, frame in enumerate(frames):
        rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        det = to_draw.get(idx)
        if det is not None:
            if display_mode == "All Detection":
                cv2.circle(rgb, det["center"], det["radius"], (0, 255, 0), 2)  # Green for detected
                if idx in growing_frames:
                    cv2.circle(rgb, det["center"], det["radius"] + 2, (255, 255, 0), 2)  # Yellow for CME
            else:
                cv2.circle(rgb, det["center"], det["radius"], (255, 255, 0), 2)  # Yellow for CME
        rendered.append(Image.fromarray(rgb))
    return rendered


def _build_report(detections, growing):
    """Format the plain-text analysis report."""
    # Timezone-aware local time, so %Z names the zone the server actually runs in.
    timestamp = datetime.now().astimezone().strftime('%I:%M %p %Z, %B %d, %Y')
    parts = [f"Analysis Report (as of {timestamp}):\n"]
    if detections:
        parts.append(f"\nAll Frames with Detected Circles ({len(detections)} frames):\n")
        parts.extend(
            f"Frame {c['frame']}: Center at {c['center']}, Radius {c['radius']} pixels\n"
            for c in detections
        )
    else:
        parts.append("No circles detected.\n")
    if growing:
        parts.append(f"\nSeries of Frames with Growing Circles (CMEs) ({len(growing)} frames):\n")
        parts.extend(
            f"Frame {c['frame']}: Center at {c['center']}, Radius {c['radius']} pixels\n"
            for c in growing
        )
        parts.append("\nConclusion: Growing concentric circles detected, indicative of a potential Earth-directed CME.")
    else:
        parts.append("\nNo growing concentric circles detected. CME may not be Earth-directed.")
    return "".join(parts)


def analyze_images(frames, lower_bound, upper_bound, param1, param2, center_tolerance, morph_iterations, min_rad, display_mode, scale_to_512):
    """Detect centred circles across frame differences and highlight growing series.

    Each consecutive pair of frames is preprocessed and differenced. The
    largest circle centred near the image centre is recorded against the later
    frame. The longest run (of at least two) of strictly increasing radii is
    flagged as a potential CME.

    Args:
        frames: List of 2-D grayscale numpy arrays, all of the same size.
        lower_bound, upper_bound: Intensity window for preprocessing.
        param1, param2: Hough edge and accumulator thresholds.
        center_tolerance: Allowed centre offset, in pixels.
        morph_iterations: Dilation iterations used in preprocessing.
        min_rad: Minimum circle radius in pixels. The maximum radius is half
            the smaller image dimension.
        display_mode: One of "Raw Frames", "CME Detection" or "All Detection".
        scale_to_512: Whether the output GIF is resized to 512x512.

    Returns:
        ``(report, rendered_images, gif_path)``. ``gif_path`` is ``None`` when
        nothing was rendered. Any exception is caught and reported as
        ``("Error during analysis: ...", [], None)``. Frame numbers in the
        report are 0-based indices into ``frames``.
    """
    try:
        if not frames or len(frames) < 2:
            return "At least 2 frames are required for analysis.", [], None

        height, width = frames[0].shape
        image_center = (width // 2, height // 2)
        min_radius = int(min_rad)
        max_radius = min(height, width) // 2

        detections = _detect_centred_circles(
            frames, lower_bound, upper_bound, param1, param2,
            center_tolerance, morph_iterations, min_radius, max_radius, image_center,
        )
        growing = _longest_growing_series(detections)
        results = _render_frames(frames, detections, growing, display_mode)
        report = _build_report(detections, growing)

        gif_path = None
        if results:
            gif_path = "output.gif"
            create_gif([np.array(img) for img in results], gif_path, scale_to_512=scale_to_512)
        return report, results, gif_path
    except Exception as e:
        return f"Error during analysis: {str(e)}", [], None
def process_input(gif_file, max_images, size, tool, lower_bound, upper_bound, param1, param2, center_tolerance, morph_iterations, min_rad, display_mode, fetched_frames_state, scale_to_512):
    """Gradio callback for the "Analyze" button.

    If a GIF was uploaded, its frames are analysed. Otherwise the frames held
    in ``fetched_frames_state`` are used; these are filled by "Fetch Images
    from URL" or "Load Demo".

    ``max_images``, ``size`` and ``tool`` are kept so the existing event
    wiring still matches. They are no longer used, because the directory is
    not re-downloaded here.

    Returns:
        ``(report, rendered_images, gif_path, preview_images, total_images, gif_path)``,
        matching the outputs wired to the button. ``total_images`` is 0 for an
        upload. For state-held frames it is ``gr.update()``, which leaves the
        count shown by the last fetch or demo load unchanged.
    """
    if gif_file:
        # Gradio 4+/5 gr.File delivers a filepath string; older versions passed
        # a tempfile wrapper exposing ``.name``. Accept both.
        upload_path = getattr(gif_file, "name", gif_file)
        frames, error = extract_frames(upload_path)
        total_images = 0  # no directory total for an uploaded GIF
        if error:
            return error, [], None, [], total_images, None
    else:
        frames = fetched_frames_state
        # Re-fetching the directory here would download up to max_images
        # images again just to recompute a count that the fetch step already
        # displayed.
        total_images = gr.update()
        if not frames:
            return "No fetched frames available. Please fetch images first.", [], None, [], total_images, None

    preview = [Image.fromarray(frame) for frame in frames] if frames else []
    report, results, gif_path = analyze_images(
        frames, lower_bound, upper_bound, param1, param2, center_tolerance, morph_iterations, min_rad, display_mode, scale_to_512
    )
    # gif_path feeds both the download widget and the inline preview.
    return report, results, gif_path, preview, total_images, gif_path
def load_demo():
    """Gradio callback for "Load Demo": read ./demo_gif.gif into the frame state.

    Returns ``(message, preview_images, frames, total_images)``. ``frames`` is
    ``None`` when the demo file is missing or cannot be decoded, and
    ``total_images`` is always 0 because no directory is involved.
    """
    demo_path = "./demo_gif.gif"
    if not os.path.exists(demo_path):
        return "Demo GIF not found.", [], None, 0

    frames, error = extract_frames(demo_path)
    if error:
        return error, [], None, 0

    preview = [Image.fromarray(arr) for arr in frames] if frames else []
    return "Demo Loaded", preview, frames, 0
# Gradio Blocks interface
with gr.Blocks(title="Solar CME Detection") as demo:
    gr.Markdown("""
    # Solar CME Detection
    Upload a GIF or fetch SDO images to detect concentric circles indicative of coronal mass ejections (CMEs).
    Yellow circles mark growing series (potential CMEs); green circles mark other detected features.
    """)
    # Frames from "Fetch Images from URL" or "Load Demo", consumed by "Analyze".
    fetched_frames_state = gr.State(value=[])

    with gr.Sidebar(open=False):
        gr.Markdown("### Analysis Parameters")
        size = gr.Textbox(label="Image Size", value="1024by960")
        tool = gr.Textbox(label="Instrument", value="ccor1")
        lower_bound = gr.Slider(minimum=0, maximum=255, value=low_int, step=1, label="Lower Intensity Bound (0-255)")
        upper_bound = gr.Slider(minimum=0, maximum=255, value=high_int, step=1, label="Upper Intensity Bound (0-255)")
        param1 = gr.Slider(minimum=10, maximum=200, value=edge_thresh, step=1, label="Hough Param1 (Edge Threshold)")
        param2 = gr.Slider(minimum=1, maximum=50, value=accum_thresh, step=1, label="Hough Param2 (Accumulator Threshold)")
        center_tolerance = gr.Slider(minimum=10, maximum=100, value=center_tol, step=1, label="Center Tolerance (Pixels)")
        morph_iterations = gr.Slider(minimum=1, maximum=5, value=morph_dia, step=1, label="Morphological Dilation Iterations")
        # Named *_slider so it does not rebind the module-level ``min_rad`` default.
        min_radius_slider = gr.Slider(minimum=1, maximum=100, value=min_rad, step=1, label="Minimum Circle Radius")
        display_mode = gr.Dropdown(
            choices=["Raw Frames", "CME Detection", "All Detection"],
            value="All Detection",
            label="Display Mode",
        )
        scale_to_512 = gr.Checkbox(label="Scale Output GIF to 512x512", value=False)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input Options")
            demo_btn = gr.Button("Load Demo")
            gif_input = gr.File(label="Upload Solar GIF (optional)", file_types=[".gif"])
            max_images = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Max Images to Fetch")
            fetch_button = gr.Button("Fetch Images from URL")
            analyze_button = gr.Button("Analyze")
        with gr.Column():
            gr.Markdown("### Outputs")
            report = gr.Textbox(label="Analysis Report", lines=10)
            gif_viewer = gr.Image(label="Output GIF Preview", type="filepath")

    with gr.Row():
        with gr.Column():
            gif_output = gr.File(label="Download Resulting GIF")
            gallery = gr.Gallery(label="Frames with Detected Circles (Green: Detected, Yellow: CME)")
        with gr.Column():
            total_images = gr.Textbox(label="Total Images Available in Directory", value="0")
            preview = gr.Gallery(label="Input Preview (All Frames)")

    # Fetch button action: fills the preview and the frame state.
    fetch_button.click(
        fn=handle_fetch,
        inputs=[max_images, size, tool],
        outputs=[report, preview, fetched_frames_state, total_images],
    )
    # Demo button action: loads ./demo_gif.gif into the same frame state.
    demo_btn.click(
        fn=load_demo,
        inputs=[],
        outputs=[report, preview, fetched_frames_state, total_images],
    )
    # Analyze button action: runs detection on the upload or the stored frames.
    analyze_button.click(
        fn=process_input,
        inputs=[
            gif_input, max_images, size, tool,
            lower_bound, upper_bound, param1, param2, center_tolerance,
            morph_iterations, min_radius_slider, display_mode, fetched_frames_state, scale_to_512,
        ],
        outputs=[report, gallery, gif_output, preview, total_images, gif_viewer],
    )

if __name__ == "__main__":
    demo.launch()