Update app.py
app.py CHANGED
@@ -361,33 +361,38 @@ class SmartDocumentRAG:
 
         return "Error: Could not decode file"
 
-    def enhanced_chunk_text(self, text: str) -> ...
-        """..."""
+    def enhanced_chunk_text(self, text: str, max_chunk_size: int = 300, overlap: int = 50) -> list[str]:
+        """
+        Splits text into smaller overlapping chunks for better semantic search.
+
+        Args:
+            text (str): The full text to chunk.
+            max_chunk_size (int): Maximum tokens/words per chunk.
+            overlap (int): Number of words overlapping between consecutive chunks.
+
+        Returns:
+            list[str]: List of text chunks.
+        """
+        import re
+
+        # Clean and normalize whitespace
+        text = re.sub(r'\s+', ' ', text).strip()
 
+        words = text.split()
         chunks = []
-        ...
-            chunk_sentences = sentences[i:i + chunk_size]
-            if chunk_sentences:
-                chunk_text = '. '.join(chunk_sentences) + '.'
-                chunks.append({
-                    'text': chunk_text,
-                    'sentence_indices': list(range(i, min(i + chunk_size, len(sentences)))),
-                    'doc_type': self.document_type
-                })
-
+        start = 0
+        text_len = len(words)
+
+        while start < text_len:
+            end = min(start + max_chunk_size, text_len)
+            chunk_words = words[start:end]
+            chunk = ' '.join(chunk_words)
+            chunks.append(chunk)
+            # Move start forward by chunk size minus overlap to create overlap
+            start += max_chunk_size - overlap
+
         return chunks
+
 
     def process_documents(self, files) -> str:
         """Enhanced document processing"""
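Side note on the new chunker: with the defaults (max_chunk_size=300, overlap=50) the stride is 250 words, so consecutive chunks share 50 words. A minimal standalone sketch of that arithmetic (the loop is copied out of the method so it runs without the class; the 700-word input is invented for illustration):

words = [f"w{i}" for i in range(700)]   # hypothetical 700-word document
chunks, start = [], 0
max_chunk_size, overlap = 300, 50
while start < len(words):
    chunks.append(' '.join(words[start:start + max_chunk_size]))
    start += max_chunk_size - overlap   # stride of 250 words

print(len(chunks))                                        # 3
print(chunks[0].split()[-50:] == chunks[1].split()[:50])  # True: 50-word overlap

One caveat: overlap must stay smaller than max_chunk_size, otherwise the stride is zero or negative and the loop never advances.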
@@ -451,7 +456,7 @@ class SmartDocumentRAG:
             return f"❌ Error processing documents: {str(e)}"
 
     def find_relevant_content(self, query: str, k: int = 3) -> str:
-        """Improved content retrieval"""
+        """Improved content retrieval with stricter relevance filter"""
         if not self.is_indexed:
             return ""
 
@@ -464,17 +469,19 @@ class SmartDocumentRAG:
 
             relevant_chunks = []
             for i, idx in enumerate(indices[0]):
-                if idx < len(self.documents):
+                score = scores[0][i]
+                if idx < len(self.documents) and score > 0.4:  # ✅ stricter similarity filter
                     relevant_chunks.append(self.documents[idx])
 
             return ' '.join(relevant_chunks)
-
+
         except Exception as e:
             print(f"Error in content retrieval: {e}")
             return ""
+
 
     def answer_question(self, query: str) -> str:
-        """Enhanced question answering with better model usage"""
+        """Enhanced question answering with better model usage and hallucination reduction."""
         if not query.strip():
             return "❌ Please ask a question!"
 
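One caveat on the new score > 0.4 cutoff: it assumes the index returns similarity scores where higher is better, for example a FAISS inner-product index over L2-normalized embeddings, which makes the score a cosine similarity in [-1, 1]. With an L2-distance index, lower is better and the comparison would have to be inverted. A sketch of the setup under which the threshold makes sense (FAISS and the 384-dim embedding size are assumptions; the Space's actual index construction is not shown in this hunk):

import numpy as np
import faiss  # assumed dependency

dim = 384  # e.g. the output size of all-MiniLM-L6-v2 (assumed)
embeddings = np.random.rand(10, dim).astype('float32')
faiss.normalize_L2(embeddings)     # unit vectors: inner product == cosine
index = faiss.IndexFlatIP(dim)
index.add(embeddings)

query = np.random.rand(1, dim).astype('float32')
faiss.normalize_L2(query)
scores, indices = index.search(query, 3)  # scores[0][i] in [-1, 1]; > 0.4 filters weak matches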
@@ -484,48 +491,61 @@ class SmartDocumentRAG:
         try:
             query_lower = query.lower()
 
-            # Handle summary requests
+            # Handle summary requests explicitly
             if any(word in query_lower for word in ['summary', 'summarize', 'about', 'overview']):
                 return f"📋 **Document Summary:**\n\n{self.document_summary}"
 
-            # ...
+            # Retrieve relevant content chunks via semantic search
             context = self.find_relevant_content(query, k=3)
 
             if not context:
                 return "🔍 No relevant information found. Try rephrasing your question."
 
-            # ...
+            # If no QA pipeline, fall back to direct extraction
             if self.qa_pipeline is None:
                 return self.extract_direct_answer(query, context)
 
             try:
-                if self.model_type ...
-                    # Use Q&A pipeline
+                if self.model_type in ["distilbert-qa", "fallback"]:
+                    # Use extractive Q&A pipeline
                     result = self.qa_pipeline(question=query, context=context)
-                    answer = result['answer']
-                    confidence = result['score']
+                    answer = result.get('answer', '').strip()
+                    confidence = result.get('score', 0)
 
-                    if confidence > 0.1:
+                    if confidence > 0.1 and answer:
                         return f"**Answer:** {answer}\n\n**Context:** {context[:200]}..."
                     else:
                         return self.extract_direct_answer(query, context)
-
+
                 elif self.model_type == "flan-t5":
-                    # Use ...
-                    prompt = ...
-                    ...
+                    # Use generative model with improved prompt to reduce hallucination
+                    prompt = (
+                        f"Answer concisely and strictly based on the following context.\n\n"
+                        f"Context:\n{context}\n\n"
+                        f"Question:\n{query}\n\n"
+                        f"If the answer is not contained in the context, reply with 'Not found in document.'\n"
+                        f"Answer:"
+                    )
+                    result = self.qa_pipeline(prompt, max_length=256, num_return_sequences=1)
+                    generated_text = result[0].get('generated_text', '')
+                    answer = generated_text.replace(prompt, '').strip()
+
+                    if answer.lower() in ["not found in document.", "no answer", "unknown", ""]:
+                        return "🔍 Sorry, the answer was not found in the documents."
+                    else:
+                        return f"**Answer:** {answer}"
+
                 else:
+                    # Default fallback extraction
                     return self.extract_direct_answer(query, context)
-
+
             except Exception as e:
                 print(f"Model inference error: {e}")
                 return self.extract_direct_answer(query, context)
 
         except Exception as e:
             return f"❌ Error processing question: {str(e)}"
+
 
     def extract_direct_answer(self, query: str, context: str) -> str:
         """Direct answer extraction as fallback"""
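The flan-t5 branch presumably runs a text2text-generation pipeline. Worth noting: for seq2seq models, 'generated_text' contains only the completion, not the prompt, so the replace(prompt, '') call above is a harmless no-op rather than a required strip. A sketch of the assumed wiring (the checkpoint name is an assumption):

from transformers import pipeline

qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")  # assumed checkpoint
context = "Paris is the capital of France."
query = "What is the capital of France?"
prompt = (
    f"Answer concisely and strictly based on the following context.\n\n"
    f"Context:\n{context}\n\n"
    f"Question:\n{query}\n\n"
    f"If the answer is not contained in the context, reply with 'Not found in document.'\n"
    f"Answer:"
)
result = qa_pipeline(prompt, max_length=256, num_return_sequences=1)
print(result[0]["generated_text"])  # e.g. "Paris"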
@@ -570,6 +590,33 @@ class SmartDocumentRAG:
 
         return "I found relevant content but couldn't extract a specific answer."
 
+    def clean_text(self, text: str) -> str:
+        """
+        Clean and normalize raw text by:
+        - Removing excessive whitespace
+        - Fixing merged words (camel case separation)
+        - Removing unwanted characters (optional)
+        - Lowercasing or preserving case (optional)
+        """
+        import re
+
+        # Replace multiple whitespace/newlines/tabs with single space
+        text = re.sub(r'\s+', ' ', text).strip()
+
+        # Fix merged words like 'wordAnotherWord' -> 'word Another Word'
+        text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)
+
+        # Optional: remove special characters except basic punctuation
+        # text = re.sub(r'[^a-zA-Z0-9,.!?;:\'\"()\-\s]', '', text)
+
+        return text
+
+
 # Initialize the system
 print("Initializing Enhanced Smart RAG System...")
 rag_system = SmartDocumentRAG()
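A quick check of what the new clean_text helper does to typical PDF-extraction artifacts (the input string is invented for illustration; the two substitutions mirror the method body):

import re

raw = "Machine   learningModels\n\tlearn   fromData."
text = re.sub(r'\s+', ' ', raw).strip()           # collapse whitespace runs
text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)  # split merged camel-case words
print(text)  # "Machine learning Models learn from Data."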