import os
import sys
import warnings

warnings.filterwarnings("ignore")

import datasets
import joblib
import gradio as gr
from huggingface_hub import hf_hub_download
from typing import List, Tuple

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever

from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles

from smolagents import Tool

from logger import logging
from exception import CustomExceptionHandling
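

# Build the retrieval corpus once: load the chunked docs from cache when present,
# otherwise download and split the Hugging Face documentation dataset.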
cache_file = "docs_processed.joblib"
if os.path.exists(cache_file):
    docs_processed = joblib.load(cache_file)
    print("Loaded docs_processed from cache.")
else:
    knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
    source_docs = [
        Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
        for doc in knowledge_base
    ]
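
    # Split the docs into ~2000-character chunks with a 100-character overlap
    # so each retrieved chunk fits easily in the model's 4096-token context.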
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=2000,
        chunk_overlap=100,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)
    joblib.dump(docs_processed, cache_file)
    print("Created and saved docs_processed to cache.")


class RetrieverTool(Tool):
    name = "retriever"
    description = "Uses BM25 keyword search to retrieve the parts of the documentation most relevant to your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
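        # BM25 is a lexical (keyword) ranker, so no embedding model is needed;
        # k=7 returns the seven highest-scoring chunks.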
        self.retriever = BM25Retriever.from_documents(
            docs,
            k=7,
        )
    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n{doc.page_content}"
            for i, doc in enumerate(docs)
        )


retriever_tool = RetrieverTool(docs_processed)

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

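# Fetch two quantized GGUF builds of Gemma 3 1B once at startup; the token is
# optional for this public repo but is passed through when provided.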
hf_hub_download(
    repo_id="bartowski/google_gemma-3-1b-it-GGUF",
    filename="google_gemma-3-1b-it-Q6_K.gguf",
    local_dir="./models",
    token=huggingface_token,
)
hf_hub_download(
    repo_id="bartowski/google_gemma-3-1b-it-GGUF",
    filename="google_gemma-3-1b-it-Q5_K_M.gguf",
    local_dir="./models",
    token=huggingface_token,
)


title = "Gemma Llama.cpp"
description = """Gemma 3 is a family of lightweight, multimodal open models that offers advanced capabilities like large context windows and multilingual support, enabling diverse applications on various devices."""

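# The loaded Llama instance and the name of its file are module-level globals
# so the GGUF model is reloaded only when the user switches models in the UI.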
llm = None
llm_model = None

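# Few-shot system prompt that rewrites a conversational question into a short
# keyword query suited to BM25 retrieval.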
| query_system = """ |
| You are a query rewriter. Your task is to convert a user's question into a concise search query suitable for information retrieval. |
| The goal is to identify the most important keywords for a search engine. |
| |
| Here are some examples: |
| |
| User Question: What is transformer? |
| Search Query: transformer |
| |
| User Question: How does a transformer model work in natural language processing? |
| Search Query: transformer model natural language processing |
| |
| User Question: What are the advantages of using transformers over recurrent neural networks? |
| Search Query: transformer vs recurrent neural network advantages |
| |
| User Question: Explain the attention mechanism in transformers. |
| Search Query: transformer attention mechanism |
| |
| User Question: What are the different types of transformer architectures? |
| Search Query: transformer architectures |
| |
| User Question: What is the history of the transformer model? |
| Search Query: transformer model history |
| """ |
def to_query(provider, message):
    try:
        agent = LlamaCppAgent(
            provider,
            system_prompt=query_system,
            predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
            debug_output=True,
        )

        settings = provider.get_provider_default_settings()
        messages = BasicChatHistory()
        result = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=False,
            print_output=False,
        )
        return result
    except Exception as e:
        raise CustomExceptionHandling(e, sys) from e


def respond(
    message: str,
    history: List[Tuple[str, str]],
    model: str,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
):
    """
    Respond to a message using the Gemma 3 model via llama.cpp, streaming the
    answer back to the UI.

    Args:
        message (str): The message to respond to.
        history (List[Tuple[str, str]]): The chat history.
        model (str): The GGUF model filename to use.
        system_message (str): The system message to use.
        max_tokens (int): The maximum number of tokens to generate.
        temperature (float): The temperature of the model.
        top_p (float): The top-p (nucleus sampling) value.
        top_k (int): The top-k value.
        repeat_penalty (float): The repetition penalty.

    Yields:
        str: The accumulated response so far.
    """
    try:
        global llm
        global llm_model

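        # (Re)load the GGUF model only when none is loaded yet or the user has
        # selected a different file.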
        if llm is None or llm_model != model:
            llm = Llama(
                model_path=f"models/{model}",
                flash_attn=False,
                n_gpu_layers=0,
                n_batch=8,
                n_ctx=4096,
                n_threads=2,
                n_threads_batch=2,
            )
            llm_model = model
        provider = LlamaCppPythonProvider(llm)
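
        # Rewrite the message into a keyword query, then fetch matching
        # documentation chunks with BM25.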
        query = to_query(provider, message)
        text = retriever_tool(query=f"{query}")
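
        # Frame the retrieved chunks as a user-supplied document; the original
        # question is embedded directly in the system prompt.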
        retriever_system = """
You are an AI assistant that answers questions based on documents provided by the user. Wait for the user to send a document. Once you receive the document, carefully read its contents and then answer the following question:

Question: %s

[Wait for user's message containing the document]
""" % message

        agent = LlamaCppAgent(
            provider,
            system_prompt=retriever_system,
            predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
            debug_output=True,
        )
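
        # Apply the sampling parameters chosen in the UI and enable streaming.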
        settings = provider.get_provider_default_settings()
        settings.temperature = temperature
        settings.top_k = top_k
        settings.top_p = top_p
        settings.max_tokens = max_tokens
        settings.repeat_penalty = repeat_penalty
        settings.stream = True

        messages = BasicChatHistory()
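
        # Replay previous turns into the agent's chat history.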
        for msn in history:
            user = {"role": Roles.user, "content": msn[0]}
            assistant = {"role": Roles.assistant, "content": msn[1]}
            messages.add_message(user)
            messages.add_message(assistant)
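
        # Send the retrieved documents (not the raw user message) as the final
        # user turn; the question itself lives in the system prompt above.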
        stream = agent.get_chat_response(
            text,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False,
        )

        logging.info("Response stream generated successfully")

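        # Accumulate streamed tokens and yield the partial text so Gradio can
        # update the chat incrementally.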
| outputs = "" |
| for output in stream: |
| outputs += output |
| yield outputs |
|
|
| |
| except Exception as e: |
| |
| raise CustomExceptionHandling(e, sys) from e |


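# Gradio chat UI; the additional inputs map one-to-one onto respond()'s
# parameters after (message, history).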
demo = gr.ChatInterface(
    respond,
    examples=[
        ["What is the Transformer?"],
        ["Tell me about Hugging Face."],
        ["How do I upload a dataset?"],
    ],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Dropdown(
            choices=[
                "google_gemma-3-1b-it-Q6_K.gguf",
                "google_gemma-3-1b-it-Q5_K_M.gguf",
            ],
            value="google_gemma-3-1b-it-Q5_K_M.gguf",
            label="Model",
            info="Select the AI model to use for chat",
        ),
        gr.Textbox(
            value="You are a helpful assistant.",
            label="System Prompt",
            info="Define the AI assistant's personality and behavior",
            lines=2,
            visible=False,
        ),
        gr.Slider(
            minimum=512,
            maximum=2048,
            value=1024,
            step=1,
            label="Max Tokens",
            info="Maximum length of response (higher = longer replies)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature",
            info="Creativity level (higher = more creative, lower = more focused)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Nucleus sampling threshold",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=40,
            step=1,
            label="Top-k",
            info="Limit vocabulary choices to top K tokens",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition Penalty",
            info="Penalize repeated words (higher = less repetition)",
        ),
    ],
    theme="Ocean",
    submit_btn="Send",
    stop_btn="Stop",
    title=title,
    description=description,
    chatbot=gr.Chatbot(scale=1, show_copy_button=True),
    flagging_mode="never",
)


if __name__ == "__main__":
    demo.launch(debug=False)