Spaces:
Sleeping
Sleeping
| import json | |
| import numpy as np | |
| import os | |
| from sentence_transformers import SentenceTransformer | |
| from typing import List, Dict, Tuple, Any, Optional | |
# Directory that holds the precomputed embedding files and the QA data file.
DATA_DIR = "/app/data"

# Lazily created embedding model; stays None until it is first requested.
_model = None

# Lazily loaded embedding cache: question vectors, answer vectors and the
# QA records they belong to. All three stay None until a load succeeds.
_question_embeddings = _answer_embeddings = _qa_data = None
def initialize_model() -> SentenceTransformer:
    """Create the sentence-embedding model once and cache it in ``_model``.

    The model is only constructed while ``_model`` is still ``None``; later
    calls reuse the cached instance instead of building a new one.

    Returns:
        The cached ``SentenceTransformer`` instance. ``get_model`` assigns
        this return value back to ``_model``, so the function must always
        return the model rather than ``None``.
    """
    global _model
    if _model is None:
        _model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
    # Returned unconditionally so callers get the model on every call,
    # whether it was just created or already cached.
    return _model
def get_model() -> SentenceTransformer:
    """Return the cached embedding model, creating it on first use."""
    global _model
    if _model is not None:
        # Already initialised: hand back the cached instance.
        return _model
    _model = initialize_model()
    return _model
def load_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """Load the question/answer embeddings and the QA data from ``DATA_DIR``.

    Reads ``question_embeddings.npy``, ``answer_embeddings.npy`` and
    ``qa_data.json``. The module-level cache variables are updated only when
    all three files load, so a failure part-way through never leaves the
    cache holding a mix of old and new data.

    Returns:
        ``(question_embeddings, answer_embeddings, qa_data)`` on success,
        where ``qa_data`` is the parsed JSON content of ``qa_data.json``.
        ``(None, None, None)`` if a file is missing or cannot be read; the
        problem is reported with ``print`` instead of being raised.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    try:
        q_emb_path = os.path.join(DATA_DIR, "question_embeddings.npy")
        a_emb_path = os.path.join(DATA_DIR, "answer_embeddings.npy")
        qa_data_path = os.path.join(DATA_DIR, "qa_data.json")
        # Load into locals first; the globals are committed only after
        # every file has been read successfully.
        question_embeddings = np.load(q_emb_path)
        answer_embeddings = np.load(a_emb_path)
        with open(qa_data_path, "r", encoding="utf-8") as f:
            qa_data = json.load(f)
    except FileNotFoundError as e:
        print(f"Warning: Embeddings not found. {str(e)}")
        return None, None, None
    except Exception as e:
        # Deliberately broad: callers rely on this function reporting
        # failure through its return value rather than raising.
        print(f"Error loading embeddings: {str(e)}")
        return None, None, None

    _question_embeddings = question_embeddings
    _answer_embeddings = answer_embeddings
    _qa_data = qa_data
    return question_embeddings, answer_embeddings, qa_data
def get_embeddings() -> (
    Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[List[Dict[str, str]]]]
):
    """Return the cached embeddings and QA data, loading them when needed.

    A load is attempted whenever any of the three cached values is still
    ``None``. The values come back as
    ``(question_embeddings, answer_embeddings, qa_data)`` and each may be
    ``None`` if that load did not succeed.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    cache_complete = (
        _question_embeddings is not None
        and _answer_embeddings is not None
        and _qa_data is not None
    )
    if not cache_complete:
        loaded = load_embeddings()
        _question_embeddings, _answer_embeddings, _qa_data = loaded
    return _question_embeddings, _answer_embeddings, _qa_data
def reload_embeddings() -> bool:
    """Reload the embeddings and QA data from disk.

    ``load_embeddings`` signals failure by returning ``None`` values rather
    than raising, so the result is checked explicitly. When the reload
    fails, the cache values that were in place before the call are put
    back, so a bad reload does not discard data that was already loaded.

    Returns:
        ``True`` if all three values were reloaded, ``False`` otherwise.
        Failures are reported with ``print`` and never raised.
    """
    global _question_embeddings, _answer_embeddings, _qa_data
    # Snapshot the current cache so it can be restored if the reload fails.
    previous = (_question_embeddings, _answer_embeddings, _qa_data)
    failure = None
    try:
        questions, answers, qa_data = load_embeddings()
        if questions is None or answers is None or qa_data is None:
            failure = "embedding files could not be loaded"
        else:
            pair_count = len(qa_data)
    except Exception as e:
        # Keeps the never-raises contract, e.g. if the JSON content has no
        # length.
        failure = str(e)

    if failure is not None:
        _question_embeddings, _answer_embeddings, _qa_data = previous
        print(f"Error reloading embeddings: {failure}")
        return False

    _question_embeddings, _answer_embeddings, _qa_data = questions, answers, qa_data
    print(f"Embeddings reloaded successfully. {pair_count} QA pairs available.")
    return True