Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import onnxruntime as ort | |
| import numpy as np | |
| from PIL import Image | |
| import time | |
| import pandas as pd | |
| import requests | |
| from io import BytesIO | |
# Human-readable labels, one per model output: predict_single_image maps the
# argmax index of the model's scores straight into this list, so the order
# here must match the order of the exported model's output classes.
CLASS_NAMES = [
    "AC Mat", "Alco brake camera", "Alco-brake device",
    "Back windshield", "Bus back side", "Bus front side",
    "Bus side", "Cabin", "Driver grooming",
    "First aid kit", "Floormats & POS", "Front windshield",
    "Hat rack", "ITMS Device", "Jack & Spare tyre",
    "Luggage compartment", "RFID Card", "Seats",
]
# Exported SigLIP v2 classifier; a relative path, so it is resolved against the
# process's current working directory.
MODEL_PATH = "siglip_v2.onnx"

# A single ONNX Runtime session shared by every request, restricted to the CPU
# execution provider. Created once at import time.
session = ort.InferenceSession(
    MODEL_PATH,
    providers=["CPUExecutionProvider"],
)

# Images are fed through the model's first declared input; cache its name for
# the feed dict passed to session.run().
input_name = session.get_inputs()[0].name
def preprocess_image(image, size=224, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """
    Turn a PIL-style image into a normalized NCHW float32 batch of one.

    The image is converted to RGB when needed, resized, scaled to [0, 1] and
    normalized per channel as ``(x - mean) / std``. The SigLIP defaults
    (0.5 / 0.5) map pixel values into [-1, 1].

    Args:
        image: PIL Image (or any object exposing ``mode``, ``convert`` and
            ``resize`` with PIL semantics).
        size: target side length as an int, or a ``(width, height)`` tuple.
            Defaults to 224 (square).
        mean: per-channel mean subtracted after scaling to [0, 1].
        std: per-channel standard deviation divided out after centering.

    Returns:
        np.ndarray of shape ``(1, 3, height, width)`` and dtype float32.
    """
    # Grayscale / RGBA / palette images would not have exactly 3 channels and
    # would break the per-channel normalization below.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    target = (size, size) if isinstance(size, int) else tuple(size)
    img_resized = image.resize(target)
    img_array = np.array(img_resized).astype(np.float32) / 255.0

    # float64 statistics keep the arithmetic identical to the original code path.
    mean_arr = np.asarray(mean, dtype=np.float64)
    std_arr = np.asarray(std, dtype=np.float64)
    img_norm = (img_array - mean_arr) / std_arr

    # HWC -> CHW, then prepend a batch axis of size 1.
    img_final = np.transpose(img_norm, (2, 0, 1))
    return np.expand_dims(img_final, axis=0).astype(np.float32)
def predict_single_image(image):
    """
    Classify one image with the ONNX model and report the top-1 class.

    Args:
        image: PIL Image or numpy array, or None (e.g. when the Gradio image
            input is left empty).

    Returns:
        dict with ``class_name``, ``confidence`` (formatted percentage string)
        and ``inference_time_ms`` (formatted string covering preprocessing,
        the model run and the softmax). When ``image`` is None, returns
        ``{"error": "No image provided"}`` instead of raising.
    """
    if image is None:
        return {"error": "No image provided"}

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()

    img_tensor = preprocess_image(image)
    outputs = session.run(None, {input_name: img_tensor})[0]

    # Single image: flatten (1, num_classes) -> (num_classes,). This also
    # accepts a model that returns an unbatched score vector.
    logits = np.asarray(outputs).reshape(-1)

    # Numerically stable softmax over the class scores.
    exp_logits = np.exp(logits - logits.max())
    probs = exp_logits / exp_logits.sum()

    pred_idx = int(np.argmax(probs))
    confidence = float(probs[pred_idx])
    pred_class = CLASS_NAMES[pred_idx]

    inference_time = (time.perf_counter() - start_time) * 1000  # milliseconds

    return {
        "class_name": pred_class,
        "confidence": f"{confidence:.2%}",
        "inference_time_ms": f"{inference_time:.2f}",
    }
# Display markers for whether a CSV row's prediction agrees with its given class.
# They must be distinct: accuracy counts the rows that carry _MATCH_MARK.
_MATCH_MARK = "✓"
_MISMATCH_MARK = "✗"
# CSV column holding the image URL, and the canonical given-class column name.
_URL_COLUMN = "Answer"
_CLASS_COLUMN = "Questions - QuestionId β Name"
# Stream a progress update after this many processed items (and after the last one).
_PROGRESS_EVERY = 5


def _as_path(item):
    """Return path-like inputs unchanged, a file wrapper's ``.name`` path, or ``item`` itself."""
    if isinstance(item, str) or hasattr(item, "__fspath__"):
        return item
    name = getattr(item, "name", None)
    return name if isinstance(name, str) else item


def _find_class_column(columns):
    """Return the given-class column, accepting any separator between 'QuestionId' and 'Name'."""
    if _CLASS_COLUMN in columns:
        return _CLASS_COLUMN
    for column in columns:
        text = str(column).strip()
        if text.startswith("Questions - QuestionId") and text.endswith("Name"):
            return column
    return None


def _cell_text(value):
    """Return a CSV cell as stripped text; missing cells (NaN/None) become ''."""
    return "" if pd.isna(value) else str(value).strip()


def _classes_match(given, predicted):
    """Case-insensitive containment in either direction; a blank side never matches."""
    given, predicted = given.lower(), predicted.lower()
    if not given or not predicted:
        return False
    return given in predicted or predicted in given


def _open_upload(item):
    """Load one uploaded batch item as RGB; return (image, gallery path or None)."""
    if isinstance(item, np.ndarray):
        return Image.fromarray(item).convert("RGB"), None
    if not isinstance(item, str) and hasattr(item, "convert"):  # already a PIL image
        return item.convert("RGB"), None
    path = _as_path(item)
    # Image.open keeps the file open for lazy decoding; the context manager closes it.
    with Image.open(path) as opened:
        image = opened.convert("RGB")
    return image, path if isinstance(path, str) else None


def _batch_summary(source, status, results, total, processed, elapsed_ms, *, final, with_accuracy):
    """Build the progress (``final=False``) or final JSON summary for a batch run."""
    successful = [r for r in results if "error" not in r]
    summary = {
        "source": source,
        "status": status,
        "total_images": total,
        "processed": processed,
        "successful_predictions": len(successful),
        "failed_predictions": len(results) - len(successful),
    }
    if with_accuracy:
        matched = sum(1 for r in successful if r.get("match") == _MATCH_MARK)
        summary["matched_predictions"] = matched
        summary["accuracy"] = f"{(matched / len(successful) * 100):.2f}%" if successful else "0%"
    summary["total_processing_time_ms" if final else "elapsed_time_ms"] = f"{elapsed_ms:.2f}"
    # An empty batch (e.g. a header-only CSV) must not divide by zero.
    summary["average_time_per_image_ms"] = f"{elapsed_ms / processed:.2f}" if processed else "0.00"
    if final:
        summary["results"] = results  # full results at the end
    else:
        summary["last_results"] = results[-_PROGRESS_EVERY:]  # latest few for reference
    return summary


def _predict_csv(csv_file):
    """Classify every image URL listed in a CSV, yielding (gallery, summary) updates."""
    try:
        df = pd.read_csv(_as_path(csv_file))
        class_column = _find_class_column(df.columns)
        if _URL_COLUMN not in df.columns or class_column is None:
            yield [], {
                "error": "CSV must have 'Answer' and 'Questions - QuestionId β Name' columns",
                "total_images": 0,
                "results": [],
            }
            return

        total = len(df)
        results, gallery = [], []
        started = time.perf_counter()

        # enumerate gives a true 1-based position even if the DataFrame index is not 0..n-1.
        for position, (_, row) in enumerate(df.iterrows(), start=1):
            img_url = _cell_text(row[_URL_COLUMN])
            given_class = _cell_text(row[class_column])
            try:
                if not img_url:
                    raise ValueError("Missing image URL")
                response = requests.get(img_url, timeout=10)
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as downloaded:
                    image = downloaded.convert("RGB")

                result = predict_single_image(image)
                is_match = _classes_match(given_class, result["class_name"])
                result["image_index"] = position
                result["given_class"] = given_class
                result["image_url"] = img_url
                result["match"] = _MATCH_MARK if is_match else _MISMATCH_MARK
                results.append(result)

                caption = (
                    f"#{position} {result['match']} Pred: {result['class_name']}\n"
                    f"β Expected: {given_class}\n"
                    f"{result['confidence']} | {result['inference_time_ms']}ms"
                )
                gallery.append((image, caption))
            except Exception as e:  # one bad row must not abort the whole batch
                results.append({
                    "image_index": position,
                    "given_class": given_class,
                    "image_url": img_url,
                    "error": str(e),
                    "class_name": None,
                    "confidence": None,
                    "inference_time_ms": None,
                    "match": _MISMATCH_MARK,
                })

            if position % _PROGRESS_EVERY == 0 or position == total:
                elapsed_ms = (time.perf_counter() - started) * 1000
                status = f"Processing... {position}/{total} ({(position / total * 100):.1f}%)"
                yield gallery.copy(), _batch_summary(
                    "CSV", status, results, total, position, elapsed_ms,
                    final=False, with_accuracy=True,
                )

        elapsed_ms = (time.perf_counter() - started) * 1000
        yield gallery, _batch_summary(
            "CSV", "β Complete!", results, total, total, elapsed_ms,
            final=True, with_accuracy=True,
        )
    except Exception as e:
        yield [], {
            "error": f"CSV processing error: {str(e)}",
            "total_images": 0,
            "results": [],
        }


def _predict_uploads(images):
    """Classify directly uploaded images, yielding (gallery, summary) updates."""
    total = len(images)
    results, gallery = [], []
    started = time.perf_counter()

    for position, item in enumerate(images, start=1):
        image = None
        try:
            image, path = _open_upload(item)
            result = predict_single_image(image)
            result["image_index"] = position
            results.append(result)
            caption = f"#{position} {result['class_name']}\n{result['confidence']} | {result['inference_time_ms']}ms"
            # Prefer the on-disk path for the gallery when one is available.
            gallery.append((path if path else image, caption))
        except Exception as e:  # one bad file must not abort the whole batch
            results.append({
                "image_index": position,
                "error": str(e),
                "class_name": None,
                "confidence": None,
                "inference_time_ms": None,
            })
            # Still show the image when it could be decoded but prediction failed.
            if image is not None:
                gallery.append((image, f"#{position}: ERROR - {str(e)}"))

        if position % _PROGRESS_EVERY == 0 or position == total:
            elapsed_ms = (time.perf_counter() - started) * 1000
            status = f"Processing... {position}/{total} ({(position / total * 100):.1f}%)"
            yield gallery.copy(), _batch_summary(
                "Direct Upload", status, results, total, position, elapsed_ms,
                final=False, with_accuracy=False,
            )

    elapsed_ms = (time.perf_counter() - started) * 1000
    yield gallery, _batch_summary(
        "Direct Upload", "β Complete!", results, total, total, elapsed_ms,
        final=True, with_accuracy=False,
    )


def predict_batch(images, csv_file):
    """
    Classify uploaded images or a CSV of image URLs, streaming progress.

    Only one source is used: the CSV when ``csv_file`` is given, otherwise
    ``images``. A progress update is yielded every ``_PROGRESS_EVERY`` items and
    after the last item. A final update with the complete per-image results
    follows.

    Args:
        images: list of file paths, PIL Images, numpy arrays, or file wrappers
            exposing a ``.name`` path (or None).
        csv_file: path or file object for a CSV with an ``Answer`` (image URL)
            column and a ``Questions - QuestionId ... Name`` (given class)
            column, or None.

    Yields:
        tuple: ``(gallery_items, summary_dict)``. On input errors a single
        ``([], {"error": ..., "total_images": 0, "results": []})`` is yielded.
    """
    if csv_file is not None:
        yield from _predict_csv(csv_file)
        return

    if images is None or len(images) == 0:
        yield [], {
            "error": "No images or CSV provided",
            "total_images": 0,
            "results": [],
        }
        return

    yield from _predict_uploads(images)
# Build the Gradio UI: a single-image tab, a streaming batch tab, and API notes.
with gr.Blocks(title="π Bus Inspection Classifier") as demo:
    gr.Markdown("# π Bus Inspection Classifier - SigLIP v2")
    # Generated from CLASS_NAMES so the header can never drift from the model labels.
    gr.Markdown(
        "Automated bus component classification using the **SigLIP v2** vision model.\n\n"
        f"**{len(CLASS_NAMES)} Categories:** {' | '.join(CLASS_NAMES)}"
    )

    with gr.Tabs():
        # Single Image Tab
        with gr.Tab("Single Image"):
            gr.Markdown("### Upload a single bus inspection image")
            with gr.Row():
                with gr.Column():
                    single_input = gr.Image(type="pil", label="Upload Image")
                    single_button = gr.Button("Classify", variant="primary")
                with gr.Column():
                    single_output = gr.JSON(label="Prediction Result")
            # An explicit api_name keeps the endpoint stable and in sync with the docs below.
            single_button.click(
                fn=predict_single_image,
                inputs=single_input,
                outputs=single_output,
                api_name="predict_single_image",
            )
            gr.Markdown("""
**Returns:**
- `class_name`: Predicted bus component category
- `confidence`: Model confidence score (%)
- `inference_time_ms`: Processing time in milliseconds
""")

        # Batch Processing Tab
        with gr.Tab("Batch Processing (Unlimited)"):
            gr.Markdown("### Upload images OR CSV file with image URLs")
            gr.Markdown("**Option 1:** Upload multiple images directly")
            gr.Markdown("**Option 2:** Upload CSV with columns: `Questions - QuestionId β Name` (given class) and `Answer` (image URL)")
            batch_input = gr.File(
                file_count="multiple",
                label="Upload Images",
                file_types=["image"],
            )
            csv_input = gr.File(
                file_count="single",
                label="OR Upload CSV with Image URLs",
                file_types=[".csv"],
            )
            batch_button = gr.Button("Classify Batch", variant="primary", size="lg")
            # Two columns of large, uncropped thumbnails in a fixed-height scroll area.
            batch_gallery = gr.Gallery(
                label="Classified Images with Predictions",
                show_label=True,
                columns=2,
                rows=4,
                height=600,
                object_fit="contain",
            )
            # JSON output for API/detailed results
            batch_output = gr.JSON(label="Detailed JSON Results")
            # predict_batch is a generator, so each yielded (gallery, json) pair streams to the UI.
            batch_button.click(
                fn=predict_batch,
                inputs=[batch_input, csv_input],
                outputs=[batch_gallery, batch_output],
                show_progress="full",
                api_name="predict_batch",
            )
            gr.Markdown("""
**Returns:** a gallery of captioned images plus a JSON summary. While running, updates
arrive every 5 images with `elapsed_time_ms` and the latest 5 entries in `last_results`.
The final update looks like this (abridged):
```json
{
  "source": "Direct Upload",
  "total_images": 10,
  "processed": 10,
  "successful_predictions": 10,
  "failed_predictions": 0,
  "total_processing_time_ms": "456.78",
  "average_time_per_image_ms": "45.68",
  "results": [
    {
      "class_name": "Bus front side",
      "confidence": "98.45%",
      "inference_time_ms": "43.21",
      "image_index": 1
    },
    ...
  ]
}
```
CSV runs also report `matched_predictions` and `accuracy`. Each CSV result also carries
`given_class`, `image_url` and `match`.
""")

    # API Documentation
    gr.Markdown("""
---
## π API Usage
### Single Image API
**Using Gradio Client (Python, `gradio_client` >= 1.0):**
```python
from gradio_client import Client, handle_file

client = Client("Wicky/bus-inspection-classifier")
result = client.predict(handle_file("bus_image.jpg"), api_name="/predict_single_image")
print(result)
```
### Batch Processing API
**Using Gradio Client (Python):**
```python
from gradio_client import Client, handle_file

client = Client("Wicky/bus-inspection-classifier")
image_files = ["img1.jpg", "img2.jpg", "img3.jpg"]
# Inputs are (images, csv_file); pass None for the one you don't use.
# The endpoint streams progress; predict() returns the final (gallery, summary) pair.
gallery, summary = client.predict(
    [handle_file(path) for path in image_files],
    None,
    api_name="/predict_batch",
)
print(f"Total: {summary['total_images']}")
print(f"Successful: {summary['successful_predictions']}")
for res in summary["results"]:
    print(f"Image {res['image_index']}: {res['class_name']} ({res.get('confidence')})")
```
**Raw HTTP:** recent Gradio versions expose each endpoint as `/call/<api_name>`. You POST
to start a job, then GET the returned event ID to stream the results. The Space's
*Use via API* page shows the exact payloads.
""")

if __name__ == "__main__":
    demo.launch()