Spaces:
Sleeping
Sleeping
File size: 6,563 Bytes
334c1a6 5edd1c1 334c1a6 5edd1c1 334c1a6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | import time
from typing import Dict, List
import cohere
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from ..config.settings import LLM_MODEL
class DocumentSummarizer:
def __init__(self, retriever, batch_size=4):
self.batch_size = batch_size
self.retriever = retriever # Store the retriever here
self.cohere_client = cohere.ClientV2()
self.components = {
'basic_info': "Basic Paper Information",
'abstract': "Abstract Summary",
'methods': "Methodology Summary",
'results': "Key Results",
'limitations': "Limitations & Future Work",
'related_work': "Related Work",
'applications': "Practical Applications",
'technical': "Technical Details",
'equations': "Key Equations",
}
self.prompts = self._initialize_prompts()
def _initialize_prompts(self):
# It's better to explicitly import what you need
from ..summarization.prompt2 import (
basic_info_prompt, abstract_prompt,
methods_prompt, results_prompt, visuals_prompt, limitations_prompt,
contributions_prompt, related_work_prompt, applications_prompt,
technical_prompt, quick_summary_prompt, reading_guide_prompt, # quick_summary & reading_guide prompts might be needed
equations_prompt
)
return {
'basic_info': basic_info_prompt,
'abstract': abstract_prompt,
'methods': methods_prompt,
'results': results_prompt,
'limitations': limitations_prompt,
'related_work': related_work_prompt,
'applications': applications_prompt,
'technical': technical_prompt,
'equations': equations_prompt,
}
def summarize_text(self, documents: List[Dict], prompt: str, language: str):
"""
Summarizes the provided documents using the given prompt and language
via the Cohere Chat API.
"""
# Use the initialized client
try:
response = self.cohere_client.chat(
model=LLM_MODEL,
documents=documents, # Pass the list of dicts directly
messages=[
{"role": "system", "content": f"You are an expert summarization AI. Please respond in {language}."},
{"role": "user", "content": f"{prompt}"}
],
)
if response and response.message and response.message.content and response.message.content[0] and response.message.content[0].text:
return response.message.content[0].text
else:
return None
except Exception as e:
print(f"Error during Cohere API call: {e}")
return None
def extract_relevant_documents(self, component: str, filename: str, chunk_size: int):
"""
Extracts relevant documents for a specific component from the retriever.
"""
query = f"Analyze the {self.components.get(component, component)} section from the document titled '{filename}'."
# Use the retriever stored in self.
# Pass the chunk_size parameter correctly
try:
documents = self.retriever.get_relevant_docs(
chromdb_query=query,
rerank_query=query,
filter={'filename': filename},
chunk_size=chunk_size
)
return documents
except Exception as e:
print(f"Error during document retrieval for component {component}: {e}")
return []
def summerize_document(self, filename: str, language: str, chunk_size: int):
"""
Summarizes a document by processing each component in parallel.
"""
start_total = time.time()
components = list(self.components.keys())
results = {}
errors = {} # Track errors
def process_component(comp):
try:
document_chunks = self.extract_relevant_documents(comp, filename, chunk_size)
prompt = self.prompts.get(comp)
summary = self.summarize_text(document_chunks, prompt, language)
return comp, summary, None
except Exception as e:
return comp, None, str(e)
with ThreadPoolExecutor(max_workers=None) as executor:
# Submit all component tasks
future_to_component = {executor.submit(process_component, comp): comp for comp in components}
# Process results as they complete
for future in as_completed(future_to_component):
comp = future_to_component[future]
try:
comp_name, result, error = future.result()
if result is not None:
results[comp_name] = result
elif error:
errors[comp_name] = error
except Exception as exc:
errors[comp] = str(exc)
end_total = time.time()
print(f"\n--- Total summarization time for {filename}: {end_total - start_total:.2f} seconds ---\n")
compiled = self.compile_summary(filename, results)
return compiled
def compile_summary(self, filename: str, results: Dict[str, str]) -> str:
"""
Compiles a summary for a document by concatenating the results of all requested components.
Orders sections according to a predefined list.
"""
# Include all components that might have results, maintaining desired order
sections_order = [
'basic_info', 'abstract',
'methods', 'results', 'equations', 'technical',
'related_work', 'applications', 'limitations'
]
lines = [f"# Summary of {filename}", f"Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}\n"]
for section in sections_order:
# Only add a section if it was processed and returned a result
if section in results and results[section]:
# Use .get with a default in case a component was added to results
# but not self.components (though validate init helps prevent this)
title = self.components.get(section, section).title()
lines.append(f"## {title}\n") # Use ## for subheadings
lines.append(f"{results[section]}\n")
return "\n".join(lines) |