# Hugging Face Space (status when this page was captured: Sleeping)
| import gradio as gr | |
| import torch | |
| import json | |
| from PIL import Image | |
| from torchvision import transforms | |
| import time | |
| import pandas as pd | |
| from pathlib import Path | |
| import io | |
| import base64 | |
| from reportlab.lib.pagesizes import letter, A4 | |
| from reportlab.lib import colors | |
| from reportlab.lib.units import inch | |
| from reportlab.platypus import SimpleDocTemplate, Table, TableStyle, Paragraph, Spacer, PageBreak, Image as RLImage | |
| from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle | |
| from reportlab.lib.enums import TA_CENTER, TA_LEFT | |
| from datetime import datetime | |
# Startup progress messages, one print per message (each adds a blank line).
for _startup_message in (
    "β Packages installed!\n",
    "π Creating Gradio Interface...\n",
):
    print(_startup_message)
| # ==================== LOAD MODEL & METADATA ==================== | |
class BusClassifierInference:
    """Load the bus-component classifier once and serve predictions.

    Attributes set by ``__init__``:
        metadata:    dict parsed from the metadata JSON file.
        class_names: ``metadata['class_names']``; index i labels model output i.
        device:      ``'cuda'`` when available, otherwise ``'cpu'``.
        model:       the loaded network, switched to eval mode.
        transform:   resize -> tensor -> normalize pipeline built from metadata.
    """

    # Checkpoint used when the TorchScript archive cannot be loaded.
    _FALLBACK_CHECKPOINT = 'deployment/bus_classifier.pth'

    def __init__(self, model_path='deployment/bus_classifier_traced.pt',
                 metadata_path='deployment/model_metadata.json'):
        """Initialize the inference model.

        Args:
            model_path: path to the TorchScript archive (tried first).
            metadata_path: path to the JSON file providing ``class_names``,
                ``image_size`` and ``normalization`` (``mean`` / ``std``).
        """
        # Load metadata (JSON is UTF-8; do not depend on the locale default).
        with open(metadata_path, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)
        self.class_names = self.metadata['class_names']
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"π§ Loading model on {self.device.upper()}...")

        # Try loading TorchScript first, fall back to the PyTorch checkpoint.
        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            print(f"β TorchScript model loaded from {model_path}")
        except Exception as exc:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still
            # propagate; the failure reason is reported instead of hidden.
            print(f"β οΈ TorchScript not found, loading PyTorch checkpoint...")
            print(f"   Reason: {exc}")
            from torchvision import models

            checkpoint = torch.load(self._FALLBACK_CHECKPOINT,
                                    map_location=self.device)
            # Recreate the architecture with a head sized to our classes.
            self.model = models.efficientnet_b0(weights=None)
            num_features = self.model.classifier[1].in_features
            self.model.classifier[1] = torch.nn.Linear(
                num_features, len(self.class_names)
            )
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model = self.model.to(self.device)
            print(f"β PyTorch checkpoint loaded")
        self.model.eval()

        # Preprocessing must match the values recorded in the metadata file.
        image_size = self.metadata['image_size']
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=self.metadata['normalization']['mean'],
                std=self.metadata['normalization']['std'],
            ),
        ])
        print(f"β Model ready for inference!")
        print(f"π Classes: {', '.join(self.class_names)}\n")

    @staticmethod
    def _to_rgb_image(image):
        """Coerce a supported input into an RGB ``PIL.Image``.

        Accepts a path (``str`` / ``Path``), a PIL image, an object exposing a
        string ``.name`` path (e.g. an upload temp-file wrapper), or an array
        accepted by ``Image.fromarray``.
        """
        if isinstance(image, Image.Image):
            # Also normalizes RGBA / grayscale inputs to three channels.
            return image.convert('RGB')
        if not isinstance(image, (str, Path)):
            named_path = getattr(image, 'name', None)
            if not isinstance(named_path, str):
                return Image.fromarray(image).convert('RGB')
            image = named_path
        # convert() loads the pixel data, so the file can be closed afterwards.
        with Image.open(image) as opened:
            return opened.convert('RGB')

    def predict_single(self, image):
        """Predict class for a single image.

        Returns:
            dict with ``predicted_class``, ``confidence`` (0-1),
            ``all_probabilities`` (class -> probability, highest first) and
            ``inference_time_ms`` (load + preprocess + forward pass).
        """
        start_time = time.perf_counter()
        rgb_image = self._to_rgb_image(image)

        # Preprocess into a batch of one on the model's device.
        input_tensor = self.transform(rgb_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(input_tensor)
            probs = torch.softmax(logits, dim=1)
            pred_class_idx = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_class_idx].item()
        inference_time = time.perf_counter() - start_time

        # Pair every class with its probability, highest first.
        all_probs = {
            name: float(prob)
            for name, prob in zip(self.class_names, probs[0].tolist())
        }
        sorted_probs = dict(
            sorted(all_probs.items(), key=lambda item: item[1], reverse=True)
        )
        return {
            'predicted_class': self.class_names[pred_class_idx],
            'confidence': confidence,
            'all_probabilities': sorted_probs,
            'inference_time_ms': inference_time * 1000,
        }

    def predict_batch(self, images):
        """Predict for multiple images, one at a time.

        Returns:
            ``(results, total_time)`` where each result is the
            ``predict_single`` dict plus a 1-based ``image_index``, and
            ``total_time`` is the elapsed wall time in seconds.
        """
        results = []
        total_start = time.perf_counter()
        for position, image in enumerate(images, start=1):
            result = self.predict_single(image)
            result['image_index'] = position
            results.append(result)
        total_time = time.perf_counter() - total_start
        return results, total_time
# Build the shared predictor once at import time, framed by two 80-char rules.
_init_rule = "=" * 80
print(_init_rule)
predictor = BusClassifierInference()
print(_init_rule)
| # ==================== PDF GENERATION FUNCTION ==================== | |
def _build_report_table(rows, col_widths, header_color, *, align='CENTER',
                        header_font_size=11, extra_commands=()):
    """Create a table with the report's shared header, grid and row striping."""
    table = Table(rows, colWidths=col_widths)
    table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor(header_color)),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
        ('ALIGN', (0, 0), (-1, -1), align),
        ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
        ('FONTSIZE', (0, 0), (-1, 0), header_font_size),
        ('BOTTOMPADDING', (0, 0), (-1, 0), 12),
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        *extra_commands,
        ('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.lightgrey]),
    ]))
    return table


def generate_pdf_report(results, images, total_time):
    """Generate a professional PDF report.

    Args:
        results: non-empty list of prediction dicts; each must provide
            ``predicted_class``, ``confidence``, ``all_probabilities``
            (ordered highest first), ``inference_time_ms`` and ``image_index``.
        images: the classified inputs; only their count is used.
        total_time: total processing time in seconds.

    Returns:
        Path of the generated PDF (a uniquely named file in the system
        temporary directory).

    Raises:
        ValueError: if ``results`` or ``images`` is empty.
    """
    import tempfile
    from collections import Counter

    # Every statistic below divides by these counts; fail with a clear message
    # instead of a ZeroDivisionError halfway through building the report.
    if len(results) == 0 or len(images) == 0:
        raise ValueError("generate_pdf_report requires at least one classified image")

    num_images = len(images)
    generated_at = datetime.now()

    def share_of_images(count):
        return f'{count/num_images*100:.1f}%'

    # Unique temp file: a timestamp-only name in the working directory
    # collides when two reports are generated within the same second.
    with tempfile.NamedTemporaryFile(
        prefix=f"classification_report_{generated_at.strftime('%Y%m%d_%H%M%S')}_",
        suffix=".pdf",
        delete=False,
    ) as handle:
        pdf_filename = handle.name

    doc = SimpleDocTemplate(pdf_filename, pagesize=letter)
    story = []
    styles = getSampleStyleSheet()

    # Custom styles
    title_style = ParagraphStyle(
        'CustomTitle',
        parent=styles['Heading1'],
        fontSize=24,
        textColor=colors.HexColor('#667eea'),
        spaceAfter=30,
        alignment=TA_CENTER,
        fontName='Helvetica-Bold',
    )
    heading_style = ParagraphStyle(
        'CustomHeading',
        parent=styles['Heading2'],
        fontSize=16,
        textColor=colors.HexColor('#333333'),
        spaceAfter=12,
        spaceBefore=12,
        fontName='Helvetica-Bold',
    )
    footer_style = ParagraphStyle(
        'Footer',
        parent=styles['Normal'],
        fontSize=9,
        textColor=colors.grey,
        alignment=TA_CENTER,
    )

    # NOTE(review): several headings below contain symbols outside Latin-1;
    # the built-in Helvetica faces may lack glyphs for them - confirm rendering.

    # Title
    story.append(Paragraph("π Bus Component Classification Report", title_style))
    story.append(Paragraph(
        f"Generated: {generated_at.strftime('%Y-%m-%d %H:%M:%S')}",
        styles['Normal'],
    ))
    story.append(Spacer(1, 0.3*inch))

    # Summary Section
    story.append(Paragraph("π Executive Summary", heading_style))
    summary_data = [
        ['Metric', 'Value'],
        ['Total Images Processed', str(num_images)],
        ['Total Processing Time', f'{total_time:.2f} seconds'],
        ['Average Time per Image', f'{total_time/num_images*1000:.2f} ms'],
        ['Model Used', 'EfficientNet-B0'],
        ['Model Accuracy', '98.71%'],
        ['Device', predictor.device.upper()],
    ]
    story.append(_build_report_table(
        summary_data, [3*inch, 3*inch], '#667eea',
        align='LEFT',
        header_font_size=12,
        extra_commands=(
            ('BACKGROUND', (0, 1), (-1, -1), colors.beige),
            ('FONTNAME', (0, 1), (-1, -1), 'Helvetica'),
            ('FONTSIZE', (0, 1), (-1, -1), 10),
        ),
    ))
    story.append(Spacer(1, 0.3*inch))

    # Performance Metrics
    story.append(Paragraph("π Performance Metrics", heading_style))
    confidences = [r['confidence'] for r in results]
    avg_confidence = sum(confidences) / len(confidences)
    high_conf = sum(1 for c in confidences if c >= 0.95)
    medium_conf = sum(1 for c in confidences if 0.80 <= c < 0.95)
    low_conf = sum(1 for c in confidences if c < 0.80)
    perf_data = [
        ['Performance Metric', 'Value', 'Percentage'],
        ['Average Confidence', f'{avg_confidence*100:.2f}%', '-'],
        ['High Confidence (β₯95%)', str(high_conf), share_of_images(high_conf)],
        ['Medium Confidence (80-95%)', str(medium_conf), share_of_images(medium_conf)],
        ['Low Confidence (<80%)', str(low_conf), share_of_images(low_conf)],
    ]
    story.append(_build_report_table(
        perf_data, [2.5*inch, 1.5*inch, 1.5*inch], '#4CAF50'
    ))
    story.append(Spacer(1, 0.3*inch))

    # Class Distribution (most frequent first; ties keep first-seen order)
    story.append(Paragraph("π¦ Class Distribution", heading_style))
    class_counts = Counter(r['predicted_class'] for r in results)
    dist_data = [['Class Name', 'Count', 'Percentage']]
    for class_name, count in class_counts.most_common():
        dist_data.append([class_name, str(count), share_of_images(count)])
    story.append(_build_report_table(
        dist_data, [3*inch, 1.5*inch, 1.5*inch], '#2196F3'
    ))
    story.append(PageBreak())

    # Detailed Results
    story.append(Paragraph("π Detailed Classification Results", heading_style))
    story.append(Spacer(1, 0.2*inch))
    detail_data = [['#', 'Predicted Class', 'Confidence', 'Time (ms)', '2nd Best', '2nd Conf']]
    for result in results:
        ranked = list(result['all_probabilities'].items())
        if len(ranked) > 1:
            second_best, second_conf = ranked[1]
            second_conf_text = f"{second_conf*100:.2f}%"
        else:
            # Single-class model: there is no runner-up to report.
            second_best, second_conf_text = '-', '-'
        detail_data.append([
            str(result['image_index']),
            result['predicted_class'],
            f"{result['confidence']*100:.2f}%",
            f"{result['inference_time_ms']:.2f}",
            second_best,
            second_conf_text,
        ])
    story.append(_build_report_table(
        detail_data,
        [0.5*inch, 1.8*inch, 1*inch, 0.9*inch, 1.8*inch, 1*inch],
        '#764ba2',
        header_font_size=9,
        extra_commands=(('FONTSIZE', (0, 1), (-1, -1), 8),),
    ))
    story.append(Spacer(1, 0.3*inch))

    # Footer
    story.append(Spacer(1, 0.5*inch))
    story.append(Paragraph(
        "Bus Component Classification System v1.0 | Powered by EfficientNet-B0",
        footer_style,
    ))
    story.append(Paragraph(
        "This report is auto-generated and contains AI predictions.",
        footer_style,
    ))

    # Build PDF
    doc.build(story)
    print(f"β PDF Report generated: {pdf_filename}")
    return pdf_filename
| # ==================== GRADIO INTERFACE FUNCTIONS ==================== | |
def _confidence_color(confidence):
    """Return the card accent colour: green >= 0.95, orange >= 0.80, else red."""
    if confidence >= 0.95:
        return "#4CAF50"
    if confidence >= 0.80:
        return "#FF9800"
    return "#F44336"


def _image_data_uri(img):
    """Return a base64 ``data:`` URI for one uploaded image.

    A string path is embedded as-is with a MIME type guessed from its
    extension (falling back to JPEG); anything else is opened with PIL and
    re-encoded as JPEG.
    """
    import mimetypes

    if isinstance(img, str):
        with open(img, 'rb') as f:
            img_data = f.read()
        guessed, _ = mimetypes.guess_type(img)
        # Raw bytes keep their original format, so label them accordingly.
        mime = guessed if guessed and guessed.startswith('image/') else 'image/jpeg'
    else:
        buffer = io.BytesIO()
        with Image.open(img) as opened:
            opened.convert('RGB').save(buffer, format='JPEG')
        img_data = buffer.getvalue()
        mime = 'image/jpeg'
    return f"data:{mime};base64,{base64.b64encode(img_data).decode()}"


def predict_images(images):
    """Main prediction function for Gradio interface.

    Args:
        images: the uploaded files (at most 50); may be ``None`` or empty.

    Returns:
        ``(html, pdf_path)``. On invalid input ``html`` is a warning message
        and ``pdf_path`` is ``None``; otherwise ``html`` is the results
        dashboard and ``pdf_path`` points at the generated report.
    """
    from collections import Counter

    if not images:
        return "<h3 style='color: #F44336; text-align: center;'>β οΈ Please upload at least one image!</h3>", None
    if len(images) > 50:
        return f"<h3 style='color: #F44336; text-align: center;'>β οΈ Maximum 50 images allowed! You uploaded {len(images)} images.</h3>", None

    num_images = len(images)
    print(f"\nπ Processing {num_images} image(s)...")

    # Get predictions, then the downloadable report.
    results, total_time = predictor.predict_batch(images)
    pdf_file = generate_pdf_report(results, images, total_time)

    class_counts = Counter(result['predicted_class'] for result in results)
    high_conf = sum(1 for r in results if r['confidence'] >= 0.95)

    # Collect fragments and join once instead of repeated string +=.
    parts = [f"""
    <div style="font-family: 'Segoe UI', Arial, sans-serif; max-width: 1400px; margin: 0 auto;">
    <!-- Summary Stats -->
    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 12px 20px; border-radius: 8px; margin-bottom: 20px; color: white; display: flex; justify-content: space-around; align-items: center; flex-wrap: wrap; gap: 10px;">
    <div><strong>π Images:</strong> {num_images}</div>
    <div><strong>β±οΈ Total Time:</strong> {total_time:.2f}s</div>
    <div><strong>β‘ Avg Time:</strong> {total_time/num_images*1000:.0f}ms</div>
    <div><strong>π― High Confidence:</strong> {high_conf}/{num_images}</div>
    </div>
    <!-- Class Distribution Chart -->
    <div style="background: white; padding: 15px; border-radius: 8px; margin-bottom: 20px; border: 2px solid #667eea;">
    <h3 style="margin: 0 0 15px 0; color: #333; font-size: 18px;">π¦ Class Distribution</h3>
    <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 12px;">
    """]

    # Class distribution bars, most frequent class first.
    for class_name, count in class_counts.most_common():
        percentage = (count / num_images) * 100
        parts.append(f"""
        <div style="background: #f5f5f5; padding: 12px; border-radius: 6px; border-left: 4px solid #667eea;">
        <div style="display: flex; justify-content: space-between; margin-bottom: 6px;">
        <strong style="color: #333; font-size: 13px;">{class_name}</strong>
        <span style="color: #667eea; font-weight: bold; font-size: 13px;">{count}</span>
        </div>
        <div style="background: #e0e0e0; height: 8px; border-radius: 4px; overflow: hidden;">
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); width: {percentage}%; height: 100%;"></div>
        </div>
        <div style="text-align: right; margin-top: 4px; color: #666; font-size: 11px;">{percentage:.1f}%</div>
        </div>
        """)

    parts.append("""
    </div>
    </div>
    <!-- Results Grid (4 per row) -->
    <h3 style="margin: 20px 0 15px 0; color: #333; font-size: 18px;">π Detailed Results</h3>
    <div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(280px, 1fr)); gap: 15px;">
    """)

    # One card per image: thumbnail, predicted class, confidence, timing.
    for number, (img, result) in enumerate(zip(images, results), start=1):
        pred_class = result['predicted_class']
        confidence = result['confidence']
        inf_time = result['inference_time_ms']
        accent = _confidence_color(confidence)
        parts.append(f"""
        <div style="border: 3px solid {accent}; border-radius: 10px; overflow: hidden; background: white; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
        <!-- Image -->
        <div style="position: relative;">
        <img src="{_image_data_uri(img)}"
        style="width: 100%; height: 200px; object-fit: cover;"
        alt="Image {number}">
        <div style="position: absolute; top: 8px; left: 8px; background: rgba(0,0,0,0.7); color: white; padding: 4px 10px; border-radius: 5px; font-size: 12px; font-weight: bold;">
        #{number}
        </div>
        </div>
        <!-- Prediction Info -->
        <div style="padding: 12px;">
        <div style="background: {accent}; color: white; padding: 8px 12px; border-radius: 6px; margin-bottom: 8px; text-align: center;">
        <div style="font-size: 14px; font-weight: bold; margin-bottom: 2px;">{pred_class}</div>
        <div style="font-size: 18px; font-weight: bold;">{confidence*100:.1f}%</div>
        </div>
        <div style="font-size: 11px; color: #666; text-align: center;">
        β±οΈ {inf_time:.1f}ms
        </div>
        </div>
        </div>
        """)

    parts.append("""
    </div>
    </div>
    """)
    html_output = "".join(parts)
    print(f"β Complete! Processed {num_images} images in {total_time:.2f}s\n")
    return html_output, pdf_file
| # ==================== CREATE MINIMAL GRADIO INTERFACE ==================== | |
# Stylesheet handed to gr.Blocks(css=...): page width cap, the gradient
# ".upload-button" look (plus hover), and <details>/<summary> styling.
custom_css = """
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
}
/* Upload button styling */
.upload-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
color: white !important;
font-size: 16px !important;
font-weight: bold !important;
padding: 25px 40px !important;
border-radius: 12px !important;
border: 3px dashed rgba(255, 255, 255, 0.5) !important;
cursor: pointer !important;
transition: all 0.3s ease !important;
}
.upload-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4) !important;
border-color: white !important;
}
details summary {
cursor: pointer;
padding: 10px 15px;
background: #f0f0f0;
border-radius: 6px;
font-weight: bold;
color: #333;
border: 1px solid #ddd;
user-select: none;
}
details[open] summary {
background: #667eea;
color: white;
border-color: #667eea;
}
details {
margin-bottom: 15px;
}
details div {
padding: 10px 15px;
background: white;
border: 1px solid #ddd;
border-top: none;
border-radius: 0 0 6px 6px;
max-height: 200px;
overflow-y: auto;
}
"""
# Gradio UI: header, system info, upload area with a live file list,
# action buttons, results panel and PDF download.
with gr.Blocks(title="π Bus Classifier", css=custom_css) as demo:
    # Header
    gr.HTML("""
    <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; margin-bottom: 20px; box-shadow: 0 4px 15px rgba(102,126,234,0.4);">
    <h1 style="color: white; font-size: 32px; margin: 0; font-weight: bold;">π Bus Component Classifier</h1>
    <p style="color: white; font-size: 15px; margin: 8px 0 0 0; opacity: 0.95;">EfficientNet-B0 | Accuracy: 98.71% | Real-time Classification</p>
    </div>
    """)

    # Collapsible System Info
    with gr.Accordion("π System Information", open=False):
        # One pill per supported class, rendered ahead of the template below.
        class_badges = ' '.join(
            f'<span style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 6px 12px; border-radius: 20px; font-size: 13px; font-weight: bold; display: inline-block;">{cls}</span>'
            for cls in predictor.class_names
        )
        gr.HTML(f"""
        <div style="padding: 15px; background: white; border-radius: 8px; border: 2px solid #667eea;">
        <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); gap: 15px; margin-bottom: 15px;">
        <div style="background: #f0f4ff; padding: 12px; border-radius: 6px; border-left: 4px solid #667eea;">
        <strong style="color: #333; font-size: 14px;">Model:</strong>
        <span style="color: #667eea; font-weight: bold; font-size: 14px;">EfficientNet-B0</span>
        </div>
        <div style="background: #f0f4ff; padding: 12px; border-radius: 6px; border-left: 4px solid #667eea;">
        <strong style="color: #333; font-size: 14px;">Classes:</strong>
        <span style="color: #667eea; font-weight: bold; font-size: 14px;">{len(predictor.class_names)}</span>
        </div>
        <div style="background: #f0f4ff; padding: 12px; border-radius: 6px; border-left: 4px solid #4CAF50;">
        <strong style="color: #333; font-size: 14px;">Accuracy:</strong>
        <span style="color: #4CAF50; font-weight: bold; font-size: 14px;">98.71%</span>
        </div>
        <div style="background: #f0f4ff; padding: 12px; border-radius: 6px; border-left: 4px solid #FF9800;">
        <strong style="color: #333; font-size: 14px;">Device:</strong>
        <span style="color: #FF9800; font-weight: bold; font-size: 14px;">{predictor.device.upper()}</span>
        </div>
        <div style="background: #f0f4ff; padding: 12px; border-radius: 6px; border-left: 4px solid #2196F3;">
        <strong style="color: #333; font-size: 14px;">Max Images:</strong>
        <span style="color: #2196F3; font-weight: bold; font-size: 14px;">50 per batch</span>
        </div>
        </div>
        <div style="padding: 15px; background: #f9f9f9; border-radius: 6px; border: 2px solid #ddd;">
        <div style="margin-bottom: 8px;">
        <strong style="color: #333; font-size: 15px;">π¦ Supported Classes:</strong>
        </div>
        <div style="display: flex; flex-wrap: wrap; gap: 8px;">
        {class_badges}
        </div>
        </div>
        </div>
        """)

    # Upload Section with clear button
    gr.HTML("""
    <div style="margin: 20px 0 15px 0;">
    <h3 style="color: #333; font-size: 20px; margin: 0; font-weight: bold;">π€ Upload Images</h3>
    <p style="color: #666; font-size: 14px; margin: 5px 0 0 0;">Click the button below to select images (JPG, PNG | Max: 50 images)</p>
    </div>
    """)
    with gr.Row():
        with gr.Column():
            image_input = gr.File(
                file_count="multiple",
                file_types=["image"],
                label="",
                show_label=False,
                elem_classes=["upload-button"],
            )
            # File count and collapsible list
            file_list_html = gr.HTML()

            def update_file_list(files):
                """Render the selected-files panel shown under the uploader.

                The first five names are listed directly; any remainder goes
                inside a collapsible ``<details>`` element. Returns an empty
                string when nothing is selected.
                """
                from html import escape

                if not files:
                    return ""
                preview_limit = 5
                file_count = len(files)

                def display_name(file):
                    # Uploads may be plain path strings or objects exposing
                    # `.name`; show only the basename either way, and escape
                    # it because the name is user-controlled text going into
                    # HTML.
                    raw_path = getattr(file, 'name', None) or str(file)
                    return escape(Path(raw_path).name)

                parts = [f"""
                <div style="background: #f5f5f5; padding: 15px; border-radius: 8px; margin: 10px 0; border: 2px solid #ddd;">
                <div style="font-weight: bold; color: #333; margin-bottom: 10px; font-size: 16px;">
                π {file_count} image{'s' if file_count != 1 else ''} selected
                </div>
                """]
                # Show first 5 files
                for idx, file in enumerate(files[:preview_limit], start=1):
                    parts.append(f"""
                    <div style="background: white; padding: 8px 12px; margin: 5px 0; border-radius: 5px; border-left: 3px solid #667eea; font-size: 13px; color: #333;">
                    {idx}. {display_name(file)}
                    </div>
                    """)
                # If more than 5, show the rest in a collapsible section
                if file_count > preview_limit:
                    parts.append(f"""
                    <details style="margin-top: 10px;">
                    <summary style="cursor: pointer; padding: 8px 12px; background: #667eea; color: white; border-radius: 5px; font-size: 14px; font-weight: bold;">
                    β Show {file_count - preview_limit} more files
                    </summary>
                    <div style="max-height: 200px; overflow-y: auto; padding: 10px; background: white; margin-top: 5px; border-radius: 5px;">
                    """)
                    for idx, file in enumerate(files[preview_limit:],
                                               start=preview_limit + 1):
                        parts.append(f"""
                        <div style="padding: 6px 10px; margin: 3px 0; border-radius: 4px; border-left: 3px solid #764ba2; font-size: 12px; color: #333; background: #f9f9f9;">
                        {idx}. {display_name(file)}
                        </div>
                        """)
                    parts.append("""
                    </div>
                    </details>
                    """)
                parts.append("</div>")
                return "".join(parts)

            image_input.change(
                fn=update_file_list,
                inputs=[image_input],
                outputs=[file_list_html],
            )

    # Buttons
    with gr.Row():
        predict_btn = gr.Button(
            "π Classify Images",
            variant="primary",
            size="lg",
        )
        clear_btn = gr.Button(
            "ποΈ Clear All",
            size="lg",
        )

    # Results Section
    gr.HTML("""
    <div style="margin: 25px 0 15px 0;">
    <h3 style="color: #333; font-size: 20px; margin: 0; font-weight: bold;">π Classification Results</h3>
    </div>
    """)
    results_output = gr.HTML()

    # PDF Download Section
    gr.HTML("""
    <div style="margin: 20px 0 10px 0;">
    <h3 style="color: #333; font-size: 18px; margin: 0; font-weight: bold;">π Download Report</h3>
    </div>
    """)
    pdf_output = gr.File(label="", show_label=False)

    # Footer Info (Collapsible)
    with gr.Accordion("βΉοΈ How to Interpret Results", open=False):
        gr.HTML("""
        <div style="padding: 15px; background: #f9f9f9; border-radius: 6px; font-size: 13px; line-height: 1.8;">
        <div style="margin: 8px 0;"><span style="color: #4CAF50; font-weight: bold; font-size: 20px;">β</span> <strong style="color: #4CAF50;">Green (β₯95%):</strong> High confidence - Very reliable prediction</div>
        <div style="margin: 8px 0;"><span style="color: #FF9800; font-weight: bold; font-size: 20px;">β</span> <strong style="color: #FF9800;">Orange (80-95%):</strong> Medium confidence - Generally reliable</div>
        <div style="margin: 8px 0;"><span style="color: #F44336; font-weight: bold; font-size: 20px;">β</span> <strong style="color: #F44336;">Red (<80%):</strong> Low confidence - Manual review recommended</div>
        </div>
        """)

    # Button actions
    predict_btn.click(
        fn=predict_images,
        inputs=[image_input],
        outputs=[results_output, pdf_output],
    )

    def clear_all():
        """Reset the uploader, results panel, PDF slot and file list."""
        return None, None, None, ""

    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, results_output, pdf_output, file_list_html],
    )
| # ==================== LAUNCH ==================== | |
# Print the launch banner in one write, then start the Gradio server.
_launch_rule = "=" * 80
print("\n".join([
    "\n" + _launch_rule,
    "π LAUNCHING GRADIO INTERFACE (LOCAL)",
    _launch_rule,
    "Model: EfficientNet-B0",
    f"Classes: {len(predictor.class_names)}",
    f"Device: {predictor.device.upper()}",
    f"{_launch_rule}\n",
]))
demo.launch()