Ilia Tambovtsev committed on
Commit
24e252a
·
1 Parent(s): 0c02234

feat: fix async eval to be really async

Browse files
Files changed (2) hide show
  1. src/eval/eval_mlflow.py +17 -9
  2. src/rag/storage.py +29 -7
src/eval/eval_mlflow.py CHANGED
@@ -18,6 +18,7 @@ from langchain_core.prompts import PromptTemplate
18
  from langchain_openai import ChatOpenAI
19
  from pydantic import BaseModel, ConfigDict, Field
20
  from tqdm import tqdm
 
21
 
22
  from src.config import Config, load_spreadsheet
23
  from src.config.logging import setup_logging
@@ -523,6 +524,7 @@ class RAGEvaluatorMlflow:
523
  # Create semaphore within the async context
524
  semaphore = asyncio.Semaphore(self._max_concurrent)
525
 
 
526
  tasks = [
527
  self.process_question(
528
  retriever=retriever,
@@ -536,12 +538,13 @@ class RAGEvaluatorMlflow:
536
  for idx, (_, row) in enumerate(questions_df.iterrows())
537
  ]
538
 
539
- for completed in tqdm(
540
- asyncio.as_completed(tasks),
541
- desc=f"Processing questions (max {self._max_concurrent} concurrent)",
542
- total=len(tasks),
543
- ):
544
- await completed
 
545
 
546
  def run_evaluation(self, questions_df: pd.DataFrame) -> None:
547
  """Run evaluation with async LLM queries and controlled concurrency"""
@@ -585,6 +588,13 @@ class RAGEvaluatorMlflow:
585
  )
586
  )
587
 
 
 
 
 
 
 
 
588
  # Process results
589
  results_df = pd.DataFrame(results_log)
590
  results_df["experiment_name"] = (
@@ -615,7 +625,5 @@ class RAGEvaluatorMlflow:
615
  if values:
616
  mean_value = sum(values) / len(values)
617
  mlflow.log_metric(f"mean_{name}", mean_value)
 
618
  self._logger.info(f"Mean {name}: {mean_value:.3f}")
619
-
620
-
621
-
 
18
  from langchain_openai import ChatOpenAI
19
  from pydantic import BaseModel, ConfigDict, Field
20
  from tqdm import tqdm
21
+ from tqdm.asyncio import tqdm_asyncio
22
 
23
  from src.config import Config, load_spreadsheet
24
  from src.config.logging import setup_logging
 
524
  # Create semaphore within the async context
525
  semaphore = asyncio.Semaphore(self._max_concurrent)
526
 
527
+ # Create tasks for all questions
528
  tasks = [
529
  self.process_question(
530
  retriever=retriever,
 
538
  for idx, (_, row) in enumerate(questions_df.iterrows())
539
  ]
540
 
541
+ # Wait for all tasks to complete
542
+ await tqdm_asyncio.gather(
543
+ *tasks,
544
+ desc=f"Processing questions for '{retriever.scorer.id[:15]}' (max {self._max_concurrent} concurrent)",
545
+ total=len(questions_df),
546
+ dynamic_ncols=True, # Adjust width automatically
547
+ )
548
 
549
  def run_evaluation(self, questions_df: pd.DataFrame) -> None:
550
  """Run evaluation with async LLM queries and controlled concurrency"""
 
588
  )
589
  )
590
 
591
+ # Calculate n_errors
592
+ n_errors = (
593
+ len(questions_df) - len(results_log)
594
+ if results_log
595
+ else len(questions_df)
596
+ )
597
+
598
  # Process results
599
  results_df = pd.DataFrame(results_log)
600
  results_df["experiment_name"] = (
 
625
  if values:
626
  mean_value = sum(values) / len(values)
627
  mlflow.log_metric(f"mean_{name}", mean_value)
628
+ mlflow.log_metric(f"error_rate", n_errors / len(questions_df))
629
  self._logger.info(f"Mean {name}: {mean_value:.3f}")
 
 
 
src/rag/storage.py CHANGED
@@ -445,7 +445,7 @@ class ChromaSlideStore:
445
  """Get embeddings for texts"""
446
  return self._embeddings.embed_documents(texts)
447
 
448
- def query_storage(
449
  self,
450
  query: str,
451
  n_results: int = 10,
@@ -462,14 +462,27 @@ class ChromaSlideStore:
462
  List of ScoredChunks sorted by similarity
463
  """
464
  # Get query embedding
465
- query_embedding = self._embeddings.embed_query(query)
466
 
467
  # Query ChromaDB
468
  result = self._collection.query(
469
  query_embeddings=[query_embedding], n_results=n_results, where=where
470
  )
 
 
 
 
 
 
 
 
 
 
471
  return result
472
 
 
 
 
473
  def _process_chroma_results(self, results: QueryResult) -> List[ScoredChunk]:
474
  """Convert ChromaDB results to list of (Document, score) tuples
475
 
@@ -490,7 +503,7 @@ class ChromaSlideStore:
490
 
491
  return sorted(scored_chunks, key=lambda chunk: chunk.score)
492
 
493
- def search_query(
494
  self,
495
  query: str,
496
  chunk_types: Optional[List[str]] = None,
@@ -545,7 +558,10 @@ class ChromaSlideStore:
545
  ),
546
  )
547
 
548
- def search_query_pages(
 
 
 
549
  self,
550
  query: str,
551
  chunk_types: Optional[List[str]] = None,
@@ -566,7 +582,7 @@ class ChromaSlideStore:
566
  List of search results with full slide context, deduplicated by slide_id
567
  """
568
  # First perform regular search
569
- search_results = self.search_query(
570
  query=query,
571
  chunk_types=chunk_types,
572
  n_results=n_results, # * 3, # Get more to ensure different pages
@@ -621,7 +637,10 @@ class ChromaSlideStore:
621
 
622
  return page_results # [:n_results]
623
 
624
- def search_query_presentations(
 
 
 
625
  self,
626
  query: str,
627
  chunk_types: Optional[List[str]] = None,
@@ -644,7 +663,7 @@ class ChromaSlideStore:
644
  List of presentations with their matching slides, sorted by best match
645
  """
646
  # Get initial search results with enough buffer for filtering
647
- search_results = self.search_query_pages(
648
  query=query,
649
  chunk_types=chunk_types,
650
  n_results=n_results,
@@ -689,6 +708,9 @@ class ChromaSlideStore:
689
 
690
  return ScoredPresentations(presentations=presentation_results, scorer=scorer)
691
 
 
 
 
692
  def get_by_metadata(
693
  self, metadata_filter: Dict, n_results: Optional[int] = None
694
  ) -> List[Document]:
 
445
  """Get embeddings for texts"""
446
  return self._embeddings.embed_documents(texts)
447
 
448
+ async def aquery_storage(
449
  self,
450
  query: str,
451
  n_results: int = 10,
 
462
  List of ScoredChunks sorted by similarity
463
  """
464
  # Get query embedding
465
+ query_embedding = await self._embeddings.aembed_query(query)
466
 
467
  # Query ChromaDB
468
  result = self._collection.query(
469
  query_embeddings=[query_embedding], n_results=n_results, where=where
470
  )
471
+
472
+ ## Run ChromaDB query in executor to avoid blocking
473
+ # result = await asyncio.get_event_loop().run_in_executor(
474
+ # None,
475
+ # lambda: self._collection.query(
476
+ # query_embeddings=[query_embedding],
477
+ # n_results=n_results,
478
+ # where=where
479
+ # )
480
+ # )
481
  return result
482
 
483
+ def query_storage(self, *args, **kwargs):
484
+ return asyncio.run(self.aquery_storage(*args, **kwargs))
485
+
486
  def _process_chroma_results(self, results: QueryResult) -> List[ScoredChunk]:
487
  """Convert ChromaDB results to list of (Document, score) tuples
488
 
 
503
 
504
  return sorted(scored_chunks, key=lambda chunk: chunk.score)
505
 
506
+ async def asearch_query(
507
  self,
508
  query: str,
509
  chunk_types: Optional[List[str]] = None,
 
558
  ),
559
  )
560
 
561
+ def search_query(self, *args, **kwargs):
562
+ return asyncio.run(self.asearch_query(*args, **kwargs))
563
+
564
+ async def asearch_query_pages(
565
  self,
566
  query: str,
567
  chunk_types: Optional[List[str]] = None,
 
582
  List of search results with full slide context, deduplicated by slide_id
583
  """
584
  # First perform regular search
585
+ search_results = await self.asearch_query(
586
  query=query,
587
  chunk_types=chunk_types,
588
  n_results=n_results, # * 3, # Get more to ensure different pages
 
637
 
638
  return page_results # [:n_results]
639
 
640
+ def search_query_pages(self, *args, **kwargs):
641
+ return asyncio.run(self.asearch_query_pages(*args, **kwargs))
642
+
643
+ async def asearch_query_presentations(
644
  self,
645
  query: str,
646
  chunk_types: Optional[List[str]] = None,
 
663
  List of presentations with their matching slides, sorted by best match
664
  """
665
  # Get initial search results with enough buffer for filtering
666
+ search_results = await self.asearch_query_pages(
667
  query=query,
668
  chunk_types=chunk_types,
669
  n_results=n_results,
 
708
 
709
  return ScoredPresentations(presentations=presentation_results, scorer=scorer)
710
 
711
+ def search_query_presentations(self, *args, **kwargs):
712
+ return asyncio.run(self.asearch_query_presentations(*args, **kwargs))
713
+
714
  def get_by_metadata(
715
  self, metadata_filter: Dict, n_results: Optional[int] = None
716
  ) -> List[Document]: