# app/inference_utils.py
"""Utilities for multimodal (image / text / 3D mesh) CAD retrieval inference.

Loads the encoder models once into module-level globals, computes embeddings
for uploaded or shared datasets, caches them in memory, and builds
cross-modal top-K similarity matrices for the frontend.
"""
import base64
import datetime
import tempfile
import uuid
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import torch
from easydict import EasyDict as edict
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import DataLoader

from cad_retrieval_utils.augmentations import build_img_transforms
from cad_retrieval_utils.datasets import (InferenceImageDataset,
                                          InferenceMeshDataset,
                                          InferenceTextDataset)
from cad_retrieval_utils.evaluation import (get_inference_embeddings_image,
                                            get_inference_embeddings_mesh,
                                            get_inference_embeddings_text)
from cad_retrieval_utils.inference import (load_image_encoder, load_pc_encoder,
                                           load_text_encoder)
from cad_retrieval_utils.models import (ImageEncoder, InferencePcEncoder,
                                        InferenceTextEncoder)
from cad_retrieval_utils.utils import init_environment, load_config

# Module-level singletons, populated once by load_models_and_config().
CONFIG: Optional[edict] = None
IMG_TRANSFORM = None
PC_ENCODER: Optional[InferencePcEncoder] = None
IMG_ENCODER: Optional[ImageEncoder] = None
TEXT_ENCODER: Optional[InferenceTextEncoder] = None

# In-memory cache:
#   dataset_id -> {"data": {...}, "embeddings": {modality: (names, embs ndarray)}}
DATASET_CACHE: dict = {}

# Number of nearest neighbours returned per target modality.
TOP_K_MATCHES = 5


def _plural_key(modality: str) -> str:
    """Map a singular modality name to the plural key used in response payloads.

    "mesh" -> "meshes"; everything else just gets an "s" appended.
    """
    return "meshes" if modality == "mesh" else modality + 's'


def _utc_iso_z(dt: Optional[datetime.datetime] = None) -> str:
    """Return the given (or current) UTC time as ISO-8601 with a 'Z' suffix.

    Uses timezone-aware datetimes; datetime.utcnow() is deprecated and its
    naive result used to be mislabelled with a hard-coded "Z".
    """
    if dt is None:
        dt = datetime.datetime.now(datetime.timezone.utc)
    return dt.astimezone(datetime.timezone.utc).isoformat().replace("+00:00", "Z")


def _build_full_comparison(all_embeddings: dict) -> dict:
    """Build the full cross-modal top-K similarity matrix.

    Args:
        all_embeddings: {modality: (names, embeddings of shape (N, D))}.

    Returns:
        {"images"/"texts"/"meshes": [{"source": name, "matches": {modality: [
            {"item": name, "confidence": float}, ...]}}, ...]}

    This logic used to be duplicated verbatim in process_uploaded_zip and
    process_shared_dataset_directory.
    """
    full_comparison = {"images": [], "texts": [], "meshes": []}
    for source_modality, (source_names, source_embs) in all_embeddings.items():
        for i, source_name in enumerate(source_names):
            source_emb = source_embs[i:i + 1]
            matches = {}
            for target_modality, (target_names, target_embs) in all_embeddings.items():
                if len(target_names) == 0:
                    continue
                sims = cosine_similarity(source_emb, target_embs).flatten()
                if source_modality == target_modality:
                    # Exclude the item itself from its own neighbour list;
                    # the `> -1` filter below then drops it.
                    sims[i] = -1
                top_indices = np.argsort(sims)[::-1][:TOP_K_MATCHES]
                matches[target_modality] = [
                    {"item": target_names[j], "confidence": float(sims[j])}
                    for j in top_indices if sims[j] > -1
                ]
            full_comparison[_plural_key(source_modality)].append(
                {"source": source_name, "matches": matches})
    return full_comparison


def load_models_and_config(config_path: str, model_paths: dict) -> None:
    """Load the config and all three encoders into module globals (idempotent).

    Args:
        config_path: path to the YAML/JSON config consumed by load_config().
        model_paths: checkpoint locations, stored as CONFIG.paths.model_spec.

    Raises:
        Re-raises any exception from config/model loading after logging it.
    """
    global CONFIG, IMG_TRANSFORM, PC_ENCODER, IMG_ENCODER, TEXT_ENCODER
    print("🚀 Загрузка конфигурации и моделей...")
    if CONFIG is not None:
        # Already initialised — loading is a one-time, process-wide operation.
        print(" Модели уже загружены.")
        return
    try:
        CONFIG = load_config(config_path)
        CONFIG.paths.model_spec = model_paths
        init_environment(CONFIG)
        PC_ENCODER = load_pc_encoder(CONFIG.paths.model_spec, CONFIG)
        IMG_ENCODER = load_image_encoder(CONFIG.paths.model_spec, CONFIG)
        TEXT_ENCODER = load_text_encoder(CONFIG.paths.model_spec, CONFIG)
        IMG_TRANSFORM = build_img_transforms(CONFIG.img_size)
        print("✅ Все модели успешно загружены в память.")
    except Exception as e:
        print(f"🔥 Критическая ошибка при загрузке моделей: {e}")
        raise


@torch.no_grad()
def get_embedding_for_single_item(modality: str, content_bytes: bytes) -> np.ndarray:
    """Compute a normalized embedding for one item of the given modality.

    Args:
        modality: one of "image", "text", "mesh".
        content_bytes: raw bytes — an image file, UTF-8 text, or an STL mesh.

    Returns:
        Embedding as a numpy array of shape (1, D).

    Raises:
        ValueError: for an unknown modality.
    """
    if modality == "image":
        image = Image.open(BytesIO(content_bytes)).convert("RGB")
        tensor = IMG_TRANSFORM(image).unsqueeze(0).to(CONFIG.device)
        emb = IMG_ENCODER.encode_image(tensor, normalize=True)
        return emb.cpu().numpy()
    if modality == "text":
        text = content_bytes.decode("utf-8")
        emb = TEXT_ENCODER.encode_text([text], normalize=True)
        return emb.cpu().numpy()
    if modality == "mesh":
        # NOTE(review): re-opening a NamedTemporaryFile by name while it is
        # still open does not work on Windows; fine on POSIX.
        with tempfile.NamedTemporaryFile(suffix=".stl", delete=True) as tmp:
            tmp.write(content_bytes)
            tmp.flush()
            dataset = InferenceMeshDataset([tmp.name], CONFIG.npoints, CONFIG.seed)
            tensor = dataset[0].unsqueeze(0).to(CONFIG.device)
            emb = PC_ENCODER.encode_pc(tensor, normalize=True)
            return emb.cpu().numpy()
    raise ValueError(f"Неизвестная модальность: {modality}")


def process_uploaded_zip(
    zip_file_bytes: bytes,
    original_filename: str,
    update_status: Callable[[str, int], None],
) -> dict:
    """Process an uploaded ZIP archive end-to-end, reporting progress via callback.

    Unpacks the archive, embeds every .png / .txt / .stl it contains, caches
    the dataset in DATASET_CACHE, and returns the full response payload
    (items with inline base64/text content plus the cross-modal comparison).

    Args:
        zip_file_bytes: raw bytes of the uploaded archive.
        original_filename: display name for the dataset.
        update_status: callback(stage_label, percent) for progress updates.
    """
    dataset_id = str(uuid.uuid4())
    print(f"⚙️ Начало обработки нового датасета: {original_filename} (ID: {dataset_id})")
    update_status("Starting", 0)
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)
        zip_path = tmp_path / "data.zip"
        zip_path.write_bytes(zip_file_bytes)
        update_status("Unpacking Files", 5)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # NOTE(security): extractall() on an untrusted archive is vulnerable
            # to zip-slip path traversal; member paths should be validated
            # (or use the `filter=` argument on Python 3.12+).
            zip_ref.extractall(tmp_path)
        print(f" 🗂️ Архив распакован в {tmpdir}")

        update_status("Preparing Data", 10)
        image_paths = sorted(tmp_path.glob("**/*.png"))
        text_paths = sorted(tmp_path.glob("**/*.txt"))
        mesh_paths = sorted(tmp_path.glob("**/*.stl"))
        image_ds = InferenceImageDataset([str(p) for p in image_paths], IMG_TRANSFORM)
        text_ds = InferenceTextDataset([str(p) for p in text_paths])
        mesh_ds = InferenceMeshDataset([str(p) for p in mesh_paths],
                                       CONFIG.npoints, CONFIG.seed)
        image_loader = DataLoader(image_ds, batch_size=CONFIG.infer_img_batch_size,
                                  shuffle=False)
        text_loader = DataLoader(text_ds, batch_size=CONFIG.infer_text_batch_size,
                                 shuffle=False)
        mesh_loader = DataLoader(mesh_ds, batch_size=CONFIG.infer_pc_batch_size,
                                 shuffle=False)

        print(" 🧠 Вычисление эмбеддингов...")
        update_status("Processing Images", 15)
        image_embs = get_inference_embeddings_image(IMG_ENCODER, image_loader, CONFIG)
        update_status("Processing Texts", 50)
        text_embs = get_inference_embeddings_text(TEXT_ENCODER, text_loader, CONFIG)
        update_status("Processing 3D Models", 55)
        mesh_embs = get_inference_embeddings_mesh(PC_ENCODER, mesh_loader, CONFIG)
        print(" ✅ Эмбеддинги вычислены.")

        update_status("Caching Data", 90)
        image_names = [p.name for p in image_paths]
        text_names = [p.name for p in text_paths]
        mesh_names = [p.name for p in mesh_paths]
        # Images and meshes are shipped inline as base64; texts as plain text.
        image_items = [
            {"id": f"image_{i}", "name": name,
             "content": base64.b64encode(p.read_bytes()).decode('utf-8')}
            for i, (p, name) in enumerate(zip(image_paths, image_names))
        ]
        text_items = [
            {"id": f"text_{i}", "name": name, "content": p.read_text()}
            for i, (p, name) in enumerate(zip(text_paths, text_names))
        ]
        mesh_items = [
            {"id": f"mesh_{i}", "name": name,
             "content": base64.b64encode(p.read_bytes()).decode('utf-8')}
            for i, (p, name) in enumerate(zip(mesh_paths, mesh_names))
        ]
        dataset_data = {"images": image_items, "texts": text_items, "meshes": mesh_items}

        # One embeddings dict serves both the cache entry and the comparison
        # pass (previously built twice with identical contents).
        all_embeddings = {
            "image": (image_names, image_embs),
            "text": (text_names, text_embs),
            "mesh": (mesh_names, mesh_embs),
        }
        DATASET_CACHE[dataset_id] = {"data": dataset_data, "embeddings": all_embeddings}
        print(f" 💾 Датасет {dataset_id} сохранен в кэш.")

        print(" ⚖️ Вычисление полной матрицы схожести...")
        update_status("Building Matrix", 95)
        full_comparison = _build_full_comparison(all_embeddings)
        print(" ✅ Матрица схожести готова.")

        final_response = {
            "id": dataset_id,
            "name": original_filename,
            "uploadDate": _utc_iso_z(),
            "data": dataset_data,
            "processingState": "processed",
            "processingProgress": 100,
            "fullComparison": full_comparison,
        }
        print(f"✅ Обработка датасета {dataset_id} завершена.")
        return final_response


def process_shared_dataset_directory(directory_path: Path, embeddings_path: Path,
                                     dataset_id: str, dataset_name: str) -> dict:
    """Index a shared on-disk dataset that has precomputed .npy embeddings.

    Data files (.png/.txt/.stl) under directory_path are matched to embedding
    files under embeddings_path by file stem. Items are exposed through
    contentUrl (relative to ./static) instead of inline content.

    Returns:
        The same response payload shape as process_uploaded_zip (plus
        "isShared": True), or None when the directory holds no data files.
    """
    print(f"⚙️ Начало обработки общего датасета: {dataset_name} (ID: {dataset_id})")
    print(" 📂 Сканирование файлов данных...")
    image_paths = sorted(directory_path.glob("**/*.png"))
    text_paths = sorted(directory_path.glob("**/*.txt"))
    mesh_paths = sorted(directory_path.glob("**/*.stl"))
    if not any([image_paths, text_paths, mesh_paths]):
        print(f"⚠️ В директории общего датасета '{directory_path}' не найдено файлов.")
        return None
    print(f" ✅ Найдено: {len(image_paths)} изображений, {len(text_paths)} текстов, {len(mesh_paths)} моделей.")

    print(" 🧠 Индексирование предварительно вычисленных эмбеддингов...")
    all_embedding_paths = list(embeddings_path.glob("**/*.npy"))
    embedding_map = {p.stem: p for p in all_embedding_paths}
    print(f" ✅ Найдено {len(embedding_map)} файлов эмбеддингов.")

    def load_embeddings_for_paths(data_paths: list[Path]):
        # Match each data file to its embedding by stem; warn and skip misses
        # so one bad file does not abort the whole dataset.
        names = []
        embs_list = []
        for data_path in data_paths:
            file_stem = data_path.stem
            if file_stem in embedding_map:
                embedding_path = embedding_map[file_stem]
                try:
                    emb = np.load(embedding_path)
                    embs_list.append(emb)
                    names.append(data_path.name)
                except Exception as e:
                    print(f" ⚠️ Не удалось загрузить или разобрать эмбеддинг для {data_path.name}: {e}")
            else:
                print(f" ⚠️ Внимание: не найден соответствующий эмбеддинг для {data_path.name}")
        return names, np.array(embs_list) if embs_list else np.array([])

    print(" 🚚 Загрузка и сопоставление эмбеддингов...")
    image_names, image_embs = load_embeddings_for_paths(image_paths)
    text_names, text_embs = load_embeddings_for_paths(text_paths)
    mesh_names, mesh_embs = load_embeddings_for_paths(mesh_paths)
    print(" ✅ Эмбеддинги для общего датасета загружены.")

    # URLs are served relative to the ./static root.
    static_root = Path("static")
    image_items = [{"id": f"image_{i}", "name": p.name, "content": None,
                    "contentUrl": f"/{p.relative_to(static_root)}"}
                   for i, p in enumerate(image_paths)]
    text_items = [{"id": f"text_{i}", "name": p.name, "content": None,
                   "contentUrl": f"/{p.relative_to(static_root)}"}
                  for i, p in enumerate(text_paths)]
    mesh_items = [{"id": f"mesh_{i}", "name": p.name, "content": None,
                   "contentUrl": f"/{p.relative_to(static_root)}"}
                  for i, p in enumerate(mesh_paths)]
    dataset_data = {"images": image_items, "texts": text_items, "meshes": mesh_items}

    all_embeddings = {"image": (image_names, image_embs),
                      "text": (text_names, text_embs),
                      "mesh": (mesh_names, mesh_embs)}
    DATASET_CACHE[dataset_id] = {"data": dataset_data, "embeddings": all_embeddings}
    print(f" 💾 Эмбеддинги для общего датасета {dataset_id} сохранены в кэш.")

    print(" ⚖️ Вычисление полной матрицы схожести для общего датасета...")
    full_comparison = _build_full_comparison(all_embeddings)
    print(" ✅ Матрица схожести для общего датасета готова.")

    try:
        # Use an aware UTC timestamp so the trailing "Z" in uploadDate is
        # truthful (previously this was naive *local* time labelled as UTC).
        creation_time = datetime.datetime.fromtimestamp(
            directory_path.stat().st_ctime, tz=datetime.timezone.utc)
    except Exception:
        creation_time = datetime.datetime.now(datetime.timezone.utc)
    final_response = {"id": dataset_id,
                      "name": dataset_name,
                      "uploadDate": _utc_iso_z(creation_time),
                      "data": dataset_data,
                      "processingState": "processed",
                      "processingProgress": 100,
                      "fullComparison": full_comparison,
                      "isShared": True}
    print(f"✅ Обработка общего датасета {dataset_id} завершена.")
    return final_response


def find_matches_for_item(modality: str, content_base64: str, dataset_id: str) -> dict:
    """Embed one query item and return its top-K matches from a cached dataset.

    Args:
        modality: "image", "text" or "mesh" — modality of the query item.
        content_base64: base64-encoded raw bytes of the query item.
        dataset_id: key into DATASET_CACHE.

    Returns:
        {"results": {plural_modality: [{"item": <full item dict>,
                                        "confidence": float}, ...]}}

    Raises:
        ValueError: if the dataset is not in the cache.
    """
    print(f"🔍 Поиск совпадений для объекта ({modality}) в датасете {dataset_id}...")
    if dataset_id not in DATASET_CACHE:
        raise ValueError(f"Датасет с ID {dataset_id} не найден в кэше.")
    content_bytes = base64.b64decode(content_base64)
    source_emb = get_embedding_for_single_item(modality, content_bytes)
    cached_dataset = DATASET_CACHE[dataset_id]
    results = {}
    for target_modality, (target_names, target_embs) in cached_dataset["embeddings"].items():
        key_name = _plural_key(target_modality)
        if not target_names:
            continue
        sims = cosine_similarity(source_emb, target_embs).flatten()
        top_indices = np.argsort(sims)[::-1][:TOP_K_MATCHES]
        # Resolve names back to the full item dicts stored in the cache.
        target_items_map = {item['name']: item for item in cached_dataset['data'][key_name]}
        matches = []
        for j in top_indices:
            item_name = target_names[j]
            if item_name in target_items_map:
                matches.append({"item": target_items_map[item_name],
                                "confidence": float(sims[j])})
        results[key_name] = matches
    print(" ✅ Поиск завершен.")
    return {"results": results}


def cache_local_dataset(dataset: dict) -> None:
    """
    Receives a full dataset object from the frontend, computes embeddings,
    and loads it into the in-memory cache.
    """
    dataset_id = dataset.get('id')
    if not dataset_id:
        print("⚠️ Attempted to cache a dataset without an ID.")
        return
    if dataset_id in DATASET_CACHE:
        print(f"✅ Dataset {dataset_id} is already in the backend cache. Skipping re-hydration.")
        return
    print(f"🧠 Re-hydrating backend cache for local dataset ID: {dataset_id}")
    try:
        all_embeddings = {}
        all_names = {}

        # The content comes in different formats (data URL for images, text
        # for text, etc.); decode it before embedding.
        def get_bytes_from_content(content_str: str, modality: str) -> bytes:
            if modality in ['image', 'mesh']:
                # Handle data URLs (e.g. "data:image/png;base64,...") or raw base64.
                if ',' in content_str:
                    header, encoded = content_str.split(',', 1)
                    return base64.b64decode(encoded)
                return base64.b64decode(content_str)
            return content_str.encode('utf-8')  # text

        for modality_plural, items in dataset.get('data', {}).items():
            modality_singular = "mesh" if modality_plural == "meshes" else modality_plural[:-1]
            names = []
            embs_list = []
            print(f" ⚙️ Processing {len(items)} items for modality: {modality_singular}")
            for item in items:
                item_content = item.get('content')
                if not item_content:
                    # Shared datasets ship contentUrl instead of inline content;
                    # nothing to embed here.
                    continue
                content_bytes = get_bytes_from_content(item_content, modality_singular)
                embedding = get_embedding_for_single_item(modality_singular, content_bytes)
                embs_list.append(embedding[0])  # get_embedding returns shape (1, D)
                names.append(item.get('name'))
            all_names[modality_singular] = names
            all_embeddings[modality_singular] = np.array(embs_list) if embs_list else np.array([])

        # Structure the cache entry exactly like process_uploaded_zip does.
        DATASET_CACHE[dataset_id] = {
            "data": dataset.get('data'),
            "embeddings": {
                mod: (all_names[mod], all_embeddings[mod]) for mod in all_embeddings
            },
        }
        print(f" ✅ Successfully cached {dataset_id} with embeddings.")
    except Exception as e:
        print(f"🔥 CRITICAL ERROR while re-hydrating cache for {dataset_id}: {e}")
        import traceback
        traceback.print_exc()