# -*- coding: utf-8 -*- import streamlit as st import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models, transforms, datasets from torch.utils.data import DataLoader, ConcatDataset from PIL import Image, ImageOps # Added ImageOps import os import tempfile import random import shutil import matplotlib.pyplot as plt from sklearn.decomposition import PCA import numpy as np import cv2 # from PIL import Image # Already imported from ultralytics import YOLO import pandas as pd import json # Added for saving/loading metadata st.set_page_config(layout="wide") # --- Constants and Setup --- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") SAVED_MODELS_DIR = "saved_few_shot_models" BASE_DATASET = "App data/data_optimize/Basee_data" RARE_DATASET = "App data/data_optimize/Rare data/" MODEL_WEIGHTS_PATH = "model/efficientnet_coffee (1).pth" # For Standard Classifier and initial feature extractor state YOLO_MODEL_PATH = "model/best.pt" # For Detection # Ensure directories exist os.makedirs(SAVED_MODELS_DIR, exist_ok=True) os.makedirs(RARE_DATASET, exist_ok=True) # Create rare data dir if not present # Check required files/folders exist if not os.path.isdir(BASE_DATASET): st.error(f"Base dataset directory not found: {BASE_DATASET}") st.stop() if not os.path.isfile(MODEL_WEIGHTS_PATH): st.error(f"Classifier weights file not found: {MODEL_WEIGHTS_PATH}") st.stop() if not os.path.isfile(YOLO_MODEL_PATH): st.error(f"YOLO detection model file not found: {YOLO_MODEL_PATH}") st.stop() st.sidebar.info(f"Using device: {DEVICE}") TEMP_DIR = tempfile.mkdtemp() # For temporary files if needed # --- Helper Functions for Saving/Loading Few-Shot States --- # d def list_saved_models(): """Returns a list of names of saved few-shot model states.""" if not os.path.isdir(SAVED_MODELS_DIR): return [] # List only directories within the SAVED_MODELS_DIR return [d for d in os.listdir(SAVED_MODELS_DIR) if 
os.path.isdir(os.path.join(SAVED_MODELS_DIR, d))] def save_few_shot_state(name, model, prototypes, proto_labels, current_class_names): """Saves the model state, prototypes, and metadata.""" if not name or not name.strip(): st.error("Please provide a valid name for the saved model.") return False # Sanitize name for directory creation sanitized_name = "".join(c for c in name if c.isalnum() or c in ('_', '-')).rstrip() if not sanitized_name: st.error("Invalid name after sanitization. Use letters, numbers, underscore, or hyphen.") return False save_dir = os.path.join(SAVED_MODELS_DIR, sanitized_name) # Handle existing directory (Ask for overwrite confirmation) if os.path.exists(save_dir): # Use columns for better layout of overwrite confirmation col1, col2 = st.columns([3,1]) with col1: st.warning(f"Model name '{sanitized_name}' already exists.") with col2: # Use a unique key for the overwrite button based on the name overwrite_key = f"overwrite_{sanitized_name}" if st.button("Overwrite?", key=overwrite_key): st.info(f"Overwriting '{sanitized_name}'...") # Proceed with saving below else: st.info("Save cancelled. Choose a different name or click 'Overwrite?'.") return False # Stop if not confirmed overwrite else: st.info(f"Saving new model state '{sanitized_name}'...") try: os.makedirs(save_dir, exist_ok=True) # Create directory if it doesn't exist or was confirmed for overwrite # 1. Save Model State Dictionary (ensure model is on CPU before saving state_dict for better compatibility) model.to('cpu') # Move model to CPU model_path = os.path.join(save_dir, "feature_extractor_state_dict.pth") torch.save(model.state_dict(), model_path) model.to(DEVICE) # Move model back to original device # 2. Save Prototypes Tensor (move to CPU before saving) prototypes_path = os.path.join(save_dir, "prototypes.pt") torch.save(prototypes.cpu(), prototypes_path) # 3. 
Save Metadata (Labels and Class Names active during training) metadata = { "prototype_labels": proto_labels, # Should be a standard list "class_names_on_save": current_class_names # List of strings } metadata_path = os.path.join(save_dir, "metadata.json") with open(metadata_path, 'w') as f: json.dump(metadata, f, indent=4) st.success(f"Few-shot model state '{sanitized_name}' saved successfully!") return True except Exception as e: st.error(f"Error saving model state '{sanitized_name}': {e}") st.exception(e) # Show full traceback for debugging # Clean up potentially partially saved directory if error occurred AFTER creation/overwrite confirmation if os.path.exists(save_dir): try: shutil.rmtree(save_dir) st.info(f"Cleaned up partially saved directory '{save_dir}'.") except Exception as cleanup_e: st.error(f"Error cleaning up directory during save failure: {cleanup_e}") return False def load_few_shot_state(name, model_to_load_into, current_class_names): """Loads a saved model state, prototypes, and labels into session state and the model.""" load_dir = os.path.join(SAVED_MODELS_DIR, name) if not os.path.isdir(load_dir): st.error(f"Saved model directory '{load_dir}' not found.") return False model_path = os.path.join(load_dir, "feature_extractor_state_dict.pth") prototypes_path = os.path.join(load_dir, "prototypes.pt") metadata_path = os.path.join(load_dir, "metadata.json") if not all(os.path.exists(p) for p in [model_path, prototypes_path, metadata_path]): st.error(f"Saved model '{name}' is incomplete. Files missing in '{load_dir}'.") return False try: # 1. 
Load Metadata with open(metadata_path, 'r') as f: metadata = json.load(f) loaded_proto_labels = metadata.get("prototype_labels") saved_class_names = metadata.get("class_names_on_save") if loaded_proto_labels is None or saved_class_names is None: st.error(f"Metadata file for '{name}' is corrupted or missing required keys.") return False # **CRUCIAL CHECK**: Compare saved class names with current class names if set(saved_class_names) != set(current_class_names): st.warning(f"⚠️ **Class Mismatch!**") st.warning(f"Saved model '{name}' classes: `{saved_class_names}`") st.warning(f"Current active classes: `{current_class_names}`") st.warning("Predictions might be incorrect or errors may occur. Proceed with caution.") # Decide whether to proceed or stop. Let's proceed with warning. # 2. Load Model State Dictionary # Ensure model architecture is correct *before* loading state dict model_to_load_into.to(DEVICE) # Move model to target device state_dict = torch.load(model_path, map_location=DEVICE) # Load state dict to target device try: # Use strict=False initially if unsure about architecture changes missing_keys, unexpected_keys = model_to_load_into.load_state_dict(state_dict, strict=True) if missing_keys: st.warning(f"Loaded state dict is missing keys: {missing_keys}") if unexpected_keys: st.warning(f"Loaded state dict has unexpected keys: {unexpected_keys}") except RuntimeError as e: st.error(f"RuntimeError loading state_dict for '{name}'. Architecture mismatch?") st.error(e) return False model_to_load_into.eval() # Set to evaluation mode after loading # 3. Load Prototypes (load directly to target device) loaded_prototypes = torch.load(prototypes_path, map_location=DEVICE) # 4. 
Update Session State st.session_state.final_prototypes = loaded_prototypes st.session_state.prototype_labels = loaded_proto_labels st.session_state.few_shot_trained = True # Mark as trained since we loaded a state st.session_state.model_mode = 'few_shot' # Switch to few-shot mode st.success(f"Successfully loaded few-shot model state '{name}'. Mode set to Few-Shot.") return True except Exception as e: st.error(f"Error loading model state '{name}': {e}") st.exception(e) # Optional: Reset state if loading fails partially # st.session_state.final_prototypes = None # st.session_state.prototype_labels = None # st.session_state.few_shot_trained = False # st.session_state.model_mode = 'standard' return False def delete_saved_model(name): """Deletes a saved model directory.""" delete_dir = os.path.join(SAVED_MODELS_DIR, name) if not os.path.isdir(delete_dir): st.error(f"Cannot delete. Saved model '{name}' not found.") return False try: shutil.rmtree(delete_dir) st.success(f"Deleted saved model '{name}'.") return True except Exception as e: st.error(f"Error deleting saved model '{name}': {e}") return False # --- Model Architectures --- # EfficientNet feature extractor with projection layer class EfficientNetWithProjection(nn.Module): def __init__(self, base_model, output_dim=1024): super(EfficientNetWithProjection, self).__init__() self.model = base_model # This holds the EfficientNet base (feature extractor part) # Determine the input feature size dynamically from the base model if possible # For EfficientNetB0, the layer before the classifier has 1280 features in_features = 1280 # Hardcoding for effnetb0 is reliable here self.projection = nn.Linear(in_features, output_dim) # Projection layer def forward(self, x): features = self.model(x) # Get features from EfficientNet base return self.projection(features) # Project to output_dim dimensions # Base EfficientNet model structure (used for loading weights) def get_base_efficientnet_architecture(num_classes=5): # Load 
architecture only, ensure it matches the saved weights structure model = models.efficientnet_b0(weights=None) # Start with no pretrained weights here in_features = model.classifier[1].in_features # Get feature dimension model.classifier[1] = nn.Linear(in_features, num_classes) # Adjust final layer to match saved model return model # Feature Extractor model structure (for few-shot) def get_feature_extractor_base(): # This function prepares the base model, loads coffee weights, then removes classifier # Start with the architecture matching the saved weights (5 classes) base_model = get_base_efficientnet_architecture(num_classes=5) # Load the coffee-specific weights into this matching architecture try: # Use weights_only=True for security unless the model itself is saved state_dict = torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE) #, weights_only=True) # Set weights_only based on how model was saved # Load weights strictly as the architecture matches exactly missing_keys, unexpected_keys = base_model.load_state_dict(state_dict, strict=True) # No warnings needed if strict=True and it passes except Exception as e: st.error(f"Error loading model weights from {MODEL_WEIGHTS_PATH} into base architecture: {e}") st.exception(e) st.stop() # Remove the classifier to use it as a feature extractor base_model.classifier = nn.Identity() base_model.eval() # Set base model to eval mode return base_model # Function to load the standard classifier model (for reset/fallback) def load_standard_classifier(): # This loads the model intended for standard 5-class classification model = get_base_efficientnet_architecture(num_classes=5) # Get the 5-class architecture try: # Use weights_only=True for security if appropriate state_dict = torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE) #, weights_only=True) model.load_state_dict(state_dict, strict=True) # Strict loading except Exception as e: st.error(f"Error loading model weights for standard classifier: {e}") st.exception(e) st.stop() 
model.to(DEVICE) model.eval() return model # --- Caching --- # Cache the feature extractor model resource (Base + Projection) @st.cache_resource def cached_feature_extractor_model(): # This function creates the base, loads weights, removes classifier, then adds projection base_model = get_feature_extractor_base() model = EfficientNetWithProjection(base_model, output_dim=1024) model.to(DEVICE) model.eval() st.sidebar.info("Feature extractor model ready (cached).") return model # Cache the standard classifier model resource @st.cache_resource def cached_standard_classifier(): model = load_standard_classifier() st.sidebar.info("Standard classifier model ready (cached).") return model # Cache data loading and processing @st.cache_data def get_combined_dataset_and_indices(base_path, rare_path): try: # Define transform inside function or ensure it's globally defined before call transform_local = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Load base dataset full_dataset = datasets.ImageFolder(base_path, transform_local) num_base_classes = len(full_dataset.classes) base_class_names = sorted(full_dataset.classes) # Get base names sorted # Load rare dataset if it exists and has content rare_classes_found = 0 rare_class_names = [] if os.path.isdir(rare_path) and any(os.scandir(rare_path)): # Check if dir exists and is not empty try: rare_dataset = datasets.ImageFolder(rare_path, transform_local) if len(rare_dataset.samples) > 0: # IMPORTANT: Adjust labels for rare classes to start after base classes rare_dataset.samples = [(path, label + num_base_classes) for path, label in rare_dataset.samples] combined_dataset = ConcatDataset([full_dataset, rare_dataset]) rare_classes_found = len(rare_dataset.classes) rare_class_names = sorted(rare_dataset.classes) # Get rare names sorted else: combined_dataset = full_dataset # Rare dir exists but no samples except Exception as 
e_rare: st.warning(f"Could not load rare dataset from {rare_path}: {e_rare}. Using base dataset only.") combined_dataset = full_dataset else: combined_dataset = full_dataset # Rare dir doesn't exist or is empty # Create class indices mapping (label -> list of dataset indices) indices = {} current_idx = 0 # Iterate through the combined dataset structure correctly if isinstance(combined_dataset, ConcatDataset): for ds in combined_dataset.datasets: for _, label in ds.samples: # We only need the label here indices.setdefault(label, []).append(current_idx) current_idx += 1 else: # Only base dataset for idx, (_, label) in enumerate(combined_dataset.samples): indices.setdefault(label, []).append(idx) # Combine class names in the correct order class_names = base_class_names + rare_class_names # Display stats in sidebar st.sidebar.metric("Base Classes", num_base_classes) st.sidebar.metric("Rare Classes", rare_classes_found) st.sidebar.metric("Total Classes", len(class_names)) if len(class_names) == 0: st.error("No classes found in base or rare datasets. Check paths/contents.") st.stop() return combined_dataset, indices, class_names, num_base_classes except FileNotFoundError as e: st.error(f"Dataset path error: {e}. Check BASE_DATASET ('{base_path}') and RARE_DATASET ('{rare_path}').") st.stop() except Exception as e: st.error(f"Error loading datasets: {e}") st.exception(e) # Show full traceback st.stop() # --- Re-define transform globally if not done inside get_combined_dataset_and_indices --- transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # --- Few-Shot Learning Functions --- def create_episode(dataset, class_indices, class_list, n_way=5, n_shot=5, n_query=5): """Creates an episode for Prototypical Networks.""" available_classes = list(class_indices.keys()) # Get labels (0, 1, 2...) 
that have samples if len(available_classes) < n_way: n_way = len(available_classes) # Adjust n_way if not enough classes if n_way < 2: # st.warning("Cannot create episode: < 2 classes available.") # Less verbose return None, None, None, None # Sample N-way CLASS LABELS from the available ones selected_class_ids = random.sample(available_classes, n_way) support_imgs, query_imgs = [], [] support_labels, query_labels = [], [] # Map original label (e.g., 3, 7, 1) to episode-specific label (0, 1, 2) episode_class_map = {original_label: episode_label for episode_label, original_label in enumerate(selected_class_ids)} actual_n_way = 0 # Track classes successfully added for original_label in selected_class_ids: indices_for_class = class_indices.get(original_label, []) min_samples_needed = n_shot + n_query if len(indices_for_class) < min_samples_needed: # Skip class for this episode if not enough samples continue # Sample indices FOR THIS CLASS from the dataset indices list sampled_indices = random.sample(indices_for_class, min_samples_needed) # Get images using the sampled dataset indices try: support_imgs += [dataset[i][0] for i in sampled_indices[:n_shot]] query_imgs += [dataset[i][0] for i in sampled_indices[n_shot:]] except IndexError as e: st.error(f"IndexError during episode creation for class {original_label}, index {i}. Check dataset/indices integrity.") st.exception(e) return None, None, None, None except Exception as e: st.error(f"Error retrieving data during episode creation: {e}") st.exception(e) return None, None, None, None # Assign new sequential labels (0 to n_way-1) for the episode new_label = episode_class_map[original_label] support_labels += [new_label] * n_shot query_labels += [new_label] * n_query actual_n_way += 1 # Increment count of classes successfully added if actual_n_way < 2: # Check if enough classes were actually added # st.warning(f"Episode creation resulted in < 2 valid classes ({actual_n_way}). 
Skipping.") return None, None, None, None # Return tensors on the correct device try: s_imgs_tensor = torch.stack(support_imgs).to(DEVICE) s_labels_tensor = torch.tensor(support_labels, dtype=torch.long).to(DEVICE) q_imgs_tensor = torch.stack(query_imgs).to(DEVICE) q_labels_tensor = torch.tensor(query_labels, dtype=torch.long).to(DEVICE) return s_imgs_tensor, s_labels_tensor, q_imgs_tensor, q_labels_tensor except Exception as e: st.error(f"Error stacking tensors in create_episode: {e}") st.exception(e) return None, None, None, None def proto_loss(support_embeddings, support_labels, query_embeddings, query_labels): """Calculates the Prototypical Network loss and accuracy.""" if support_embeddings is None or support_embeddings.numel() == 0 or \ query_embeddings is None or query_embeddings.numel() == 0: return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0 unique_episode_labels = torch.unique(support_labels) n_way_actual = len(unique_episode_labels) if n_way_actual < 2: # Need at least 2 classes for meaningful loss/accuracy return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0 prototypes = [] # Map the episode labels (e.g., 0, 1, 2...) to their prototype index (0, 1, 2...) proto_map = {label.item(): i for i, label in enumerate(unique_episode_labels)} for episode_label in unique_episode_labels: class_mask = (support_labels == episode_label) class_embeddings = support_embeddings[class_mask] if class_embeddings.size(0) > 0: prototypes.append(class_embeddings.mean(dim=0)) else: st.warning(f"ProtoLoss: No support embeddings found for episode label {episode_label}. 
Skipping.") return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0 if len(prototypes) != n_way_actual: # Should match unique labels count st.warning("ProtoLoss: Mismatch between unique labels and calculated prototypes.") return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0 prototypes = torch.stack(prototypes) # Shape: [n_way_actual, embedding_dim] # Ensure query labels correspond to the classes present in the support set valid_query_mask = torch.isin(query_labels, unique_episode_labels) if not torch.any(valid_query_mask): # st.warning("ProtoLoss: No query samples match support set labels.") return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0 # No valid query samples filtered_query_embeddings = query_embeddings[valid_query_mask] filtered_query_labels_original = query_labels[valid_query_mask] # Keep original episode labels (e.g., 0, 1, 2...) # Calculate distances between filtered query embeddings and prototypes # Shape: [num_filtered_query, n_way_actual] distances = torch.cdist(filtered_query_embeddings, prototypes) # Get predictions based on nearest prototype (indices 0 to n_way_actual-1) predictions = torch.argmin(distances, dim=1) # Shape: [num_filtered_query] # Map the original filtered query labels to the prototype indices (0 to n_way_actual-1) for comparison mapped_true_labels = torch.tensor([proto_map[lbl.item()] for lbl in filtered_query_labels_original], dtype=torch.long).to(DEVICE) correct_predictions = (predictions == mapped_true_labels).sum().item() total_predictions = mapped_true_labels.size(0) accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0 # Calculate cross-entropy loss using negative distances (closer = higher logit) loss = F.cross_entropy(-distances, mapped_true_labels) return loss, accuracy # Recalculate prototypes after training for the entire dataset @st.cache_data(show_spinner="Calculating final prototypes for all classes...") # Added spinner message def 
calculate_final_prototypes(_model, _dataset, _class_names): _model.eval() all_embeddings = {} # Dict: label -> list of tensors # Use DataLoader for efficient batch processing # Consider pinning memory if using GPU and workers > 0 loader = DataLoader(_dataset, batch_size=128, shuffle=False, num_workers=0) with torch.no_grad(): for imgs, labs in loader: imgs = imgs.to(DEVICE) try: emb = _model(imgs) # Move embeddings to CPU *before* storing to avoid accumulating GPU memory emb_cpu = emb.cpu() labs_list = labs.tolist() # Convert labels tensor to list for i in range(emb_cpu.size(0)): label = labs_list[i] # Use setdefault for cleaner initialization all_embeddings.setdefault(label, []).append(emb_cpu[i]) # Append CPU tensor except Exception as e: st.error(f"Error during embedding calculation batch: {e}") # Decide whether to continue or stop continue # Continue with next batch final_prototypes = [] prototype_labels = [] # Store the original dataset labels corresponding to prototypes unique_labels_present = sorted(list(all_embeddings.keys())) # Labels for which embeddings were found if not unique_labels_present: st.warning("No embeddings were generated. 
Cannot calculate prototypes.") return None, None for label in unique_labels_present: if not (0 <= label < len(_class_names)): st.warning(f"Skipping label {label} during prototype calculation: Out of bounds for class names list (len={len(_class_names)}).") continue class_embeddings_list = all_embeddings[label] if class_embeddings_list: try: class_embeddings = torch.stack(class_embeddings_list) # Stack CPU tensors prototype = class_embeddings.mean(dim=0) # Calculate mean on CPU final_prototypes.append(prototype) prototype_labels.append(label) # Append the original label except Exception as e: st.error(f"Error processing embeddings for class {label} ('{_class_names[label]}'): {e}") continue # Skip this class if stacking/mean fails # else: # Should not happen if label is in unique_labels_present keys # st.warning(f"No embeddings found for class ID {label} ('{_class_names[label]}') though key existed.") if not final_prototypes: st.warning("Could not calculate any valid final prototypes.") return None, None # Stack final prototypes and move to target device final_prototypes_tensor = torch.stack(final_prototypes).to(DEVICE) st.success(f"Calculated {len(final_prototypes)} final prototypes for labels: {prototype_labels}") return final_prototypes_tensor, prototype_labels # Visualize Prototypes Function def visualize_prototypes(prototypes_tensor, prototype_labels, class_names_list): st.write("Visualizing Prototypes using PCA...") if prototypes_tensor is None or prototypes_tensor.numel() == 0: st.warning("⚠️ No prototypes available to visualize.") return num_prototypes = prototypes_tensor.size(0) if num_prototypes < 2: st.warning(f"⚠️ Need at least 2 prototypes for PCA. 
Found {num_prototypes}.") # Optionally display the single prototype info if num_prototypes == 1: st.write(f"Single prototype label: {prototype_labels[0]} ({class_names_list[prototype_labels[0]]})") return # Ensure number of labels matches number of prototypes if len(prototype_labels) != num_prototypes: st.error(f"Mismatch between number of prototypes ({num_prototypes}) and labels ({len(prototype_labels)}). Cannot visualize.") return pca = PCA(n_components=2) try: # Detach, move to CPU, convert to numpy prototypes_np = prototypes_tensor.detach().cpu().numpy() # Check for NaN or Inf before PCA if not np.all(np.isfinite(prototypes_np)): st.error("Prototypes contain NaN or Infinite values. Cannot perform PCA.") # Optionally show where NaNs are: np.isnan(prototypes_np).any(axis=1) return prototypes_2d = pca.fit_transform(prototypes_np) fig, ax = plt.subplots(figsize=(12, 9)) # Adjusted size # Use a colormap for potentially many classes cmap = plt.get_cmap('tab20', len(class_names_list)) # Use tab20 or other suitable map plotted_labels = set() # Keep track of labels already in legend for i, label_index in enumerate(prototype_labels): if 0 <= label_index < len(class_names_list): class_name = class_names_list[label_index] legend_label = f"{class_name} ({label_index})" color = cmap(label_index / len(class_names_list)) # Assign color based on global index # Only add label to legend once per class if label_index not in plotted_labels: ax.scatter(prototypes_2d[i, 0], prototypes_2d[i, 1], label=legend_label, s=100, color=color, alpha=0.8) plotted_labels.add(label_index) else: ax.scatter(prototypes_2d[i, 0], prototypes_2d[i, 1], s=100, color=color, alpha=0.8) # No label if already added else: # Handle labels out of bounds ax.scatter(prototypes_2d[i, 0], prototypes_2d[i, 1], label=f"Invalid Label ({label_index})", s=100, marker='x', color='red') st.warning(f"Label index {label_index} out of bounds for class names list (length {len(class_names_list)}).") ax.set_title("Prototypes 
# --- Object Detection (YOLO) ---
@st.cache_resource(show_spinner="Loading detection model...")
def load_yolo_model():
    """Load the YOLO detection model once and cache it as a Streamlit resource.

    The model is built from ``YOLO_MODEL_PATH`` and moved to ``DEVICE``.

    Returns:
        The loaded ``YOLO`` model instance.

    On any loading error the error is shown in the UI and the script run is
    halted via ``st.stop()``.
    """
    try:
        model = YOLO(YOLO_MODEL_PATH)
        model.to(DEVICE)
        # NOTE(review): on some ultralytics releases the YOLO wrapper overrides
        # ``train()``, and ``nn.Module.eval()`` delegates to ``train(False)`` --
        # confirm this call is safe for the installed ultralytics version.
        model.eval()
        return model
    except Exception as e:
        st.error(f"Failed to load YOLO detection model from {YOLO_MODEL_PATH}: {e}")
        st.exception(e)
        st.stop()


def detect_objects(image):
    """Run YOLO detection on a PIL image.

    Args:
        image: PIL image; it is converted to RGB before inference.

    Returns:
        tuple: ``(image_rgb, detections_df)`` where ``image_rgb`` is a NumPy
        array (the annotated result, or the un-annotated input when inference
        fails or yields no result object) and ``detections_df`` is a DataFrame
        with the columns ``Class``, ``Confidence``, ``X_min``, ``Y_min``,
        ``X_max``, ``Y_max``. The DataFrame is empty (but keeps its columns)
        when nothing was detected or inference failed.
    """
    # Fixed schema so callers always get the same columns, even with zero rows.
    columns = ["Class", "Confidence", "X_min", "Y_min", "X_max", "Y_max"]

    model = load_yolo_model()

    # PIL (RGB) -> NumPy -> BGR, the channel order passed to the model here.
    img_array = np.array(image.convert("RGB"))
    img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

    try:
        results = model(img_bgr)
    except Exception as e:
        st.error(f"Error during YOLO inference: {e}")
        # Fall back to the original image and an empty table.
        return img_array, pd.DataFrame(columns=columns)

    # Guard against an empty result list before indexing into it.
    if not results:
        return img_array, pd.DataFrame(columns=columns)

    result = results[0]

    # plot() draws boxes (with confidences and labels); the array is converted
    # from BGR back to RGB for display in Streamlit.
    result_image_bgr = result.plot(conf=True, labels=True)
    result_image_rgb = cv2.cvtColor(result_image_bgr, cv2.COLOR_BGR2RGB)

    detections_list = []
    if result.boxes is not None:
        boxes = result.boxes.xyxy.cpu().numpy()            # (xmin, ymin, xmax, ymax)
        confs = result.boxes.conf.cpu().numpy()            # confidence scores
        cls_ids = result.boxes.cls.cpu().numpy().astype(int)  # class ids
        class_names = model.names                          # id -> name mapping

        detections_list = [
            {
                # .get() keeps unknown ids displayable instead of raising.
                "Class": class_names.get(cls_id, f"ID {cls_id}"),
                "Confidence": conf,
                "X_min": box[0],
                "Y_min": box[1],
                "X_max": box[2],
                "Y_max": box[3],
            }
            for box, conf, cls_id in zip(boxes, confs, cls_ids)
        ]

    detections_df = pd.DataFrame(detections_list, columns=columns)
    return result_image_rgb, detections_df
# === Main App Logic ===
# Top-level Streamlit script: initialises session state, loads the combined
# dataset, renders the sidebar (mode status, saved few-shot states) and then
# dispatches on the selected action.
st.title("🌿 Coffee Leaf Disease Classifier + Few-Shot Learning + Detection")

# --- Initialize Session State ---
st.session_state.setdefault('few_shot_trained', False)
st.session_state.setdefault('final_prototypes', None)   # Tensor of prototypes
st.session_state.setdefault('prototype_labels', None)   # List of original labels
st.session_state.setdefault('model_mode', 'standard')   # 'standard' or 'few_shot'

# --- Load Data ---
combined_dataset, class_indices, class_names, num_base_classes = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)

# --- Sidebar ---
st.sidebar.header("⚙️ Options & Status")

# --- Mode Selection / Status ---
st.sidebar.subheader("Mode")
if st.sidebar.button("🔄 Reset to Standard Classifier"):
    st.session_state.model_mode = 'standard'
    st.session_state.few_shot_trained = False
    st.session_state.final_prototypes = None
    st.session_state.prototype_labels = None
    st.success("Switched to Standard Classification Mode.")
    # Drop cached data so few-shot results are not reused after the reset.
    st.cache_data.clear()
    st.rerun()

# Display current mode
mode_status = "Standard Classifier"
if st.session_state.model_mode == 'few_shot' and st.session_state.final_prototypes is not None:
    mode_status = f"Few-Shot ({len(st.session_state.final_prototypes)} Prototypes Active)"
st.sidebar.info(f"**Current Mode:** {mode_status}")

# --- Load/Delete Saved Few-Shot Models ---
st.sidebar.divider()
st.sidebar.subheader("💾 Saved Few-Shot Models")
saved_model_names = list_saved_models()

# --- Loading Section ---
if not saved_model_names:
    st.sidebar.info("No saved few-shot models found.")
else:
    selected_model_to_load = st.sidebar.selectbox(
        "Load a saved few-shot state:",
        options=[""] + saved_model_names,  # Empty option acts as placeholder
        key="load_model_select",
        index=0
    )
    if st.sidebar.button("📥 Load Selected State", key="load_model_button", disabled=(not selected_model_to_load)):
        if selected_model_to_load:
            # Weights are loaded into the cached feature extractor instance.
            model_instance = cached_feature_extractor_model()
            # Class names are passed to load_few_shot_state along with the model.
            _, _, current_cls_names_on_load, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)
            if load_few_shot_state(selected_model_to_load, model_instance, current_cls_names_on_load):
                st.rerun()  # Reflect loaded state in the UI

# --- Deleting Section ---
if saved_model_names:
    st.sidebar.markdown("---")
    selected_model_to_delete = st.sidebar.selectbox(
        "Delete a saved few-shot state:",
        options=[""] + saved_model_names,
        key="delete_model_select",
        index=0
    )
    if selected_model_to_delete:
        confirm_delete = st.sidebar.checkbox(f"Confirm deletion of '{selected_model_to_delete}'", key="delete_confirm")
        if st.sidebar.button("❌ Delete Selected State", key="delete_model_button", disabled=(not confirm_delete)):
            if confirm_delete:
                if delete_saved_model(selected_model_to_delete):
                    # list_saved_models is not cached, so a rerun refreshes the list.
                    st.rerun()
    else:
        st.sidebar.write("Select a model above to enable deletion.")

# --- Main Panel Options ---
option = st.radio(
    "Choose an action:",
    ["Upload & Predict", "Add/Manage Rare Classes", "Train Few-Shot Model", "Visualize Prototypes", "Detection"],
    horizontal=True,
    key="main_option"
)

# Both cached models are fetched up front; the prediction branch picks one.
feature_extractor_model = cached_feature_extractor_model()
standard_classifier_model = cached_standard_classifier()

# --- Action Implementation ---
if option == "Upload & Predict":
    st.header("🔎 Upload Image for Prediction")
    uploaded_file = st.file_uploader("Choose a coffee leaf image...", type=["jpg", "jpeg", "png"], key="file_uploader")

    if uploaded_file:
        try:
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Uploaded Image", width=300)

            input_tensor = transform(image).unsqueeze(0).to(DEVICE)

            # Few-shot prediction needs the mode set AND non-empty prototypes.
            use_few_shot = (
                st.session_state.model_mode == 'few_shot'
                and st.session_state.final_prototypes is not None
                and st.session_state.prototype_labels is not None
                and st.session_state.final_prototypes.numel() > 0
            )

            if use_few_shot:
                # --- Few-Shot Prototype Prediction ---
                st.subheader("Prediction using Prototypes")
                model_to_use = feature_extractor_model
                model_to_use.eval()
                with torch.no_grad():
                    embedding = model_to_use(input_tensor)
                    prototypes_for_pred = st.session_state.final_prototypes.to(DEVICE)
                    # Pairwise distances between the query embedding and each prototype.
                    distances = torch.cdist(embedding, prototypes_for_pred)
                    pred_prototype_index = torch.argmin(distances, dim=1).item()

                # Map prototype index -> dataset label -> class name.
                predicted_original_label = st.session_state.prototype_labels[pred_prototype_index]

                if 0 <= predicted_original_label < len(class_names):
                    predicted_class_name = class_names[predicted_original_label]
                    # Confidence: softmax over negative distances.
                    confidence_scores = torch.softmax(-distances, dim=1)
                    confidence = confidence_scores[0, pred_prototype_index].item()
                    st.metric(label="Prediction (Prototype)", value=predicted_class_name, delta=f"{confidence * 100:.1f}% Confidence")
                    st.info(f"(Matched prototype for class ID: {predicted_original_label})")
                else:
                    st.error(f"Predicted prototype label index {predicted_original_label} is out of range for known class names ({len(class_names)}). Prototype labels might be inconsistent.")
            else:
                # --- Standard Classification Prediction ---
                st.subheader("Prediction using Standard Classifier")
                if st.session_state.model_mode != 'standard':
                    st.warning("Falling back to Standard Classifier mode (Few-shot prototypes not available or mode not set).")

                model_to_use = standard_classifier_model
                model_to_use.eval()
                with torch.no_grad():
                    outputs = model_to_use(input_tensor)
                    probs = torch.softmax(outputs, dim=1)
                    pred_label = torch.argmax(probs, dim=1).item()
                    confidence = probs[0][pred_label].item()

                # The standard classifier only knows the base classes.
                if 0 <= pred_label < num_base_classes:
                    predicted_class_name = class_names[pred_label]
                    st.metric(label="Prediction (Standard)", value=predicted_class_name, delta=f"{confidence * 100:.1f}% Confidence")
                else:
                    st.error(f"Standard classifier predicted label {pred_label}, which is out of range for base classes ({num_base_classes}). Model output issue?")

        except Exception as e:
            st.error(f"An error occurred during prediction: {e}")
            st.exception(e)

elif option == "Detection":
    st.header("🕵️ Object Detection with YOLO")
    uploaded_file_detect = st.file_uploader("Upload an image for detection", type=["jpg", "jpeg", "png"], key="detect_uploader")

    if uploaded_file_detect:
        image_detect = Image.open(uploaded_file_detect).convert("RGB")
        result_image, detections = detect_objects(image_detect)

        # Downscale for display while keeping the aspect ratio.
        display_image = Image.fromarray(result_image)
        display_image = ImageOps.contain(display_image, (900, 700))
        st.image(display_image, caption="Detection Result", use_column_width=True)

        if not detections.empty:
            st.subheader("📋 Detection Results:")
            # Format confidence as percentage
            detections['Confidence'] = detections['Confidence'].map('{:.1%}'.format)
            st.dataframe(detections[['Class', 'Confidence', 'X_min', 'Y_min', 'X_max', 'Y_max']])
        else:
            st.info("No objects detected.")

elif option == "Add/Manage Rare Classes":
    st.header("➕ Add New Rare Class (Few-Shot)")
    st.write("Upload exactly 5 sample images for the new disease class.")

    with st.form("add_class_form"):
        new_class_name = st.text_input("Enter the name for the new rare class:")
        uploaded_files_rare = st.file_uploader(
            "Upload 5 images:",
            accept_multiple_files=True,
            type=["jpg", "jpeg", "png"],
            key="add_class_uploader"
        )
        submitted_add = st.form_submit_button("Add Class")

    if submitted_add:
        valid = True
        if not new_class_name:
            st.warning("Please enter a class name.")
            valid = False
        if len(uploaded_files_rare) != 5:
            st.warning(f"Please upload exactly 5 images. You uploaded {len(uploaded_files_rare)}.")
            valid = False

        if valid:
            # Sanitize class name for use as a directory name.
            sanitized_class_name = "".join(c for c in new_class_name if c.isalnum() or c in (' ', '_')).strip().replace(" ", "_")
            if not sanitized_class_name:
                st.error("Invalid class name after sanitization.")
            else:
                new_class_dir = os.path.join(RARE_DATASET, sanitized_class_name)
                if os.path.exists(new_class_dir):
                    st.warning(f"A class directory named '{sanitized_class_name}' already exists. Choose a different name or delete the existing one first.")
                else:
                    try:
                        os.makedirs(new_class_dir, exist_ok=True)
                        image_save_errors = 0
                        for i, file in enumerate(uploaded_files_rare):
                            try:
                                img = Image.open(file).convert("RGB")
                                save_path = os.path.join(new_class_dir, f"sample_{i+1}.jpg")
                                img.save(save_path, format='JPEG', quality=95)
                            except Exception as img_e:
                                st.error(f"Error saving image {i+1} ({file.name}): {img_e}")
                                image_save_errors += 1

                        if image_save_errors == 0:
                            st.success(f"✅ Added new class: '{sanitized_class_name}'. Please re-run 'Train Few-Shot Model' to incorporate it.")
                            # Clear cache so the dataset is reloaded with the new class.
                            st.cache_data.clear()
                            # Existing prototypes no longer match the class list.
                            st.session_state.final_prototypes = None
                            st.session_state.prototype_labels = None
                            st.session_state.few_shot_trained = False
                            st.session_state.model_mode = 'standard'
                            st.rerun()
                        else:
                            # The directory was created just above (it did not exist
                            # before), so removing it is safe. Leaving it behind would
                            # block a retry under the same name and leave an incomplete
                            # class folder inside the rare dataset directory.
                            shutil.rmtree(new_class_dir, ignore_errors=True)
                            st.error(f"Failed to save {image_save_errors} images. The incomplete class directory was removed. Please check the files and try again.")
                    except Exception as e:
                        st.error(f"Error creating directory or saving images for class '{sanitized_class_name}': {e}")
                        st.exception(e)

    st.divider()
    st.header("❌ Delete a Rare Class")
    try:
        # Only sub-directories of RARE_DATASET count as rare classes.
        rare_class_dirs = [d for d in os.listdir(RARE_DATASET) if os.path.isdir(os.path.join(RARE_DATASET, d))]
        if not rare_class_dirs:
            st.info("No rare classes found to delete.")
        else:
            with st.form("delete_class_form"):
                to_delete = st.selectbox("Select rare class to delete:", rare_class_dirs, key="delete_rare_select")
                confirm_delete_rare = st.checkbox(f"Are you sure you want to permanently delete '{to_delete}' and its contents?", key="delete_rare_confirm")
                delete_submit_rare = st.form_submit_button("Delete Class")

            if delete_submit_rare:
                if confirm_delete_rare:
                    delete_path = os.path.join(RARE_DATASET, to_delete)
                    try:
                        shutil.rmtree(delete_path)
                        st.success(f"✅ Deleted rare class: {to_delete}")
                        # Clear cache and reset few-shot state: the class list changed.
                        st.cache_data.clear()
                        st.session_state.few_shot_trained = False
                        st.session_state.final_prototypes = None
                        st.session_state.prototype_labels = None
                        st.session_state.model_mode = 'standard'
                        st.rerun()
                    except Exception as e:
                        st.error(f"Error deleting directory {delete_path}: {e}")
                else:
                    st.warning("Please confirm the deletion by checking the box.")
    except FileNotFoundError:
        st.info(f"Rare dataset directory '{RARE_DATASET}' not found or inaccessible.")
    except Exception as e:
        st.error(f"Error listing rare classes: {e}")

elif option == "Train Few-Shot Model":
    st.header("🚀 Train Few-Shot Model")
    st.warning("⚠️ This will fine-tune the feature extractor. Performance on original classes might change. Consider saving the state after training.")

    # Need at least 2 classes; st.stop() ends this script run otherwise.
    if len(class_names) < 2:
        st.error("Need at least two classes (Base + Rare combined) to perform few-shot training.")
        st.stop()

    st.info(f"Training will use all {len(class_names)} available classes: {class_names}")

    # --- Training Parameters (hard-coded) ---
    epochs = 10
    n_way_train = len(class_names)  # N-way uses every available class
    episodes_per_epoch = 5
    n_shot = 2
    n_query = 2
    learning_rate = 1e-5

    # Warn about resource usage when N-way is high.
    if n_way_train > 7:
        st.warning(f"⚠️ Using N-Way = {n_way_train} (all classes). This uses more memory and computation per episode than lower N-way settings. Ensure your system has sufficient resources (especially GPU VRAM).")
    elif n_way_train > 5:
        st.info(f"Using N-Way = {n_way_train} (all classes). This might be resource-intensive.")

    st.info(f"Training Parameters: Epochs={epochs}, N-Way={n_way_train}, Episodes/Epoch={episodes_per_epoch}, N-Shot={n_shot}, N-Query={n_query}, LR={learning_rate:.0e}")

    # --- Training Form ---
    with st.form("train_form"):
        freeze_backbone = st.checkbox("❄️ Freeze Base Model Layers (Train only projection layer)", value=True, help="Recommended to prevent catastrophic forgetting of base classes.")
        submitted_train = st.form_submit_button("Start Training")

        if submitted_train:
            st.info("🚀 Starting few-shot training...")
            progress_bar = st.progress(0)
            status_placeholder = st.empty()
            chart_placeholder = st.empty()

            # NOTE: this is the cached model instance, so training mutates it in place.
            model_train = cached_feature_extractor_model()
            model_train.train()

            # --- Optimizer Setup ---
            trainable_params = []
            if freeze_backbone:
                st.info("Freezing base model layers...")
                try:
                    # Assumes the extractor exposes `.model` (backbone) and `.projection`.
                    for param in model_train.model.parameters():
                        param.requires_grad = False
                    for param in model_train.projection.parameters():
                        param.requires_grad = True
                    trainable_params = [p for p in model_train.parameters() if p.requires_grad]
                    if not trainable_params:
                        st.error("No trainable parameters found (projection layer)! Check model structure ('model' and 'projection' attributes).")
                        st.stop()
                    st.success("Base layers frozen. Training projection layer only.")
                except AttributeError:
                    st.error("Could not access model.model or model.projection attributes to freeze/unfreeze. Training ALL layers.")
                    # Fallback: make sure everything is trainable.
                    for param in model_train.parameters():
                        param.requires_grad = True
                    trainable_params = list(model_train.parameters())
            else:
                st.info("Training all layers (base model + projection).")
                for param in model_train.parameters():
                    param.requires_grad = True
                trainable_params = list(model_train.parameters())

            if not trainable_params:
                st.error("Optimizer setup failed: No trainable parameters collected.")
                st.stop()

            optimizer = torch.optim.Adam(trainable_params, lr=learning_rate, weight_decay=1e-5)
            st.info(f"Using Adam optimizer with LR={learning_rate:.0e}. Training {len(trainable_params)} parameters.")

            # --- Training Loop ---
            loss_history = []
            accuracy_history = []
            total_steps = epochs * episodes_per_epoch
            current_step = 0

            for epoch in range(epochs):
                epoch_loss = 0.0
                epoch_accuracy = 0.0
                valid_episodes_in_epoch = 0

                for episode in range(episodes_per_epoch):
                    current_step += 1

                    s_imgs, s_labels, q_imgs, q_labels = create_episode(
                        combined_dataset, class_indices, class_names,
                        n_way=n_way_train, n_shot=n_shot, n_query=n_query
                    )
                    if s_imgs is None or q_imgs is None:
                        continue  # Episode creation failed; skip it.

                    optimizer.zero_grad()

                    try:
                        s_emb = model_train(s_imgs)
                        q_emb = model_train(q_imgs)
                    except Exception as model_e:
                        st.error(f"Error during model forward pass in training (Epoch {epoch+1}, Ep {episode+1}): {model_e}")
                        st.exception(model_e)
                        continue  # Skip episode on model error

                    loss, accuracy = proto_loss(s_emb, s_labels, q_emb, q_labels)

                    if loss is not None and not torch.isnan(loss) and loss.requires_grad:
                        try:
                            loss.backward()
                            optimizer.step()
                            epoch_loss += loss.item()
                            epoch_accuracy += accuracy
                            valid_episodes_in_epoch += 1
                        except Exception as optim_e:
                            # Report and move on to the next episode.
                            st.error(f"Error during optimizer step or backward pass (Epoch {epoch+1}, Ep {episode+1}): {optim_e}")
                            st.exception(optim_e)

                    # Update status/progress only occasionally to limit UI churn.
                    if (episode + 1) % 10 == 0 or episode == episodes_per_epoch - 1:
                        progress = current_step / total_steps
                        progress_bar.progress(progress)
                        # The update guard above already allows for loss being None,
                        # so the status line must not dereference it unconditionally.
                        loss_text = f"{loss.item():.4f}" if loss is not None else "n/a"
                        acc_text = f"{accuracy*100:.2f}%" if accuracy is not None else "n/a"
                        status_placeholder.text(f"Epoch {epoch+1}/{epochs} | Episode {episode+1}/{episodes_per_epoch} | Current Loss: {loss_text} | Current Acc: {acc_text}")

                # --- Log epoch results (one entry per epoch, NaN if nothing ran) ---
                if valid_episodes_in_epoch > 0:
                    avg_loss = epoch_loss / valid_episodes_in_epoch
                    avg_accuracy = epoch_accuracy / valid_episodes_in_epoch
                    loss_history.append(avg_loss)
                    accuracy_history.append(avg_accuracy)
                    status_placeholder.text(f"Epoch {epoch+1}/{epochs} Completed - Avg Loss: {avg_loss:.4f} - Avg Accuracy: {avg_accuracy*100:.2f}%")
                else:
                    loss_history.append(float('nan'))
                    accuracy_history.append(float('nan'))
                    status_placeholder.text(f"Epoch {epoch+1}/{epochs} Completed - No valid episodes were run.")

            status_placeholder.success("✅ Few-Shot Training Finished!")

            # --- Final Prototype Calculation ---
            st.info("Calculating final prototypes...")
            model_train.eval()
            # Clear the data cache so prototypes are not served from a stale entry.
            st.cache_data.clear()
            current_combined_dataset_proto, _, current_class_names_proto, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)
            final_prototypes_tensor, final_prototype_labels = calculate_final_prototypes(model_train, current_combined_dataset_proto, current_class_names_proto)

            # --- Store results in session state ---
            if final_prototypes_tensor is not None:
                st.session_state.final_prototypes = final_prototypes_tensor
                st.session_state.prototype_labels = final_prototype_labels
                st.session_state.few_shot_trained = True
                st.session_state.model_mode = 'few_shot'  # Switch mode automatically
                st.success("Prototypes Calculated. Model ready for Few-Shot Prediction.")

                # --- Display Training Curves ---
                chart_data = pd.DataFrame({
                    "Epoch": list(range(1, epochs + 1)),
                    "Average Loss": loss_history,
                    "Average Accuracy": accuracy_history
                })
                chart_placeholder.line_chart(chart_data.set_index("Epoch"))
            else:
                # Reset state if prototype calculation fails.
                st.session_state.final_prototypes = None
                st.session_state.prototype_labels = None
                st.session_state.few_shot_trained = False
                st.error("Failed to calculate final prototypes after training.")
                chart_placeholder.empty()

    # --- Saving Section (outside the form) ---
    # Shown only when prototypes exist and few-shot mode is active.
    if st.session_state.get('final_prototypes') is not None and st.session_state.get('model_mode') == 'few_shot':
        st.divider()
        st.subheader("💾 Save Current Few-Shot State")
        st.info("Save the fine-tuned model weights and calculated prototypes.")

        save_model_name = st.text_input("Enter a name for this state:", key="save_model_name_input_main")

        if st.button("Save State", key="save_state_button_main"):
            if save_model_name:
                # Training used the same cached instance, so it is fetched from the cache again.
                model_to_save = cached_feature_extractor_model()
                model_to_save.eval()
                # Current class names are stored as metadata.
                _, _, current_cls_names_for_saving, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)
                save_few_shot_state(
                    save_model_name,
                    model_to_save,
                    st.session_state.final_prototypes,
                    st.session_state.prototype_labels,
                    current_cls_names_for_saving
                )
            else:
                st.warning("Please enter a name before saving.")

elif option == "Visualize Prototypes":
    st.header("📊 Visualize Class Prototypes (PCA)")
    if st.session_state.final_prototypes is not None and st.session_state.prototype_labels is not None:
        visualize_prototypes(
            st.session_state.final_prototypes,
            st.session_state.prototype_labels,
            class_names
        )
    else:
        st.warning("⚠️ No prototypes calculated yet. Run 'Train Few-Shot Model' or load a saved state.")

    st.divider()
    st.info("You can calculate prototypes based on the *current* feature extractor state (useful before/after training or loading).")
    if st.button("Calculate/Recalculate Prototypes Now"):
        with st.spinner("Calculating prototypes..."):
            # Use the current state of the cached feature extractor.
            proto_model = cached_feature_extractor_model()
            proto_model.eval()
            st.cache_data.clear()  # Clear data cache to be safe
            current_combined_dataset_viz, _, current_class_names_viz, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)
            temp_prototypes, temp_labels = calculate_final_prototypes(proto_model, current_combined_dataset_viz, current_class_names_viz)

            if temp_prototypes is not None:
                st.session_state.final_prototypes = temp_prototypes
                st.session_state.prototype_labels = temp_labels
                # few_shot_trained is deliberately left untouched; only the mode changes.
                st.session_state.model_mode = 'few_shot'
                st.success("Prototypes calculated/recalculated.")
                st.rerun()
            else:
                st.error("Failed to calculate prototypes.")

# NOTE: TEMP_DIR is not cleaned up in this section of the script.