Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import gradio as gr | |
| import gradio.routes as gr_routes | |
| from PIL import Image, ImageOps | |
| from pathlib import Path | |
| import importlib.util | |
| import json | |
# Number of output classes: two-digit numbers 00-99.
OUTPUT_CLASSES = 100
# Model input is a 28x56 grayscale image (two 28x28 digits side by side).
TARGET_HEIGHT, TARGET_WIDTH = 28, 56
# Lower bound applied to the per-pixel std before standardization (avoids /0).
STD_FLOOR = 1e-8
# Acceptable (min, max) ranges for the input-quality diagnostics; values
# outside a range are flagged "out_of_range" by _metric_status.
METRIC_TARGETS = {
    "mass_fraction": (0.08, 0.35),
    "stroke_density": (0.12, 0.65),
    "center_offset": (0.0, 8.0),
    "mean_abs_z_score": (0.0, 2.5),
    "max_abs_z_score": (0.0, 6.0),
    "std_abs_z_score": (0.0, 1.5),
}
def _load_training_module():
    """Import the sibling file ``training-100.py`` as a module object.

    The hyphenated filename cannot be imported with a normal ``import``
    statement, so the module is loaded explicitly via importlib from a
    path relative to this file.
    """
    script_path = Path(__file__).resolve().parent / "training-100.py"
    spec = importlib.util.spec_from_file_location("mnist100_training", script_path)
    loaded_module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(loaded_module)
    return loaded_module
# Load the training script once at import time and reuse its network
# functions for inference.
training_mod = _load_training_module()
forward_prop = training_mod.forward_prop
get_predictions = training_mod.get_predictions
softmax = training_mod.softmax
# Detect if running on Hugging Face Spaces (SPACE_ID is set by the platform).
IS_SPACE = bool(os.getenv("SPACE_ID"))
def _metric_status(name, value):
    """Classify *value* against the configured target range for metric *name*.

    Returns ``(status, target_dict)`` where status is one of
    ``"not_tracked"`` (no configured range), ``"invalid"`` (None/NaN value),
    ``"ok"``, or ``"out_of_range"``; target_dict is ``{"min": .., "max": ..}``
    for tracked metrics, else None.
    """
    bounds = METRIC_TARGETS.get(name)
    if bounds is None:
        return "not_tracked", None
    lo, hi = bounds
    target_dict = {"min": lo, "max": hi}
    if value is None or np.isnan(value):
        return "invalid", target_dict
    if lo <= value <= hi:
        return "ok", target_dict
    return "out_of_range", target_dict
def load_trained_artifacts(model_path=None):
    """Load trained weights plus standardization statistics from an .npz file.

    *model_path* may be absolute, relative to this file's directory, or None
    (default archive location). Returns ``(params, mean, std)`` where params
    maps parameter names to arrays. Raises RuntimeError when the file is
    missing.
    """
    base_dir = Path(__file__).resolve().parent
    if model_path is None:
        resolved_path = base_dir / "archive" / "trained_model_mnist100.npz"
    else:
        candidate = Path(model_path)
        resolved_path = candidate if candidate.is_absolute() else base_dir / candidate
    if not resolved_path.exists():
        raise RuntimeError(
            f"Model file '{resolved_path}' not found. Train the MNIST-100 model first by running 'python training-100.py'."
        )
    archive = np.load(resolved_path)
    # Everything except the normalization stats is a network parameter.
    weights = {name: archive[name] for name in archive.files if name not in {"mean", "std"}}
    return weights, archive["mean"], archive["std"]
# Lazily-populated model artifacts (weights, per-pixel mean/std); filled in
# on first use by ensure_model_loaded().
params, mean, std = None, None, None
def ensure_model_loaded():
    """Load model weights and normalization stats into module globals once."""
    global params, mean, std
    if params is None or mean is None or std is None:
        params, mean, std = load_trained_artifacts()
def extract_canvas_array(img_input):
    """Coerce a Gradio canvas payload into a PIL Image, or return None.

    Accepts a PIL Image, a numpy array (uint8 or float), or the dict payloads
    emitted by different Gradio image components; any other input yields None.
    Float arrays with a peak <= 1.5 are treated as normalized to [0, 1].
    """
    if img_input is None:
        return None
    if isinstance(img_input, dict):
        # Gradio versions differ in which dict key carries the pixels.
        for key in ("image", "composite", "background", "value"):
            candidate = img_input.get(key)
            if candidate is not None:
                img_input = candidate
                break
        else:
            return None
    if isinstance(img_input, Image.Image):
        return img_input
    if not isinstance(img_input, np.ndarray):
        return None
    arr = img_input
    if arr.dtype != np.uint8:
        peak = float(arr.max()) if arr.size else 1.0
        if peak <= 1.5:
            # Normalized floats: scale up to the 8-bit range.
            arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
        else:
            arr = np.clip(arr, 0, 255).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[2] == 4:
        return Image.fromarray(arr, mode="RGBA")
    return Image.fromarray(arr)
def shift_with_zero_pad(arr, shift_y=0, shift_x=0):
    """Translate a 2-D array by (shift_y, shift_x), zero-filling vacated cells.

    Positive shifts move content down/right. When both shifts are zero the
    original array object is returned unchanged (no copy).
    """
    if not (shift_y or shift_x):
        return arr
    shifted = np.roll(np.roll(arr, shift_y, axis=0), shift_x, axis=1).copy()
    # np.roll wraps content around the edges; blank out the wrapped strip.
    if shift_y > 0:
        shifted[:shift_y, :] = 0.0
    elif shift_y < 0:
        shifted[shift_y:, :] = 0.0
    if shift_x > 0:
        shifted[:, :shift_x] = 0.0
    elif shift_x < 0:
        shifted[:, shift_x:] = 0.0
    return shifted
def dilate_binary_like(arr, radius=1):
    """Morphological dilation: per-pixel max over the 3x3 neighborhood.

    Only radius=1 is implemented; any other radius value is treated as 1
    (deliberate simplicity/performance fallback, matching erode_binary_like).
    """
    neighborhood = [
        shift_with_zero_pad(arr, dy, dx)
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
    ]
    return np.max(np.stack(neighborhood, axis=0), axis=0)
def erode_binary_like(arr, radius=1):
    """Morphological erosion: per-pixel min over the 3x3 neighborhood.

    Only radius=1 is implemented; any other radius value is treated as 1
    (deliberate simplicity/performance fallback, matching dilate_binary_like).
    """
    neighborhood = [
        shift_with_zero_pad(arr, dy, dx)
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
    ]
    return np.min(np.stack(neighborhood, axis=0), axis=0)
def generate_inference_variants(arr, *, fast: bool = False):
    """Build test-time-augmentation variants of a preprocessed digit image.

    fast=True (used on Spaces) yields only the four cardinal 1-pixel shifts;
    otherwise all eight 1-pixel shifts plus one dilated and one eroded copy.
    The unshifted original is NOT included — callers prepend it themselves.
    """
    if fast:
        # Space-optimized: only cardinal shifts (4 variants).
        return [
            shift_with_zero_pad(arr, dy, dx)
            for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1))
        ]
    # Full set: 8 shifts + morphology.
    variants = [
        shift_with_zero_pad(arr, dy, dx)
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    ]
    variants.append(dilate_binary_like(arr, radius=1))
    variants.append(erode_binary_like(arr, radius=1))
    return variants
def _auto_balance_stroke(arr: np.ndarray, *, target_mass_fraction: float, clamp: tuple[float, float]):
    """Rescale pixel intensities so total ink mass approaches a target fraction.

    Returns ``(adjusted, applied_scale, resulting_mass_fraction)``. A nearly
    blank input is returned untouched with scale 1.0. The correction factor
    is sqrt-damped and clamped to *clamp* so strokes are never over-corrected.
    """
    pixel_count = TARGET_HEIGHT * TARGET_WIDTH
    current_fraction = float(arr.sum() / pixel_count)
    if current_fraction <= 1e-6:
        return arr, 1.0, current_fraction
    lo, hi = clamp
    factor = float(np.clip(np.sqrt(target_mass_fraction / current_fraction), lo, hi))
    rebalanced = np.clip(arr * factor, 0.0, 1.0)
    return rebalanced, factor, float(rebalanced.sum() / pixel_count)
def compose_dual_canvas(left_input, right_input):
    """Stitch the tens and ones canvases side by side into one grayscale image.

    A missing canvas is replaced by a white placeholder matching the other
    canvas's size; returns None when both canvases are empty. Mismatched
    heights are resized down to the smaller height (widths are left as-is).
    """
    left_img = extract_canvas_array(left_input)
    right_img = extract_canvas_array(right_input)
    if left_img is None and right_img is None:
        return None
    if left_img is None:
        left_img = Image.new("L", right_img.size, color=255)
    if right_img is None:
        right_img = Image.new("L", left_img.size, color=255)
    left_img = left_img.convert("L")
    right_img = right_img.convert("L")
    if left_img.height != right_img.height:
        common_height = min(left_img.height, right_img.height)
        left_img = left_img.resize(
            (left_img.width, common_height), Image.Resampling.LANCZOS
        )
        right_img = right_img.resize(
            (right_img.width, common_height), Image.Resampling.LANCZOS
        )
    stitched = Image.new(
        "L",
        (left_img.width + right_img.width, left_img.height),
        color=255,
    )
    stitched.paste(left_img, (0, 0))
    stitched.paste(right_img, (left_img.width, 0))
    return stitched
def preprocess_image(img_input, stroke_scale: float = 1.0, *, auto_balance: bool = True):
    """Convert a raw canvas payload into standardized, model-ready inputs.

    Pipeline: flatten alpha onto white -> grayscale -> invert (ink becomes
    bright on black) -> crop to the drawn bounding box (+4px margin) -> pad
    to the 2:1 target aspect -> resize to 28x56 -> clamp/apply stroke scale
    and optional auto-balance -> build TTA variants -> standardize each with
    the training mean/std.

    Returns (standardized_variants, preview_array, mean_diff_uint8,
    diagnostics) or None when the canvas is empty/unusable.
    """
    ensure_model_loaded()
    img = extract_canvas_array(img_input)
    if img is None:
        return None
    try:
        bands = img.getbands()
    except Exception:
        bands = ()
    if "A" in bands:
        # Flatten transparency onto white so erased regions read as paper,
        # not ink, after inversion.
        rgba = img.convert("RGBA")
        white_bg = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white_bg, rgba).convert("RGB")
    img = img.convert("L")
    img = ImageOps.invert(img)  # drawn strokes become bright
    arr_u8 = np.array(img, dtype=np.uint8)
    original_canvas_shape = arr_u8.shape
    # Pixels above a small noise floor count as drawing.
    coords = np.column_stack(np.where(arr_u8 > 10))
    bbox = None
    if coords.size > 0:
        # Crop to the drawing with a 4-pixel margin, clamped to the canvas.
        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0) + 1
        pad = 4
        y_min = max(0, y_min - pad)
        x_min = max(0, x_min - pad)
        y_max = min(arr_u8.shape[0], y_max + pad)
        x_max = min(arr_u8.shape[1], x_max + pad)
        bbox = (int(y_min), int(y_max), int(x_min), int(x_max))
        arr_u8 = arr_u8[y_min:y_max, x_min:x_max]
    if arr_u8.size == 0:
        return None
    h, w = arr_u8.shape
    target_ratio = TARGET_WIDTH / TARGET_HEIGHT
    if h == 0 or w == 0:
        return None
    current_ratio = w / h if h else target_ratio
    # Pad (never stretch) to the 2:1 aspect ratio so digits keep their
    # proportions through the resize below.
    if current_ratio > target_ratio:
        new_height = int(round(w / target_ratio))
        pad_total = max(new_height - h, 0)
        pad_top = pad_total // 2
        pad_bottom = pad_total - pad_top
        pad_left = pad_right = 0
    else:
        new_width = int(round(h * target_ratio))
        pad_total = max(new_width - w, 0)
        pad_left = pad_total // 2
        pad_right = pad_total - pad_left
        pad_top = pad_bottom = 0
    arr_padded = np.pad(
        arr_u8,
        ((pad_top, pad_bottom), (pad_left, pad_right)),
        mode="constant",
        constant_values=0,
    )
    resized = Image.fromarray(arr_padded).resize(
        (TARGET_WIDTH, TARGET_HEIGHT), Image.Resampling.LANCZOS
    )
    arr_resized = np.array(resized, dtype=np.float32) / 255.0
    mean_image = mean.reshape(TARGET_HEIGHT, TARGET_WIDTH)
    std_safe = np.maximum(std, STD_FLOOR)  # guard zero-variance pixels
    stroke_scale = float(stroke_scale)
    stroke_scale = max(0.3, min(stroke_scale, 1.5))  # clamp user slider value
    arr_resized = np.clip(arr_resized * stroke_scale, 0.0, 1.0)
    auto_balance_scale = 1.0
    balanced_mass_fraction = float(arr_resized.sum() / (TARGET_HEIGHT * TARGET_WIDTH))
    if auto_balance:
        # Nudge total ink mass toward the midpoint of the training target band.
        target_mass = sum(METRIC_TARGETS["mass_fraction"]) / 2.0
        arr_resized, auto_balance_scale, balanced_mass_fraction = _auto_balance_stroke(
            arr_resized,
            target_mass_fraction=target_mass,
            clamp=(0.7, 1.4),
        )
    # Test-time augmentation: the original image plus shifted/morphed variants.
    augmented_arrays = [arr_resized, *generate_inference_variants(arr_resized, fast=IS_SPACE)]
    augmented_standardized = [
        (arr.reshape(TARGET_HEIGHT * TARGET_WIDTH, 1) - mean) / std_safe
        for arr in augmented_arrays
    ]
    # Visualization of how the input deviates from the training mean image.
    mean_diff = np.abs(arr_resized - mean_image)
    mean_diff_uint8 = (mean_diff / (mean_diff.max() + 1e-8) * 255.0).astype(np.uint8)
    diagnostics = compute_diagnostics(
        arr_resized,
        bbox,
        original_canvas_shape,
        mean_image,
        augmented_standardized[0],
        std_safe,
    )
    diagnostics["applied_auto_balance"] = {
        "enabled": bool(auto_balance),
        "scale": float(auto_balance_scale),
        "mass_fraction_after": float(balanced_mass_fraction),
    }
    return augmented_standardized, arr_resized, mean_diff_uint8, diagnostics
def compute_diagnostics(arr_float, bbox, original_shape, mean_image, standardized, std_safe):
    """Compute input-quality statistics for the processed 28x56 image.

    arr_float: float image in [0, 1]; bbox: crop box measured on the original
    canvas (may be None); mean_image/std_safe: training statistics;
    standardized: the z-scored column vector for the unaugmented image.
    Returns a dict of metrics, including per-metric target-range checks
    under "metric_checks" (see _metric_status).
    """
    mass = arr_float
    total_intensity = float(mass.sum())
    # Treat very faint pixels as background when estimating the bbox.
    mass_threshold = mass > 0.05
    if mass_threshold.any():
        ys, xs = np.where(mass_threshold)
        bbox_est = (int(ys.min()), int(ys.max()) + 1, int(xs.min()), int(xs.max()) + 1)
    else:
        bbox_est = None
    cy = cx = None
    if total_intensity > 1e-6:
        # Intensity-weighted center of mass of the processed image.
        grid_y, grid_x = np.indices(mass.shape)
        weighted_sum = mass.sum()
        cy = float((grid_y * mass).sum() / weighted_sum)
        cx = float((grid_x * mass).sum() / weighted_sum)
    # Prefer the bbox measured on the processed image; fall back to the crop box.
    bbox_use = bbox_est or bbox
    if bbox_use:
        top, bottom, left, right = bbox_use
        height = bottom - top
        width = right - left
        bbox_area = height * width
        bbox_metrics = {
            "top": top,
            "bottom": bottom,
            "left": left,
            "right": right,
            "height": height,
            "width": width,
            "aspect_ratio": float(width / height) if height else None,
            "area": bbox_area,
            "area_ratio": float(bbox_area / (TARGET_HEIGHT * TARGET_WIDTH)) if bbox_area else 0.0,
        }
    else:
        bbox_metrics = {
            "top": None,
            "bottom": None,
            "left": None,
            "right": None,
            "height": 0,
            "width": 0,
            "aspect_ratio": None,
            "area": 0,
            "area_ratio": 0.0,
        }
    # Mean ink per bbox pixel — a proxy for stroke thickness.
    density = 0.0
    bbox_area = bbox_metrics.get("area", 0)
    if bbox_area:
        density = float(total_intensity / bbox_area)
    center_offset = None
    if cy is not None and cx is not None:
        # Euclidean distance of the center of mass from the canvas center.
        ideal_cy = (TARGET_HEIGHT - 1) / 2.0
        ideal_cx = (TARGET_WIDTH - 1) / 2.0
        center_offset = float(np.sqrt((cy - ideal_cy) ** 2 + (cx - ideal_cx) ** 2))
    standardized_flat = standardized.flatten()
    mean_flat = mean_image.flatten()
    arr_flat = arr_float.flatten()
    std_flat = std_safe.flatten()
    norm_input = np.linalg.norm(arr_flat)
    norm_mean = np.linalg.norm(mean_flat)
    cosine_similarity = None
    if norm_input > 0.0 and norm_mean > 0.0:
        cosine_similarity = float(np.dot(arr_flat, mean_flat) / (norm_input * norm_mean))
    mean_abs_z = float(np.mean(np.abs(standardized_flat)))
    max_abs_z = float(np.max(np.abs(standardized_flat)))
    std_of_z = float(np.std(standardized_flat))
    # Pixels whose training variance was ~zero but that the input activated
    # anyway — an out-of-distribution hint.
    low_var_mask = std_flat <= STD_FLOOR + 1e-12
    activated_low_var = int(np.count_nonzero(low_var_mask & (np.abs(arr_flat - mean_flat) > 1e-3)))
    stats = {
        "total_intensity": total_intensity,
        "mass_fraction": float(total_intensity / (TARGET_HEIGHT * TARGET_WIDTH)),
        "center_of_mass": {"row": cy, "col": cx},
        "center_offset": center_offset,
        "bbox": bbox_metrics,
        "original_canvas_shape": original_shape,
        "stroke_density": density,
        "warnings": [],
        "mean_intensity": float(arr_float.mean()),
        "pixel_intensity_range": {
            "min": float(arr_float.min()),
            "max": float(arr_float.max()),
        },
        "cosine_similarity_vs_mean": cosine_similarity,
        "mean_abs_z_score": mean_abs_z,
        "max_abs_z_score": max_abs_z,
        "std_abs_z_score": std_of_z,
        "low_variance_pixels_triggered": activated_low_var,
        "low_variance_threshold": STD_FLOOR,
        "low_variance_pixels_fraction": float(activated_low_var / max(1, int(low_var_mask.sum()))),
    }
    if mean_image is not None:
        stats["distance_from_mean"] = float(np.linalg.norm(arr_float - mean_image))
    # Attach pass/fail status for each tracked metric.
    metric_checks = {}
    for metric_name in (
        "mass_fraction",
        "stroke_density",
        "center_offset",
        "mean_abs_z_score",
        "max_abs_z_score",
        "std_abs_z_score",
    ):
        value = stats.get(metric_name)
        if value is not None:
            value = float(value)
        status, target_dict = _metric_status(metric_name, value)
        entry = {"value": value, "status": status}
        if target_dict is not None:
            entry["target"] = target_dict
        metric_checks[metric_name] = entry
    stats["metric_checks"] = metric_checks
    return stats
def enrich_diagnostics(stats, probs):
    """Return a copy of *stats* augmented with confidence info and warnings.

    Adds top/second class confidences and the probability margin between
    them (flagged when below 0.05), plus human-readable warnings for every
    metric check that fell outside its target range. *stats* itself is not
    mutated.
    """
    warnings = []
    for name, info in stats.get("metric_checks", {}).items():
        if info.get("status") != "out_of_range":
            continue
        value = info.get("value")
        value_str = "None" if value is None else f"{value:.4f}"
        target = info.get("target")
        if target is None:
            warnings.append(f"{name}: value={value_str}")
        else:
            warnings.append(
                f"{name}: value={value_str}, target=[{target['min']:.4f},{target['max']:.4f}]"
            )
    aspect_ratio = stats.get("bbox", {}).get("aspect_ratio")
    if aspect_ratio is not None and (aspect_ratio < 1.0 or aspect_ratio > 3.5):
        warnings.append(f"aspect_ratio: value={aspect_ratio:.4f}, expected≈[1.00,3.50]")
    ranked = np.sort(probs.flatten())[::-1]
    if ranked.size >= 2:
        gap = ranked[0] - ranked[1]
        margin_entry = {
            "value": float(gap),
            "status": "ok" if gap >= 0.05 else "low_margin",
            "target": {"min": 0.05, "max": 1.0},
        }
        if gap < 0.05:
            warnings.append(f"prob_margin: value={gap:.4f}, target≥0.0500")
    else:
        margin_entry = {"value": None, "status": "insufficient_classes"}
    enriched = dict(stats)
    enriched["warnings"] = warnings
    enriched["top_confidence"] = float(ranked[0]) if ranked.size else None
    enriched["second_confidence"] = float(ranked[1]) if ranked.size > 1 else None
    enriched["prob_margin"] = margin_entry
    return enriched
def predict_number(left_canvas, right_canvas, stroke_scale, auto_balance):
    """Gradio callback: predict the two-digit number drawn on the canvases.

    Returns (prediction, probability_rows, preview_image, mean_diff_image,
    diagnostics_json). Empty/unusable canvases return placeholder outputs
    with a hint in the diagnostics.

    NOTE(review): the placeholder paths return a dict for the probability
    output while the success path returns row lists for gr.Dataframe —
    confirm the component accepts both shapes.
    """
    ensure_model_loaded()
    combined_canvas = compose_dual_canvas(left_canvas, right_canvas)
    if combined_canvas is None:
        # Nothing drawn on either canvas.
        blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
        empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
        empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
        diagnostics = {"warnings": ["Draw both digits to see diagnostics."]}
        return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)
    result = preprocess_image(
        combined_canvas,
        stroke_scale=stroke_scale,
        auto_balance=bool(auto_balance),
    )
    if result is None:
        # Canvas content vanished during preprocessing (e.g. all background).
        blank_probs = {f"{i:02d}": 0.0 for i in range(OUTPUT_CLASSES)}
        empty_preview = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
        empty_diff = np.zeros((TARGET_HEIGHT, TARGET_WIDTH), dtype=np.uint8)
        diagnostics = {"warnings": ["Draw a number to see diagnostics."]}
        return None, blank_probs, empty_preview, empty_diff, json.dumps(diagnostics, indent=2)
    standardized_variants, preview, mean_diff, diagnostics = result
    # Batch all TTA variants as columns of one matrix for a single forward
    # pass, then ensemble by averaging the logits.
    variants_matrix = np.concatenate(standardized_variants, axis=1).astype(np.float32, copy=False)
    cache, probs_matrix = forward_prop(variants_matrix, params, training=False)
    logits_matrix = cache["Z_fc2"]
    avg_logits = np.mean(logits_matrix, axis=1, keepdims=True)
    probs = softmax(avg_logits)
    pred = int(get_predictions(probs)[0])
    # Probability table rows sorted by descending confidence.
    prob_rows = [[f"{i:02d}", float(probs[i, 0])] for i in range(OUTPUT_CLASSES)]
    prob_rows.sort(key=lambda r: r[1], reverse=True)
    diagnostics = enrich_diagnostics(diagnostics, probs)
    diagnostics["variants_used"] = int(probs_matrix.shape[1])
    # Per-variant confidence in the ensembled class — a stability signal.
    diagnostics["variant_top_confidences"] = [
        float(probs_matrix[pred, idx]) for idx in range(probs_matrix.shape[1])
    ]
    return pred, prob_rows, (preview * 255).astype(np.uint8), mean_diff, json.dumps(diagnostics, indent=2)
# Build the Gradio UI at import time so the `demo` object is importable
# (e.g. by the Spaces runtime) without calling launch().
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Elliot's MNIST-100 Classifier
        Draw a two-digit number (00-99). Use the left canvas for the tens digit and the right canvas for the ones digit. The model will predict the number, show the top class probabilities, and display diagnostics for the processed input.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                # One sketchpad per digit; they are stitched side by side
                # before inference (see compose_dual_canvas).
                left_canvas = gr.Sketchpad(label="Tens Digit")
                right_canvas = gr.Sketchpad(label="Ones Digit")
            stroke_slider = gr.Slider(
                minimum=0.3,
                maximum=1.2,
                value=1.0,
                step=0.05,
                label="Stroke Intensity (scale)",
            )
            auto_balance = gr.Checkbox(
                value=True,
                label="Auto Balance Stroke Thickness",
                info="Automatically rescales the digit to match training mass and brightness.",
            )
        with gr.Column(scale=1):
            pred_box = gr.Number(label="Predicted Number", precision=0, value=None)
            prob_table = gr.Dataframe(
                label="Class Probabilities",
                headers=["class", "prob"],
                datatype=["str", "number"],
                interactive=False,
            )
            preview = gr.Image(label="Model Input Preview (28x56)", image_mode="L")
            mean_diff_view = gr.Image(label="Difference vs Training Mean", image_mode="L")
            diagnostics_box = gr.Code(label="Diagnostics (JSON)", language="json")
    predict_btn = gr.Button("Predict", variant="primary")
    # Clear button resets every input and output component in one click.
    clear_btn = gr.ClearButton(
        [
            left_canvas,
            right_canvas,
            stroke_slider,
            auto_balance,
            pred_box,
            prob_table,
            preview,
            mean_diff_view,
            diagnostics_box,
        ]
    )
    predict_btn.click(
        fn=predict_number,
        inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
        outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
    )
    # On Spaces, avoid per-stroke inference to prevent event floods
    if not IS_SPACE:
        # Local dev convenience: live re-prediction on every canvas change.
        left_canvas.change(
            fn=predict_number,
            inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
            outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
        )
        right_canvas.change(
            fn=predict_number,
            inputs=[left_canvas, right_canvas, stroke_slider, auto_balance],
            outputs=[pred_box, prob_table, preview, mean_diff_view, diagnostics_box],
        )
def _disable_gradio_api_schema(*_args, **_kwargs):
    """Work around Gradio schema bug on Python 3.13 by returning empty metadata."""
    return {}


# Apply the workaround BEFORE launching. In the previous ordering this
# assignment sat after the blocking launch() call, so it never took effect
# while the server was actually running.
gr_routes.api_info = _disable_gradio_api_schema


def _queue_app(blocks):
    """Enable request queuing, tolerating Gradio versions with differing APIs."""
    try:
        return blocks.queue(concurrency_count=1)
    except TypeError:
        # Older Gradio versions don't support the argument
        try:
            return blocks.queue()
        except Exception:
            # Best effort: serve without a queue rather than crash at startup.
            return blocks


if __name__ == "__main__":
    space_env = os.getenv("SPACE_ID")
    app_to_launch = _queue_app(demo)
    if space_env:
        # On Spaces the platform supplies the host/port binding.
        app_to_launch.launch(show_api=False)
    else:
        app_to_launch.launch(server_name="0.0.0.0", share=True, show_api=False)