Rulga committed on
Commit
e81d67d
·
1 Parent(s): de863da

Refactor DatasetManager to use DATASET_ID for dataset name and improve chat history file handling and validation

Browse files
Files changed (1) hide show
  1. src/knowledge_base/dataset.py +27 -32
src/knowledge_base/dataset.py CHANGED
@@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Any, Optional, Union
9
  from datetime import datetime
10
  from huggingface_hub import HfApi, HfFolder
11
  from langchain_community.vectorstores import FAISS
12
- from config.settings import VECTOR_STORE_PATH, HF_TOKEN, EMBEDDING_MODEL
13
  from langchain_community.embeddings import HuggingFaceEmbeddings # Updated import
14
  import logging
15
 
@@ -18,31 +18,17 @@ logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
20
  class DatasetManager:
21
- def __init__(self, dataset_name="Rulga/status-law-knowledge-base", token: Optional[str] = None):
22
  """
23
  Initialize dataset manager
24
 
25
  Args:
26
  dataset_name: Hugging Face Hub dataset name
27
- token: Hugging Face access token (if None, will use HF_TOKEN from settings)
28
  """
 
29
  self.token = token if token else HF_TOKEN
30
- if not self.token:
31
- raise ValueError("Hugging Face token not found. Please set HUGGINGFACE_TOKEN environment variable")
32
-
33
- self.dataset_name = dataset_name
34
  self.api = HfApi(token=self.token)
35
-
36
- # Проверяем/создаем репозиторий при инициализации
37
- try:
38
- self.api.repo_info(repo_id=self.dataset_name, repo_type="dataset")
39
- except Exception:
40
- print(f"Создаем новый репозиторий датасета: {self.dataset_name}")
41
- self.api.create_repo(
42
- repo_id=self.dataset_name,
43
- repo_type="dataset",
44
- private=True
45
- )
46
 
47
  def init_dataset_structure(self) -> Tuple[bool, str]:
48
  """
@@ -318,23 +304,21 @@ class DatasetManager:
318
  try:
319
  logger.info(f"Attempting to get chat history from dataset {self.dataset_name}")
320
 
321
- # Получаем список всех файлов в репозитории
322
  files = self.api.list_repo_files(
323
  repo_id=self.dataset_name,
324
  repo_type="dataset"
325
  )
326
 
327
- # Фильтруем только файлы из директории chat_history
328
- chat_files = [f for f in files if f.startswith("chat_history/")]
329
- logger.info(f"Found {len(chat_files)} files in chat_history")
330
 
331
- # Фильтруем по conversation_id если указан
332
  if conversation_id:
333
- chat_files = [f for f in chat_files if f.startswith(f"chat_history/{conversation_id}_")]
334
  logger.info(f"Filtered to {len(chat_files)} files for conversation {conversation_id}")
335
 
336
- # Если нет файлов, возвращаем пустой список
337
- if not chat_files or all(f.endswith(".gitkeep") for f in chat_files):
338
  logger.warning("No chat history files found")
339
  return True, []
340
 
@@ -345,7 +329,6 @@ class DatasetManager:
345
  continue
346
 
347
  try:
348
- # Скачиваем и читаем файл
349
  local_file = self.api.hf_hub_download(
350
  repo_id=self.dataset_name,
351
  filename=file,
@@ -355,20 +338,32 @@ class DatasetManager:
355
 
356
  with open(local_file, "r", encoding="utf-8") as f:
357
  chat_data = json.load(f)
358
- if not isinstance(chat_data, dict) or "messages" not in chat_data:
359
- logger.error(f"Invalid chat data structure in {file}")
 
 
 
 
360
  continue
 
 
 
 
361
  chat_histories.append(chat_data)
362
-
 
 
 
 
363
  except Exception as e:
364
  logger.error(f"Error processing file {file}: {str(e)}")
365
  continue
366
 
367
  return True, chat_histories
368
-
369
  except Exception as e:
370
  logger.error(f"Error getting chat history: {str(e)}")
371
- return False, f"Error getting chat history: {str(e)}"
372
 
373
  def upload_document(self, file_path: str, document_id: Optional[str] = None) -> Tuple[bool, str]:
374
  """
 
9
  from datetime import datetime
10
  from huggingface_hub import HfApi, HfFolder
11
  from langchain_community.vectorstores import FAISS
12
+ from config.settings import VECTOR_STORE_PATH, HF_TOKEN, EMBEDDING_MODEL, DATASET_ID, CHAT_HISTORY_PATH
13
  from langchain_community.embeddings import HuggingFaceEmbeddings # Updated import
14
  import logging
15
 
 
18
  logger = logging.getLogger(__name__)
19
 
20
  class DatasetManager:
21
+ def __init__(self, dataset_name: Optional[str] = None, token: Optional[str] = None):
22
  """
23
  Initialize dataset manager
24
 
25
  Args:
26
  dataset_name: Hugging Face Hub dataset name
27
+ token: Hugging Face access token
28
  """
29
+ self.dataset_name = dataset_name or DATASET_ID
30
  self.token = token if token else HF_TOKEN
 
 
 
 
31
  self.api = HfApi(token=self.token)
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def init_dataset_structure(self) -> Tuple[bool, str]:
34
  """
 
304
  try:
305
  logger.info(f"Attempting to get chat history from dataset {self.dataset_name}")
306
 
307
+ # Get all files from repository
308
  files = self.api.list_repo_files(
309
  repo_id=self.dataset_name,
310
  repo_type="dataset"
311
  )
312
 
313
+ # Filter only files from chat_history directory using settings
314
+ chat_files = [f for f in files if f.startswith(f"{CHAT_HISTORY_PATH}/")]
315
+ logger.info(f"Found {len(chat_files)} files in {CHAT_HISTORY_PATH}")
316
 
 
317
  if conversation_id:
318
+ chat_files = [f for f in chat_files if conversation_id in f]
319
  logger.info(f"Filtered to {len(chat_files)} files for conversation {conversation_id}")
320
 
321
+ if not chat_files:
 
322
  logger.warning("No chat history files found")
323
  return True, []
324
 
 
329
  continue
330
 
331
  try:
 
332
  local_file = self.api.hf_hub_download(
333
  repo_id=self.dataset_name,
334
  filename=file,
 
338
 
339
  with open(local_file, "r", encoding="utf-8") as f:
340
  chat_data = json.load(f)
341
+ # Validate chat data structure
342
+ if not isinstance(chat_data, dict):
343
+ logger.error(f"Chat data is not a dictionary in {file}")
344
+ continue
345
+ if "messages" not in chat_data:
346
+ logger.error(f"No 'messages' key in chat data in {file}")
347
  continue
348
+ if not isinstance(chat_data["messages"], list):
349
+ logger.error(f"'messages' is not a list in {file}")
350
+ continue
351
+
352
  chat_histories.append(chat_data)
353
+ logger.info(f"Successfully loaded chat data from {file}")
354
+
355
+ except json.JSONDecodeError as e:
356
+ logger.error(f"Invalid JSON in file {file}: {str(e)}")
357
+ continue
358
  except Exception as e:
359
  logger.error(f"Error processing file {file}: {str(e)}")
360
  continue
361
 
362
  return True, chat_histories
363
+
364
  except Exception as e:
365
  logger.error(f"Error getting chat history: {str(e)}")
366
+ return False, str(e)
367
 
368
  def upload_document(self, file_path: str, document_id: Optional[str] = None) -> Tuple[bool, str]:
369
  """