Spaces:
Sleeping
Sleeping
| # app/inference_utils.py | |
| import base64 | |
| import tempfile | |
| import uuid | |
| import zipfile | |
| from io import BytesIO | |
| from pathlib import Path | |
| import datetime | |
| from typing import Callable # Add this import | |
| import numpy as np | |
| import torch | |
| from easydict import EasyDict as edict | |
| from PIL import Image | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from torch.utils.data import DataLoader | |
| from cad_retrieval_utils.augmentations import build_img_transforms | |
| from cad_retrieval_utils.datasets import (InferenceImageDataset, | |
| InferenceMeshDataset, | |
| InferenceTextDataset) | |
| from cad_retrieval_utils.evaluation import (get_inference_embeddings_image, | |
| get_inference_embeddings_mesh, | |
| get_inference_embeddings_text) | |
| from cad_retrieval_utils.inference import (load_image_encoder, load_pc_encoder, | |
| load_text_encoder) | |
| from cad_retrieval_utils.models import (ImageEncoder, InferencePcEncoder, | |
| InferenceTextEncoder) | |
| from cad_retrieval_utils.utils import init_environment, load_config | |
| CONFIG: edict = None | |
| IMG_TRANSFORM = None | |
| PC_ENCODER: InferencePcEncoder = None | |
| IMG_ENCODER: ImageEncoder = None | |
| TEXT_ENCODER: InferenceTextEncoder = None | |
| DATASET_CACHE = {} | |
| TOP_K_MATCHES = 5 | |
def load_models_and_config(config_path: str, model_paths: dict) -> None:
    """Load the config and all three encoders into the module-level globals.

    Idempotent: once CONFIG is set, subsequent calls return immediately.

    Args:
        config_path: Path to the config file understood by load_config().
        model_paths: Mapping of model names to checkpoint paths; stored into
            CONFIG.paths.model_spec before the encoders are constructed.

    Raises:
        Exception: re-raises whatever the loading pipeline raised, after
            rolling the globals back so a later retry is possible.
    """
    global CONFIG, IMG_TRANSFORM, PC_ENCODER, IMG_ENCODER, TEXT_ENCODER
    print("🚀 Загрузка конфигурации и моделей...")
    if CONFIG is not None:
        print(" Модели уже загружены.")
        return
    try:
        CONFIG = load_config(config_path)
        CONFIG.paths.model_spec = model_paths
        init_environment(CONFIG)
        PC_ENCODER = load_pc_encoder(CONFIG.paths.model_spec, CONFIG)
        IMG_ENCODER = load_image_encoder(CONFIG.paths.model_spec, CONFIG)
        TEXT_ENCODER = load_text_encoder(CONFIG.paths.model_spec, CONFIG)
        IMG_TRANSFORM = build_img_transforms(CONFIG.img_size)
        print("✅ Все модели успешно загружены в память.")
    except Exception as e:
        # Fix: a failure halfway through used to leave CONFIG assigned, so the
        # "already loaded" guard above short-circuited every retry even though
        # the encoders were never built. Roll everything back before re-raising.
        CONFIG = IMG_TRANSFORM = PC_ENCODER = IMG_ENCODER = TEXT_ENCODER = None
        print(f"🔥 Критическая ошибка при загрузке моделей: {e}")
        raise
def get_embedding_for_single_item(modality: str, content_bytes: bytes) -> np.ndarray:
    """Encode one raw item into a normalized embedding via the global encoders.

    Args:
        modality: One of "image", "text" or "mesh".
        content_bytes: Raw item bytes (PNG bytes, UTF-8 text, or STL bytes).

    Returns:
        A numpy array of shape (1, D) holding the normalized embedding.

    Raises:
        ValueError: if the modality is not one of the three supported kinds.
    """
    if modality == "image":
        pil_img = Image.open(BytesIO(content_bytes)).convert("RGB")
        batch = IMG_TRANSFORM(pil_img).unsqueeze(0).to(CONFIG.device)
        return IMG_ENCODER.encode_image(batch, normalize=True).cpu().numpy()
    if modality == "text":
        decoded = content_bytes.decode("utf-8")
        return TEXT_ENCODER.encode_text([decoded], normalize=True).cpu().numpy()
    if modality == "mesh":
        # The mesh dataset only accepts file paths, so spill the bytes into a
        # temporary .stl file that is removed when the context exits.
        with tempfile.NamedTemporaryFile(suffix=".stl", delete=True) as handle:
            handle.write(content_bytes)
            handle.flush()
            mesh_ds = InferenceMeshDataset([handle.name], CONFIG.npoints, CONFIG.seed)
            batch = mesh_ds[0].unsqueeze(0).to(CONFIG.device)
            return PC_ENCODER.encode_pc(batch, normalize=True).cpu().numpy()
    raise ValueError(f"Неизвестная модальность: {modality}")
def _build_full_comparison(all_embeddings: dict) -> dict:
    """Compute the top-K cross-modal matches for every item of a dataset.

    Args:
        all_embeddings: mapping modality -> (names, ndarray of shape (N, D)).

    Returns:
        {"images": [...], "texts": [...], "meshes": [...]} where each entry is
        {"source": name, "matches": {modality: [{"item", "confidence"}, ...]}}.
    """
    comparison = {"images": [], "texts": [], "meshes": []}
    for src_mod, (src_names, src_embs) in all_embeddings.items():
        for i, src_name in enumerate(src_names):
            query = src_embs[i:i + 1]
            matches = {}
            for tgt_mod, (tgt_names, tgt_embs) in all_embeddings.items():
                if len(tgt_names) == 0:
                    continue
                sims = cosine_similarity(query, tgt_embs).flatten()
                if src_mod == tgt_mod:
                    # Mask the item itself. Fix: -inf (instead of -1) cannot
                    # collide with a legitimate cosine similarity of exactly -1.
                    sims[i] = -np.inf
                top = np.argsort(sims)[::-1][:TOP_K_MATCHES]
                matches[tgt_mod] = [
                    {"item": tgt_names[j], "confidence": float(sims[j])}
                    for j in top if np.isfinite(sims[j])
                ]
            key = "meshes" if src_mod == "mesh" else src_mod + 's'
            comparison[key].append({"source": src_name, "matches": matches})
    return comparison


def process_uploaded_zip(
    zip_file_bytes: bytes,
    original_filename: str,
    update_status: Callable[[str, int], None]
) -> dict:
    """Process an uploaded ZIP archive end-to-end into a searchable dataset.

    Unpacks the archive, computes embeddings for every .png / .txt / .stl
    file, stores the dataset plus its embeddings in DATASET_CACHE, and builds
    the full cross-modal top-K similarity table.

    Args:
        zip_file_bytes: Raw bytes of the uploaded ZIP archive.
        original_filename: Display name reported back to the frontend.
        update_status: Progress callback invoked as update_status(stage, percent).

    Returns:
        dict describing the processed dataset (id, inlined data, fullComparison).
    """
    dataset_id = str(uuid.uuid4())
    print(f"⚙️ Начало обработки нового датасета: {original_filename} (ID: {dataset_id})")
    update_status("Starting", 0)
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)
        zip_path = tmp_path / "data.zip"
        zip_path.write_bytes(zip_file_bytes)
        update_status("Unpacking Files", 5)
        # NOTE(security): the archive is untrusted user input; zipfile.extract*
        # strips absolute paths, but explicit member-name validation (rejecting
        # ".." components) would be safer — TODO review.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(tmp_path)
        print(f" 🗂️ Архив распакован в {tmpdir}")
        update_status("Preparing Data", 10)
        image_paths = sorted(tmp_path.glob("**/*.png"))
        text_paths = sorted(tmp_path.glob("**/*.txt"))
        mesh_paths = sorted(tmp_path.glob("**/*.stl"))
        image_ds = InferenceImageDataset([str(p) for p in image_paths], IMG_TRANSFORM)
        text_ds = InferenceTextDataset([str(p) for p in text_paths])
        mesh_ds = InferenceMeshDataset([str(p) for p in mesh_paths], CONFIG.npoints, CONFIG.seed)
        image_loader = DataLoader(image_ds, batch_size=CONFIG.infer_img_batch_size, shuffle=False)
        text_loader = DataLoader(text_ds, batch_size=CONFIG.infer_text_batch_size, shuffle=False)
        mesh_loader = DataLoader(mesh_ds, batch_size=CONFIG.infer_pc_batch_size, shuffle=False)
        print(" 🧠 Вычисление эмбеддингов...")
        update_status("Processing Images", 15)
        image_embs = get_inference_embeddings_image(IMG_ENCODER, image_loader, CONFIG)
        update_status("Processing Texts", 50)
        text_embs = get_inference_embeddings_text(TEXT_ENCODER, text_loader, CONFIG)
        update_status("Processing 3D Models", 55)
        mesh_embs = get_inference_embeddings_mesh(PC_ENCODER, mesh_loader, CONFIG)
        print(" ✅ Эмбеддинги вычислены.")
        update_status("Caching Data", 90)
        image_names = [p.name for p in image_paths]
        text_names = [p.name for p in text_paths]
        mesh_names = [p.name for p in mesh_paths]
        # Inline file contents so the frontend can render items without extra requests.
        image_items = [
            {"id": f"image_{i}", "name": p.name,
             "content": base64.b64encode(p.read_bytes()).decode('utf-8')}
            for i, p in enumerate(image_paths)
        ]
        # Fix: decode text explicitly as UTF-8 instead of the locale default,
        # which varies between machines and can raise on non-ASCII content.
        text_items = [
            {"id": f"text_{i}", "name": p.name, "content": p.read_text(encoding="utf-8")}
            for i, p in enumerate(text_paths)
        ]
        mesh_items = [
            {"id": f"mesh_{i}", "name": p.name,
             "content": base64.b64encode(p.read_bytes()).decode('utf-8')}
            for i, p in enumerate(mesh_paths)
        ]
        dataset_data = {"images": image_items, "texts": text_items, "meshes": mesh_items}
        all_embeddings = {
            "image": (image_names, image_embs),
            "text": (text_names, text_embs),
            "mesh": (mesh_names, mesh_embs)
        }
        DATASET_CACHE[dataset_id] = {"data": dataset_data, "embeddings": all_embeddings}
        print(f" 💾 Датасет {dataset_id} сохранен в кэш.")
        print(" ⚖️ Вычисление полной матрицы схожести...")
        update_status("Building Matrix", 95)
        full_comparison = _build_full_comparison(all_embeddings)
        print(" ✅ Матрица схожести готова.")
        final_response = {
            "id": dataset_id,
            "name": original_filename,
            # Fix: utcnow() is deprecated; use an aware UTC time and keep the
            # exact same "...Z"-suffixed naive-ISO wire format as before.
            "uploadDate": datetime.datetime.now(datetime.timezone.utc)
                                           .replace(tzinfo=None).isoformat() + "Z",
            "data": dataset_data,
            "processingState": "processed",
            "processingProgress": 100,
            "fullComparison": full_comparison
        }
    print(f"✅ Обработка датасета {dataset_id} завершена.")
    return final_response
def process_shared_dataset_directory(directory_path: Path, embeddings_path: Path, dataset_id: str, dataset_name: str) -> dict:
    """Index a pre-computed shared dataset that lives on local disk.

    Scans directory_path for .png/.txt/.stl files, pairs each data file with a
    pre-computed .npy embedding from embeddings_path (matched by file stem),
    stores everything in DATASET_CACHE, and builds the full cross-modal
    top-K comparison table. Items are referenced by URL (relative to the
    "static" root) rather than inlined.

    Args:
        directory_path: Root directory containing the data files.
        embeddings_path: Root directory containing the .npy embedding files.
        dataset_id: Cache key / response id for this dataset.
        dataset_name: Display name reported back to the frontend.

    Returns:
        The dataset response dict, or None when no data files were found.
    """
    print(f"⚙️ Начало обработки общего датасета: {dataset_name} (ID: {dataset_id})")
    print(" 📂 Сканирование файлов данных...")
    image_paths = sorted(directory_path.glob("**/*.png"))
    text_paths = sorted(directory_path.glob("**/*.txt"))
    mesh_paths = sorted(directory_path.glob("**/*.stl"))
    if not any([image_paths, text_paths, mesh_paths]):
        print(f"⚠️ В директории общего датасета '{directory_path}' не найдено файлов.")
        return None
    print(f" ✅ Найдено: {len(image_paths)} изображений, {len(text_paths)} текстов, {len(mesh_paths)} моделей.")
    print(" 🧠 Индексирование предварительно вычисленных эмбеддингов...")
    embedding_map = {p.stem: p for p in embeddings_path.glob("**/*.npy")}
    print(f" ✅ Найдено {len(embedding_map)} файлов эмбеддингов.")

    def load_embeddings_for_paths(data_paths: list[Path]):
        """Pair data files with their .npy embeddings; skip files lacking one."""
        names = []
        embs_list = []
        for data_path in data_paths:
            embedding_path = embedding_map.get(data_path.stem)
            if embedding_path is None:
                print(f" ⚠️ Внимание: не найден соответствующий эмбеддинг для {data_path.name}")
                continue
            try:
                emb = np.load(embedding_path)
            except Exception as e:
                print(f" ⚠️ Не удалось загрузить или разобрать эмбеддинг для {data_path.name}: {e}")
                continue
            embs_list.append(emb)
            names.append(data_path.name)
        # NOTE(review): assumes each .npy holds a 1-D (D,) vector so the stack
        # below is (N, D); a stored (1, D) array would break cosine_similarity
        # downstream — TODO confirm against the embedding export pipeline.
        return names, np.array(embs_list) if embs_list else np.array([])

    print(" 🚚 Загрузка и сопоставление эмбеддингов...")
    image_names, image_embs = load_embeddings_for_paths(image_paths)
    text_names, text_embs = load_embeddings_for_paths(text_paths)
    mesh_names, mesh_embs = load_embeddings_for_paths(mesh_paths)
    print(" ✅ Эмбеддинги для общего датасета загружены.")
    static_root = Path("static")
    # Shared datasets are served from the static root, so only URLs are sent.
    image_items = [{"id": f"image_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(image_paths)]
    text_items = [{"id": f"text_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(text_paths)]
    mesh_items = [{"id": f"mesh_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(mesh_paths)]
    dataset_data = {"images": image_items, "texts": text_items, "meshes": mesh_items}
    all_embeddings = {"image": (image_names, image_embs), "text": (text_names, text_embs), "mesh": (mesh_names, mesh_embs)}
    DATASET_CACHE[dataset_id] = {"data": dataset_data, "embeddings": all_embeddings}
    print(f" 💾 Эмбеддинги для общего датасета {dataset_id} сохранены в кэш.")
    print(" ⚖️ Вычисление полной матрицы схожести для общего датасета...")
    full_comparison = {"images": [], "texts": [], "meshes": []}
    for source_modality, (source_names, source_embs) in all_embeddings.items():
        if len(source_names) == 0:
            continue
        for i, source_name in enumerate(source_names):
            source_emb = source_embs[i:i + 1]
            matches = {}
            for target_modality, (target_names, target_embs) in all_embeddings.items():
                if len(target_names) == 0:
                    continue
                sims = cosine_similarity(source_emb, target_embs).flatten()
                if source_modality == target_modality:
                    # Mask the item itself. Fix: -inf (instead of -1) cannot
                    # collide with a legitimate cosine similarity of -1.
                    sims[i] = -np.inf
                top_indices = np.argsort(sims)[::-1][:TOP_K_MATCHES]
                matches[target_modality] = [
                    {"item": target_names[j], "confidence": float(sims[j])}
                    for j in top_indices if np.isfinite(sims[j])
                ]
            key_name = "meshes" if source_modality == "mesh" else source_modality + 's'
            full_comparison[key_name].append({"source": source_name, "matches": matches})
    print(" ✅ Матрица схожести для общего датасета готова.")
    try:
        # Fix: fromtimestamp() without tz gives LOCAL naive time, but the
        # response appends "Z" (claims UTC). Convert through UTC so the
        # timestamp matches its suffix.
        creation_time = datetime.datetime.fromtimestamp(
            directory_path.stat().st_ctime, tz=datetime.timezone.utc)
    except Exception:
        creation_time = datetime.datetime.now(datetime.timezone.utc)
    upload_date = creation_time.replace(tzinfo=None).isoformat() + "Z"
    final_response = {"id": dataset_id, "name": dataset_name, "uploadDate": upload_date, "data": dataset_data, "processingState": "processed", "processingProgress": 100, "fullComparison": full_comparison, "isShared": True}
    print(f"✅ Обработка общего датасета {dataset_id} завершена.")
    return final_response
def find_matches_for_item(modality: str, content_base64: str, dataset_id: str) -> dict:
    """Find the top-K most similar cached items for a single query item.

    Args:
        modality: Query modality — "image", "text" or "mesh".
        content_base64: Base64-encoded raw bytes of the query item.
        dataset_id: Key of a dataset previously stored in DATASET_CACHE.

    Returns:
        {"results": {plural_modality: [{"item": <full item dict>,
                                        "confidence": float}, ...]}}.

    Raises:
        ValueError: if the dataset id is not present in the cache.
    """
    print(f"🔍 Поиск совпадений для объекта ({modality}) в датасете {dataset_id}...")
    if dataset_id not in DATASET_CACHE:
        raise ValueError(f"Датасет с ID {dataset_id} не найден в кэше.")
    query_emb = get_embedding_for_single_item(modality, base64.b64decode(content_base64))
    cached = DATASET_CACHE[dataset_id]
    results = {}
    for tgt_modality, (tgt_names, tgt_embs) in cached["embeddings"].items():
        plural = "meshes" if tgt_modality == "mesh" else tgt_modality + 's'
        if not tgt_names:
            continue
        sims = cosine_similarity(query_emb, tgt_embs).flatten()
        ranked = np.argsort(sims)[::-1][:TOP_K_MATCHES]
        # Resolve names back to the full item dicts stored in the cache.
        by_name = {entry['name']: entry for entry in cached['data'][plural]}
        results[plural] = [
            {"item": by_name[tgt_names[j]], "confidence": float(sims[j])}
            for j in ranked if tgt_names[j] in by_name
        ]
    print(" ✅ Поиск завершен.")
    return {"results": results}
def cache_local_dataset(dataset: dict) -> None:
    """Hydrate the backend cache from a full dataset object sent by the frontend.

    Computes an embedding for every item that carries inline content and
    stores them, together with the raw data, in DATASET_CACHE under the
    dataset's id. No-ops when the id is missing or already cached; any
    failure is logged and swallowed (best-effort re-hydration).
    """
    dataset_id = dataset.get('id')
    if not dataset_id:
        print("⚠️ Attempted to cache a dataset without an ID.")
        return
    if dataset_id in DATASET_CACHE:
        print(f"✅ Dataset {dataset_id} is already in the backend cache. Skipping re-hydration.")
        return
    print(f"🧠 Re-hydrating backend cache for local dataset ID: {dataset_id}")
    try:
        def _decode(content_str: str, modality: str) -> bytes:
            # Text arrives as plain UTF-8; images/meshes arrive either as a
            # data URL ("data:...;base64,<payload>") or as raw base64.
            if modality not in ('image', 'mesh'):
                return content_str.encode('utf-8')
            payload = content_str.split(',', 1)[1] if ',' in content_str else content_str
            return base64.b64decode(payload)

        embeddings_by_modality = {}
        for plural, items in dataset.get('data', {}).items():
            singular = "mesh" if plural == "meshes" else plural[:-1]
            print(f" ⚙️ Processing {len(items)} items for modality: {singular}")
            kept_names = []
            vectors = []
            for entry in items:
                raw = entry.get('content')
                if not raw:
                    # Items served by URL (shared datasets) carry no inline content.
                    continue
                emb = get_embedding_for_single_item(singular, _decode(raw, singular))
                vectors.append(emb[0])  # encoder returns shape (1, D)
                kept_names.append(entry.get('name'))
            embeddings_by_modality[singular] = (
                kept_names,
                np.array(vectors) if vectors else np.array([]),
            )
        # Mirror the cache layout produced by process_uploaded_zip.
        DATASET_CACHE[dataset_id] = {
            "data": dataset.get('data'),
            "embeddings": embeddings_by_modality,
        }
        print(f" ✅ Successfully cached {dataset_id} with embeddings.")
    except Exception as e:
        print(f"🔥 CRITICAL ERROR while re-hydrating cache for {dataset_id}: {e}")
        import traceback
        traceback.print_exc()