Spaces:
Sleeping
Sleeping
| # Utilities to build a RAG system to query information from the | |
| # gwIAS search pipeline using Langchain | |
| # Thanks to Pablo Villanueva Domingo for sharing his CAMELS template | |
| # https://huggingface.co/spaces/PabloVD/CAMELSDocBot | |
from collections.abc import Iterator
import base64
import json
import re
from urllib.parse import urljoin, urlparse

import requests
from bs4 import BeautifulSoup

from langchain import hub
from langchain.schema import Document
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
def github_to_raw(url):
    """Convert a GitHub "blob" URL into its raw-content equivalent."""
    # Swap the host, then drop the "/blob/" path segment; all other
    # URL components are left untouched.
    raw = url.replace("github.com", "raw.githubusercontent.com")
    return raw.replace("/blob/", "/")
def load_github_notebook(url):
    """Load a Jupyter notebook from a GitHub blob URL via the GitHub contents API.

    Returns one Document per non-empty cell, tagged with cell type and number
    metadata. Best-effort: any failure (bad URL, network error, invalid JSON)
    is reported to stdout and an empty list is returned.
    """
    try:
        # Convert the blob URL to the API URL. Shape is:
        # https://github.com/<owner>/<repo>/blob/<branch>/<path...>
        if "github.com" in url and "/blob/" in url:
            parts = url.replace("https://github.com/", "").split("/")
            owner = parts[0]
            repo = parts[1]
            branch = parts[3]  # parts[2] is the literal "blob"
            path = "/".join(parts[4:])
            api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}?ref={branch}"
        else:
            raise ValueError("URL must be a GitHub blob URL")

        # Fetch the notebook; the timeout keeps a stalled request from
        # hanging the whole documentation load.
        response = requests.get(api_url, timeout=30)
        response.raise_for_status()
        content_data = response.json()

        # The contents API normally base64-encodes file bodies.
        if content_data.get('encoding') == 'base64':
            notebook_content = base64.b64decode(content_data['content']).decode('utf-8')
        else:
            notebook_content = content_data['content']

        # Parse the notebook JSON (nbformat structure).
        notebook = json.loads(notebook_content)

        docs = []
        cell_count = 0
        for cell in notebook.get('cells', []):
            cell_count += 1
            cell_type = cell.get('cell_type', 'unknown')
            source = cell.get('source', [])
            # nbformat stores cell source as a list of lines; tolerate
            # plain strings as well.
            if isinstance(source, list):
                content = ''.join(source)
            else:
                content = str(source)
            if not content.strip():
                continue  # skip empty cells
            metadata = {
                'source': url,
                'cell_type': cell_type,
                'cell_number': cell_count,
                'name': f"{url} - Cell {cell_count} ({cell_type})"
            }
            # Prefix with the cell type so the retriever can distinguish
            # code from markdown context.
            formatted_content = f"[{cell_type.upper()} CELL {cell_count}]\n{content}"
            docs.append(Document(page_content=formatted_content, metadata=metadata))
        return docs
    except Exception as e:
        # Best-effort loader: report and fall back to an empty result.
        print(f"Error loading notebook from {url}: {str(e)}")
        return []
def clean_text(text):
    """Normalise whitespace in scraped page text."""
    # First collapse runs of 3+ newlines to a paragraph break, then squeeze
    # every remaining whitespace run (including that break) to one space.
    collapsed = re.sub(r'\n{3,}', '\n\n', text)
    collapsed = re.sub(r'\s{2,}', ' ', collapsed)
    return collapsed.strip()
def clean_github_content(html_content):
    """Extract the meaningful text content from a GitHub HTML page.

    Accepts either raw HTML or an existing BeautifulSoup object and tries,
    from most to least specific: rendered README, code listing, directory
    grid, <main>, <body>, then the whole document.
    """
    # Work on a BeautifulSoup object regardless of the input form.
    if isinstance(html_content, str):
        soup = BeautifulSoup(html_content, 'html.parser')
    else:
        soup = html_content

    # Drop chrome that never carries document content.
    for boilerplate in soup.find_all(['nav', 'footer', 'header']):
        boilerplate.decompose()

    # Candidate containers, most specific first.
    candidates = (
        soup.find('article', class_='markdown-body'),  # rendered README / markdown
        soup.find('table', class_='highlight'),        # code file view
        soup.find('div', role='grid'),                 # directory listing
        soup.find('main'),                             # generic main content
        soup.find('body'),
    )
    for node in candidates:
        if node:
            return clean_text(node.get_text())

    # Nothing matched: fall back to the full document text.
    return clean_text(soup.get_text())
class GitHubLoader(WebBaseLoader):
    """Custom loader for GitHub pages with better content cleaning.

    Handles three page shapes: directory listings ("/tree/" URLs, read via
    the GitHub API), rendered markdown/README pages, and code-file views.
    """

    def clean_text(self, text):
        """Normalise whitespace and strip common GitHub chrome from text."""
        # Collapse newline and whitespace runs.
        text = re.sub(r'\n{2,}', '\n', text)
        text = re.sub(r'\s{2,}', ' ', text)
        # Remove navigation/footer strings GitHub injects into every page.
        text = re.sub(r'Skip to content|Sign in|Search or jump to|Footer navigation|Terms|Privacy|Security|Status|Docs', '', text)
        return text.strip()

    def lazy_load(self) -> Iterator[Document]:
        """Yield one cleaned Document per URL in ``self.web_paths``.

        Fixes: this method is a generator, so the return annotation is
        ``Iterator[Document]`` (it was mis-annotated ``list[Document]``),
        and requests carry a timeout so a stalled connection cannot hang
        the loader indefinitely.
        """
        for url in self.web_paths:
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()

                # Directory listings are read through the API: the JSON is
                # far more reliable than scraping the HTML file grid.
                if '/tree/' in url:
                    # URL shape: https://github.com/<owner>/<repo>/tree/<branch>/<path...>
                    parts = url.replace("https://github.com/", "").split("/")
                    owner = parts[0]
                    repo = parts[1]
                    branch = parts[3]  # parts[2] is the literal "tree"
                    path = "/".join(parts[4:]) if len(parts) > 4 else ""
                    api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}?ref={branch}"
                    api_response = requests.get(api_url, timeout=30)
                    api_response.raise_for_status()
                    contents = api_response.json()
                    if isinstance(contents, list):
                        content = "Directory contents:\n" + "\n".join(
                            [f"{item['name']} ({item['type']})" for item in contents]
                        )
                        yield Document(
                            page_content=self.clean_text(content),
                            metadata={'source': url, 'type': 'github_directory'}
                        )
                        continue
                    # Non-list API response: fall through to HTML parsing.

                # Regular files: parse the HTML, most specific container first.
                soup = BeautifulSoup(response.text, 'html.parser')

                # Rendered README / markdown files.
                readme_content = soup.find('article', class_='markdown-body')
                if readme_content:
                    yield Document(
                        page_content=self.clean_text(readme_content.get_text()),
                        metadata={'source': url, 'type': 'github_markdown'}
                    )
                    continue

                # Code-file view.
                code_content = soup.find('table', class_='highlight')
                if code_content:
                    yield Document(
                        page_content=self.clean_text(code_content.get_text()),
                        metadata={'source': url, 'type': 'github_code'}
                    )
                    continue

                # Any other page with a <main> element.
                main_content = soup.find('main')
                if main_content:
                    yield Document(
                        page_content=self.clean_text(main_content.get_text()),
                        metadata={'source': url, 'type': 'github_other'}
                    )
                    continue

                # Fallback: whole-page text.
                yield Document(
                    page_content=self.clean_text(soup.get_text()),
                    metadata={'source': url, 'type': 'github_fallback'}
                )
            except Exception as e:
                # Best-effort: report the failure and move on to the next URL.
                print(f"Error processing {url}: {str(e)}")
                continue

    def load(self) -> list[Document]:
        """Eagerly materialise ``lazy_load`` into a list."""
        return list(self.lazy_load())
class ReadTheDocsLoader(WebBaseLoader):
    """Custom loader for ReadTheDocs pages.

    Crawls every page under ``base_url`` and extracts the main content
    of each one.
    """

    def __init__(self, base_url: str):
        """Initialize with base URL of the documentation."""
        super().__init__([])
        self.base_url = base_url.rstrip('/')

    def clean_text(self, text: str) -> str:
        """Clean text content from a ReadTheDocs page."""
        # Normalise whitespace runs, then strip common theme boilerplate.
        text = re.sub(r'\s{2,}', ' ', text)
        text = re.sub(r'\n{3,}', '\n\n', text)
        text = re.sub(r'View page source|Next|Previous|©.*?\.', '', text)
        return text.strip()

    def normalize_url(self, base_url: str, href: str) -> str:
        """Resolve *href* against *base_url*; absolute URLs pass through."""
        if href.startswith(('http://', 'https://')):
            return href
        return urljoin(base_url, href)

    def get_all_pages(self) -> list[str]:
        """Crawl all documentation pages reachable from the base URL.

        Fixes: URL fragments (``page.html#section``) are stripped before
        queueing so the same page is no longer fetched once per anchor
        link, and requests carry a timeout so a stalled server cannot
        hang the crawl.
        """
        visited = set()
        to_visit = {self.base_url}
        docs_urls = set()
        while to_visit:
            url = to_visit.pop()
            if url in visited:
                continue
            visited.add(url)
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                soup = BeautifulSoup(response.text, 'html.parser')
                # Record the page itself if it belongs to the documentation.
                if url.startswith(self.base_url):
                    docs_urls.add(url)
                # Queue every in-domain link found on the page.
                for link in soup.find_all('a'):
                    href = link.get('href')
                    if not href:
                        continue
                    # Skip pure anchors and external absolute links.
                    if href.startswith('#') or (href.startswith(('http://', 'https://')) and not href.startswith(self.base_url)):
                        continue
                    # Resolve relative links and drop any #fragment so each
                    # page is queued at most once.
                    full_url = self.normalize_url(url, href).split('#', 1)[0]
                    if full_url.startswith(self.base_url):
                        to_visit.add(full_url)
            except Exception as e:
                print(f"Error fetching {url}: {str(e)}")
        return list(docs_urls)

    def load(self) -> list[Document]:
        """Fetch every crawled page and return its cleaned main content."""
        urls = self.get_all_pages()
        docs = []
        for url in urls:
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                soup = BeautifulSoup(response.text, 'html.parser')
                # Prefer the themed main-content container, then <main>.
                main_content = soup.find('div', {'role': 'main'})
                if not main_content:
                    main_content = soup.find('main')
                if not main_content:
                    continue  # no recognisable content area on this page
                content = self.clean_text(main_content.get_text())
                if content:
                    docs.append(Document(
                        page_content=content,
                        metadata={'source': url, 'type': 'readthedocs'}
                    ))
            except Exception as e:
                print(f"Error processing {url}: {str(e)}")
        return docs
def load_docs():
    """Load all documentation listed in ``urls.txt`` plus the ReadTheDocs site.

    Each URL is dispatched to the appropriate loader: Jupyter notebooks go
    through the GitHub API, raw.githubusercontent.com files are fetched
    verbatim, and other GitHub URLs are scraped with ``GitHubLoader``.
    Non-GitHub URLs in the file are ignored.
    """
    # Read the URL list, skipping blank lines.
    with open("urls.txt", "r") as f:
        urls = [line.strip() for line in f if line.strip()]

    docs = []
    for url in urls:
        if "github.com" not in url and "raw.githubusercontent.com" not in url:
            continue  # only GitHub-hosted sources are handled here
        if "/blob/" in url and url.endswith(".ipynb"):
            # Jupyter notebooks: load cell-by-cell via the GitHub API.
            docs.extend(load_github_notebook(url))
        elif "raw.githubusercontent.com" in url:
            # Raw files: fetch the content verbatim (timeout prevents a
            # stalled request from hanging the whole load).
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                docs.append(Document(
                    page_content=response.text,
                    metadata={'source': url, 'type': 'github_raw'}
                ))
            except Exception as e:
                print(f"Error loading raw content from {url}: {str(e)}")
        else:
            # Everything else: scrape the rendered GitHub page.
            loader = GitHubLoader([url])
            docs.extend(loader.load())

    # Append the gwfast ReadTheDocs documentation.
    rtd_loader = ReadTheDocsLoader("https://gwfast.readthedocs.io/en/latest")
    docs.extend(rtd_loader.load())
    return docs
def extract_reference(url):
    """Derive a short reference label (repo-relative path) from a GitHub URL."""
    # (substring to detect, separator to split on, use-"root"-fallback).
    # The first matching needle wins, mirroring the original if/elif chain;
    # tree URLs that end at the branch fall back to "root".
    markers = (
        ("blob/main", "blob/main/", False),
        ("tree/main", "tree/main/", True),
        ("blob/master", "blob/master/", False),
        ("tree/master", "tree/master/", True),
        ("refs/heads/master", "refs/heads/master/", False),
    )
    for needle, separator, root_fallback in markers:
        if needle in url:
            tail = url.split(separator)[-1]
            return (tail or "root") if root_fallback else tail
    # Unknown URL shape: use the URL itself as the reference.
    return url


# Join content pages for processing
def format_docs(docs):
    """Concatenate documents, appending a bracketed source reference to each."""
    rendered = []
    for doc in docs:
        src = doc.metadata.get('source', 'Unknown source')
        rendered.append(f"{doc.page_content}\n\nReference: [{extract_reference(src)}]")
    return "\n\n---\n\n".join(rendered)
# Create a RAG chain
def RAG(llm, docs, embeddings):
    """Build a retrieval-augmented generation chain over *docs*.

    Args:
        llm: Language model (LangChain runnable) used for answer generation.
        docs: Documents to index in the vector store.
        embeddings: Embedding function used by the Chroma store.

    Returns:
        A runnable chain mapping a question string to an answer string.
    """
    # Split documents into overlapping chunks so retrieval stays fine-grained.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Index the chunks in an in-memory Chroma vector store.
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    retriever = vectorstore.as_retriever()

    # Start from the community RAG prompt and splice custom instructions in
    # front of its "Question: {question}" suffix.
    prompt = hub.pull("rlm/rag-prompt")
    template = prompt.messages[0].prompt.template
    template_parts = template.split("\nQuestion: {question}")
    combined_template = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer the question. "
        "If you don't know the answer, just say that you don't know. "
        "Try to keep the answer concise if possible. "
        # Fixed typo in the instruction text: "retrived" -> "retrieved".
        "Write the names of the relevant functions from the retrieved code "
        "and include code snippets to aid the user's understanding. "
        "Include the references used in square brackets at the end of your answer."
        + template_parts[1]
    )
    prompt.messages[0].prompt.template = combined_template

    # context -> retrieved + formatted docs; the question passes through as-is.
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain