import os

import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from tqdm import tqdm
from transformers import AutoTokenizer

from smolagents import LiteLLMModel, Tool
from smolagents.agents import CodeAgent
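
# Load the knowledge base: a dump of Hugging Face documentation pages.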
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
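
# Wrap each entry in a LangChain Document, keeping the source repo name as metadata.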
source_docs = [
    Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
]
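
# Split documents into chunks of at most 200 tokens (as counted by the gte-small
# tokenizer), with a 20-token overlap so context isn't lost at chunk boundaries.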
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    AutoTokenizer.from_pretrained("thenlper/gte-small"),
    chunk_size=200,
    chunk_overlap=20,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)

print("Splitting documents...")
docs_processed = []
unique_texts = {}
for doc in tqdm(source_docs):
    new_docs = text_splitter.split_documents([doc])
    for new_doc in new_docs:
        # Keep only unique chunks; duplicate passages recur across documentation pages.
        if new_doc.page_content not in unique_texts:
            unique_texts[new_doc.page_content] = True
            docs_processed.append(new_doc)
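
# Note: chunk sizes above were budgeted with the gte-small tokenizer, while the
# embeddings below use all-MiniLM-L6-v2; for exact token budgeting you could
# build the splitter from the embedding model's own tokenizer instead.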
print("Embedding documents... This should take a few minutes (5 minutes on MacBook with M1 Pro)")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
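
# Build a Chroma index over the chunks, persisted to ./chroma_db. On a later run
# you could skip re-embedding and reload it instead (a sketch, assuming the same
# embeddings object): Chroma(persist_directory="./chroma_db", embedding_function=embeddings)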
vector_store = Chroma.from_documents(docs_processed, embeddings, persist_directory="./chroma_db")
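
# Expose retrieval to the agent as a smolagents Tool: the agent sees the name,
# description, and input schema, and its generated code calls forward() with a query.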
class RetrieverTool(Tool):
    name = "retriever"
    description = (
        "Uses semantic search to retrieve the parts of the documentation that could be most relevant to answer your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vector_store, **kwargs):
        super().__init__(**kwargs)
        self.vector_store = vector_store

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        # Return the top-3 chunks by embedding similarity, formatted for the LLM.
        docs = self.vector_store.similarity_search(query, k=3)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.page_content for i, doc in enumerate(docs)
        )


retriever_tool = RetrieverTool(vector_store)
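
# Optional sanity check: smolagents tools are directly callable, e.g.
# print(retriever_tool("How to push a model to the Hub"))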
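
# Any LiteLLM-supported model id works here; this one routes to Groq, so
# GROQ_API_KEY must be set in the environment.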
model = LiteLLMModel(
    model_id="groq/llama-3.3-70b-versatile",
    api_key=os.environ.get("GROQ_API_KEY"),
)
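
# The CodeAgent writes and runs Python snippets that call the retriever tool;
# max_steps caps the reasoning loop, and verbosity_level=2 prints detailed step logs.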
agent = CodeAgent(
    tools=[retriever_tool],
    model=model,
    max_steps=4,
    verbosity_level=2,
)

agent_output = agent.run("How can I push a model to the Hub?")

print("Final output:")
print(agent_output)