# Alphin Jain
# added summarizer module
# 5edd1c1
import time
from typing import Dict, List
import cohere
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from ..config.settings import LLM_MODEL
class DocumentSummarizer:
    """Build a sectioned Markdown summary of one document via the Cohere Chat API.

    For every configured component (abstract, methods, results, ...) the
    summarizer asks the retriever for chunks of the document, sends them
    together with that component's prompt to the chat model, and finally
    stitches the per-component answers into a single Markdown report.
    """

    # Preferred order of sections in the compiled report. Result keys that are
    # not listed here are appended after these instead of being dropped.
    _SECTION_ORDER = (
        'basic_info', 'abstract',
        'methods', 'results', 'equations', 'technical',
        'related_work', 'applications', 'limitations',
    )

    def __init__(self, retriever, batch_size=4):
        """Store collaborators and set up component titles and prompts.

        Args:
            retriever: Object exposing ``get_relevant_docs(chromdb_query=...,
                rerank_query=..., filter=..., chunk_size=...)``.
            batch_size: Stored on the instance; not read by any method of
                this class.
        """
        self.batch_size = batch_size
        self.retriever = retriever
        # NOTE(review): no credentials are passed here, so the client
        # presumably picks up its API key from the environment - confirm.
        self.cohere_client = cohere.ClientV2()
        # Component key -> human readable title (used in retrieval queries
        # and as section headings of the compiled summary).
        self.components = {
            'basic_info': "Basic Paper Information",
            'abstract': "Abstract Summary",
            'methods': "Methodology Summary",
            'results': "Key Results",
            'limitations': "Limitations & Future Work",
            'related_work': "Related Work",
            'applications': "Practical Applications",
            'technical': "Technical Details",
            'equations': "Key Equations",
        }
        # Component key -> prompt text sent to the chat model.
        self.prompts = self._initialize_prompts()

    def _initialize_prompts(self):
        """Return the mapping of component key to its summarization prompt.

        The import is done lazily inside the method, and only the prompts that
        are actually mapped to a component are imported, so an unrelated prompt
        being renamed or removed in ``prompt2`` cannot break this class.
        """
        from ..summarization.prompt2 import (
            abstract_prompt,
            applications_prompt,
            basic_info_prompt,
            equations_prompt,
            limitations_prompt,
            methods_prompt,
            related_work_prompt,
            results_prompt,
            technical_prompt,
        )
        return {
            'basic_info': basic_info_prompt,
            'abstract': abstract_prompt,
            'methods': methods_prompt,
            'results': results_prompt,
            'limitations': limitations_prompt,
            'related_work': related_work_prompt,
            'applications': applications_prompt,
            'technical': technical_prompt,
            'equations': equations_prompt,
        }

    def summarize_text(self, documents: List[Dict], prompt: str, language: str):
        """Summarize ``documents`` with ``prompt`` through the Cohere Chat API.

        Args:
            documents: Documents handed to the chat call unchanged.
            prompt: User message describing what to summarize.
            language: Language the system message asks the model to answer in.

        Returns:
            The text of the first content item of the response, or ``None``
            when the response carries no text or the call raised. Failures are
            printed, not propagated (best-effort behaviour).
        """
        try:
            response = self.cohere_client.chat(
                model=LLM_MODEL,
                documents=documents,
                messages=[
                    {"role": "system", "content": f"You are an expert summarization AI. Please respond in {language}."},
                    {"role": "user", "content": f"{prompt}"}
                ],
            )
            # Walk the response defensively: any missing level means "no text".
            content = response.message.content if response and response.message else None
            if content and content[0] and content[0].text:
                return content[0].text
            return None
        except Exception as e:
            print(f"Error during Cohere API call: {e}")
            return None

    def extract_relevant_documents(self, component: str, filename: str, chunk_size: int):
        """Fetch the chunks of ``filename`` relevant to ``component``.

        The same natural-language query is used for both the vector-store
        lookup and the rerank step, and results are filtered by filename.

        Returns:
            Whatever the retriever returns, or ``[]`` when the retriever
            raised (the error is printed).
        """
        section_title = self.components.get(component, component)
        query = f"Analyze the {section_title} section from the document titled '{filename}'."
        try:
            return self.retriever.get_relevant_docs(
                chromdb_query=query,
                rerank_query=query,
                filter={'filename': filename},
                chunk_size=chunk_size
            )
        except Exception as e:
            print(f"Error during document retrieval for component {component}: {e}")
            return []

    def summarize_document(self, filename: str, language: str, chunk_size: int) -> str:
        """Summarize every configured component of ``filename`` concurrently.

        Each component is retrieved and summarized in its own worker thread
        (the work is network-bound). Components that cannot be summarized are
        reported on stdout and left out of the report; they never abort the
        whole run.

        Args:
            filename: Name of the document, used for retrieval filtering and
                in the report title.
            language: Language the model is asked to respond in.
            chunk_size: Forwarded to the retriever.

        Returns:
            The compiled Markdown summary (see ``compile_summary``).
        """
        started = time.perf_counter()  # monotonic: immune to wall-clock jumps
        results = {}
        errors = {}

        def process_component(comp):
            """Return ``(component, summary_or_None, error_or_None)``."""
            try:
                document_chunks = self.extract_relevant_documents(comp, filename, chunk_size)
                if not document_chunks:
                    # Without source material the model could only guess, so
                    # skip the API call instead of producing an ungrounded summary.
                    return comp, None, "no relevant documents were retrieved"
                prompt = self.prompts.get(comp)
                if prompt is None:
                    return comp, None, "no prompt is configured for this component"
                summary = self.summarize_text(document_chunks, prompt, language)
                if summary is None:
                    return comp, None, "the model returned no summary"
                return comp, summary, None
            except Exception as e:
                return comp, None, str(e)

        with ThreadPoolExecutor(max_workers=None) as executor:
            future_to_component = {
                executor.submit(process_component, comp): comp
                for comp in self.components
            }
            for future in as_completed(future_to_component):
                comp = future_to_component[future]
                try:
                    _, summary, error = future.result()
                except Exception as exc:
                    errors[comp] = str(exc)
                    continue
                if summary is not None:
                    results[comp] = summary
                else:
                    errors[comp] = error or "unknown error"

        # Surface failures instead of silently discarding them.
        for comp, message in errors.items():
            print(f"Component '{comp}' was not summarized: {message}")

        elapsed = time.perf_counter() - started
        print(f"\n--- Total summarization time for {filename}: {elapsed:.2f} seconds ---\n")
        return self.compile_summary(filename, results)

    def summerize_document(self, filename: str, language: str, chunk_size: int) -> str:
        """Backward-compatible (misspelled) name for ``summarize_document``."""
        return self.summarize_document(filename, language, chunk_size)

    def compile_summary(self, filename: str, results: Dict[str, str]) -> str:
        """Assemble per-component summaries into one Markdown document.

        Sections appear in the predefined order; result keys outside that
        order follow afterwards in the order they occur in ``results``.
        Components with an empty or missing summary are omitted.

        Args:
            filename: Document name shown in the top-level heading.
            results: Mapping of component key to its summary text.

        Returns:
            Markdown text: a title, a generation timestamp and one ``##``
            section per summarized component.
        """
        ordered = [key for key in self._SECTION_ORDER if key in results]
        ordered.extend(key for key in results if key not in self._SECTION_ORDER)

        lines = [f"# Summary of {filename}", f"Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}\n"]
        for section in ordered:
            body = results[section]
            if not body:
                continue
            # Fall back to the raw key when the component has no registered title.
            title = self.components.get(section, section).title()
            lines.append(f"## {title}\n")
            lines.append(f"{body}\n")
        return "\n".join(lines)