# project-demo / backend / inference_utils.py
# Uploaded by: Шатурный Алексей Давыдович — commit 0269f70 ("add files")
# app/inference_utils.py
import base64
import tempfile
import uuid
import zipfile
from io import BytesIO
from pathlib import Path
import datetime
from typing import Callable # Add this import
import numpy as np
import torch
from easydict import EasyDict as edict
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import DataLoader
from cad_retrieval_utils.augmentations import build_img_transforms
from cad_retrieval_utils.datasets import (InferenceImageDataset,
InferenceMeshDataset,
InferenceTextDataset)
from cad_retrieval_utils.evaluation import (get_inference_embeddings_image,
get_inference_embeddings_mesh,
get_inference_embeddings_text)
from cad_retrieval_utils.inference import (load_image_encoder, load_pc_encoder,
load_text_encoder)
from cad_retrieval_utils.models import (ImageEncoder, InferencePcEncoder,
InferenceTextEncoder)
from cad_retrieval_utils.utils import init_environment, load_config
# --- Module-level singletons, populated once by load_models_and_config() ---
CONFIG: edict = None  # loaded runtime configuration; None until models are loaded
IMG_TRANSFORM = None  # image preprocessing pipeline built from CONFIG.img_size
PC_ENCODER: InferencePcEncoder = None  # point-cloud (mesh) embedding encoder
IMG_ENCODER: ImageEncoder = None  # image embedding encoder
TEXT_ENCODER: InferenceTextEncoder = None  # text embedding encoder
# In-memory store: dataset_id -> {"data": ..., "embeddings": {modality: (names, ndarray)}}
DATASET_CACHE: dict = {}
# Number of nearest neighbours returned per target modality in similarity searches.
TOP_K_MATCHES: int = 5
def load_models_and_config(config_path: str, model_paths: dict) -> None:
    """Populate the module-level CONFIG, transform and encoder singletons.

    Idempotent: once CONFIG is set, subsequent calls return immediately
    without reloading anything. Re-raises any loading failure after logging.

    Args:
        config_path: Path to the configuration file for load_config().
        model_paths: Mapping of model checkpoint paths, stored under
            CONFIG.paths.model_spec and consumed by the loaders.
    """
    global CONFIG, IMG_TRANSFORM, PC_ENCODER, IMG_ENCODER, TEXT_ENCODER
    print("🚀 Загрузка конфигурации и моделей...")
    if CONFIG is not None:
        # Already initialized — nothing to do.
        print(" Модели уже загружены.")
        return
    try:
        CONFIG = load_config(config_path)
        CONFIG.paths.model_spec = model_paths
        init_environment(CONFIG)
        # Load all three modality encoders, then the image preprocessing.
        PC_ENCODER = load_pc_encoder(CONFIG.paths.model_spec, CONFIG)
        IMG_ENCODER = load_image_encoder(CONFIG.paths.model_spec, CONFIG)
        TEXT_ENCODER = load_text_encoder(CONFIG.paths.model_spec, CONFIG)
        IMG_TRANSFORM = build_img_transforms(CONFIG.img_size)
        print("✅ Все модели успешно загружены в память.")
    except Exception as e:
        print(f"🔥 Критическая ошибка при загрузке моделей: {e}")
        raise
@torch.no_grad()
def get_embedding_for_single_item(modality: str, content_bytes: bytes) -> np.ndarray:
    """Embed a single raw item with the encoder matching its modality.

    Args:
        modality: One of "image", "text" or "mesh".
        content_bytes: Raw item payload (PNG bytes, UTF-8 text, or STL bytes).

    Returns:
        A numpy array with a leading batch dimension of 1 (normalized embedding).

    Raises:
        ValueError: If the modality is not recognized.
    """
    if modality == "image":
        pil_img = Image.open(BytesIO(content_bytes)).convert("RGB")
        batch = IMG_TRANSFORM(pil_img).unsqueeze(0).to(CONFIG.device)
        return IMG_ENCODER.encode_image(batch, normalize=True).cpu().numpy()
    elif modality == "text":
        decoded = content_bytes.decode("utf-8")
        return TEXT_ENCODER.encode_text([decoded], normalize=True).cpu().numpy()
    elif modality == "mesh":
        # The mesh dataset loads from a file path, so stage the bytes in a temp file.
        with tempfile.NamedTemporaryFile(suffix=".stl", delete=True) as stl_file:
            stl_file.write(content_bytes)
            stl_file.flush()
            mesh_ds = InferenceMeshDataset([stl_file.name], CONFIG.npoints, CONFIG.seed)
            batch = mesh_ds[0].unsqueeze(0).to(CONFIG.device)
            return PC_ENCODER.encode_pc(batch, normalize=True).cpu().numpy()
    raise ValueError(f"Неизвестная модальность: {modality}")
def process_uploaded_zip(
    zip_file_bytes: bytes,
    original_filename: str,
    update_status: Callable[[str, int], None]
) -> dict:
    """Process an uploaded ZIP archive of images/texts/meshes into a dataset.

    Unpacks the archive into a temp dir, computes embeddings for every
    .png/.txt/.stl file, stores the dataset and its embeddings in
    DATASET_CACHE, and builds the full cross-modal similarity matrix
    (top TOP_K_MATCHES per target modality).

    Args:
        zip_file_bytes: Raw bytes of the uploaded ZIP archive (untrusted input).
        original_filename: Display name recorded in the response.
        update_status: Callback(stage_name, percent) for progress reporting.

    Returns:
        A dict describing the processed dataset (id, data, fullComparison, ...).

    Raises:
        ValueError: If the archive contains an unsafe (path-traversal) member.
    """
    dataset_id = str(uuid.uuid4())
    print(f"⚙️ Начало обработки нового датасета: {original_filename} (ID: {dataset_id})")
    update_status("Starting", 0)
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)
        zip_path = tmp_path / "data.zip"
        zip_path.write_bytes(zip_file_bytes)
        update_status("Unpacking Files", 5)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # Security fix: the archive is untrusted — reject "zip-slip" members
            # (absolute paths or ".." traversal) before extracting anything.
            safe_root = tmp_path.resolve()
            for member in zip_ref.namelist():
                if not (safe_root / member).resolve().is_relative_to(safe_root):
                    raise ValueError(f"Unsafe path in archive: {member}")
            zip_ref.extractall(tmp_path)
        print(f" 🗂️ Архив распакован в {tmpdir}")
        update_status("Preparing Data", 10)
        image_paths = sorted(tmp_path.glob("**/*.png"))
        text_paths = sorted(tmp_path.glob("**/*.txt"))
        mesh_paths = sorted(tmp_path.glob("**/*.stl"))
        image_ds = InferenceImageDataset([str(p) for p in image_paths], IMG_TRANSFORM)
        text_ds = InferenceTextDataset([str(p) for p in text_paths])
        mesh_ds = InferenceMeshDataset([str(p) for p in mesh_paths], CONFIG.npoints, CONFIG.seed)
        # shuffle=False keeps loader order aligned with the sorted path lists.
        image_loader = DataLoader(image_ds, batch_size=CONFIG.infer_img_batch_size, shuffle=False)
        text_loader = DataLoader(text_ds, batch_size=CONFIG.infer_text_batch_size, shuffle=False)
        mesh_loader = DataLoader(mesh_ds, batch_size=CONFIG.infer_pc_batch_size, shuffle=False)
        print(" 🧠 Вычисление эмбеддингов...")
        update_status("Processing Images", 15)
        image_embs = get_inference_embeddings_image(IMG_ENCODER, image_loader, CONFIG)
        update_status("Processing Texts", 50)
        text_embs = get_inference_embeddings_text(TEXT_ENCODER, text_loader, CONFIG)
        update_status("Processing 3D Models", 55)
        mesh_embs = get_inference_embeddings_mesh(PC_ENCODER, mesh_loader, CONFIG)
        print(" ✅ Эмбеддинги вычислены.")
        update_status("Caching Data", 90)
        image_names = [p.name for p in image_paths]
        text_names = [p.name for p in text_paths]
        mesh_names = [p.name for p in mesh_paths]
        # Inline payloads (base64 for binary, plain text for .txt) so the
        # frontend can render items directly from the response.
        image_items = [
            {"id": f"image_{i}", "name": p.name,
             "content": base64.b64encode(p.read_bytes()).decode('utf-8')}
            for i, p in enumerate(image_paths)
        ]
        # Fix: read text explicitly as UTF-8 — the default encoding is
        # locale-dependent (e.g. cp1251 on Russian Windows).
        text_items = [
            {"id": f"text_{i}", "name": p.name, "content": p.read_text(encoding="utf-8")}
            for i, p in enumerate(text_paths)
        ]
        mesh_items = [
            {"id": f"mesh_{i}", "name": p.name,
             "content": base64.b64encode(p.read_bytes()).decode('utf-8')}
            for i, p in enumerate(mesh_paths)
        ]
        dataset_data = {"images": image_items, "texts": text_items, "meshes": mesh_items}
        all_embeddings = {
            "image": (image_names, image_embs),
            "text": (text_names, text_embs),
            "mesh": (mesh_names, mesh_embs)
        }
        DATASET_CACHE[dataset_id] = {"data": dataset_data, "embeddings": all_embeddings}
        print(f" 💾 Датасет {dataset_id} сохранен в кэш.")
        print(" ⚖️ Вычисление полной матрицы схожести...")
        update_status("Building Matrix", 95)
        full_comparison = {"images": [], "texts": [], "meshes": []}
        for source_modality, (source_names, source_embs) in all_embeddings.items():
            key_name = "meshes" if source_modality == "mesh" else source_modality + 's'
            for i, source_name in enumerate(source_names):
                source_emb = source_embs[i:i+1]
                matches = {}
                for target_modality, (target_names, target_embs) in all_embeddings.items():
                    if not target_names:
                        continue
                    sims = cosine_similarity(source_emb, target_embs).flatten()
                    if source_modality == target_modality:
                        # Mask the item itself so it never matches itself;
                        # the > -1 filter below drops the masked entry.
                        sims[i] = -1
                    top_indices = np.argsort(sims)[::-1][:TOP_K_MATCHES]
                    matches[target_modality] = [
                        {"item": target_names[j], "confidence": float(sims[j])}
                        for j in top_indices if sims[j] > -1
                    ]
                full_comparison[key_name].append({"source": source_name, "matches": matches})
        print(" ✅ Матрица схожести готова.")
        final_response = {
            "id": dataset_id,
            "name": original_filename,
            # Fix: datetime.utcnow() is deprecated and naive; emit an aware UTC
            # timestamp while keeping the original trailing-"Z" ISO format.
            "uploadDate": datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z"),
            "data": dataset_data,
            "processingState": "processed",
            "processingProgress": 100,
            "fullComparison": full_comparison
        }
        print(f"✅ Обработка датасета {dataset_id} завершена.")
        return final_response
def process_shared_dataset_directory(directory_path: Path, embeddings_path: Path, dataset_id: str, dataset_name: str) -> dict:
    """Index a shared on-disk dataset using precomputed .npy embeddings.

    Scans directory_path for .png/.txt/.stl files, pairs each data file with
    the embedding file of the same stem under embeddings_path, caches
    everything in DATASET_CACHE and builds the full cross-modal similarity
    matrix (top TOP_K_MATCHES per target modality).

    Args:
        directory_path: Root of the shared dataset files (expected under "static/").
        embeddings_path: Root containing precomputed .npy embedding files.
        dataset_id: Cache key / response id for this dataset.
        dataset_name: Display name recorded in the response.

    Returns:
        The dataset description dict, or None when no data files are found.
    """
    print(f"⚙️ Начало обработки общего датасета: {dataset_name} (ID: {dataset_id})")
    print(" 📂 Сканирование файлов данных...")
    image_paths = sorted(directory_path.glob("**/*.png"))
    text_paths = sorted(directory_path.glob("**/*.txt"))
    mesh_paths = sorted(directory_path.glob("**/*.stl"))
    if not any([image_paths, text_paths, mesh_paths]):
        print(f"⚠️ В директории общего датасета '{directory_path}' не найдено файлов.")
        return None
    print(f" ✅ Найдено: {len(image_paths)} изображений, {len(text_paths)} текстов, {len(mesh_paths)} моделей.")
    print(" 🧠 Индексирование предварительно вычисленных эмбеддингов...")
    # Embeddings are matched to data files by file stem (name without suffix).
    embedding_map = {p.stem: p for p in embeddings_path.glob("**/*.npy")}
    print(f" ✅ Найдено {len(embedding_map)} файлов эмбеддингов.")

    def load_embeddings_for_paths(data_paths: list[Path]):
        """Return (names, stacked embeddings) for every path with a matching .npy."""
        names = []
        embs_list = []
        for data_path in data_paths:
            embedding_path = embedding_map.get(data_path.stem)
            if embedding_path is None:
                print(f" ⚠️ Внимание: не найден соответствующий эмбеддинг для {data_path.name}")
                continue
            try:
                embs_list.append(np.load(embedding_path))
                names.append(data_path.name)
            except Exception as e:
                # Best-effort: skip unreadable embeddings, keep processing the rest.
                print(f" ⚠️ Не удалось загрузить или разобрать эмбеддинг для {data_path.name}: {e}")
        return names, np.array(embs_list) if embs_list else np.array([])

    print(" 🚚 Загрузка и сопоставление эмбеддингов...")
    image_names, image_embs = load_embeddings_for_paths(image_paths)
    text_names, text_embs = load_embeddings_for_paths(text_paths)
    mesh_names, mesh_embs = load_embeddings_for_paths(mesh_paths)
    print(" ✅ Эмбеддинги для общего датасета загружены.")
    static_root = Path("static")
    # Shared datasets are served from /static, so items carry URLs, not inline content.
    image_items = [{"id": f"image_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(image_paths)]
    text_items = [{"id": f"text_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(text_paths)]
    mesh_items = [{"id": f"mesh_{i}", "name": p.name, "content": None, "contentUrl": f"/{p.relative_to(static_root)}"} for i, p in enumerate(mesh_paths)]
    dataset_data = {"images": image_items, "texts": text_items, "meshes": mesh_items}
    all_embeddings = {"image": (image_names, image_embs), "text": (text_names, text_embs), "mesh": (mesh_names, mesh_embs)}
    DATASET_CACHE[dataset_id] = {"data": dataset_data, "embeddings": all_embeddings}
    print(f" 💾 Эмбеддинги для общего датасета {dataset_id} сохранены в кэш.")
    print(" ⚖️ Вычисление полной матрицы схожести для общего датасета...")
    full_comparison = {"images": [], "texts": [], "meshes": []}
    for source_modality, (source_names, source_embs) in all_embeddings.items():
        if len(source_names) == 0:
            continue
        key_name = "meshes" if source_modality == "mesh" else source_modality + 's'
        for i, source_name in enumerate(source_names):
            source_emb = source_embs[i:i+1]
            matches = {}
            for target_modality, (target_names, target_embs) in all_embeddings.items():
                if len(target_names) == 0:
                    continue
                sims = cosine_similarity(source_emb, target_embs).flatten()
                if source_modality == target_modality:
                    # Mask self-match; filtered out by the > -1 check below.
                    sims[i] = -1
                top_indices = np.argsort(sims)[::-1][:TOP_K_MATCHES]
                matches[target_modality] = [{"item": target_names[j], "confidence": float(sims[j])} for j in top_indices if sims[j] > -1]
            full_comparison[key_name].append({"source": source_name, "matches": matches})
    print(" ✅ Матрица схожести для общего датасета готова.")
    try:
        # Bug fix: the original converted st_ctime to naive *local* time and then
        # appended "Z" (claiming UTC). Convert to an aware UTC datetime instead.
        # NOTE(review): on Unix st_ctime is the inode-change time, not creation
        # time — accepted here as an approximation of the upload date.
        creation_time = datetime.datetime.fromtimestamp(directory_path.stat().st_ctime, tz=datetime.timezone.utc)
    except Exception:
        # Fix: utcnow() is deprecated and naive; use an aware UTC timestamp.
        creation_time = datetime.datetime.now(datetime.timezone.utc)
    final_response = {
        "id": dataset_id,
        "name": dataset_name,
        "uploadDate": creation_time.isoformat().replace("+00:00", "Z"),
        "data": dataset_data,
        "processingState": "processed",
        "processingProgress": 100,
        "fullComparison": full_comparison,
        "isShared": True,
    }
    print(f"✅ Обработка общего датасета {dataset_id} завершена.")
    return final_response
def find_matches_for_item(modality: str, content_base64: str, dataset_id: str) -> dict:
    """Embed one query item and return its top matches from a cached dataset.

    Args:
        modality: "image", "text" or "mesh" — modality of the query item.
        content_base64: Base64-encoded raw content of the query item.
        dataset_id: Id of a dataset previously stored in DATASET_CACHE.

    Returns:
        {"results": {plural_modality: [{"item": <cached item dict>,
        "confidence": float}, ...]}} with up to TOP_K_MATCHES per modality.

    Raises:
        ValueError: If the dataset id is not present in the cache.
    """
    print(f"🔍 Поиск совпадений для объекта ({modality}) в датасете {dataset_id}...")
    if dataset_id not in DATASET_CACHE:
        raise ValueError(f"Датасет с ID {dataset_id} не найден в кэше.")
    query_emb = get_embedding_for_single_item(modality, base64.b64decode(content_base64))
    cached = DATASET_CACHE[dataset_id]
    results = {}
    for target_modality, (target_names, target_embs) in cached["embeddings"].items():
        plural_key = "meshes" if target_modality == "mesh" else target_modality + 's'
        if not target_names:
            continue
        scores = cosine_similarity(query_emb, target_embs).flatten()
        best = np.argsort(scores)[::-1][:TOP_K_MATCHES]
        # Map names back to the full cached item dicts for the response.
        by_name = {entry['name']: entry for entry in cached['data'][plural_key]}
        results[plural_key] = [
            {"item": by_name[target_names[j]], "confidence": float(scores[j])}
            for j in best
            if target_names[j] in by_name
        ]
    print(" ✅ Поиск завершен.")
    return {"results": results}
def cache_local_dataset(dataset: dict) -> None:
    """Re-hydrate the backend cache for a dataset object sent by the frontend.

    Decodes each item's content, computes its embedding, and stores the
    dataset plus per-modality embeddings in DATASET_CACHE under the dataset's
    own id. No-op when the id is missing or already cached; errors are logged
    rather than raised.
    """
    dataset_id = dataset.get('id')
    if not dataset_id:
        print("⚠️ Attempted to cache a dataset without an ID.")
        return
    if dataset_id in DATASET_CACHE:
        print(f"✅ Dataset {dataset_id} is already in the backend cache. Skipping re-hydration.")
        return
    print(f"🧠 Re-hydrating backend cache for local dataset ID: {dataset_id}")
    try:
        def decode_content(raw: str, kind: str) -> bytes:
            # Images/meshes arrive base64-encoded (optionally as data URLs
            # like "data:image/png;base64,..."); text arrives as plain text.
            if kind in ['image', 'mesh']:
                payload = raw.split(',', 1)[1] if ',' in raw else raw
                return base64.b64decode(payload)
            return raw.encode('utf-8')

        names_by_modality = {}
        embs_by_modality = {}
        for plural, items in dataset.get('data', {}).items():
            singular = "mesh" if plural == "meshes" else plural[:-1]
            collected_names = []
            collected_embs = []
            print(f" ⚙️ Processing {len(items)} items for modality: {singular}")
            for entry in items:
                raw_content = entry.get('content')
                if not raw_content:
                    continue
                emb = get_embedding_for_single_item(singular, decode_content(raw_content, singular))
                collected_embs.append(emb[0])  # encoder returns shape (1, D)
                collected_names.append(entry.get('name'))
            names_by_modality[singular] = collected_names
            embs_by_modality[singular] = np.array(collected_embs) if collected_embs else np.array([])
        # Mirror the cache-entry layout produced by process_uploaded_zip.
        DATASET_CACHE[dataset_id] = {
            "data": dataset.get('data'),
            "embeddings": {
                mod: (names_by_modality[mod], embs_by_modality[mod])
                for mod in embs_by_modality
            }
        }
        print(f" ✅ Successfully cached {dataset_id} with embeddings.")
    except Exception as e:
        print(f"🔥 CRITICAL ERROR while re-hydrating cache for {dataset_id}: {e}")
        import traceback
        traceback.print_exc()