# NOTE(review): the three lines originally here ("Spaces:" / "Sleeping" /
# "Sleeping") were Hugging Face Spaces page chrome captured by the scrape,
# not part of the application source.
| # -*- coding: utf-8 -*- | |
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models, transforms, datasets | |
| from torch.utils.data import DataLoader, ConcatDataset | |
| from PIL import Image, ImageOps # Added ImageOps | |
| import os | |
| # import tempfile # Not used for primary storage here | |
| import random | |
| import shutil | |
| import matplotlib.pyplot as plt | |
| # from sklearn.decomposition import PCA # PCA not used, removed | |
| import numpy as np | |
| import cv2 | |
| # from PIL import Image # Already imported | |
| from ultralytics import YOLO | |
| import pandas as pd | |
| import json # Added for saving/loading metadata | |
| # --- README.md Configuration Reminder --- | |
| # Make sure your README.md contains at least: | |
| # --- | |
| # hardware: your_gpu_id_here # e.g., t4-small (Required for persistent storage) | |
| # storage: | |
| # mount_point: /data | |
| # --- | |
| # And other keys like sdk, app_file. | |
| # ---------------------------------------- | |
st.set_page_config(layout="wide")  # must run before any other st.* call

# --- Constants and Setup ---
# All models/tensors are placed on this device (GPU when available).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Path Configuration ---
# Directory containing this script: repo-shipped assets are resolved
# relative to it so the app works regardless of the working directory.
APP_DIR = os.path.dirname(os.path.abspath(__file__))

# --- Persistent Storage Mount Point ---
PERSISTENT_STORAGE_MOUNT_POINT = "/data"  # Standard mount point

# Read-only assets shipped with the Git repo:
BASE_DATASET = os.path.join(APP_DIR, "App data/data_optimize/Basee_data")  # From Git Repo
MODEL_WEIGHTS_PATH = os.path.join(APP_DIR, "model/efficientnet_coffee (1).pth")  # From Git Repo
YOLO_MODEL_PATH = os.path.join(APP_DIR, "model/best.pt")  # From Git Repo

# Dynamic, user-generated data lives on the persistent volume so it
# survives Space restarts and rebuilds.
RARE_DATASET = os.path.join(PERSISTENT_STORAGE_MOUNT_POINT, "Rare_data")  # On Persistent Volume
SAVED_MODELS_DIR = os.path.join(PERSISTENT_STORAGE_MOUNT_POINT, "saved_few_shot_models")  # On Persistent Volume

# --- Ensure Persistent Directories Exist ---
try:
    os.makedirs(RARE_DATASET, exist_ok=True)
    os.makedirs(SAVED_MODELS_DIR, exist_ok=True)
except OSError as e:
    # Typically means the Space lacks persistent storage (see README
    # reminder above) or the mount is not writable.
    st.error(f"Fatal: Error creating directories in persistent storage ('{PERSISTENT_STORAGE_MOUNT_POINT}'). Check Space config/permissions. Error: {e}")
    st.stop()  # Stop execution if persistent storage dirs can't be created
except Exception as e:
    st.error(f"Fatal: An unexpected error occurred accessing persistent storage: {e}")
    st.stop()

# Fail fast with clear messages if repo-shipped assets are missing,
# rather than erroring deep inside model/data loading.
if not os.path.isdir(BASE_DATASET):
    st.error(f"Base dataset directory not found: {BASE_DATASET}")
    st.stop()
if not os.path.isfile(MODEL_WEIGHTS_PATH):
    st.error(f"Classifier weights file not found: {MODEL_WEIGHTS_PATH}")
    st.stop()
if not os.path.isfile(YOLO_MODEL_PATH):
    st.error(f"YOLO detection model file not found: {YOLO_MODEL_PATH}")
    st.stop()
st.sidebar.info(f"Using device: {DEVICE}")
| # --- Helper Functions for Saving/Loading Few-Shot States --- | |
| # (Original Code - relies on SAVED_MODELS_DIR which now points to /data/...) | |
def list_saved_models():
    """Return the names of all saved few-shot model states.

    Each saved state occupies one subdirectory of SAVED_MODELS_DIR (on
    the persistent volume); the directory name is the model name. An
    empty list is returned when the directory is absent or unreadable.
    """
    if not os.path.isdir(SAVED_MODELS_DIR):
        return []
    try:
        return [
            entry
            for entry in os.listdir(SAVED_MODELS_DIR)
            if os.path.isdir(os.path.join(SAVED_MODELS_DIR, entry))
        ]
    except Exception as e:
        st.error(f"Error listing {SAVED_MODELS_DIR}: {e}")
        return []
def save_few_shot_state(name, model, prototypes, proto_labels, current_class_names, few_shot_strategy):
    """Persist a trained few-shot state under SAVED_MODELS_DIR/<name>.

    Writes three artifacts: the feature extractor's state_dict, the
    prototype tensor (moved to CPU), and metadata.json holding the
    prototype labels, the class names at save time, and the strategy.

    Returns True on success; False on invalid name, user-cancelled
    overwrite, unsupported strategy, or any I/O error (in which case
    the partially written directory is cleaned up).
    """
    if not name or not name.strip():
        st.error("Please provide a valid name for the saved model.")
        return False
    # Keep only filesystem-safe characters for the directory name.
    sanitized_name = "".join(c for c in name if c.isalnum() or c in ('_', '-')).rstrip()
    if not sanitized_name:
        st.error("Invalid name after sanitization. Use letters, numbers, underscore, or hyphen.")
        return False
    save_dir = os.path.join(SAVED_MODELS_DIR, sanitized_name)  # on persistent volume
    proceed_with_save = True
    if os.path.exists(save_dir):
        # Name collision: require an explicit "Overwrite?" button press.
        # NOTE(review): the button triggers a Streamlit rerun, so the
        # call that first shows the warning returns False; the actual
        # save happens on the rerun in which the button reports True.
        col1, col2 = st.columns([3, 1])
        with col1:
            st.warning(f"Model name '{sanitized_name}' already exists.")
        with col2:
            overwrite_key = f"overwrite_button_{sanitized_name}"
            if not st.button("Overwrite?", key=overwrite_key):
                st.info("Save cancelled. Choose a different name or click 'Overwrite?'.")
                proceed_with_save = False
            else:
                st.info(f"Overwriting '{sanitized_name}'...")
    else:
        st.info(f"Saving new model state '{sanitized_name}'...")
    if not proceed_with_save:
        return False
    try:
        os.makedirs(save_dir, exist_ok=True)
        # 1. Save Model State Dictionary — moved to CPU first so the
        #    checkpoint loads on machines without a GPU, then restored.
        model.to('cpu')
        model_path = os.path.join(save_dir, "feature_extractor_state_dict.pth")
        torch.save(model.state_dict(), model_path)
        model.to(DEVICE)
        # 2. Save Prototypes Tensor (CPU copy for portability).
        prototypes_path = os.path.join(save_dir, "prototypes.pt")
        torch.save(prototypes.cpu(), prototypes_path)
        # 3. Save Metadata — only 'train_projection' is supported;
        #    anything else aborts and removes the half-written dir.
        if few_shot_strategy != 'train_projection':
            st.error(f"Internal Error: Attempting to save with invalid strategy '{few_shot_strategy}'. Expected 'train_projection'. Save cancelled.")
            if os.path.exists(save_dir): shutil.rmtree(save_dir)
            return False
        metadata = {
            "prototype_labels": proto_labels,
            "class_names_on_save": current_class_names,
            "few_shot_strategy": 'train_projection'
        }
        metadata_path = os.path.join(save_dir, "metadata.json")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=4)
        st.success(f"Few-shot model state '{sanitized_name}' saved successfully!")
        return True
    except Exception as e:
        st.error(f"Error saving model state '{sanitized_name}' to {save_dir}: {e}")
        st.exception(e)
        # Best-effort cleanup so a half-written directory doesn't later
        # show up as a loadable (but corrupt) saved state.
        if os.path.exists(save_dir):
            try:
                shutil.rmtree(save_dir)
                st.info(f"Cleaned up partially saved directory '{save_dir}'.")
            except Exception as cleanup_e:
                st.error(f"Error cleaning up directory during save failure: {cleanup_e}")
        return False
def load_few_shot_state(name, model_to_load_into, current_class_names):
    """Loads a saved model state, prototypes, labels, and strategy into session state and the model.

    *model_to_load_into* is mutated in place (state_dict loaded, moved
    to DEVICE, set to eval). On success, Streamlit session state is
    switched to few-shot mode; on any failure after partial progress,
    few-shot session state is fully reset to standard mode.
    Returns True on success, False otherwise.
    """
    load_dir = os.path.join(SAVED_MODELS_DIR, name)  # on persistent volume
    if not os.path.isdir(load_dir):
        st.error(f"Saved model directory '{load_dir}' not found.")
        return False
    model_path = os.path.join(load_dir, "feature_extractor_state_dict.pth")
    prototypes_path = os.path.join(load_dir, "prototypes.pt")
    metadata_path = os.path.join(load_dir, "metadata.json")
    # All three artifacts written by save_few_shot_state must be present.
    if not all(os.path.exists(p) for p in [model_path, prototypes_path, metadata_path]):
        st.error(f"Saved model '{name}' is incomplete. Files missing in '{load_dir}'.")
        return False
    try:
        # 1. Load Metadata (labels, class names at save time, strategy).
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        loaded_proto_labels = metadata.get("prototype_labels")
        saved_class_names = metadata.get("class_names_on_save")
        loaded_strategy = metadata.get("few_shot_strategy")
        if loaded_proto_labels is None or saved_class_names is None or loaded_strategy is None:
            st.error(f"Metadata file for '{name}' is corrupted or missing required keys (labels, class_names, strategy).")
            return False
        # Only the frozen-backbone strategy is supported by this build.
        if loaded_strategy != 'train_projection':
            st.error(f"Saved model '{name}' used strategy '{loaded_strategy}', but only 'train_projection' (frozen backbone) is currently supported. Cannot load.")
            return False
        # Warn (but don't block) when the active class set has changed
        # since this state was saved — prototype labels may be stale.
        if set(saved_class_names) != set(current_class_names):
            st.warning(f"⚠️ **Class Mismatch!**")
            st.warning(f"Saved model '{name}' classes: `{saved_class_names}`")
            st.warning(f"Current active classes: `{current_class_names}`")
            st.warning("Predictions might be incorrect or errors may occur. Proceed with caution.")
        # 2. Load Model State Dictionary
        model_to_load_into.to(DEVICE)
        state_dict = torch.load(model_path, map_location=DEVICE)
        try:
            # strict=True: any architecture drift should fail loudly here.
            missing_keys, unexpected_keys = model_to_load_into.load_state_dict(state_dict, strict=True)
            if missing_keys: st.warning(f"Loaded state dict is missing keys: {missing_keys}")
            if unexpected_keys: st.warning(f"Loaded state dict has unexpected keys: {unexpected_keys}")
        except RuntimeError as e:
            st.error(f"RuntimeError loading state_dict for '{name}'. Architecture mismatch? {e}")
            st.error("This usually means the saved model structure (base + projection) doesn't match the current code's structure.")
            return False
        model_to_load_into.eval()
        # 3. Load Prototypes
        loaded_prototypes = torch.load(prototypes_path, map_location=DEVICE)
        # 4. Update Session State so the UI switches to few-shot mode.
        st.session_state.final_prototypes = loaded_prototypes
        st.session_state.prototype_labels = loaded_proto_labels
        st.session_state.few_shot_strategy = loaded_strategy
        st.session_state.few_shot_trained = True
        st.session_state.model_mode = 'few_shot'
        st.success(f"Successfully loaded few-shot model state '{name}' (Strategy: {loaded_strategy}). Mode set to Few-Shot.")
        return True
    except Exception as e:
        st.error(f"Error loading model state '{name}' from {load_dir}: {e}")
        st.exception(e)
        # Reset state if loading fails partially, so the app falls back
        # to the standard classifier instead of a half-loaded state.
        st.session_state.final_prototypes = None
        st.session_state.prototype_labels = None
        st.session_state.few_shot_strategy = None
        st.session_state.few_shot_trained = False
        st.session_state.model_mode = 'standard'
        return False
def delete_saved_model(name):
    """Remove the saved few-shot state directory named *name*.

    Returns True on success; False when the directory is missing or the
    removal fails (an st.error is shown in either failure case).
    """
    target_dir = os.path.join(SAVED_MODELS_DIR, name)  # on persistent volume
    if not os.path.isdir(target_dir):
        st.error(f"Cannot delete. Saved model '{name}' not found in {SAVED_MODELS_DIR}.")
        return False
    try:
        shutil.rmtree(target_dir)
    except Exception as e:
        st.error(f"Error deleting saved model '{name}' from {target_dir}: {e}")
        return False
    st.success(f"Deleted saved model '{name}'.")
    return True
| # --- Model Architectures --- | |
| # (Original Code) | |
class EfficientNetWithProjection(nn.Module):
    """Backbone wrapper that projects 1280-d features to *output_dim*.

    The attribute names (`model`, `projection`) are part of the saved
    state_dict layout and must not change.
    """

    def __init__(self, base_model, output_dim=1024):
        super(EfficientNetWithProjection, self).__init__()
        self.model = base_model
        # EfficientNet-B0 emits 1280-dim pooled features.
        self.projection = nn.Linear(1280, output_dim)

    def forward(self, x):
        # Embed with the backbone, then map into the projection space.
        return self.projection(self.model(x))
def get_base_efficientnet_architecture(num_classes=5):
    """Build an EfficientNet-B0 (random init) with a *num_classes*-way head."""
    net = models.efficientnet_b0(weights=None)
    # Swap the final classifier layer for one matching our class count.
    head_in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_in_features, num_classes)
    return net
def get_feature_extractor_base():
    """Build the EfficientNet backbone, load the fine-tuned weights, and
    strip the classifier head so the network emits pooled features."""
    backbone = get_base_efficientnet_architecture(num_classes=5)
    try:
        # Weights ship with the Git repo. strict=False because classifier
        # keys may be absent/extra — the head is discarded below anyway.
        checkpoint = torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE)
        missing_keys, unexpected_keys = backbone.load_state_dict(checkpoint, strict=False)
        non_head_unexpected = [k for k in unexpected_keys if not k.startswith('classifier.')]
        if non_head_unexpected:
            st.warning(f"Loading base weights: Unexpected keys found beyond classifier: {unexpected_keys}")
        if missing_keys:
            st.warning(f"Loading base weights: Missing keys: {missing_keys}")
    except Exception as e:
        st.error(f"Error loading model weights from {MODEL_WEIGHTS_PATH} into base architecture: {e}")
        st.exception(e)
        st.stop()
    backbone.classifier = nn.Identity()  # drop the 5-way head entirely
    backbone.eval()
    return backbone
def load_standard_classifier():
    """Load the 5-class EfficientNet classifier with its trained weights,
    moved to DEVICE and set to eval mode."""
    classifier = get_base_efficientnet_architecture(num_classes=5)
    try:
        # strict=True: the checkpoint must match the architecture exactly.
        checkpoint = torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE)
        classifier.load_state_dict(checkpoint, strict=True)
    except Exception as e:
        st.error(f"Error loading model weights for standard classifier: {e}")
        st.exception(e)
        st.stop()
    classifier.to(DEVICE)
    classifier.eval()
    return classifier
| # --- Caching --- | |
| # (Original Code) | |
@st.cache_resource
def cached_feature_extractor_model():
    """Build (once per server process) the projection-head feature extractor.

    FIX: this function is named and used as cached — the sidebar reset
    button even calls st.cache_resource.clear() — but it carried no
    caching decorator, so the backbone weights were re-read from disk on
    every Streamlit rerun, and any state loaded into the instance by
    load_few_shot_state() was silently discarded on the next rerun.
    @st.cache_resource returns the same shared instance each time.
    """
    base_model = get_feature_extractor_base()
    model = EfficientNetWithProjection(base_model, output_dim=1024)
    model.to(DEVICE)
    model.eval()
    st.sidebar.info("Feature extractor model ready (cached).")
    return model
@st.cache_resource
def cached_standard_classifier():
    """Load (once per server process) the standard 5-class classifier.

    FIX: as with cached_feature_extractor_model, the name and the
    st.cache_resource.clear() call in the reset button indicate caching
    was intended, but the decorator was missing — the checkpoint was
    reloaded from disk on every rerun. @st.cache_resource fixes that.
    """
    model = load_standard_classifier()
    st.sidebar.info("Standard classifier model ready (cached).")
    return model
| # --- Data Loading --- | |
| # (Original Code - uses base_path from Git, rare_path from /data/...) | |
def get_combined_dataset_and_indices(base_path, rare_path):
    """Loads base data from Git repo path and rare data from persistent storage path.

    Returns a 4-tuple:
      - combined_dataset: ImageFolder (base only) or ConcatDataset (base + rare)
      - indices: dict mapping global class label -> list of sample indices
      - class_names: sorted base names followed by sorted rare names
      - num_base_classes: number of classes in the base dataset

    Rare-class labels are shifted by num_base_classes so the two label
    spaces don't collide inside the concatenated dataset. Any failure
    stops the Streamlit script via st.stop().
    """
    try:
        transform_local = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Load Base dataset from the Git repo path.
        if not os.path.isdir(base_path):
            st.error(f"Base dataset path not found: {base_path}")
            st.stop()
        full_dataset = datasets.ImageFolder(base_path, transform=transform_local)
        num_base_classes = len(full_dataset.classes)
        base_class_names = sorted(full_dataset.classes)
        rare_classes_found = 0
        rare_class_names = []
        combined_dataset = full_dataset  # Start with base dataset
        # Try loading the rare dataset from persistent storage; missing
        # or broken rare data degrades gracefully to base-only.
        if os.path.isdir(rare_path) and any(os.scandir(rare_path)):
            try:
                rare_dataset = datasets.ImageFolder(rare_path, transform=transform_local)
                if len(rare_dataset.samples) > 0:
                    # Shift rare labels past the base label range so they
                    # stay unique after concatenation.
                    rare_dataset.samples = [(path, label + num_base_classes) for path, label in rare_dataset.samples]
                    combined_dataset = ConcatDataset([full_dataset, rare_dataset])
                    rare_classes_found = len(rare_dataset.classes)
                    rare_class_names = sorted(rare_dataset.classes)
                else:
                    st.info(f"Rare dataset directory '{rare_path}' exists but is empty.")
            except Exception as e_rare:
                st.warning(f"Could not load rare dataset from {rare_path}: {e_rare}. Using base dataset only.")
        else:
            st.info(f"Rare dataset directory '{rare_path}' not found or empty. Using base dataset only.")
        # --- Index calculation: map each (possibly shifted) label to the
        # positions of its samples within the combined dataset. ---
        indices = {}
        current_idx = 0
        if isinstance(combined_dataset, ConcatDataset):
            for ds in combined_dataset.datasets:
                if hasattr(ds, 'samples'):
                    # Fast path: read labels from ImageFolder metadata
                    # without loading any images.
                    for _, label in ds.samples:
                        indices.setdefault(label, []).append(current_idx)
                        current_idx += 1
                else:
                    # Slow path: materialize each item to read its label.
                    for i in range(len(ds)):
                        _, label = ds[i]
                        indices.setdefault(label, []).append(current_idx)
                        current_idx += 1
        elif isinstance(combined_dataset, datasets.ImageFolder):
            for idx, (_, label) in enumerate(combined_dataset.samples):
                indices.setdefault(label, []).append(idx)
        else:
            st.error("Unexpected dataset type encountered when building indices.")
            st.stop()
        # NOTE(review): appending rare names after base names assumes
        # ImageFolder assigned rare labels in sorted-name order, matching
        # the label shift above — confirm if rare folder names change.
        class_names = base_class_names + rare_class_names
        st.sidebar.metric("Base Classes (Git)", num_base_classes)
        st.sidebar.metric("Rare Classes (Storage)", rare_classes_found)
        st.sidebar.metric("Total Classes", len(class_names))
        if len(class_names) == 0:
            st.error("No classes found in base or rare datasets. Check paths/contents.")
            st.stop()
        return combined_dataset, indices, class_names, num_base_classes
    except FileNotFoundError as e:
        st.error(f"Dataset path error: {e}. Check BASE_DATASET ('{base_path}') and RARE_DATASET ('{rare_path}').")
        st.stop()
    except Exception as e:
        st.error(f"Error loading datasets: {e}")
        st.exception(e)
        st.stop()
# --- Global transform ---
# Standard ImageNet preprocessing (224x224, ImageNet mean/std); matches
# transform_local inside get_combined_dataset_and_indices so uploaded
# images are preprocessed the same way as the training data.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
| # --- Few-Shot Learning Functions --- | |
| # (Original Code) | |
def create_episode(dataset, class_indices, class_list, n_way=5, n_shot=5, n_query=5):
    """Sample one Prototypical-Network episode from *dataset*.

    Selects up to *n_way* classes having at least n_shot + n_query
    samples, then draws n_shot support and n_query query images per
    class. Original labels are remapped to contiguous episode labels
    0..n_way-1.

    Returns (support_imgs, support_labels, query_imgs, query_labels) as
    tensors on DEVICE, or (None, None, None, None) on any failure.
    """
    candidate_ids = list(class_indices.keys())
    if len(candidate_ids) < n_way:
        n_way = len(candidate_ids)
        if n_way < 2:
            st.error(f"Episode creation failed: Need at least 2 classes with enough samples, found {n_way}.")
            return None, None, None, None
    needed_per_class = n_shot + n_query
    eligible = [
        cid for cid in candidate_ids
        if len(class_indices.get(cid, [])) >= needed_per_class
    ]
    if len(eligible) < n_way:
        n_way = len(eligible)
        if n_way < 2:
            st.error(f"Episode creation failed: Need at least 2 eligible classes (with {n_shot+n_query} samples each). Found {n_way}.")
            return None, None, None, None
    chosen_ids = random.sample(eligible, n_way)
    support_imgs, query_imgs = [], []
    support_labels, query_labels = [], []
    # Map original dataset labels onto contiguous episode labels.
    episode_label_of = {orig: ep for ep, orig in enumerate(chosen_ids)}
    for orig_label in chosen_ids:
        pool = class_indices.get(orig_label, [])
        drawn = random.sample(pool, needed_per_class)
        try:
            support_imgs.extend(dataset[i][0] for i in drawn[:n_shot])
            query_imgs.extend(dataset[i][0] for i in drawn[n_shot:])
        except IndexError as e:
            st.error(f"IndexError during episode creation for class {orig_label}. Sampled: {drawn}. Dataset len: {len(dataset)}.")
            st.exception(e)
            return None, None, None, None
        except Exception as e:
            st.error(f"Error retrieving data during episode creation: {e}")
            st.exception(e)
            return None, None, None, None
        ep_label = episode_label_of[orig_label]
        support_labels.extend([ep_label] * n_shot)
        query_labels.extend([ep_label] * n_query)
    try:
        return (
            torch.stack(support_imgs).to(DEVICE),
            torch.tensor(support_labels, dtype=torch.long).to(DEVICE),
            torch.stack(query_imgs).to(DEVICE),
            torch.tensor(query_labels, dtype=torch.long).to(DEVICE),
        )
    except Exception as e:
        st.error(f"Error stacking tensors in create_episode: {e}")
        st.exception(e)
        return None, None, None, None
| # (Original Code - may have issues, see previous discussions if needed) | |
def proto_loss(support_embeddings, support_labels, query_embeddings, query_labels):
    """Calculates the Prototypical Network loss and accuracy.

    Prototypes are per-class means of the support embeddings; queries
    are classified by nearest prototype (Euclidean distance), and the
    loss is cross-entropy over negated distances.

    Returns (loss, accuracy). Degenerate episodes return a
    grad-enabled zero loss and 0.0 accuracy so .backward() stays safe.
    """
    # Guard: empty support or query set -> nothing to learn from.
    if support_embeddings is None or support_embeddings.numel() == 0 or \
       query_embeddings is None or query_embeddings.numel() == 0:
        return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0
    unique_episode_labels = torch.unique(support_labels)
    n_way_actual = len(unique_episode_labels)
    if n_way_actual < 2:
        return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0
    prototypes = []
    # NOTE(review): assumes episode labels are exactly 0..n_way_actual-1,
    # which holds for episodes built by create_episode's label remapping;
    # other callers must guarantee the same.
    for episode_label in range(n_way_actual):
        class_mask = (support_labels == episode_label)
        if torch.any(class_mask):
            class_embeddings = support_embeddings[class_mask]
            prototypes.append(class_embeddings.mean(dim=0))
        else:
            # No support samples for this label: bail out rather than
            # compute distances against an incomplete prototype set.
            st.warning(f"ProtoLoss (Original): No support embeddings found for episode label {episode_label}. Potential issue.")
            return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0
    if len(prototypes) != n_way_actual:
        st.warning(f"ProtoLoss (Original): Mismatch ways ({n_way_actual}) vs prototypes ({len(prototypes)}).")
        return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0
    prototypes = torch.stack(prototypes)
    # Drop query samples whose label has no prototype in this episode.
    valid_query_mask = torch.isin(query_labels, unique_episode_labels)
    if not torch.any(valid_query_mask):
        return torch.tensor(0.0, requires_grad=True).to(DEVICE), 0.0
    filtered_query_embeddings = query_embeddings[valid_query_mask]
    filtered_query_labels = query_labels[valid_query_mask]
    # Pairwise Euclidean distances: shape (n_query, n_way).
    distances = torch.cdist(filtered_query_embeddings, prototypes)
    predictions = torch.argmin(distances, dim=1)
    # NOTE(review): this comparison also relies on labels being 0..n_way-1
    # so the argmin index space matches the label space.
    correct_predictions = (predictions == filtered_query_labels).sum().item()
    total_predictions = filtered_query_labels.size(0)
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
    loss = F.cross_entropy(-distances, filtered_query_labels)  # -distance acts as the logit
    return loss, accuracy
| # (Original Code) | |
def calculate_final_prototypes(_model, _dataset, _class_names, _strategy):
    """Embed the whole dataset with *_model* and average per class.

    Returns (prototypes_tensor_on_DEVICE, list_of_original_labels), or
    (None, None) when no valid prototypes could be formed.

    NOTE(review): the leading-underscore parameter names follow the
    Streamlit caching convention (skip hashing) — presumably this was
    once decorated with @st.cache_*; confirm before re-adding one.
    """
    if _strategy != 'train_projection':
        st.warning(f"calculate_final_prototypes called with unexpected strategy: '{_strategy}'. Proceeding as if 'train_projection'.")
    _model.eval()
    all_embeddings = {}  # original label -> list of CPU embedding tensors
    # NOTE(review): `DEVICE=='cuda'` compares a torch.device to a str —
    # verify this evaluates as intended for pin_memory.
    loader = DataLoader(_dataset, batch_size=128, shuffle=False, num_workers=0, pin_memory=True if DEVICE=='cuda' else False)
    with torch.no_grad():
        for imgs, labs in loader:
            imgs = imgs.to(DEVICE)
            try:
                emb = _model(imgs)
                emb_cpu = emb.cpu()  # keep GPU memory free; averaging done on CPU
                labs_list = labs.tolist()
                for i in range(emb_cpu.size(0)):
                    label = labs_list[i]
                    all_embeddings.setdefault(label, []).append(emb_cpu[i])
            except Exception as e:
                # Skip the failing batch but keep embedding the rest.
                st.error(f"Error during embedding calculation batch: {e}")
                continue
    final_prototypes = []
    prototype_labels = []
    unique_labels_present = sorted(list(all_embeddings.keys()))
    if not unique_labels_present:
        st.warning("No embeddings were generated. Cannot calculate prototypes.")
        return None, None
    for label in unique_labels_present:
        # Labels outside the current class-name list (e.g. stale rare
        # classes) are skipped rather than producing unusable prototypes.
        if not (0 <= label < len(_class_names)):
            st.warning(f"Skipping label {label} during prototype calculation: Out of bounds for class names list (len={len(_class_names)}).")
            continue
        class_embeddings_list = all_embeddings[label]
        if class_embeddings_list:
            try:
                class_embeddings = torch.stack(class_embeddings_list)
                prototype = class_embeddings.mean(dim=0)  # class centroid
                final_prototypes.append(prototype)
                prototype_labels.append(label)
            except Exception as e:
                st.error(f"Error processing embeddings for class {label} ('{_class_names[label]}'): {e}")
                continue
    if not final_prototypes:
        st.warning("Could not calculate any valid final prototypes.")
        return None, None
    final_prototypes_tensor = torch.stack(final_prototypes).to(DEVICE)
    st.success(f"Calculated {len(final_prototypes)} final prototypes (Strategy: {_strategy}) for original labels: {prototype_labels}")
    return final_prototypes_tensor, prototype_labels
| # --- Object Detection (YOLO) --- | |
| # (Original Code) | |
@st.cache_resource
def load_yolo_model():
    """Load (once per server process) the YOLO detector from best.pt.

    FIX: detect_objects() calls this for every uploaded image; without
    caching, the .pt file was re-read and the model re-initialized per
    detection. @st.cache_resource keeps a single shared instance (it is
    also what the sidebar reset's st.cache_resource.clear() expects).
    """
    try:
        model = YOLO(YOLO_MODEL_PATH)
        return model
    except Exception as e:
        st.error(f"Failed to load YOLO detection model from {YOLO_MODEL_PATH}: {e}")
        st.exception(e)
        st.stop()
def detect_objects(image):
    """Run YOLO detection on a PIL image.

    Returns (annotated_rgb_array, detections_dataframe). If inference
    fails, the raw RGB array and an empty DataFrame are returned.
    """
    model = load_yolo_model()
    rgb = np.array(image.convert("RGB"))
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)  # OpenCV/YOLO plotting use BGR
    try:
        results = model(bgr, device=DEVICE)
    except Exception as e:
        st.error(f"Error during YOLO inference: {e}")
        return rgb, pd.DataFrame()
    annotated_bgr = results[0].plot(conf=True, labels=True)
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
    rows = []
    detected_boxes = results[0].boxes
    if detected_boxes is not None:
        xyxy = detected_boxes.xyxy.cpu().numpy()
        confidences = detected_boxes.conf.cpu().numpy()
        class_ids = detected_boxes.cls.cpu().numpy().astype(int)
        id_to_name = model.names  # class-id -> name mapping from the YOLO model
        for box, conf, cid in zip(xyxy, confidences, class_ids):
            rows.append({
                "Class": id_to_name.get(cid, f"ID {cid}"),
                "Confidence": conf,
                "X_min": box[0],
                "Y_min": box[1],
                "X_max": box[2],
                "Y_max": box[3],
            })
    return annotated_rgb, pd.DataFrame(rows)
# === Main App Logic ===
st.title("🌿 Coffee Leaf Disease Classifier + Few-Shot Learning + Detection")

# --- Initialize Session State ---
# setdefault keeps values across reruns while providing first-run defaults.
st.session_state.setdefault('few_shot_trained', False)
st.session_state.setdefault('final_prototypes', None)
st.session_state.setdefault('prototype_labels', None)
st.session_state.setdefault('model_mode', 'standard')
st.session_state.setdefault('few_shot_strategy', None)

# --- Load Data ---
# Base classes come from the Git repo; rare classes from /data (persistent).
combined_dataset, class_indices, class_names, num_base_classes = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)

# --- Sidebar ---
st.sidebar.header("⚙️ Options & Status")

# --- Mode Selection / Status ---
st.sidebar.subheader("Mode")
if st.sidebar.button("🔄 Reset to Standard Classifier"):
    # Wipe all few-shot session state and Streamlit caches, then rerun.
    st.session_state.model_mode = 'standard'
    st.session_state.few_shot_trained = False
    st.session_state.final_prototypes = None
    st.session_state.prototype_labels = None
    st.session_state.few_shot_strategy = None
    st.success("Switched to Standard Classification Mode.")
    st.cache_data.clear()
    st.cache_resource.clear()
    st.rerun()
mode_status = "Standard Classifier"
strategy_info = ""
# Report few-shot mode only when prototypes actually exist and are non-empty.
if st.session_state.model_mode == 'few_shot' and \
   st.session_state.final_prototypes is not None and \
   st.session_state.prototype_labels is not None and \
   len(st.session_state.prototype_labels) > 0:
    mode_status = f"Few-Shot ({len(st.session_state.prototype_labels)} Prototypes)"
    strategy_info = f"(Strategy: {st.session_state.get('few_shot_strategy', 'N/A').replace('_', ' ').title()})"
st.sidebar.info(f"**Current Mode:** {mode_status} {strategy_info}")

# --- Load/Delete Saved Few-Shot Models (persistent volume) ---
st.sidebar.divider()
st.sidebar.subheader("💾 Saved Few-Shot Models (Persistent)")
saved_model_names = list_saved_models()
# --- Loading Section ---
if not saved_model_names:
    st.sidebar.info("No saved few-shot models found in persistent storage.")
else:
    selected_model_to_load = st.sidebar.selectbox(
        "Load a saved few-shot state:",
        options=[""] + saved_model_names,  # leading "" = nothing selected
        key="load_model_select",
        index=0
    )
    if st.sidebar.button("📥 Load Selected State", key="load_model_button", disabled=(not selected_model_to_load)):
        if selected_model_to_load:
            model_instance = cached_feature_extractor_model()
            # Re-fetch class names so the class-mismatch warning inside
            # load_few_shot_state compares against the classes active now.
            _, _, current_cls_names_on_load, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)
            if load_few_shot_state(selected_model_to_load, model_instance, current_cls_names_on_load):
                st.rerun()
# --- Deleting Section ---
if saved_model_names:
    st.sidebar.markdown("---")
    selected_model_to_delete = st.sidebar.selectbox(
        "Delete a saved few-shot state:",
        options=[""] + saved_model_names,
        key="delete_model_select",
        index=0
    )
    if selected_model_to_delete:
        # Deletion requires an explicit confirmation checkbox first.
        confirm_delete = st.sidebar.checkbox(f"Confirm deletion of '{selected_model_to_delete}' from persistent storage", key="delete_confirm")
        if st.sidebar.button("❌ Delete Selected State", key="delete_model_button", disabled=(not confirm_delete)):
            if confirm_delete:
                if delete_saved_model(selected_model_to_delete):
                    st.rerun()

# --- Main Panel Options ---
option = st.radio(
    "Choose an action:",
    ["Upload & Predict", "Add/Manage Rare Classes", "Train Few-Shot Model", "Detection"],
    horizontal=True, key="main_option"
)
# Load models (cached); weights come from the Git repo paths.
feature_extractor_model = cached_feature_extractor_model()
standard_classifier_model = cached_standard_classifier()
| # --- Action Implementation --- | |
| # (Original Code) | |
if option == "Upload & Predict":
    # Classify one uploaded leaf image, either with learned prototypes
    # (few-shot mode) or with the standard EfficientNet classifier.
    st.header("🔎 Upload Image for Prediction")
    uploaded_file = st.file_uploader("Choose a coffee leaf image...", type=["jpg", "jpeg", "png"], key="file_uploader")
    if uploaded_file:
        try:
            pil_img = Image.open(uploaded_file).convert("RGB")
            st.image(pil_img, caption="Uploaded Image", width=300)
            batch = transform(pil_img).unsqueeze(0).to(DEVICE)

            ss = st.session_state
            # Prototype prediction is only valid when a complete few-shot state
            # exists (mode flag, prototypes, labels, and the right strategy).
            few_shot_ready = (
                ss.model_mode == 'few_shot'
                and ss.final_prototypes is not None
                and ss.prototype_labels is not None
                and ss.few_shot_strategy == 'train_projection'
                and ss.final_prototypes.numel() > 0
            )

            if few_shot_ready:
                st.subheader("Prediction using Prototypes")
                extractor = feature_extractor_model
                extractor.eval()
                strategy_name = ss.few_shot_strategy
                st.info(f"Using Few-Shot Strategy: {strategy_name.replace('_', ' ').title()}")
                with torch.no_grad():
                    query_emb = extractor(batch)
                protos = ss.final_prototypes.to(DEVICE)
                if query_emb.shape[1] != protos.shape[1]:
                    st.error(f"Dimension mismatch! Emb: {query_emb.shape[1]}, Proto: {protos.shape[1]}.")
                    st.stop()
                # Nearest prototype (Euclidean) wins; softmax over negated
                # distances gives a confidence-like score.
                dists = torch.cdist(query_emb, protos)
                best_proto = torch.argmin(dists, dim=1).item()
                orig_label = ss.prototype_labels[best_proto]
                # Re-fetch the current combined (base + rare) class-name list.
                _, _, names_now, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)
                if 0 <= orig_label < len(names_now):
                    pred_name = names_now[orig_label]
                    confidence = torch.softmax(-dists, dim=1)[0, best_proto].item()
                    st.metric(label="Prediction (Prototype)", value=pred_name, delta=f"{confidence * 100:.1f}% Confidence")
                    st.info(f"(Matched prototype for class: '{pred_name}' [Orig Label: {orig_label}])")
                else:
                    st.error(f"Predicted prototype label index {orig_label} out of range for current classes ({len(names_now)}).")
            else:
                st.subheader("Prediction using Standard Classifier")
                if ss.model_mode != 'standard':
                    st.warning("Falling back to Standard Classifier mode.")
                classifier = standard_classifier_model
                classifier.eval()
                with torch.no_grad():
                    probs = torch.softmax(classifier(batch), dim=1)
                pred_label = torch.argmax(probs, dim=1).item()
                confidence = probs[0][pred_label].item()
                # The standard classifier only covers base classes, so the
                # predicted index must fall below num_base_classes.
                _, _, names_now, n_base = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)
                if 0 <= pred_label < n_base:
                    st.metric(label="Prediction (Standard)", value=names_now[pred_label], delta=f"{confidence * 100:.1f}% Confidence")
                else:
                    st.error(f"Standard classifier predicted label {pred_label}, out of range for base classes ({n_base}).")
        except Exception as e:
            st.error(f"An error occurred during prediction: {e}")
            st.exception(e)
| elif option == "Detection": | |
| # (Original Code - Uses YOLO model from Git) | |
| st.header("🕵️ Object Detection with YOLO") | |
| uploaded_file_detect = st.file_uploader("Upload an image for detection", type=["jpg", "jpeg", "png"], key="detect_uploader") | |
| if uploaded_file_detect: | |
| try: | |
| image_detect = Image.open(uploaded_file_detect).convert("RGB") | |
| result_image, detections = detect_objects(image_detect) | |
| display_image = Image.fromarray(result_image) | |
| display_image = ImageOps.contain(display_image, (900, 700)) | |
| st.image(display_image, caption="Detection Result", use_container_width=True) | |
| if not detections.empty: | |
| st.subheader("📋 Detection Results:") | |
| detections['Confidence'] = detections['Confidence'].map('{:.1%}'.format) | |
| detections[['X_min', 'Y_min', 'X_max', 'Y_max']] = detections[['X_min', 'Y_min', 'X_max', 'Y_max']].round(1) | |
| st.dataframe(detections[['Class', 'Confidence', 'X_min', 'Y_min', 'X_max', 'Y_max']]) | |
| else: | |
| st.info("No objects detected.") | |
| except Exception as e: | |
| st.error(f"An error occurred during detection: {e}") | |
| st.exception(e) | |
elif option == "Add/Manage Rare Classes":
    # --- Add a new rare class ---
    # Uploaded samples are written under RARE_DATASET, which points at the
    # persistent /data volume, so they survive Space restarts.
    st.header("➕ Add New Rare Class (to Persistent Storage)")
    # Minimum sample count: enough for one few-shot episode (n_shot + n_query).
    n_shot_req = 2
    n_query_req = 2
    required_samples = n_shot_req + n_query_req
    st.write(f"Upload at least **{required_samples}** sample images. Images saved to `{RARE_DATASET}`.")
    with st.form("add_class_form"):
        new_class_name = st.text_input("Enter the name for the new rare class:")
        uploaded_files_rare = st.file_uploader(
            f"Upload {required_samples} or more images:", accept_multiple_files=True, type=["jpg", "jpeg", "png"], key="add_class_uploader"
        )
        submitted_add = st.form_submit_button("Add Class")
        if submitted_add:
            # Validate name and sample count before touching the filesystem.
            valid = True
            if not new_class_name or not new_class_name.strip():
                st.warning("Please enter a valid class name."); valid = False
            if len(uploaded_files_rare) < required_samples:
                st.warning(f"Please upload at least {required_samples} images."); valid = False
            if valid:
                # Keep only alphanumerics/spaces/underscores so the class name
                # is safe to use as a directory name.
                sanitized_class_name = "".join(c for c in new_class_name if c.isalnum() or c in (' ', '_')).strip().replace(" ", "_")
                if not sanitized_class_name:
                    st.error("Invalid class name after sanitization.")
                else:
                    new_class_dir = os.path.join(RARE_DATASET, sanitized_class_name)  # path on /data
                    if os.path.exists(new_class_dir):
                        st.warning(f"Class directory '{sanitized_class_name}' already exists in {RARE_DATASET}.")
                    else:
                        try:
                            os.makedirs(new_class_dir, exist_ok=True)  # create on the persistent volume
                            image_save_errors = 0
                            for i, file in enumerate(uploaded_files_rare):
                                try:
                                    # Re-encode every upload as an RGB JPEG under a
                                    # sanitized, randomized filename.
                                    img = Image.open(file).convert("RGB")
                                    base, ext = os.path.splitext(file.name)
                                    safe_base = "".join(c for c in base if c.isalnum() or c in ('_', '-')).strip()[:50]
                                    filename = f"{safe_base}_{random.randint(1000, 9999)}_{i+1}.jpg"
                                    save_path = os.path.join(new_class_dir, filename)
                                    img.save(save_path, format='JPEG', quality=95)
                                except Exception as img_e:
                                    st.error(f"Error saving image {i+1} ({file.name}): {img_e}")
                                    image_save_errors += 1
                            if image_save_errors == 0:
                                st.success(f"✅ Added class: '{sanitized_class_name}'. Re-run 'Train Few-Shot Model'.")
                                # Invalidate cached dataset listings; model (resource)
                                # caches are deliberately left intact.
                                st.cache_data.clear()
                                # Reset few-shot state: existing prototypes are stale
                                # now that the class list changed.
                                st.session_state.final_prototypes = None
                                st.session_state.prototype_labels = None
                                st.session_state.few_shot_strategy = None
                                st.session_state.few_shot_trained = False
                                st.session_state.model_mode = 'standard'
                                st.rerun()
                            else:
                                st.error(f"Failed to save {image_save_errors} images.")
                        except Exception as e:
                            st.error(f"Error creating directory or saving images: {e}")
                            st.exception(e)
    st.divider()
    # --- Delete an existing rare class (from the persistent volume) ---
    st.header("❌ Delete a Rare Class (from Persistent Storage)")
    try:
        if os.path.isdir(RARE_DATASET):  # check /data/Rare_data exists
            rare_class_dirs = [d for d in os.listdir(RARE_DATASET) if os.path.isdir(os.path.join(RARE_DATASET, d))]
            if not rare_class_dirs:
                st.info(f"No rare classes found to delete in {RARE_DATASET}.")
            else:
                with st.form("delete_class_form"):
                    to_delete = st.selectbox("Select rare class to delete:", rare_class_dirs, key="delete_rare_select")
                    confirm_delete_rare = st.checkbox(f"Confirm deletion of '{to_delete}' from persistent storage?", key="delete_rare_confirm")
                    delete_submit_rare = st.form_submit_button("Delete Class")
                    if delete_submit_rare:
                        if confirm_delete_rare and to_delete:
                            delete_path = os.path.join(RARE_DATASET, to_delete)  # path on /data
                            try:
                                shutil.rmtree(delete_path)  # permanently removes the class folder
                                st.success(f"✅ Deleted rare class: {to_delete}")
                                st.cache_data.clear()
                                # Reset few-shot state: prototypes referenced the deleted class.
                                st.session_state.few_shot_trained = False
                                st.session_state.final_prototypes = None
                                st.session_state.prototype_labels = None
                                st.session_state.few_shot_strategy = None
                                st.session_state.model_mode = 'standard'
                                st.rerun()
                            except Exception as e:
                                st.error(f"Error deleting directory {delete_path}: {e}")
                        elif not confirm_delete_rare:
                            st.warning("Please confirm the deletion.")
        else:
            st.info(f"Rare dataset directory '{RARE_DATASET}' does not exist.")
    except Exception as e:
        st.error(f"Error listing/deleting rare classes from {RARE_DATASET}: {e}")
        st.exception(e)
| elif option == "Train Few-Shot Model": | |
| # (Original Code - uses data potentially from /data/..., saves state to /data/...) | |
| st.header("🚀 Train Few-Shot Model") | |
| if len(class_names) < 2: | |
| st.error("Need at least two classes (Base + Rare combined) to perform few-shot training.") | |
| st.stop() | |
| # --- Training Parameters --- | |
| epochs = 10 | |
| n_way_train = len(class_names) # Original: Use all available classes | |
| episodes_per_epoch = 5 | |
| n_shot = 2 | |
| n_query = 2 | |
| learning_rate = 1e-4 | |
| weight_decay_proj = 1e-4 | |
| eligible_classes_check = [ | |
| cls_id for cls_id in class_indices | |
| if len(class_indices.get(cls_id, [])) >= (n_shot + n_query) | |
| ] | |
| if len(eligible_classes_check) < 2: | |
| st.error(f"Need >= 2 classes with {n_shot + n_query} samples. Found {len(eligible_classes_check)}.") | |
| st.stop() | |
| # Adjust n_way based on eligibility if using all classes was intended | |
| if n_way_train > len(eligible_classes_check): | |
| st.warning(f"Adjusting n-way from {n_way_train} to {len(eligible_classes_check)} based on eligible classes.") | |
| n_way_train = len(eligible_classes_check) | |
| # --- Training Form --- | |
| with st.form("train_form"): | |
| submitted_train = st.form_submit_button("Start Few-Shot Training") | |
| if submitted_train: | |
| active_strategy = 'train_projection' | |
| st.info(f"🚀 Starting few-shot process...") | |
| # Re-fetch/confirm data state | |
| current_combined_dataset_train, current_indices_train, current_names_train, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) | |
| current_eligible_train = [ | |
| cls_id for cls_id in current_indices_train | |
| if len(current_indices_train.get(cls_id, [])) >= (n_shot + n_query) | |
| ] | |
| if len(current_eligible_train) < 2: | |
| st.error(f"Error before starting: Need >= 2 eligible classes, found {len(current_eligible_train)}."); st.stop() | |
| n_way_for_episode = min(n_way_train, len(current_eligible_train)) # Final check on n_way | |
| if n_way_for_episode < 2: st.error("Error: Cannot create 2-way+ episodes."); st.stop() | |
| progress_bar = st.progress(0) | |
| status_placeholder = st.empty() | |
| chart_placeholder = st.empty() # Keep placeholder | |
| model_train = cached_feature_extractor_model() | |
| model_train.train() | |
| # --- Optimizer Setup (Original) --- | |
| try: | |
| for param in model_train.model.parameters(): param.requires_grad = False | |
| for param in model_train.projection.parameters(): param.requires_grad = True | |
| trainable_params = list(filter(lambda p: p.requires_grad, model_train.parameters())) | |
| if not trainable_params: st.error("No trainable params found!"); st.stop() | |
| except AttributeError as e: st.error(f"Model setup error: {e}"); st.stop() | |
| optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay_proj) | |
| # --- Training Loop (Original) --- | |
| loss_history = [] | |
| accuracy_history = [] | |
| total_steps = epochs * episodes_per_epoch | |
| current_step = 0 | |
| training_successful = False | |
| for epoch in range(epochs): | |
| epoch_loss = 0.0 | |
| epoch_accuracy = 0.0 | |
| valid_episodes_in_epoch = 0 | |
| for episode in range(episodes_per_epoch): | |
| current_step += 1 | |
| s_imgs, s_labels, q_imgs, q_labels = create_episode( | |
| current_combined_dataset_train, current_indices_train, current_names_train, | |
| n_way=n_way_for_episode, n_shot=n_shot, n_query=n_query | |
| ) | |
| if s_imgs is None: continue | |
| try: | |
| s_emb = model_train(s_imgs) | |
| q_emb = model_train(q_imgs) | |
| loss, accuracy = proto_loss(s_emb, s_labels, q_emb, q_labels) # Use original proto_loss | |
| except Exception as model_e: | |
| st.error(f"Model/Loss Error Ep {epoch+1}-{episode+1}: {model_e}"); continue | |
| if loss is not None and not torch.isnan(loss) and loss.requires_grad: | |
| try: | |
| optimizer.zero_grad(); loss.backward(); optimizer.step() | |
| epoch_loss += loss.item(); epoch_accuracy += accuracy; valid_episodes_in_epoch += 1 | |
| except Exception as optim_e: st.error(f"Optim Error Ep {epoch+1}-{episode+1}: {optim_e}") | |
| elif torch.isnan(loss): st.warning(f"NaN Loss Ep {epoch+1}-{episode+1}") | |
| if (episode + 1) % 5 == 0 or episode == episodes_per_epoch - 1: | |
| progress = current_step / total_steps | |
| progress_bar.progress(min(progress, 1.0)) | |
| # Log epoch results (Original logic) | |
| if valid_episodes_in_epoch > 0: | |
| avg_epoch_loss = epoch_loss / valid_episodes_in_epoch | |
| avg_epoch_accuracy = epoch_accuracy / valid_episodes_in_epoch | |
| loss_history.append(avg_epoch_loss) | |
| accuracy_history.append(avg_epoch_accuracy) | |
| status_placeholder.text(f"Epoch {epoch+1}/{epochs} | Loss: {avg_epoch_loss:.4f} | Acc: {avg_epoch_accuracy:.2%}") | |
| else: | |
| status_placeholder.text(f"Epoch {epoch+1}/{epochs} | No valid episodes.") | |
| loss_history.append(float('nan')); accuracy_history.append(float('nan')) | |
| status_placeholder.success("✅ Training Finished!") | |
| # --- Final Prototype Calculation (Original) --- | |
| st.info(f"Calculating final prototypes...") | |
| model_train.eval() | |
| st.cache_data.clear() | |
| final_combined_dataset_proto, _, final_class_names_proto, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) | |
| final_prototypes_tensor, final_prototype_labels = calculate_final_prototypes( | |
| model_train, final_combined_dataset_proto, final_class_names_proto, active_strategy | |
| ) | |
| # --- Store results in session state (Original) --- | |
| if final_prototypes_tensor is not None and final_prototype_labels is not None: | |
| st.session_state.final_prototypes = final_prototypes_tensor | |
| st.session_state.prototype_labels = final_prototype_labels | |
| st.session_state.few_shot_strategy = active_strategy | |
| st.session_state.few_shot_trained = True | |
| st.session_state.model_mode = 'few_shot' | |
| st.success(f"Prototypes Calculated.") | |
| training_successful = True | |
| else: | |
| st.session_state.final_prototypes = None; st.session_state.prototype_labels = None | |
| st.session_state.few_shot_strategy = None; st.session_state.few_shot_trained = False | |
| st.session_state.model_mode = 'standard'; st.error("Prototype calculation failed.") | |
| training_successful = False | |
| st.rerun() # Rerun to clear form | |
| # --- SAVING SECTION (Original logic - saves to SAVED_MODELS_DIR which is now /data/...) --- | |
| if st.session_state.get('final_prototypes') is not None and \ | |
| st.session_state.get('model_mode') == 'few_shot' and \ | |
| st.session_state.get('few_shot_strategy') == 'train_projection': | |
| st.divider() | |
| st.subheader("💾 Save Current Few-Shot State (to Persistent Storage)") # Updated title | |
| st.info(f"Saves current state to {SAVED_MODELS_DIR}") # Show path | |
| save_model_name = st.text_input("Enter name:", key="save_model_name_input_main") | |
| if st.button("Save State", key="save_state_button_main"): | |
| if save_model_name: | |
| model_to_save = cached_feature_extractor_model() | |
| model_to_save.eval() | |
| _, _, current_cls_names_for_saving, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) | |
| # save_few_shot_state saves to /data/... | |
| save_successful = save_few_shot_state( | |
| save_model_name, model_to_save, st.session_state.final_prototypes, | |
| st.session_state.prototype_labels, current_cls_names_for_saving, | |
| st.session_state.few_shot_strategy | |
| ) | |
| if save_successful: st.rerun() # Refresh sidebar list | |
| else: | |
| st.warning("Please enter name.") | |
| # --- END OF `elif option == "Train Few-Shot Model":` --- |