alx-d committed on
Commit
01330c2
·
verified ·
1 Parent(s): 97f878b

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. advanced_rag.py +99 -23
advanced_rag.py CHANGED
@@ -397,31 +397,25 @@ def load_txt_from_url(url: str) -> Document:
397
  from pdfminer.high_level import extract_text
398
  from langchain_core.documents import Document
399
 
400
-
401
  def get_confirm_token(response):
402
  for key, value in response.cookies.items():
403
  if key.startswith("download_warning"):
404
  return value
405
  return None
406
 
407
-
408
  def download_file_from_google_drive(file_id, destination):
409
  """
410
  Download a file from Google Drive handling large file confirmation.
411
  """
412
  URL = "https://docs.google.com/uc?export=download&confirm=1"
413
  session = requests.Session()
414
-
415
  response = session.get(URL, params={"id": file_id}, stream=True)
416
  token = get_confirm_token(response)
417
-
418
  if token:
419
  params = {"id": file_id, "confirm": token}
420
  response = session.get(URL, params=params, stream=True)
421
-
422
  save_response_content(response, destination)
423
 
424
-
425
  def save_response_content(response, destination):
426
  CHUNK_SIZE = 32768
427
  with open(destination, "wb") as f:
@@ -429,47 +423,131 @@ def save_response_content(response, destination):
429
  if chunk:
430
  f.write(chunk)
431
 
432
-
433
  def extract_file_id(drive_link: str) -> str:
 
434
  match = re.search(r"/d/([a-zA-Z0-9_-]+)", drive_link)
435
  if match:
436
  return match.group(1)
 
 
 
 
 
 
437
  raise ValueError("Could not extract file ID from the provided Google Drive link.")
438
 
439
-
440
def load_file_from_google_drive(link: str) -> list:
    """
    Fetch a PDF from a Google Drive link and extract its text with pdfminer.

    Returns a list holding one LangChain Document whose metadata records the
    original link as ``source``, or an empty list when text extraction fails
    or yields only whitespace.
    """
    drive_id = extract_file_id(link)
    print(f"[DEBUG] Extracted file ID: {drive_id}")

    # delete=False keeps the file on disk after the handle is closed, so it
    # can be reopened by path; it is removed in the ``finally`` below.
    with tempfile.NamedTemporaryFile(delete=False) as handle:
        local_path = handle.name

    try:
        download_file_from_google_drive(drive_id, local_path)
        print(f"[DEBUG] File downloaded to: {local_path}")

        try:
            text = extract_text(local_path)
            if text.strip() == "":
                raise ValueError("Extracted text is empty. The PDF might be image-based.")
            print("[DEBUG] Extracted preview text from PDF:")
            print(text[:1000])  # Preview first 1000 characters
            return [Document(page_content=text, metadata={"source": link})]
        except Exception as exc:
            # Extraction problems are reported and mapped to "no documents".
            print(f"[ERROR] Could not extract text from PDF: {exc}")
            return []
    finally:
        if os.path.exists(local_path):
            os.remove(local_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  class ElevatedRagChain:
474
  def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
475
  bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
@@ -768,8 +846,6 @@ class ElevatedRagChain:
768
  self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
769
  debug_print("Elevated RAG chain successfully built and ready to use.")
770
 
771
-
772
-
773
  def get_current_context(self) -> str:
774
  base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
775
  history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
 
397
  from pdfminer.high_level import extract_text
398
  from langchain_core.documents import Document
399
 
 
400
  def get_confirm_token(response):
401
  for key, value in response.cookies.items():
402
  if key.startswith("download_warning"):
403
  return value
404
  return None
405
 
 
406
  def download_file_from_google_drive(file_id, destination):
407
  """
408
  Download a file from Google Drive handling large file confirmation.
409
  """
410
  URL = "https://docs.google.com/uc?export=download&confirm=1"
411
  session = requests.Session()
 
412
  response = session.get(URL, params={"id": file_id}, stream=True)
413
  token = get_confirm_token(response)
 
414
  if token:
415
  params = {"id": file_id, "confirm": token}
416
  response = session.get(URL, params=params, stream=True)
 
417
  save_response_content(response, destination)
418
 
 
419
  def save_response_content(response, destination):
420
  CHUNK_SIZE = 32768
421
  with open(destination, "wb") as f:
 
423
  if chunk:
424
  f.write(chunk)
425
 
 
426
  def extract_file_id(drive_link: str) -> str:
427
+ # Check for /d/ format
428
  match = re.search(r"/d/([a-zA-Z0-9_-]+)", drive_link)
429
  if match:
430
  return match.group(1)
431
+
432
+ # Check for open?id= format
433
+ match = re.search(r"open\?id=([a-zA-Z0-9_-]+)", drive_link)
434
+ if match:
435
+ return match.group(1)
436
+
437
  raise ValueError("Could not extract file ID from the provided Google Drive link.")
438
 
439
def load_txt_from_google_drive(link: str) -> Document:
    """
    Load text from a Google Drive shared link.

    The file ID is taken from ``link``, the file is fetched through Drive's
    direct-download URL, and the response body is wrapped in a Document whose
    metadata records ``link`` as ``source``.

    Args:
        link: Google Drive share link of the text file.

    Returns:
        A Document containing the downloaded text.

    Raises:
        ValueError: If the file ID cannot be extracted, the download does not
            return HTTP 200, or the downloaded text is empty or whitespace.
        requests.exceptions.RequestException: On network failure, including a
            timeout while connecting or waiting for data.
    """
    file_id = extract_file_id(link)

    # Create direct download link
    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"

    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.get(download_url, timeout=30)
    if response.status_code != 200:
        raise ValueError(f"Failed to download file from Google Drive. Status code: {response.status_code}")

    content = response.text
    if not content.strip():
        raise ValueError("TXT file from Google Drive is empty.")
    return Document(page_content=content, metadata={"source": link})
459
+
460
def load_pdf_from_google_drive(link: str) -> list:
    """
    Fetch a PDF from a Google Drive link and extract its text with pdfminer.

    Returns a list holding one LangChain Document whose metadata records the
    original link as ``source``, or an empty list when text extraction fails
    or yields only whitespace.
    """
    drive_id = extract_file_id(link)
    debug_print(f"Extracted file ID: {drive_id}")

    # delete=False keeps the file on disk after the handle is closed, so it
    # can be reopened by path; it is removed in the ``finally`` below.
    with tempfile.NamedTemporaryFile(delete=False) as handle:
        local_path = handle.name

    try:
        download_file_from_google_drive(drive_id, local_path)
        debug_print(f"File downloaded to: {local_path}")

        try:
            text = extract_text(local_path)
            if text.strip() == "":
                raise ValueError("Extracted text is empty. The PDF might be image-based.")
            debug_print("Extracted preview text from PDF:")
            debug_print(text[:1000])  # Preview first 1000 characters
            return [Document(page_content=text, metadata={"source": link})]
        except Exception as exc:
            # Extraction problems are reported and mapped to "no documents".
            debug_print(f"Could not extract text from PDF: {exc}")
            return []
    finally:
        if os.path.exists(local_path):
            os.remove(local_path)
486
+
487
def load_file_from_google_drive(link: str) -> list:
    """
    Load a document from a Google Drive link, detecting whether it's a PDF or TXT file.
    Returns a list of LangChain Document objects.

    Detection reads the first 1024 bytes of the download and looks for the
    ``%PDF-`` signature. Files with the signature go to
    ``load_pdf_from_google_drive``; anything else is loaded as text through
    ``load_txt_from_google_drive``. If detection or the chosen loader raises,
    both loaders are tried in turn (PDF first, then TXT).

    Args:
        link: Google Drive share link of the file.

    Returns:
        A list of Documents. May be empty when a file carrying the PDF
        signature yields no extractable text.

    Raises:
        ValueError: If the file ID cannot be extracted from ``link``, or if
            the fallback could load the file neither as PDF nor as TXT.
    """
    file_id = extract_file_id(link)

    # Create direct download link
    download_url = f"https://drive.google.com/uc?export=download&id={file_id}"

    try:
        # Stream so that only the first chunk is fetched; the context manager
        # closes the connection on every path, including the error ones.
        with requests.get(download_url, stream=True, timeout=30) as response:
            if response.status_code != 200:
                raise ValueError(f"Failed to download file from Google Drive. Status code: {response.status_code}")
            # The b"" default avoids StopIteration on an empty body.
            file_start = next(response.iter_content(1024), b"")

        # The signature is checked on the raw bytes, so no decoding is needed.
        if b"%PDF-" in file_start:
            debug_print(f"Detected PDF file by content signature from Google Drive: {link}")
            return load_pdf_from_google_drive(link)

        debug_print(f"No PDF signature found, treating as TXT file from Google Drive: {link}")
        return [load_txt_from_google_drive(link)]
    except Exception as e:
        debug_print(f"Error detecting file type: {e}")

    # Fall back to trying both formats
    debug_print("Falling back to trying both formats for Google Drive file")
    try:
        documents = load_pdf_from_google_drive(link)
        if documents:
            return documents
        # The PDF loader reports extraction failure by returning an empty
        # list rather than raising, so that case must also move on to TXT.
        debug_print("Failed to load as PDF: no text could be extracted")
    except Exception as pdf_error:
        debug_print(f"Failed to load as PDF: {pdf_error}")

    try:
        return [load_txt_from_google_drive(link)]
    except Exception as txt_error:
        debug_print(f"Failed to load as TXT: {txt_error}")
        raise ValueError(f"Could not load file from Google Drive as either PDF or TXT: {link}") from txt_error
550
+
551
  class ElevatedRagChain:
552
  def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
553
  bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
 
846
  self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
847
  debug_print("Elevated RAG chain successfully built and ready to use.")
848
 
 
 
849
  def get_current_context(self) -> str:
850
  base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
851
  history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"