Spaces:
Runtime error
Runtime error
| from typing import Tuple, Dict, List | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import zipfile | |
| import os | |
| import tempfile | |
| import json | |
| from datetime import datetime | |
| import io | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| matplotlib.use("Agg") # Use non-interactive backend | |
| from utils.image_batch_classes import coco_class_mapping | |
| from test_combined_models import run_model_predictions, combine_and_filter_predictions | |
# =====================================================================
# CONFIG
# =====================================================================
# Absolute directory containing this script (for locating local assets).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Final combined classes exposed to the user: the keys of the project's
# coco_class_mapping (the merged label set of the three models).
FINAL_CLASSES = list(coco_class_mapping.keys())
| # ===================================================================== | |
| # COCO / COMBINATION HELPERS | |
| # ===================================================================== | |
| def _filter_coco_by_classes(coco_json: Dict, allowed_classes: List[str]) -> Dict: | |
| """Return a new COCO dict filtered to given class names.""" | |
| id_to_name = {c["id"]: c["name"] for c in coco_json["categories"]} | |
| allowed_ids = {cid for cid, name in id_to_name.items() if name in allowed_classes} | |
| filtered_categories = [c for c in coco_json["categories"] if c["id"] in allowed_ids] | |
| filtered_annotations = [ | |
| a for a in coco_json["annotations"] if a["category_id"] in allowed_ids | |
| ] | |
| used_image_ids = {a["image_id"] for a in filtered_annotations} | |
| filtered_images = [img for img in coco_json["images"] if img["id"] in used_image_ids] | |
| return { | |
| **coco_json, | |
| "categories": filtered_categories, | |
| "annotations": filtered_annotations, | |
| "images": filtered_images, | |
| } | |
# Predefined dark, vibrant colors for each final class (hex color strings,
# consumed by matplotlib).  Classes missing from this map get a hash-derived
# fallback color at draw time.
CLASS_COLOR_MAP = {
    'Border': '#8B0000',                      # Dark red
    'Table': '#006400',                       # Dark green
    'Diagram': '#00008B',                     # Dark blue
    'Main script black': '#FF0000',           # Bright red
    'Main script coloured': '#FF4500',        # Orange red
    'Variant script black': '#8B008B',        # Dark magenta
    'Variant script coloured': '#FF1493',     # Deep pink
    'Historiated': '#FFD700',                 # Gold
    'Inhabited': '#FF8C00',                   # Dark orange
    'Zoo - Anthropomorphic': '#32CD32',       # Lime green
    'Embellished': '#FF00FF',                 # Magenta
    'Plain initial- coloured': '#00CED1',     # Dark turquoise
    'Plain initial - Highlighted': '#00BFFF', # Deep sky blue
    'Plain initial - Black': '#000000',       # Black
    'Page Number': '#DC143C',                 # Crimson
    'Quire Mark': '#9932CC',                  # Dark orchid
    'Running header': '#228B22',              # Forest green
    'Catchword': '#B22222',                   # Fire brick
    'Gloss': '#4169E1',                       # Royal blue
    'Illustrations': '#FF6347',               # Tomato
    'Column': '#2E8B57',                      # Sea green
    'GraphicZone': '#8A2BE2',                 # Blue violet
    'MusicLine': '#20B2AA',                   # Light sea green
    'MusicZone': '#4682B4',                   # Steel blue
    'Music': '#1E90FF',                       # Dodger blue
}
def _draw_coco_on_image(
    image_path: str,
    coco_json: Dict,
    allowed_classes: List[str],
) -> np.ndarray:
    """Render the COCO annotations for ``allowed_classes`` onto the image.

    Each annotation is drawn as a polygon when a polygon segmentation is
    available, otherwise as its bounding box, with a semi-transparent fill,
    a stronger outline, and a class-name tag placed so tags overlap as
    little as possible.

    Returns the annotated image as an RGB numpy array.  If no annotation
    survives the class filter, the unmodified image is returned.
    """
    from matplotlib.patches import Polygon, Rectangle
    import matplotlib.colors as mcolors

    coco_filtered = _filter_coco_by_classes(coco_json, allowed_classes)
    id_to_name = {c["id"]: c["name"] for c in coco_filtered["categories"]}
    if not coco_filtered["images"]:
        # Nothing to draw: hand back the plain image.
        return np.array(Image.open(image_path).convert("RGB"))
    img_info = coco_filtered["images"][0]
    img_id = img_info["id"]
    anns = [a for a in coco_filtered["annotations"] if a["image_id"] == img_id]
    img = Image.open(image_path).convert("RGB")
    fig, ax = plt.subplots(1, 1, figsize=(10, 14))
    ax.imshow(img)
    ax.axis("off")

    def get_class_color(class_name: str) -> str:
        """Hex color for a class; unknown names get a stable hash-based dark hue."""
        if class_name in CLASS_COLOR_MAP:
            return CLASS_COLOR_MAP[class_name]
        import hashlib
        hash_val = int(hashlib.md5(class_name.encode()).hexdigest()[:8], 16)
        hue = (hash_val % 360) / 360.0
        # Dark, saturated fallback so the label border stays readable.
        rgb = mcolors.hsv_to_rgb([hue, 0.9, 0.7])
        return mcolors.rgb2hex(rgb)

    # Anchors of labels already placed, used to dodge overlaps.
    label_positions = []

    def find_label_position(x, y, w, h, existing_positions, img_width, img_height):
        """Pick a label anchor near the box that stays in-bounds and avoids other labels."""
        label_w, label_h = 150, 30  # approximate label footprint in pixels
        candidates = [
            (x, y - label_h - 5),   # above top-left
            (x, y),                 # top-left corner
            (x + w - label_w, y),   # top-right corner
            (x, y + h + 5),         # below bottom-left
        ]
        for pos_x, pos_y in candidates:
            # Reject candidates outside the image.
            if pos_x < 0 or pos_y < 0 or pos_x + label_w > img_width or pos_y + label_h > img_height:
                continue
            # Reject candidates too close to an already-placed label.
            collides = any(
                abs(pos_x - ex_x) < label_w * 0.8 and abs(pos_y - ex_y) < label_h * 0.8
                for ex_x, ex_y in existing_positions
            )
            if not collides:
                return pos_x, pos_y
        # All candidates overlap or are out of bounds: fall back to top-left.
        return x, y

    def draw_label(label_x, label_y, name):
        """Draw the class-name tag: black bold text on a white, color-bordered box."""
        ax.text(
            label_x,
            label_y,
            name,
            color='black',
            fontsize=9,
            fontweight='bold',
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor="white",
                edgecolor=color_rgb,
                linewidth=2,
                alpha=0.7,  # slightly translucent label background
            ),
            zorder=10,  # keep labels above the shapes
        )

    img_width, img_height = img.size
    for ann in anns:
        name = id_to_name[ann["category_id"]]
        color_rgb = mcolors.hex2color(get_class_color(name))
        segs = ann.get("segmentation", [])
        # A usable polygon is a list whose first element is a flat
        # [x1, y1, x2, y2, ...] sequence of at least 3 points.  The
        # isinstance check on segs[0] also guards against RLE dicts,
        # which would previously crash len().
        is_polygon = (
            bool(segs)
            and isinstance(segs, list)
            and isinstance(segs[0], (list, tuple))
            and len(segs[0]) >= 6
        )
        if is_polygon:
            coords = segs[0]
            xs = coords[0::2]
            ys = coords[1::2]
            pts = list(zip(xs, ys))
            # Semi-transparent fill plus a more opaque outline.
            ax.add_patch(Polygon(pts, closed=True, edgecolor=color_rgb,
                                 facecolor=color_rgb, linewidth=2.5, alpha=0.3))
            ax.add_patch(Polygon(pts, closed=True, edgecolor=color_rgb,
                                 facecolor="none", linewidth=2.5, alpha=0.8))
            min_x, min_y = min(xs), min(ys)
            label_x, label_y = find_label_position(
                min_x, min_y, max(xs) - min_x, max(ys) - min_y,
                label_positions, img_width, img_height
            )
        else:
            x, y, w, h = ann.get("bbox", [0, 0, 0, 0])
            ax.add_patch(Rectangle((x, y), w, h, edgecolor=color_rgb,
                                   facecolor=color_rgb, linewidth=2.5, alpha=0.3))
            ax.add_patch(Rectangle((x, y), w, h, edgecolor=color_rgb,
                                   facecolor="none", linewidth=2.5, alpha=0.8))
            label_x, label_y = find_label_position(
                x, y, w, h, label_positions, img_width, img_height
            )
        label_positions.append((label_x, label_y))
        draw_label(label_x, label_y, name)

    plt.tight_layout()
    # Render to a UNIQUE temp file.  The original wrote to a fixed
    # "tmp_vis.png" name, which concurrent requests (demo.queue with
    # several threads) would clobber; it also never deleted the file.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_path = tmp.name
    tmp.close()
    try:
        fig.savefig(tmp_path, dpi=150, bbox_inches="tight")
        # np.array() forces a full decode before the file is removed.
        return np.array(Image.open(tmp_path).convert("RGB"))
    finally:
        plt.close(fig)
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
| def _stats_from_coco(coco_json: Dict) -> Dict[str, int]: | |
| """Return {class_name: count} from COCO annotations.""" | |
| id_to_name = {c["id"]: c["name"] for c in coco_json["categories"]} | |
| counts: Dict[str, int] = {} | |
| for ann in coco_json["annotations"]: | |
| name = id_to_name.get(ann["category_id"], f"cls_{ann['category_id']}") | |
| counts[name] = counts.get(name, 0) + 1 | |
| return counts | |
| def _stats_table(stats: Dict[str, int], image_name: str | None = None) -> pd.DataFrame: | |
| if not stats: | |
| return pd.DataFrame(columns=["Class", "Count"]) | |
| rows = [{"Class": k, "Count": v} for k, v in sorted(stats.items())] | |
| df = pd.DataFrame(rows) | |
| if image_name: | |
| df.insert(0, "Image", image_name) | |
| return df | |
def _stats_graph(stats: Dict[str, int], title: str) -> str:
    """Render a bar chart of per-class counts and return the PNG file path.

    An empty ``stats`` dict produces a placeholder "No detections" image.
    The file is written to the system temp directory with a unique name.
    """
    if not stats:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, "No detections", ha="center", va="center")
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        classes = sorted(stats)
        counts = [stats[c] for c in classes]
        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(range(len(classes)), counts, color="steelblue")
        ax.set_xticks(range(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha="right")
        ax.set_ylabel("Count")
        ax.set_title(title)
        # Annotate each bar with its count just above the bar top.
        for b, c in zip(bars, counts):
            ax.text(
                b.get_x() + b.get_width() / 2.0,
                b.get_height(),
                str(c),
                ha="center",
                va="bottom",
            )
    # Unique path per call: the original's second-resolution timestamp
    # name collides (and silently overwrites) when two graphs are
    # generated within the same second.
    tmp = tempfile.NamedTemporaryFile(prefix="stats_", suffix=".png", delete=False)
    out_path = tmp.name
    tmp.close()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path
def _run_combined_on_path(
    image_path: str, conf: float, iou: float
) -> Tuple[Dict, np.ndarray]:
    """
    Run the three models on image_path, combine via ImageBatch helpers, and return:
      - combined COCO json (already filtered to coco_class_mapping)
      - annotated image (all final classes drawn)

    NOTE(review): ``conf`` and ``iou`` are accepted but never forwarded to
    ``run_model_predictions`` / ``combine_and_filter_predictions``, so the
    UI threshold sliders currently have no effect on detection — confirm
    whether the downstream helpers should receive them.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Run the 3 models; their prediction label files land under tmp_dir.
        labels_folders = run_model_predictions(image_path, tmp_dir)
        # Combine & filter to the classes in coco_class_mapping
        # (output_json_path=None keeps the result in memory only).
        coco_json = combine_and_filter_predictions(
            image_path, labels_folders, output_json_path=None
        )
        annotated = _draw_coco_on_image(image_path, coco_json, FINAL_CLASSES)
        return coco_json, annotated
| # ===================================================================== | |
| # COCO MERGE FOR BATCH | |
| # ===================================================================== | |
def _merge_coco_list(coco_list: List[Dict]) -> Dict:
    """Merge several single-image COCO dicts into one multi-image COCO dict.

    Image and annotation ids are renumbered sequentially starting at 1;
    annotation ``image_id`` references are remapped accordingly.  The
    category list is the fixed set defined by ``coco_class_mapping``.
    """
    merged: Dict = {
        "info": {"description": "Combined predictions", "version": "1.0"},
        "licenses": [],
        "images": [],
        "annotations": [],
        # Fixed categories taken from coco_class_mapping.
        "categories": [
            {"id": cid, "name": name, "supercategory": ""}
            for name, cid in coco_class_mapping.items()
        ],
    }
    next_img_id = 1
    next_ann_id = 1
    for coco in coco_list:
        # Old image id -> new global image id, local to this source dict.
        remap = {}
        for img in coco["images"]:
            merged["images"].append({**img, "id": next_img_id})
            remap[img["id"]] = next_img_id
            next_img_id += 1
        for ann in coco["annotations"]:
            merged["annotations"].append(
                {
                    **ann,
                    "id": next_ann_id,
                    "image_id": remap.get(ann["image_id"], ann["image_id"]),
                }
            )
            next_ann_id += 1
    return merged
| # ===================================================================== | |
| # SINGLE IMAGE HANDLER | |
| # ===================================================================== | |
def process_single_image(
    image: np.ndarray,
    conf_threshold: float,
    iou_threshold: float,
    final_classes: List[str],
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame, str]:
    """Gradio handler for the single-image tab.

    Returns (original image, annotated image, stats DataFrame, stats graph
    path).  Returns empty placeholders when no image is supplied; raises
    ``gr.Error`` when no class is selected.  Side effect: stores the
    filtered COCO result in the module-level ``_single_coco_json`` for the
    download button.
    """
    if image is None:
        return None, None, pd.DataFrame(columns=["Class", "Count"]), None
    if not final_classes:
        raise gr.Error("Please select at least one class.")
    # Persist the uploaded array to disk: the model pipeline works on paths.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        Image.fromarray(image.astype("uint8")).save(tmp.name)
        img_path = tmp.name
    try:
        coco_json, annotated_all = _run_combined_on_path(
            img_path, conf_threshold, iou_threshold
        )
        filtered_coco = _filter_coco_by_classes(coco_json, final_classes)
        annotated = _draw_coco_on_image(img_path, filtered_coco, final_classes)
    finally:
        # Always remove the temp image — the original leaked the file
        # whenever the pipeline raised.
        try:
            os.unlink(img_path)
        except OSError:
            pass
    stats = _stats_from_coco(filtered_coco)
    stats_table = _stats_table(stats, image_name="image")
    stats_graph = _stats_graph(stats, "Single Image Statistics")
    # Store for the download buttons.
    global _single_coco_json
    _single_coco_json = filtered_coco
    return image, annotated, stats_table, stats_graph
# Module-level stores populated by the processing handlers and read back
# by the download-button callbacks.
_single_coco_json: Dict | None = None  # last single-image filtered COCO result
_batch_coco_list: List[Dict] = []      # per-image filtered COCOs from the last batch run
def download_single_annotations() -> str | None:
    """Write the last single-image COCO result to a temp JSON file.

    Returns the file path, or ``None`` when no result has been stored yet.
    """
    if not _single_coco_json:
        return None
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(tempfile.gettempdir(), f"combined_single_{stamp}.json")
    with open(out_path, "w") as fh:
        json.dump(_single_coco_json, fh, indent=2)
    return out_path
| def download_single_image(annotated: np.ndarray) -> str | None: | |
| if annotated is None: | |
| return None | |
| path = os.path.join( | |
| tempfile.gettempdir(), | |
| f"combined_single_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jpg", | |
| ) | |
| Image.fromarray(annotated.astype("uint8")).save(path, "JPEG", quality=95) | |
| return path | |
| # ===================================================================== | |
| # BATCH HANDLERS | |
| # ===================================================================== | |
# File extensions accepted as images when scanning an extracted ZIP.
_IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp', '.gif'}
# Filename prefixes of hidden/system files that must never be processed.
_SKIP_PATTERNS = [
    '._',           # macOS resource forks
    '.DS_Store',    # macOS metadata
    'Thumbs.db',    # Windows thumbnail cache
    'desktop.ini',  # Windows folder settings
    '~$',           # temporary files
]


def _collect_image_paths(root_dir: str) -> List[str]:
    """Walk ``root_dir`` and return paths of valid, non-empty image files.

    Skips __MACOSX and hidden directories, hidden/system files, non-image
    extensions, empty files, and anything PIL cannot verify and decode.
    """
    image_paths: List[str] = []
    for root, dirs, files in os.walk(root_dir):
        # Prune macOS metadata directories entirely.
        if '__MACOSX' in root or '__MACOSX' in dirs:
            dirs[:] = [d for d in dirs if d != '__MACOSX']
            continue
        # Prune hidden directories (dot-prefixed).
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        for fn in files:
            if any(fn.startswith(pattern) for pattern in _SKIP_PATTERNS):
                continue
            if fn.startswith('.'):
                continue
            if os.path.splitext(fn)[1].lower() not in _IMAGE_EXTENSIONS:
                continue
            full_path = os.path.join(root, fn)
            # Regular files only (no symlinks to dirs, etc.).
            if not os.path.isfile(full_path):
                continue
            if os.path.getsize(full_path) == 0:
                print(f"WARNING ⚠️ Skipping empty file: {full_path}")
                continue
            try:
                # verify() detects corruption but closes the file, so the
                # image is reopened afterwards for a full decode check.
                with Image.open(full_path) as test_img:
                    test_img.verify()
                with Image.open(full_path) as test_img:
                    test_img.load()
                    if test_img.size[0] == 0 or test_img.size[1] == 0:
                        print(f"WARNING ⚠️ Skipping image with invalid dimensions: {full_path}")
                        continue
                image_paths.append(full_path)
            except Exception as e:
                print(f"WARNING ⚠️ Skipping invalid/non-image file: {full_path} (Error: {str(e)})")
                continue
    return image_paths


def process_batch_images(
    zip_file,
    conf_threshold: float,
    iou_threshold: float,
    final_classes: List[str],
):
    """Gradio handler for the batch (ZIP) tab.

    Extracts the uploaded ZIP, runs the combined model pipeline on every
    valid image, and returns (gallery items, status text, per-image stats
    DataFrame, summary stats DataFrame, stats graph path).  Side effect:
    fills the module-level ``_batch_coco_list`` for the download buttons.
    Raises ``gr.Error`` when no class is selected.
    """
    global _batch_coco_list
    _batch_coco_list = []
    if zip_file is None:
        return [], "Please upload a ZIP file.", pd.DataFrame(), pd.DataFrame(), None
    if not final_classes:
        raise gr.Error("Please select at least one class.")
    gallery = []
    per_image_tables = []
    total_stats: Dict[str, int] = {}
    processed_count = 0
    error_count = 0
    error_messages = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        with zipfile.ZipFile(zip_file.name, "r") as zf:
            zf.extractall(tmp_dir)
        image_paths = _collect_image_paths(tmp_dir)
        for path in sorted(image_paths):
            fn = os.path.basename(path)
            try:
                coco_json, _ = _run_combined_on_path(path, conf_threshold, iou_threshold)
                filtered = _filter_coco_by_classes(coco_json, final_classes)
                _batch_coco_list.append(filtered)
                annotated = _draw_coco_on_image(path, filtered, final_classes)
                gallery.append((annotated, fn))
                stats = _stats_from_coco(filtered)
                per_image_tables.append(_stats_table(stats, image_name=fn))
                for k, v in stats.items():
                    total_stats[k] = total_stats.get(k, 0) + v
                processed_count += 1
            except Exception as e:
                # One bad image must not abort the whole batch: log and move on.
                error_count += 1
                error_msg = f"Error processing {fn}: {str(e)}"
                error_messages.append(error_msg)
                print(f"WARNING ⚠️ {error_msg}")
                import traceback
                traceback.print_exc()
                continue
    per_image_df = (
        pd.concat(per_image_tables, ignore_index=True) if per_image_tables else pd.DataFrame()
    )
    summary_df = _stats_table(total_stats)
    graph_path = _stats_graph(total_stats, "Batch Statistics")
    # Human-readable status summary; only the first 3 errors are shown.
    status_parts = [f"Processed {processed_count} images successfully."]
    if error_count > 0:
        status_parts.append(f"Skipped {error_count} files with errors.")
        if error_messages:
            status_parts.append(f"Errors: {', '.join(error_messages[:3])}")
    status = " ".join(status_parts)
    return gallery, status, per_image_df, summary_df, graph_path
def download_batch_annotations() -> str | None:
    """Merge the stored batch COCO results and write them to a temp JSON file.

    Returns the file path, or ``None`` when no batch has been processed yet.
    """
    if not _batch_coco_list:
        return None
    merged = _merge_coco_list(_batch_coco_list)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(tempfile.gettempdir(), f"combined_batch_{stamp}.json")
    with open(out_path, "w") as fh:
        json.dump(merged, fh, indent=2)
    return out_path
def download_batch_zip(gallery_images: List[Tuple[np.ndarray, str]]) -> str | None:
    """Bundle the annotated gallery images plus merged annotations into a ZIP.

    Each gallery item is an (image array, caption) pair; images are
    re-encoded as JPEG in memory and stored under ``images/<caption>``,
    and the merged COCO annotations are stored as ``annotations.json``.
    Returns the ZIP path, or ``None`` when the gallery is empty.

    NOTE(review): captions keep their original filename extension even
    though the bytes written are JPEG (e.g. a ``foo.png`` entry contains
    JPEG data) — confirm whether names should be normalized to ``.jpg``.
    Duplicate basenames from different subfolders would also collide.
    """
    if not gallery_images:
        return None
    merged = _merge_coco_list(_batch_coco_list)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    zip_path = os.path.join(
        tempfile.gettempdir(), f"combined_batch_results_{ts}.zip"
    )
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        # images: re-encode each gallery array as JPEG without touching disk
        for idx, (img_np, caption) in enumerate(gallery_images, start=1):
            img = Image.fromarray(img_np.astype("uint8"))
            buf = io.BytesIO()
            img.save(buf, format="JPEG", quality=95)
            zf.writestr(f"images/{caption}", buf.getvalue())
        # annotations: merged COCO for the whole batch
        zf.writestr("annotations.json", json.dumps(merged, indent=2))
    return zip_path
| # ===================================================================== | |
| # GRADIO UI | |
| # ===================================================================== | |
# Gradio UI: two tabs (single image / batch ZIP), each with detection
# settings, a class filter, result display, downloads, and statistics.
with gr.Blocks() as demo:
    gr.Markdown("## Combined Manuscript Models (emanuskript + catmus + zone)")
    with gr.Tabs():
        # ---------------- Single image tab ----------------
        with gr.TabItem("Single Image"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="numpy")
                    with gr.Accordion("Detection Settings", open=True):
                        with gr.Row():
                            conf_threshold = gr.Slider(
                                label="Confidence Threshold",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                value=0.25,
                            )
                            iou_threshold = gr.Slider(
                                label="IoU Threshold",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                value=0.3,
                            )
                    with gr.Accordion("Final Classes (combined)", open=True):
                        final_classes_box = gr.CheckboxGroup(
                            label="Final Classes",
                            choices=FINAL_CLASSES,
                            value=FINAL_CLASSES,
                        )
                        with gr.Row():
                            select_all_btn = gr.Button("Select All", size="sm")
                            unselect_all_btn = gr.Button("Unselect All", size="sm")
                    with gr.Row():
                        clear_btn = gr.Button("Clear")
                        detect_btn = gr.Button(
                            "Run 3 Models + Combine", variant="primary"
                        )
                with gr.Column():
                    output_image = gr.Image(
                        label="Combined Result (final classes)", type="numpy"
                    )
                    with gr.Row():
                        single_download_json_btn = gr.Button(
                            "๐ Download Annotations (JSON)", size="sm"
                        )
                        single_download_image_btn = gr.Button(
                            "๐ผ๏ธ Download Image", size="sm"
                        )
                    single_json_output = gr.File(
                        label="Single JSON", visible=True, height=50
                    )
                    single_image_output = gr.File(
                        label="Single Image", visible=True, height=50
                    )
                    with gr.Accordion("๐ Statistics", open=False):
                        with gr.Tabs():
                            with gr.TabItem("Table"):
                                single_stats_table = gr.Dataframe(
                                    label="Detection Statistics",
                                    headers=["Class", "Count"],
                                    wrap=True,
                                )
                            with gr.TabItem("Graph"):
                                single_stats_graph = gr.Image(
                                    label="Statistics Graph", type="filepath"
                                )
        # ---------------- Batch (ZIP) tab ----------------
        with gr.TabItem("Batch Processing (ZIP)"):
            with gr.Row():
                with gr.Column():
                    zip_file = gr.File(
                        label="Upload ZIP with images", file_types=[".zip"]
                    )
                    with gr.Accordion("Detection Settings", open=True):
                        with gr.Row():
                            batch_conf_threshold = gr.Slider(
                                label="Confidence Threshold",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                value=0.25,
                            )
                            batch_iou_threshold = gr.Slider(
                                label="IoU Threshold",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.05,
                                value=0.3,
                            )
                    with gr.Accordion("Final Classes (combined)", open=True):
                        batch_final_classes_box = gr.CheckboxGroup(
                            label="Final Classes",
                            choices=FINAL_CLASSES,
                            value=FINAL_CLASSES,
                        )
                        with gr.Row():
                            batch_select_all_btn = gr.Button("Select All", size="sm")
                            batch_unselect_all_btn = gr.Button(
                                "Unselect All", size="sm"
                            )
                    batch_status = gr.Textbox(
                        label="Processing Status",
                        value="Ready to process ZIP file...",
                        interactive=False,
                        max_lines=3,
                    )
                    with gr.Row():
                        clear_batch_btn = gr.Button("Clear")
                        process_batch_btn = gr.Button(
                            "Process ZIP", variant="primary"
                        )
                with gr.Column():
                    batch_gallery = gr.Gallery(
                        label="Batch Results (combined)",
                        show_label=True,
                        columns=2,
                        rows=2,
                        height="auto",
                        type="numpy",
                    )
                    with gr.Row():
                        download_json_btn = gr.Button(
                            "๐ Download COCO Annotations (JSON)", size="sm"
                        )
                        download_zip_btn = gr.Button(
                            "๐ฆ Download Results (ZIP)", size="sm"
                        )
                    json_file_output = gr.File(
                        label="Batch JSON", visible=True, height=50
                    )
                    zip_file_output = gr.File(
                        label="Batch ZIP", visible=True, height=50
                    )
                    with gr.Accordion("๐ Statistics", open=False):
                        with gr.Tabs():
                            with gr.TabItem("Per Image"):
                                batch_stats_table = gr.Dataframe(
                                    label="Per Image Statistics", wrap=True
                                )
                            with gr.TabItem("Overall Summary"):
                                batch_stats_summary_table = gr.Dataframe(
                                    label="Overall Statistics", wrap=True
                                )
                            with gr.TabItem("Graph"):
                                batch_stats_graph = gr.Image(
                                    label="Batch Statistics Graph", type="filepath"
                                )
    # ---------------- Event wiring ----------------
    # Single image: run pipeline; clear resets all four outputs.
    detect_btn.click(
        fn=process_single_image,
        inputs=[input_image, conf_threshold, iou_threshold, final_classes_box],
        outputs=[input_image, output_image, single_stats_table, single_stats_graph],
    )
    clear_btn.click(
        fn=lambda: (None, None, pd.DataFrame(columns=["Class", "Count"]), None),
        inputs=None,
        outputs=[input_image, output_image, single_stats_table, single_stats_graph],
    )
    select_all_btn.click(fn=lambda: FINAL_CLASSES, outputs=[final_classes_box])
    unselect_all_btn.click(fn=lambda: [], outputs=[final_classes_box])
    # Batch: run pipeline over the ZIP; clear resets file, gallery, status.
    process_batch_btn.click(
        fn=process_batch_images,
        inputs=[
            zip_file,
            batch_conf_threshold,
            batch_iou_threshold,
            batch_final_classes_box,
        ],
        outputs=[
            batch_gallery,
            batch_status,
            batch_stats_table,
            batch_stats_summary_table,
            batch_stats_graph,
        ],
    )
    clear_batch_btn.click(
        fn=lambda: (None, [], "Ready to process ZIP file..."),
        inputs=None,
        outputs=[zip_file, batch_gallery, batch_status],
    )
    batch_select_all_btn.click(
        fn=lambda: FINAL_CLASSES, outputs=[batch_final_classes_box]
    )
    batch_unselect_all_btn.click(fn=lambda: [], outputs=[batch_final_classes_box])
    # Download buttons: write files to the temp dir and surface them.
    single_download_json_btn.click(
        fn=download_single_annotations, inputs=None, outputs=[single_json_output]
    )
    single_download_image_btn.click(
        fn=lambda img: download_single_image(img),
        inputs=[output_image],
        outputs=[single_image_output],
    )
    download_json_btn.click(
        fn=download_batch_annotations, inputs=None, outputs=[json_file_output]
    )
    download_zip_btn.click(
        fn=lambda gallery: download_batch_zip(gallery),
        inputs=[batch_gallery],
        outputs=[zip_file_output],
    )
if __name__ == "__main__":
    # Enable request queueing so long-running model inference does not
    # block the HTTP workers.
    demo.queue()
    demo.launch(
        debug=False,
        show_error=True,         # surface handler exceptions in the UI
        server_name="0.0.0.0",   # listen on all interfaces (container/Space)
        server_port=8000,
        share=False,
        max_threads=4,
        inbrowser=False,
        ssl_verify=True,
        quiet=False,
    )