Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models, transforms, datasets | |
| from torch.utils.data import DataLoader, ConcatDataset | |
| from PIL import Image, ImageOps # Added ImageOps | |
| import os | |
| import tempfile | |
| import random | |
| import shutil | |
| import matplotlib.pyplot as plt | |
| from sklearn.decomposition import PCA | |
| import numpy as np | |
| import cv2 | |
| # from PIL import Image # Already imported | |
| from ultralytics import YOLO | |
| import pandas as pd | |
| import json # Added for saving/loading metadata | |
st.set_page_config(layout="wide")

# --- Constants and Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVED_MODELS_DIR = "saved_few_shot_models"
BASE_DATASET = "App data/data_optimize/Basee_data"
RARE_DATASET = "App data/data_optimize/Rare data/"
# Weights for the standard classifier; also the initial state of the feature extractor.
MODEL_WEIGHTS_PATH = "model/efficientnet_coffee (1).pth"
YOLO_MODEL_PATH = "model/best.pt"  # For detection

# Ensure writable directories exist (the rare-data directory may start out empty).
os.makedirs(SAVED_MODELS_DIR, exist_ok=True)
os.makedirs(RARE_DATASET, exist_ok=True)

# Abort early with a visible message when a required input is missing.
if not os.path.isdir(BASE_DATASET):
    st.error(f"Base dataset directory not found: {BASE_DATASET}")
    st.stop()
if not os.path.isfile(MODEL_WEIGHTS_PATH):
    st.error(f"Classifier weights file not found: {MODEL_WEIGHTS_PATH}")
    st.stop()
if not os.path.isfile(YOLO_MODEL_PATH):
    st.error(f"YOLO detection model file not found: {YOLO_MODEL_PATH}")
    st.stop()

st.sidebar.info(f"Using device: {DEVICE}")

# Streamlit re-executes this script on every interaction, so an unconditional
# tempfile.mkdtemp() would leave a new orphaned directory behind on each rerun.
# Keep one directory per session and recreate it only if it has disappeared.
_session_temp_dir = st.session_state.get("_temp_dir")
if not _session_temp_dir or not os.path.isdir(_session_temp_dir):
    _session_temp_dir = tempfile.mkdtemp()
    st.session_state["_temp_dir"] = _session_temp_dir
TEMP_DIR = _session_temp_dir  # For temporary files if needed
| # --- Helper Functions for Saving/Loading Few-Shot States --- | |
| # d | |
def list_saved_models():
    """Return the names of saved few-shot model states, sorted alphabetically.

    A saved state is any directory directly inside ``SAVED_MODELS_DIR``.
    The result is sorted so the order shown to the user does not depend on
    the filesystem's arbitrary listing order. Returns an empty list when
    ``SAVED_MODELS_DIR`` does not exist.
    """
    if not os.path.isdir(SAVED_MODELS_DIR):
        return []
    # scandir is used as a context manager so the directory handle is released promptly.
    with os.scandir(SAVED_MODELS_DIR) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())
def save_few_shot_state(name, model, prototypes, proto_labels, current_class_names):
    """Save the model weights, prototypes, and metadata under a named directory.

    The state is first written to a temporary staging directory and only
    swapped into ``SAVED_MODELS_DIR/<sanitized name>`` once every file has been
    written, so a failure while overwriting does not destroy the previous save.

    Args:
        name: User-supplied name; reduced to letters, digits, '_' and '-'.
        model: Feature extractor whose ``state_dict`` is saved (from CPU).
        prototypes: Tensor of class prototypes.
        proto_labels: Iterable of integer labels, one per prototype.
        current_class_names: Class names active when the prototypes were built.

    Returns:
        True when the state was saved; False when the name is invalid, the
        overwrite was not confirmed, or an error occurred.
    """
    if not name or not name.strip():
        st.error("Please provide a valid name for the saved model.")
        return False

    # Keep only characters that are safe as a directory name (this also removes
    # whitespace and path separators).
    sanitized_name = "".join(c for c in name if c.isalnum() or c in ("_", "-"))
    if not sanitized_name:
        st.error("Invalid name after sanitization. Use letters, numbers, underscore, or hyphen.")
        return False

    save_dir = os.path.join(SAVED_MODELS_DIR, sanitized_name)

    # An existing save is only replaced after explicit confirmation.
    if os.path.exists(save_dir):
        warn_col, button_col = st.columns([3, 1])
        with warn_col:
            st.warning(f"Model name '{sanitized_name}' already exists.")
        with button_col:
            # The key is derived from the name so each confirmation button is unique.
            if st.button("Overwrite?", key=f"overwrite_{sanitized_name}"):
                st.info(f"Overwriting '{sanitized_name}'...")
            else:
                st.info("Save cancelled. Choose a different name or click 'Overwrite?'.")
                return False
    else:
        st.info(f"Saving new model state '{sanitized_name}'...")

    staging_dir = None
    try:
        staging_dir = tempfile.mkdtemp(prefix="few_shot_save_")

        # 1. Model weights: saved from CPU for portability; the model is always
        #    moved back to DEVICE, even when saving fails.
        model.to("cpu")
        try:
            torch.save(
                model.state_dict(),
                os.path.join(staging_dir, "feature_extractor_state_dict.pth"),
            )
        finally:
            model.to(DEVICE)

        # 2. Prototypes tensor (CPU copy).
        torch.save(prototypes.cpu(), os.path.join(staging_dir, "prototypes.pt"))

        # 3. Metadata: coerce labels to plain ints so numpy/tensor scalars
        #    do not break JSON serialisation.
        metadata = {
            "prototype_labels": [int(label) for label in proto_labels],
            "class_names_on_save": list(current_class_names),
        }
        with open(os.path.join(staging_dir, "metadata.json"), "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=4)

        # Everything was written: replace the old save (if any) with the staged one.
        os.makedirs(SAVED_MODELS_DIR, exist_ok=True)
        if os.path.isdir(save_dir):
            shutil.rmtree(save_dir)
        elif os.path.exists(save_dir):
            os.remove(save_dir)
        shutil.move(staging_dir, save_dir)
        staging_dir = None  # Ownership transferred; nothing left to clean up.
    except Exception as e:
        st.error(f"Error saving model state '{sanitized_name}': {e}")
        st.exception(e)  # Show full traceback for debugging
        return False
    finally:
        # Remove a partially written staging directory; the previous save is untouched.
        if staging_dir and os.path.isdir(staging_dir):
            shutil.rmtree(staging_dir, ignore_errors=True)

    st.success(f"Few-shot model state '{sanitized_name}' saved successfully!")
    return True
def load_few_shot_state(name, model_to_load_into, current_class_names):
    """Load a saved few-shot state into the model and the Streamlit session state.

    Metadata and prototypes are read and validated before the model weights
    are touched, so a corrupt save is rejected without modifying the model.

    Args:
        name: Directory name of the saved state inside ``SAVED_MODELS_DIR``.
        model_to_load_into: Model that receives the saved ``state_dict``.
        current_class_names: Ordered class names currently in use.

    Returns:
        True on success (session state is switched to few-shot mode),
        False when the save is missing, incomplete, inconsistent, or fails to load.
    """
    load_dir = os.path.join(SAVED_MODELS_DIR, name)
    if not os.path.isdir(load_dir):
        st.error(f"Saved model directory '{load_dir}' not found.")
        return False

    model_path = os.path.join(load_dir, "feature_extractor_state_dict.pth")
    prototypes_path = os.path.join(load_dir, "prototypes.pt")
    metadata_path = os.path.join(load_dir, "metadata.json")
    if not all(os.path.exists(p) for p in (model_path, prototypes_path, metadata_path)):
        st.error(f"Saved model '{name}' is incomplete. Files missing in '{load_dir}'.")
        return False

    try:
        # 1. Metadata
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        loaded_proto_labels = metadata.get("prototype_labels")
        saved_class_names = metadata.get("class_names_on_save")
        if loaded_proto_labels is None or saved_class_names is None:
            st.error(f"Metadata file for '{name}' is corrupted or missing required keys.")
            return False

        # Prototype labels are positions in the class-name list, so the ORDER
        # must match as well as the membership; compare lists, not sets.
        if list(saved_class_names) != list(current_class_names):
            st.warning(f"⚠️ **Class Mismatch!**")
            st.warning(f"Saved model '{name}' classes: `{saved_class_names}`")
            st.warning(f"Current active classes: `{current_class_names}`")
            st.warning("Predictions might be incorrect or errors may occur. Proceed with caution.")
            # Loading proceeds after the warning.

        # 2. Prototypes (validated before the model is modified).
        # NOTE(review): torch.load unpickles the file; only load saves from a
        # trusted source, or consider weights_only=True if the installed torch supports it.
        loaded_prototypes = torch.load(prototypes_path, map_location=DEVICE)
        if len(loaded_prototypes) != len(loaded_proto_labels):
            st.error(
                f"Saved model '{name}' is inconsistent: {len(loaded_prototypes)} prototypes "
                f"but {len(loaded_proto_labels)} labels."
            )
            return False

        # 3. Model weights
        model_to_load_into.to(DEVICE)
        state_dict = torch.load(model_path, map_location=DEVICE)
        try:
            # strict=True raises RuntimeError on any missing or unexpected key.
            model_to_load_into.load_state_dict(state_dict, strict=True)
        except RuntimeError as e:
            st.error(f"RuntimeError loading state_dict for '{name}'. Architecture mismatch?")
            st.error(e)
            return False
        model_to_load_into.eval()

        # 4. Session state
        st.session_state.final_prototypes = loaded_prototypes
        st.session_state.prototype_labels = loaded_proto_labels
        st.session_state.few_shot_trained = True
        st.session_state.model_mode = 'few_shot'
        st.success(f"Successfully loaded few-shot model state '{name}'. Mode set to Few-Shot.")
        return True
    except Exception as e:
        st.error(f"Error loading model state '{name}': {e}")
        st.exception(e)
        return False
def delete_saved_model(name):
    """Delete the saved model directory called ``name``.

    Only a direct child of ``SAVED_MODELS_DIR`` may be deleted. An empty name
    or a name containing path components (for example '..') is rejected,
    because it would otherwise resolve to the models root or to a location
    outside it and remove far more than one saved state.

    Returns:
        True when the directory was removed, False otherwise.
    """
    if not name or not name.strip():
        st.error("Cannot delete. No saved model name was provided.")
        return False

    models_root = os.path.abspath(SAVED_MODELS_DIR)
    delete_dir = os.path.abspath(os.path.join(SAVED_MODELS_DIR, name))
    # The normalised target must sit directly inside the models root.
    if os.path.dirname(delete_dir) != models_root:
        st.error(f"Cannot delete. '{name}' is not a valid saved model name.")
        return False

    if not os.path.isdir(delete_dir):
        st.error(f"Cannot delete. Saved model '{name}' not found.")
        return False

    try:
        shutil.rmtree(delete_dir)
    except Exception as e:
        st.error(f"Error deleting saved model '{name}': {e}")
        return False

    st.success(f"Deleted saved model '{name}'.")
    return True
| # --- Model Architectures --- | |
| # EfficientNet feature extractor with projection layer | |
class EfficientNetWithProjection(nn.Module):
    """Feature-extractor backbone followed by a linear projection layer.

    Args:
        base_model: Module that maps an input batch to feature vectors of
            width ``in_features`` (its classifier head is expected to have
            been removed by the caller).
        output_dim: Width of the projected embedding.
        in_features: Width of the backbone output. Defaults to 1280, the
            value this file uses for EfficientNet-B0.

    The attribute names ``model`` and ``projection`` are part of the
    ``state_dict`` key layout and must not be renamed.
    """

    def __init__(self, base_model, output_dim=1024, in_features=1280):
        super().__init__()
        self.model = base_model
        self.projection = nn.Linear(in_features, output_dim)

    def forward(self, x):
        """Return the projected embedding of ``x``."""
        features = self.model(x)
        return self.projection(features)
| # Base EfficientNet model structure (used for loading weights) | |
def get_base_efficientnet_architecture(num_classes=5):
    """Return an untrained EfficientNet-B0 whose final layer has ``num_classes`` outputs.

    No pretrained weights are loaded here; the caller loads a state dict
    that matches this structure.
    """
    network = models.efficientnet_b0(weights=None)
    # Replace the final linear layer, keeping its input width.
    head_width = network.classifier[1].in_features
    network.classifier[1] = nn.Linear(head_width, num_classes)
    return network
| # Feature Extractor model structure (for few-shot) | |
def get_feature_extractor_base():
    """Return the EfficientNet backbone with coffee weights and no classifier head.

    The 5-class architecture is built first so the saved weights load
    strictly; the classifier is then replaced by ``nn.Identity`` so the
    module outputs features. The app is stopped if the weights cannot be loaded.
    """
    backbone = get_base_efficientnet_architecture(num_classes=5)
    try:
        # NOTE(review): torch.load unpickles the file; weights_only=True may be
        # preferable depending on how the checkpoint was saved.
        coffee_weights = torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE)
        # strict=True: any key mismatch raises and is reported below.
        backbone.load_state_dict(coffee_weights, strict=True)
    except Exception as e:
        st.error(f"Error loading model weights from {MODEL_WEIGHTS_PATH} into base architecture: {e}")
        st.exception(e)
        st.stop()

    # Drop the classification head and switch to evaluation mode.
    backbone.classifier = nn.Identity()
    backbone.eval()
    return backbone
| # Function to load the standard classifier model (for reset/fallback) | |
def load_standard_classifier():
    """Return the 5-class EfficientNet classifier on ``DEVICE`` in eval mode.

    The app is stopped with an error message if the weights cannot be loaded.
    """
    classifier = get_base_efficientnet_architecture(num_classes=5)
    try:
        # NOTE(review): weights_only=True may be preferable if the checkpoint allows it.
        classifier.load_state_dict(
            torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE),
            strict=True,
        )
    except Exception as e:
        st.error(f"Error loading model weights for standard classifier: {e}")
        st.exception(e)
        st.stop()

    # Both calls return the module itself, so they can be chained.
    return classifier.to(DEVICE).eval()
| # --- Caching --- | |
| # Cache the feature extractor model resource (Base + Projection) | |
@st.cache_resource
def cached_feature_extractor_model():
    """Return the shared feature extractor (backbone + projection) on ``DEVICE``.

    The instance is cached with ``st.cache_resource`` so every caller gets
    the same object. Without the cache each call would build a fresh model
    and discard any weights previously loaded into it.
    """
    base_model = get_feature_extractor_base()
    model = EfficientNetWithProjection(base_model, output_dim=1024)
    model.to(DEVICE)
    model.eval()
    st.sidebar.info("Feature extractor model ready (cached).")
    return model
| # Cache the standard classifier model resource | |
@st.cache_resource
def cached_standard_classifier():
    """Return the shared standard classifier, built once and cached.

    ``st.cache_resource`` keeps a single instance so the weights file is not
    re-read on every call.
    """
    model = load_standard_classifier()
    st.sidebar.info("Standard classifier model ready (cached).")
    return model
| # Cache data loading and processing | |
def get_combined_dataset_and_indices(base_path, rare_path):
    """Load the base (and optional rare) image folders and index them by label.

    Rare-class labels are shifted by the number of base classes so they
    follow the base labels. Summary metrics are written to the sidebar.
    The app is stopped when the datasets cannot be loaded or contain no classes.

    Args:
        base_path: ImageFolder root of the base dataset.
        rare_path: ImageFolder root of the rare dataset; may be missing or empty.

    Returns:
        Tuple ``(combined_dataset, indices, class_names, num_base_classes)``
        where ``indices`` maps each label to the positions of its samples in
        ``combined_dataset`` and ``class_names`` lists base names then rare names.
    """
    try:
        transform_local = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Base dataset
        full_dataset = datasets.ImageFolder(base_path, transform_local)
        num_base_classes = len(full_dataset.classes)
        base_class_names = sorted(full_dataset.classes)

        combined_dataset = full_dataset
        rare_classes_found = 0
        rare_class_names = []

        # Check for content with scandir as a context manager so the directory
        # handle is closed instead of being left to the garbage collector.
        rare_has_entries = False
        if os.path.isdir(rare_path):
            with os.scandir(rare_path) as entries:
                rare_has_entries = any(entries)

        if rare_has_entries:
            try:
                rare_dataset = datasets.ImageFolder(rare_path, transform_local)
                if len(rare_dataset.samples) > 0:
                    # Shift rare labels so they start after the base labels.
                    rare_dataset.samples = [
                        (path, label + num_base_classes)
                        for path, label in rare_dataset.samples
                    ]
                    # Keep the targets list in step with the shifted samples.
                    rare_dataset.targets = [label for _, label in rare_dataset.samples]
                    combined_dataset = ConcatDataset([full_dataset, rare_dataset])
                    rare_classes_found = len(rare_dataset.classes)
                    rare_class_names = sorted(rare_dataset.classes)
            except Exception as e_rare:
                st.warning(f"Could not load rare dataset from {rare_path}: {e_rare}. Using base dataset only.")
                combined_dataset = full_dataset

        # Map label -> positions in the combined dataset. Positions run
        # consecutively through the constituent datasets, in order.
        if isinstance(combined_dataset, ConcatDataset):
            parts = combined_dataset.datasets
        else:
            parts = [combined_dataset]
        indices = {}
        position = 0
        for part in parts:
            for _, label in part.samples:
                indices.setdefault(label, []).append(position)
                position += 1

        class_names = base_class_names + rare_class_names

        # Sidebar statistics
        st.sidebar.metric("Base Classes", num_base_classes)
        st.sidebar.metric("Rare Classes", rare_classes_found)
        st.sidebar.metric("Total Classes", len(class_names))

        if len(class_names) == 0:
            st.error("No classes found in base or rare datasets. Check paths/contents.")
            st.stop()

        return combined_dataset, indices, class_names, num_base_classes
    except FileNotFoundError as e:
        st.error(f"Dataset path error: {e}. Check BASE_DATASET ('{base_path}') and RARE_DATASET ('{rare_path}').")
        st.stop()
    except Exception as e:
        st.error(f"Error loading datasets: {e}")
        st.exception(e)  # Show full traceback
        st.stop()
| # --- Re-define transform globally if not done inside get_combined_dataset_and_indices --- | |
# Module-level preprocessing pipeline: resize to 224x224, convert to a tensor,
# then normalise each channel with the mean/std values below.
_resize_step = transforms.Resize((224, 224))
_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_resize_step, _tensor_step, _normalize_step])
| # --- Few-Shot Learning Functions --- | |
def create_episode(dataset, class_indices, class_list, n_way=5, n_shot=5, n_query=5):
    """Sample one support/query episode for Prototypical Network training.

    Args:
        dataset: Indexable dataset; ``dataset[i][0]`` must be an image tensor.
        class_indices: Mapping of label -> list of dataset positions.
        class_list: Unused; kept for interface compatibility.
        n_way: Number of classes to sample (reduced if fewer are available).
        n_shot: Support samples per class.
        n_query: Query samples per class.

    Returns:
        ``(support_imgs, support_labels, query_imgs, query_labels)`` as tensors
        on ``DEVICE``, with labels renumbered per episode, or four ``None``
        values when fewer than two classes can be used or an error occurs.
        A class without ``n_shot + n_query`` samples is skipped, so episode
        labels may be non-contiguous.
    """
    no_episode = (None, None, None, None)

    available_classes = list(class_indices.keys())
    n_way = min(n_way, len(available_classes))
    if n_way < 2:
        return no_episode

    selected_class_ids = random.sample(available_classes, n_way)
    # Original dataset label -> episode-local label (0 .. n_way-1).
    episode_class_map = {
        original_label: episode_label
        for episode_label, original_label in enumerate(selected_class_ids)
    }

    support_imgs, query_imgs = [], []
    support_labels, query_labels = [], []
    samples_per_class = n_shot + n_query
    actual_n_way = 0  # Classes that actually contributed samples.

    for original_label in selected_class_ids:
        indices_for_class = class_indices.get(original_label, [])
        if len(indices_for_class) < samples_per_class:
            continue  # Not enough samples for this class in this episode.

        sampled_indices = random.sample(indices_for_class, samples_per_class)

        # Fetch images one by one so the failing index is known when reporting
        # an error (a comprehension variable is not visible in the handler).
        sample_idx = None
        class_imgs = []
        try:
            for sample_idx in sampled_indices:
                class_imgs.append(dataset[sample_idx][0])
        except IndexError as e:
            st.error(f"IndexError during episode creation for class {original_label}, index {sample_idx}. Check dataset/indices integrity.")
            st.exception(e)
            return no_episode
        except Exception as e:
            st.error(f"Error retrieving data during episode creation: {e}")
            st.exception(e)
            return no_episode

        support_imgs += class_imgs[:n_shot]
        query_imgs += class_imgs[n_shot:]

        new_label = episode_class_map[original_label]
        support_labels += [new_label] * n_shot
        query_labels += [new_label] * n_query
        actual_n_way += 1

    if actual_n_way < 2:
        return no_episode

    try:
        s_imgs_tensor = torch.stack(support_imgs).to(DEVICE)
        s_labels_tensor = torch.tensor(support_labels, dtype=torch.long).to(DEVICE)
        q_imgs_tensor = torch.stack(query_imgs).to(DEVICE)
        q_labels_tensor = torch.tensor(query_labels, dtype=torch.long).to(DEVICE)
        return s_imgs_tensor, s_labels_tensor, q_imgs_tensor, q_labels_tensor
    except Exception as e:
        st.error(f"Error stacking tensors in create_episode: {e}")
        st.exception(e)
        return no_episode
def proto_loss(support_embeddings, support_labels, query_embeddings, query_labels):
    """Compute the Prototypical Network loss and accuracy for one episode.

    Each class prototype is the mean of its support embeddings. Query samples
    are classified by nearest prototype (``torch.cdist``), and the loss is the
    cross-entropy over negative distances. Query samples whose label does not
    occur in the support set are ignored.

    Args:
        support_embeddings: Tensor of support embeddings, one row per sample.
        support_labels: Integer label tensor aligned with ``support_embeddings``.
        query_embeddings: Tensor of query embeddings, one row per sample.
        query_labels: Integer label tensor aligned with ``query_embeddings``.

    Returns:
        ``(loss, accuracy)``. For a degenerate episode (missing/empty inputs,
        fewer than two support classes, or no usable query sample) the loss is
        a zero scalar with ``requires_grad=True`` and the accuracy is 0.0.
    """
    def _zero_result(device=None):
        # Differentiable placeholder so callers can still call backward().
        return torch.tensor(0.0, device=device, requires_grad=True), 0.0

    if (
        support_embeddings is None
        or support_embeddings.numel() == 0
        or query_embeddings is None
        or query_embeddings.numel() == 0
    ):
        device = next(
            (t.device for t in (support_embeddings, query_embeddings) if t is not None),
            None,
        )
        return _zero_result(device)

    # Work on the device of the inputs rather than a module-level global.
    device = support_embeddings.device

    # torch.unique returns sorted labels; prototype k belongs to unique_labels[k].
    unique_labels = torch.unique(support_labels)
    if unique_labels.numel() < 2:
        return _zero_result(device)

    # Every unique label comes from support_labels, so each mask selects at
    # least one row and the mean is always defined.
    prototypes = torch.stack([
        support_embeddings[support_labels == label].mean(dim=0)
        for label in unique_labels
    ])

    valid_query_mask = torch.isin(query_labels, unique_labels)
    if not torch.any(valid_query_mask):
        return _zero_result(device)

    valid_query_embeddings = query_embeddings[valid_query_mask]
    # Map episode labels to prototype rows in one vectorised call; since
    # unique_labels is sorted and contains every valid label, searchsorted
    # returns the exact position.
    targets = torch.searchsorted(unique_labels, query_labels[valid_query_mask])

    distances = torch.cdist(valid_query_embeddings, prototypes)
    predictions = torch.argmin(distances, dim=1)
    accuracy = (predictions == targets).sum().item() / targets.numel()

    # Closer prototype => larger logit.
    loss = F.cross_entropy(-distances, targets)
    return loss, accuracy
| # Recalculate prototypes after training for the entire dataset | |
| # Added spinner message | |
def calculate_final_prototypes(_model, _dataset, _class_names):
    """Compute one mean-embedding prototype per label over the whole dataset.

    Args:
        _model: Embedding model; it is switched to eval mode.
        _dataset: Dataset yielding ``(image, label)`` pairs.
        _class_names: Class-name list used to bounds-check labels.

    Returns:
        ``(prototypes, labels)`` where ``prototypes`` is a tensor on ``DEVICE``
        with one row per entry of ``labels`` (ascending dataset labels), or
        ``(None, None)`` when no prototype could be computed.
    """
    _model.eval()
    embeddings_by_label = {}  # label -> list of CPU embedding tensors
    batches = DataLoader(_dataset, batch_size=128, shuffle=False, num_workers=0)

    with torch.no_grad():
        for images, labels in batches:
            images = images.to(DEVICE)
            try:
                # Embeddings are moved to CPU before being kept.
                vectors = _model(images).cpu()
                for vector, label in zip(vectors, labels.tolist()):
                    embeddings_by_label.setdefault(label, []).append(vector)
            except Exception as e:
                st.error(f"Error during embedding calculation batch: {e}")
                continue  # Skip this batch and keep going.

    if not embeddings_by_label:
        st.warning("No embeddings were generated. Cannot calculate prototypes.")
        return None, None

    final_prototypes = []
    prototype_labels = []  # Dataset labels, aligned with final_prototypes.
    for label in sorted(embeddings_by_label):
        if label < 0 or label >= len(_class_names):
            st.warning(f"Skipping label {label} during prototype calculation: Out of bounds for class names list (len={len(_class_names)}).")
            continue
        members = embeddings_by_label[label]
        if not members:
            continue
        try:
            class_mean = torch.stack(members).mean(dim=0)  # Computed on CPU.
            final_prototypes.append(class_mean)
            prototype_labels.append(label)
        except Exception as e:
            st.error(f"Error processing embeddings for class {label} ('{_class_names[label]}'): {e}")
            continue

    if not final_prototypes:
        st.warning("Could not calculate any valid final prototypes.")
        return None, None

    final_prototypes_tensor = torch.stack(final_prototypes).to(DEVICE)
    st.success(f"Calculated {len(final_prototypes)} final prototypes for labels: {prototype_labels}")
    return final_prototypes_tensor, prototype_labels
| # Visualize Prototypes Function | |
def visualize_prototypes(prototypes_tensor, prototype_labels, class_names_list):
    """Plot the prototypes in 2-D using PCA and render the figure in Streamlit.

    Args:
        prototypes_tensor: Tensor with one prototype per row.
        prototype_labels: Label (position in ``class_names_list``) per prototype.
        class_names_list: Class names used for the legend.

    Nothing is returned. Problems are reported through Streamlit messages.
    The figure is always closed after rendering so figures do not accumulate
    in pyplot across script reruns.
    """
    st.write("Visualizing Prototypes using PCA...")
    if prototypes_tensor is None or prototypes_tensor.numel() == 0:
        st.warning("⚠️ No prototypes available to visualize.")
        return

    num_prototypes = prototypes_tensor.size(0)
    if num_prototypes < 2:
        st.warning(f"⚠️ Need at least 2 prototypes for PCA. Found {num_prototypes}.")
        # Describe the single prototype, guarding against a missing or
        # out-of-range label instead of raising.
        if num_prototypes == 1 and prototype_labels:
            only_label = prototype_labels[0]
            if 0 <= only_label < len(class_names_list):
                st.write(f"Single prototype label: {only_label} ({class_names_list[only_label]})")
            else:
                st.write(f"Single prototype label: {only_label} (unknown class)")
        return

    if len(prototype_labels) != num_prototypes:
        st.error(f"Mismatch between number of prototypes ({num_prototypes}) and labels ({len(prototype_labels)}). Cannot visualize.")
        return

    pca = PCA(n_components=2)
    fig = None
    try:
        prototypes_np = prototypes_tensor.detach().cpu().numpy()
        # Reject NaN/Inf values before running PCA.
        if not np.all(np.isfinite(prototypes_np)):
            st.error("Prototypes contain NaN or Infinite values. Cannot perform PCA.")
            return
        prototypes_2d = pca.fit_transform(prototypes_np)

        fig, ax = plt.subplots(figsize=(12, 9))
        cmap = plt.get_cmap('tab20', len(class_names_list))
        plotted_labels = set()  # Labels that already have a legend entry.
        for (x, y), label_index in zip(prototypes_2d, prototype_labels):
            if 0 <= label_index < len(class_names_list):
                color = cmap(label_index / len(class_names_list))
                if label_index not in plotted_labels:
                    legend_label = f"{class_names_list[label_index]} ({label_index})"
                    ax.scatter(x, y, label=legend_label, s=100, color=color, alpha=0.8)
                    plotted_labels.add(label_index)
                else:
                    ax.scatter(x, y, s=100, color=color, alpha=0.8)
            else:
                ax.scatter(x, y, label=f"Invalid Label ({label_index})", s=100, marker='x', color='red')
                st.warning(f"Label index {label_index} out of bounds for class names list (length {len(class_names_list)}).")

        ax.set_title("Prototypes Visualization (PCA - 2 Components)")
        ax.set_xlabel("PCA Component 1")
        ax.set_ylabel("PCA Component 2")
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize='small')
        ax.grid(True, linestyle='--', alpha=0.6)
        fig.tight_layout(rect=[0, 0, 0.85, 1])  # Leave room for the legend outside the axes.
        st.pyplot(fig)
    except ValueError as ve:
        st.error(f"PCA Error: {ve}. Check prototype dimensions and values.")
    except Exception as e:
        st.error(f"Error during PCA visualization: {e}")
        st.exception(e)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        if fig is not None:
            plt.close(fig)
| # --- Object Detection (YOLO) --- | |
@st.cache_resource
def load_yolo_model():
    """Return the YOLO detection model on ``DEVICE``, loaded once and cached.

    ``detect_objects`` calls this for every image; ``st.cache_resource``
    avoids re-reading the weights file on each call. The app is stopped with
    an error message if the model cannot be loaded.
    """
    try:
        model = YOLO(YOLO_MODEL_PATH)
        model.to(DEVICE)
        model.eval()
        return model
    except Exception as e:
        st.error(f"Failed to load YOLO detection model from {YOLO_MODEL_PATH}: {e}")
        st.exception(e)
        st.stop()
| # Detection function for YOLOv11 | |
def detect_objects(image):
    """Run YOLO detection on a PIL image.

    Args:
        image: PIL image; it is converted to RGB, then to BGR for inference.

    Returns:
        ``(annotated_rgb, detections_df)``: the image with boxes drawn (RGB
        array) and a DataFrame with one row per detection (Class, Confidence,
        X_min, Y_min, X_max, Y_max). On an inference error the unannotated RGB
        array and an empty DataFrame are returned.
    """
    detector = load_yolo_model()

    rgb_pixels = np.array(image.convert("RGB"))
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)

    try:
        results = detector(bgr_pixels)
    except Exception as e:
        st.error(f"Error during YOLO inference: {e}")
        return rgb_pixels, pd.DataFrame()

    first_result = results[0]
    # The plotted array is converted from BGR to RGB for display.
    annotated_bgr = first_result.plot(conf=True, labels=True)
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

    rows = []
    found = first_result.boxes
    if found is not None:
        corners = found.xyxy.cpu().numpy()  # (xmin, ymin, xmax, ymax) per box
        scores = found.conf.cpu().numpy()
        class_ids = found.cls.cpu().numpy().astype(int)
        id_to_name = detector.names
        for (x_min, y_min, x_max, y_max), score, class_id in zip(corners, scores, class_ids):
            rows.append({
                "Class": id_to_name.get(class_id, f"ID {class_id}"),  # Fallback for unknown ids.
                "Confidence": score,
                "X_min": x_min,
                "Y_min": y_min,
                "X_max": x_max,
                "Y_max": y_max,
            })

    return annotated_rgb, pd.DataFrame(rows)
| # === Main App Logic === | |
| st.title("🌿 Coffee Leaf Disease Classifier + Few-Shot Learning + Detection") | |
| # --- Initialize Session State --- | |
| # Use .setdefault() for cleaner initialization | |
| st.session_state.setdefault('few_shot_trained', False) | |
| st.session_state.setdefault('final_prototypes', None) # Tensor of prototypes | |
| st.session_state.setdefault('prototype_labels', None) # List of original labels | |
| st.session_state.setdefault('model_mode', 'standard') # 'standard' or 'few_shot' | |
| # --- Load Data --- | |
| # This runs once and caches results, or re-runs if cache is cleared or args change | |
| # get_combined_dataset_and_indices handles errors internally and stops if needed | |
| combined_dataset, class_indices, class_names, num_base_classes = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) | |
# --- Sidebar ---
st.sidebar.header("⚙️ Options & Status")

# --- Mode Selection / Status ---
st.sidebar.subheader("Mode")

# Reset button: drop every few-shot artefact and fall back to the standard
# classifier, then rerun so the whole page reflects the new mode.
reset_requested = st.sidebar.button("🔄 Reset to Standard Classifier")
if reset_requested:
    st.session_state.final_prototypes = None
    st.session_state.prototype_labels = None
    st.session_state.few_shot_trained = False
    st.session_state.model_mode = 'standard'
    st.success("Switched to Standard Classification Mode.")
    # Cached data (dataset listing, prototype calculation) may be stale now.
    st.cache_data.clear()
    st.rerun()

# Display current mode. Few-shot is only reported as active when prototypes
# are actually present in the session.
few_shot_active = (
    st.session_state.model_mode == 'few_shot'
    and st.session_state.final_prototypes is not None
)
if few_shot_active:
    mode_status = f"Few-Shot ({len(st.session_state.final_prototypes)} Prototypes Active)"
else:
    mode_status = "Standard Classifier"
st.sidebar.info(f"**Current Mode:** {mode_status}")
# --- Load/Delete Saved Few-Shot Models ---
st.sidebar.divider()
st.sidebar.subheader("💾 Saved Few-Shot Models")
saved_model_names = list_saved_models()

if saved_model_names:
    # --- Loading Section ---
    # The leading empty entry acts as a "nothing selected" placeholder.
    selected_model_to_load = st.sidebar.selectbox(
        "Load a saved few-shot state:",
        options=[""] + saved_model_names,
        key="load_model_select",
        index=0,
    )
    load_clicked = st.sidebar.button(
        "📥 Load Selected State",
        key="load_model_button",
        disabled=(not selected_model_to_load),
    )
    if load_clicked and selected_model_to_load:
        # Weights are loaded into the cached feature extractor instance.
        model_instance = cached_feature_extractor_model()
        # The current class-name list is handed to the loader for its check.
        _, _, current_cls_names_on_load, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)
        if load_few_shot_state(selected_model_to_load, model_instance, current_cls_names_on_load):
            st.rerun()  # reflect the loaded state in the UI

    # --- Deleting Section ---
    st.sidebar.markdown("---")
    selected_model_to_delete = st.sidebar.selectbox(
        "Delete a saved few-shot state:",
        options=[""] + saved_model_names,
        key="delete_model_select",
        index=0,
    )
    if not selected_model_to_delete:
        st.sidebar.write("Select a model above to enable deletion.")
    else:
        # Deletion needs an explicit confirmation tick before the button works.
        confirm_delete = st.sidebar.checkbox(
            f"Confirm deletion of '{selected_model_to_delete}'", key="delete_confirm"
        )
        delete_clicked = st.sidebar.button(
            "❌ Delete Selected State",
            key="delete_model_button",
            disabled=(not confirm_delete),
        )
        if delete_clicked and confirm_delete and delete_saved_model(selected_model_to_delete):
            # list_saved_models is read again on rerun, refreshing both pickers.
            st.rerun()
else:
    st.sidebar.info("No saved few-shot models found.")
# --- Main Panel Options ---
# One entry per action branch handled below.
_main_actions = [
    "Upload & Predict",
    "Add/Manage Rare Classes",
    "Train Few-Shot Model",
    "Visualize Prototypes",
    "Detection",
]
option = st.radio("Choose an action:", _main_actions, horizontal=True, key="main_option")

# Fetch both cached models up front; the prediction branch decides which one
# to use depending on the current mode.
feature_extractor_model = cached_feature_extractor_model()
standard_classifier_model = cached_standard_classifier()
# --- Action Implementation ---
if option == "Upload & Predict":
    st.header("🔎 Upload Image for Prediction")
    uploaded_file = st.file_uploader(
        "Choose a coffee leaf image...", type=["jpg", "jpeg", "png"], key="file_uploader"
    )
    if uploaded_file:
        try:
            image = Image.open(uploaded_file).convert("RGB")
            # Fixed display width keeps the preview compact.
            st.image(image, caption="Uploaded Image", width=300)
            input_tensor = transform(image).unsqueeze(0).to(DEVICE)

            # Prototype prediction is only possible in few-shot mode with a
            # non-empty prototype tensor and its label list in the session.
            stored_prototypes = st.session_state.final_prototypes
            stored_labels = st.session_state.prototype_labels
            use_few_shot = (
                st.session_state.model_mode == 'few_shot'
                and stored_prototypes is not None
                and stored_labels is not None
                and stored_prototypes.numel() > 0
            )

            if use_few_shot:
                # --- Few-Shot Prototype Prediction ---
                st.subheader("Prediction using Prototypes")
                feature_extractor_model.eval()
                with torch.no_grad():
                    embedding = feature_extractor_model(input_tensor)
                    # Distance from the single query embedding to each prototype.
                    distances = torch.cdist(embedding, stored_prototypes.to(DEVICE))
                    pred_prototype_index = distances.argmin(dim=1).item()

                # Translate the prototype position back to its dataset label.
                predicted_original_label = stored_labels[pred_prototype_index]
                if not (0 <= predicted_original_label < len(class_names)):
                    st.error(f"Predicted prototype label index {predicted_original_label} is out of range for known class names ({len(class_names)}). Prototype labels might be inconsistent.")
                else:
                    predicted_class_name = class_names[predicted_original_label]
                    # Confidence: softmax over negated distances (closer => higher).
                    confidence = (-distances).softmax(dim=1)[0, pred_prototype_index].item()
                    st.metric(label="Prediction (Prototype)", value=predicted_class_name, delta=f"{confidence * 100:.1f}% Confidence")
                    st.info(f"(Matched prototype for class ID: {predicted_original_label})")
            else:
                # --- Standard Classification Prediction ---
                st.subheader("Prediction using Standard Classifier")
                if st.session_state.model_mode != 'standard':
                    st.warning("Falling back to Standard Classifier mode (Few-shot prototypes not available or mode not set).")
                standard_classifier_model.eval()
                with torch.no_grad():
                    probs = standard_classifier_model(input_tensor).softmax(dim=1)
                    pred_label = probs.argmax(dim=1).item()
                    confidence = probs[0][pred_label].item()

                # The standard classifier is only expected to emit base-class indices.
                if not (0 <= pred_label < num_base_classes):
                    st.error(f"Standard classifier predicted label {pred_label}, which is out of range for base classes ({num_base_classes}). Model output issue?")
                else:
                    predicted_class_name = class_names[pred_label]
                    st.metric(label="Prediction (Standard)", value=predicted_class_name, delta=f"{confidence * 100:.1f}% Confidence")
        except Exception as e:
            # Top-level UI boundary: report the failure instead of aborting the page.
            st.error(f"An error occurred during prediction: {e}")
            st.exception(e)
| elif option == "Detection": | |
| st.header("🕵️ Object Detection with YOLO") | |
| uploaded_file_detect = st.file_uploader("Upload an image for detection", type=["jpg", "jpeg", "png"], key="detect_uploader") | |
| if uploaded_file_detect: | |
| image_detect = Image.open(uploaded_file_detect).convert("RGB") | |
| result_image, detections = detect_objects(image_detect) # Calls the detection function | |
| # Resize for display (maintain aspect ratio) using Pillow ImageOps | |
| display_image = Image.fromarray(result_image) # Convert numpy array back to PIL Image | |
| display_image = ImageOps.contain(display_image, (900, 700)) # Resize while keeping aspect ratio within bounds | |
| st.image(display_image, caption="Detection Result", use_column_width=True) # Use column width | |
| # Show detection results | |
| if not detections.empty: | |
| st.subheader("📋 Detection Results:") | |
| # Format confidence as percentage | |
| detections['Confidence'] = detections['Confidence'].map('{:.1%}'.format) | |
| st.dataframe(detections[['Class', 'Confidence', 'X_min', 'Y_min', 'X_max', 'Y_max']]) # Display selected columns | |
| else: | |
| st.info("No objects detected.") | |
| elif option == "Add/Manage Rare Classes": | |
| st.header("➕ Add New Rare Class (Few-Shot)") | |
| st.write(f"Upload exactly 5 sample images for the new disease class.") | |
| with st.form("add_class_form"): | |
| new_class_name = st.text_input("Enter the name for the new rare class:") | |
| uploaded_files_rare = st.file_uploader( | |
| "Upload 5 images:", accept_multiple_files=True, type=["jpg", "jpeg", "png"], key="add_class_uploader" | |
| ) | |
| submitted_add = st.form_submit_button("Add Class") | |
| if submitted_add: | |
| valid = True | |
| if not new_class_name: | |
| st.warning("Please enter a class name.") | |
| valid = False | |
| if len(uploaded_files_rare) != 5: | |
| st.warning(f"Please upload exactly 5 images. You uploaded {len(uploaded_files_rare)}.") | |
| valid = False | |
| if valid: | |
| # Sanitize class name for directory | |
| sanitized_class_name = "".join(c for c in new_class_name if c.isalnum() or c in (' ', '_')).strip().replace(" ", "_") | |
| if not sanitized_class_name: | |
| st.error("Invalid class name after sanitization.") | |
| else: | |
| new_class_dir = os.path.join(RARE_DATASET, sanitized_class_name) | |
| if os.path.exists(new_class_dir): | |
| st.warning(f"A class directory named '{sanitized_class_name}' already exists. Choose a different name or delete the existing one first.") | |
| else: | |
| try: | |
| os.makedirs(new_class_dir, exist_ok=True) | |
| image_save_errors = 0 | |
| for i, file in enumerate(uploaded_files_rare): | |
| try: | |
| img = Image.open(file).convert("RGB") | |
| # Optional: Resize or standardize images on upload if needed | |
| # img = img.resize((256, 256)) | |
| save_path = os.path.join(new_class_dir, f"sample_{i+1}.jpg") | |
| img.save(save_path, format='JPEG', quality=95) # Save as JPEG | |
| except Exception as img_e: | |
| st.error(f"Error saving image {i+1} ({file.name}): {img_e}") | |
| image_save_errors += 1 | |
| if image_save_errors == 0: | |
| st.success(f"✅ Added new class: '{sanitized_class_name}'. Please re-run 'Train Few-Shot Model' to incorporate it.") | |
| # CRITICAL: Clear cache so dataset is reloaded with the new class | |
| st.cache_data.clear() | |
| # Clear potentially outdated prototypes if a class is added | |
| st.session_state.final_prototypes = None | |
| st.session_state.prototype_labels = None | |
| st.session_state.few_shot_trained = False | |
| st.session_state.model_mode = 'standard' # Reset to standard after adding class | |
| st.rerun() # Force rerun to reload data | |
| else: | |
| st.error(f"Failed to save {image_save_errors} images. Class directory might be incomplete. Please check and try again.") | |
| # Optionally remove the created directory if saving failed partially | |
| # shutil.rmtree(new_class_dir) | |
| except Exception as e: | |
| st.error(f"Error creating directory or saving images for class '{sanitized_class_name}': {e}") | |
| st.exception(e) | |
| st.divider() | |
| st.header("❌ Delete a Rare Class") | |
| try: | |
| # List only directories inside RARE_DATASET | |
| rare_class_dirs = [d for d in os.listdir(RARE_DATASET) if os.path.isdir(os.path.join(RARE_DATASET, d))] | |
| if not rare_class_dirs: | |
| st.info("No rare classes found to delete.") | |
| else: | |
| with st.form("delete_class_form"): | |
| to_delete = st.selectbox("Select rare class to delete:", rare_class_dirs, key="delete_rare_select") | |
| confirm_delete_rare = st.checkbox(f"Are you sure you want to permanently delete '{to_delete}' and its contents?", key="delete_rare_confirm") | |
| delete_submit_rare = st.form_submit_button("Delete Class") | |
| if delete_submit_rare: | |
| if confirm_delete_rare: | |
| delete_path = os.path.join(RARE_DATASET, to_delete) | |
| try: | |
| shutil.rmtree(delete_path) | |
| st.success(f"✅ Deleted rare class: {to_delete}") | |
| # CRITICAL: Clear cache and reset state | |
| st.cache_data.clear() | |
| st.session_state.few_shot_trained = False | |
| st.session_state.final_prototypes = None | |
| st.session_state.prototype_labels = None | |
| st.session_state.model_mode = 'standard' | |
| st.rerun() # Rerun to reload data and update UI | |
| except Exception as e: | |
| st.error(f"Error deleting directory {delete_path}: {e}") | |
| else: | |
| st.warning("Please confirm the deletion by checking the box.") | |
| except FileNotFoundError: | |
| st.info(f"Rare dataset directory '{RARE_DATASET}' not found or inaccessible.") | |
| except Exception as e: | |
| st.error(f"Error listing rare classes: {e}") | |
| elif option == "Train Few-Shot Model": | |
| st.header("🚀 Train Few-Shot Model") | |
| st.warning("⚠️ This will fine-tune the feature extractor. Performance on original classes might change. Consider saving the state after training.") | |
| # --- Check if enough classes exist before showing the form --- | |
| if len(class_names) < 2: # Need at least 2 classes for N-way=2 training | |
| st.error("Need at least two classes (Base + Rare combined) to perform few-shot training.") | |
| st.stop() | |
| else: | |
| st.info(f"Training will use all {len(class_names)} available classes: {class_names}") | |
| # --- Hardcoded Training Parameters --- | |
| # --- Training Parameters --- | |
| epochs = 10 | |
| # Set N-Way dynamically to the total number of available classes | |
| n_way_train = len(class_names) | |
| # Ensure N-way is at least 2 for meaningful training | |
| if n_way_train < 2: | |
| st.error(f"Training requires at least 2 classes, but found only {n_way_train}. Cannot proceed with few-shot training.") | |
| st.stop() # Stop if not enough classes | |
| episodes_per_epoch = 5 # Keep other parameters | |
| n_shot = 2 | |
| n_query = 2 | |
| learning_rate = 1e-5 | |
| # Add a warning about potential resource usage if N-way is high | |
| if n_way_train > 7: # Example threshold, adjust as needed | |
| st.warning(f"⚠️ Using N-Way = {n_way_train} (all classes). This uses more memory and computation per episode than lower N-way settings. Ensure your system has sufficient resources (especially GPU VRAM).") | |
| elif n_way_train > 5: | |
| st.info(f"Using N-Way = {n_way_train} (all classes). This might be resource-intensive.") | |
| st.info(f"Training Parameters: Epochs={epochs}, N-Way={n_way_train}, Episodes/Epoch={episodes_per_epoch}, N-Shot={n_shot}, N-Query={n_query}, LR={learning_rate:.0e}") | |
| # --- Training Form --- | |
| with st.form("train_form"): | |
| freeze_backbone = st.checkbox("❄️ Freeze Base Model Layers (Train only projection layer)", value=True, help="Recommended to prevent catastrophic forgetting of base classes.") | |
| submitted_train = st.form_submit_button("Start Training") | |
| if submitted_train: | |
| # Re-check class count just before training starts | |
| if len(class_names) < n_way_train: | |
| st.error(f"Cannot start training. Need at least {n_way_train} classes for {n_way_train}-way training, found {len(class_names)}.") | |
| st.stop() | |
| st.info("🚀 Starting few-shot training...") | |
| progress_bar = st.progress(0) | |
| status_placeholder = st.empty() | |
| chart_placeholder = st.empty() # Placeholder for the chart | |
| # --- Get model instance for training --- | |
| model_train = cached_feature_extractor_model() # Get the cached model | |
| model_train.train() # Set model to training mode | |
| # --- Optimizer Setup --- | |
| trainable_params = [] | |
| if freeze_backbone: | |
| st.info("Freezing base model layers...") | |
| try: | |
| # Freeze base model (assuming it's in model_train.model) | |
| for param in model_train.model.parameters(): | |
| param.requires_grad = False | |
| # Ensure projection layer is trainable (assuming model_train.projection) | |
| for param in model_train.projection.parameters(): | |
| param.requires_grad = True | |
| trainable_params = list(filter(lambda p: p.requires_grad, model_train.parameters())) | |
| if not trainable_params: | |
| st.error("No trainable parameters found (projection layer)! Check model structure ('model' and 'projection' attributes).") | |
| st.stop() | |
| st.success("Base layers frozen. Training projection layer only.") | |
| except AttributeError: | |
| st.error("Could not access model.model or model.projection attributes to freeze/unfreeze. Training ALL layers.") | |
| for param in model_train.parameters(): # Ensure all are trainable if fallback | |
| param.requires_grad = True | |
| trainable_params = list(model_train.parameters()) | |
| else: | |
| st.info("Training all layers (base model + projection).") | |
| # Ensure all parameters are trainable | |
| for param in model_train.parameters(): | |
| param.requires_grad = True | |
| trainable_params = list(model_train.parameters()) | |
| # Check if any trainable parameters were actually found | |
| if not trainable_params: | |
| st.error("Optimizer setup failed: No trainable parameters collected.") | |
| st.stop() | |
| optimizer = torch.optim.Adam(trainable_params, lr=learning_rate, weight_decay=1e-5) | |
| st.info(f"Using Adam optimizer with LR={learning_rate:.0e}. Training {len(trainable_params)} parameters.") | |
| # --- Training Loop --- | |
| loss_history = [] | |
| accuracy_history = [] | |
| total_steps = epochs * episodes_per_epoch | |
| current_step = 0 | |
| # Get fresh dataset info for episode creation within the loop if needed, or assume cached is fine | |
| # _, current_class_indices_train, current_class_names_train, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) | |
| for epoch in range(epochs): | |
| epoch_loss = 0.0 | |
| epoch_accuracy = 0.0 | |
| valid_episodes_in_epoch = 0 | |
| for episode in range(episodes_per_epoch): | |
| current_step += 1 | |
| # Create episode using current dataset state | |
| s_imgs, s_labels, q_imgs, q_labels = create_episode( | |
| combined_dataset, class_indices, class_names, n_way=n_way_train, n_shot=n_shot, n_query=n_query | |
| ) | |
| if s_imgs is None or q_imgs is None: continue # Skip if episode creation failed | |
| optimizer.zero_grad() | |
| try: | |
| s_emb = model_train(s_imgs) | |
| q_emb = model_train(q_imgs) | |
| except Exception as model_e: | |
| st.error(f"Error during model forward pass in training (Epoch {epoch+1}, Ep {episode+1}): {model_e}") | |
| st.exception(model_e) | |
| continue # Skip episode on model error | |
| loss, accuracy = proto_loss(s_emb, s_labels, q_emb, q_labels) | |
| if loss is not None and not torch.isnan(loss) and loss.requires_grad: | |
| try: | |
| loss.backward() | |
| # Optional: Gradient clipping | |
| # torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0) | |
| optimizer.step() | |
| epoch_loss += loss.item() | |
| epoch_accuracy += accuracy | |
| valid_episodes_in_epoch += 1 | |
| except Exception as optim_e: | |
| st.error(f"Error during optimizer step or backward pass (Epoch {epoch+1}, Ep {episode+1}): {optim_e}") | |
| st.exception(optim_e) | |
| # Consider stopping or just skipping step? Skipping for now. | |
| # else: # Reduce verbosity for invalid loss skipping | |
| # Update status and progress bar less frequently for performance | |
| if (episode + 1) % 10 == 0 or episode == episodes_per_epoch - 1: | |
| progress = current_step / total_steps | |
| progress_bar.progress(progress) | |
| status_placeholder.text(f"Epoch {epoch+1}/{epochs} | Episode {episode+1}/{episodes_per_epoch} | Current Loss: {loss.item():.4f} | Current Acc: {accuracy*100:.2f}%") | |
| # --- Log epoch results --- | |
| if valid_episodes_in_epoch > 0: | |
| avg_loss = epoch_loss / valid_episodes_in_epoch | |
| avg_accuracy = epoch_accuracy / valid_episodes_in_epoch | |
| loss_history.append(avg_loss) | |
| accuracy_history.append(avg_accuracy) | |
| status_placeholder.text(f"Epoch {epoch+1}/{epochs} Completed - Avg Loss: {avg_loss:.4f} - Avg Accuracy: {avg_accuracy*100:.2f}%") | |
| else: | |
| # Handle epochs where no valid episodes ran | |
| loss_history.append(float('nan')) | |
| accuracy_history.append(float('nan')) | |
| status_placeholder.text(f"Epoch {epoch+1}/{epochs} Completed - No valid episodes were run.") | |
| status_placeholder.success("✅ Few-Shot Training Finished!") | |
| # --- Final Prototype Calculation (still inside 'if submitted') --- | |
| st.info("Calculating final prototypes...") | |
| # Ensure model is in eval mode | |
| model_train.eval() | |
| # Clear cache before calculating prototypes based on potentially new model state | |
| st.cache_data.clear() # Clear data cache | |
| # Recalculate dataset info to ensure it's current | |
| current_combined_dataset_proto, _, current_class_names_proto, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) | |
| # Pass the *trained* model instance | |
| final_prototypes_tensor, final_prototype_labels = calculate_final_prototypes(model_train, current_combined_dataset_proto, current_class_names_proto) | |
| # --- Store results in session state --- | |
| if final_prototypes_tensor is not None: | |
| st.session_state.final_prototypes = final_prototypes_tensor | |
| st.session_state.prototype_labels = final_prototype_labels | |
| st.session_state.few_shot_trained = True | |
| st.session_state.model_mode = 'few_shot' # Switch mode automatically | |
| st.success("Prototypes Calculated. Model ready for Few-Shot Prediction.") | |
| # --- Display Training Curves (use the placeholder) --- | |
| chart_data = pd.DataFrame({ | |
| "Epoch": list(range(1, epochs + 1)), | |
| "Average Loss": loss_history, | |
| "Average Accuracy": accuracy_history | |
| }) | |
| chart_placeholder.line_chart(chart_data.set_index("Epoch")) | |
| else: | |
| # Ensure state is reset if prototype calculation fails | |
| st.session_state.final_prototypes = None | |
| st.session_state.prototype_labels = None | |
| st.session_state.few_shot_trained = False | |
| # Optionally reset mode? Or leave as is? Resetting is safer. | |
| # st.session_state.model_mode = 'standard' | |
| st.error("Failed to calculate final prototypes after training.") | |
| chart_placeholder.empty() # Clear chart placeholder on failure | |
| # --- END OF TRAINING FORM `with st.form("train_form"):` BLOCK --- | |
| # --- SAVING SECTION (OUTSIDE AND AFTER THE FORM) --- | |
| # Show this section ONLY if few-shot training has run successfully (prototypes exist) | |
| # Check session state for prototypes to decide whether to show this. | |
| if st.session_state.get('final_prototypes') is not None and st.session_state.get('model_mode') == 'few_shot': | |
| st.divider() | |
| st.subheader("💾 Save Current Few-Shot State") | |
| st.info("Save the fine-tuned model weights and calculated prototypes.") | |
| # Use a different key for the text input to avoid conflicts | |
| save_model_name = st.text_input("Enter a name for this state:", key="save_model_name_input_main") | |
| # This button is now OUTSIDE the form, so it's allowed. | |
| if st.button("Save State", key="save_state_button_main"): | |
| if save_model_name: | |
| # Need the model instance that was potentially trained. | |
| # Getting it from cache should work if it was modified in-place. | |
| model_to_save = cached_feature_extractor_model() | |
| model_to_save.eval() # Ensure eval mode | |
| # Get the current class names list for metadata | |
| _, _, current_cls_names_for_saving, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) | |
| save_few_shot_state( | |
| save_model_name, | |
| model_to_save, # Pass the model instance | |
| st.session_state.final_prototypes, | |
| st.session_state.prototype_labels, | |
| current_cls_names_for_saving | |
| ) | |
| # Clear input field is tricky without callbacks, maybe omit for simplicity | |
| else: | |
| st.warning("Please enter a name before saving.") | |
| elif option == "Visualize Prototypes": | |
| st.header("📊 Visualize Class Prototypes (PCA)") | |
| if st.session_state.final_prototypes is not None and st.session_state.prototype_labels is not None: | |
| # Pass the required arguments: tensor, labels list, class names list | |
| visualize_prototypes( | |
| st.session_state.final_prototypes, | |
| st.session_state.prototype_labels, | |
| class_names # Pass the globally available class names list | |
| ) | |
| else: | |
| st.warning("⚠️ No prototypes calculated yet. Run 'Train Few-Shot Model' or load a saved state.") | |
| st.divider() | |
| st.info("You can calculate prototypes based on the *current* feature extractor state (useful before/after training or loading).") | |
| if st.button("Calculate/Recalculate Prototypes Now"): | |
| with st.spinner("Calculating prototypes..."): | |
| # Use the current state of the cached feature extractor | |
| proto_model = cached_feature_extractor_model() | |
| proto_model.eval() | |
| # Clear cache before calculation? Depends if calculate_final_prototypes uses cached data internally | |
| st.cache_data.clear() # Clear data cache to be safe | |
| # Recalculate dataset info | |
| current_combined_dataset_viz, _, current_class_names_viz, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) | |
| # Calculate prototypes | |
| temp_prototypes, temp_labels = calculate_final_prototypes(proto_model, current_combined_dataset_viz, current_class_names_viz) | |
| if temp_prototypes is not None: | |
| st.session_state.final_prototypes = temp_prototypes | |
| st.session_state.prototype_labels = temp_labels | |
| # Don't set few_shot_trained=True here automatically, maybe? | |
| # Let's set mode to few_shot if prototypes are calculated successfully | |
| st.session_state.model_mode = 'few_shot' | |
| st.success("Prototypes calculated/recalculated.") | |
| st.rerun() # Rerun to show visualization or update status | |
| else: | |
| st.error("Failed to calculate prototypes.") | |
| # --- Cleanup Temporary Directory (Optional) --- | |
| # This might run on every script run, which is usually fine for temp dirs. | |
| # Consider more robust cleanup if needed. | |
| # try: | |
| # shutil.rmtree(TEMP_DIR) | |
| # except Exception as e: | |
| # st.sidebar.warning(f"Could not cleanup temp dir {TEMP_DIR}: {e}") |