import os

import gradio as gr
from huggingface_hub import InferenceClient
from cryptography.fernet import Fernet

# --- LangChain / RAG imports ---
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
# --- Core functions ---
def load_decrypted_preprompt(file_path="pre_prompt.enc"):
    """
    Load and decrypt the pre-prompt from the encrypted file using the key
    stored in the environment variable 'KEY'.
    """
    try:
        key_str = os.getenv("KEY", "")
        if not key_str:
            print("Warning: KEY environment variable not set, using default pre-prompt")
            return "You are AMAbot, a helpful assistant that answers questions about Christopher."
        key = key_str.encode()
        fernet = Fernet(key)
        with open(file_path, "rb") as file:
            encrypted_text = file.read()
        decrypted_text = fernet.decrypt(encrypted_text)
        return decrypted_text.decode("utf-8")
    except Exception as e:
        print(f"Error loading pre-prompt: {e}, using default")
        return "You are AMAbot, a helpful assistant that answers questions about Christopher."
PRE_PROMPT = load_decrypted_preprompt()  # NOTE: loaded here but not currently injected into the prompt below

DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 512
DEFAULT_TOP_K = 50  # defined for completeness; not passed to chat_completion below
DEFAULT_TOP_P = 0.95

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
def load_vector_db(index_path="faiss_index", model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Load the FAISS vector database from disk."""
    try:
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        vector_db = FAISS.load_local(
            index_path,
            embeddings,
            # Required by current LangChain versions to unpickle a saved index;
            # acceptable here because the index is built and bundled with this repo.
            allow_dangerous_deserialization=True,
        )
        print(f"Successfully loaded vector database from {index_path}")
        return vector_db
    except Exception as e:
        print(f"Failed to load vector database: {e}")
        return None
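
# A minimal sketch of how the bundled `faiss_index` could have been built (run once,
# offline). The source file `about_christopher.txt` and the chunking parameters are
# assumptions for illustration; the embedding model must match load_vector_db().
#
#   from langchain_community.document_loaders import TextLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#   docs = TextLoader("about_christopher.txt").load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
#   embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   FAISS.from_documents(chunks, embeddings).save_local("faiss_index")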
def create_qa_prompt():
    """
    Create a prompt template for QA, formatted for Zephyr/Mistral models.
    This is the specific prompt format Zephyr was trained on.
    """
    template = """<|system|>
You are a helpful assistant that answers questions using the context provided.
If you don't know the answer based on the context, just say that you don't know. Don't try to make up an answer.</s>
<|user|>
Context:
{context}
Question: {question}</s>
<|assistant|>
Helpful Answer:"""
    return PromptTemplate(template=template, input_variables=["context", "question"])
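
# Usage note: create_qa_prompt().format(context=..., question=...) returns the fully
# rendered "<|system|> ... <|user|> ... <|assistant|>" string. Below it is sent as a
# single user message rather than as separate chat turns.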
def update_chat(message, history):
    """Append the user's message to the chat history and clear the input box."""
    if history is None:
        history = []
    history = history.copy()
    history.append({"role": "user", "content": message})
    return history, message, ""
def get_assistant_response(message, history, max_tokens, temperature, top_p, qa_chain_state_dict):
    """
    Generate the assistant's response by manually running the RAG pipeline
    and calling the chat_completion endpoint.
    """
    vector_db = qa_chain_state_dict.get("vector_db")
    answer = "I apologize, but I'm having trouble accessing my knowledge base right now."
    if not vector_db:
        print("Error: Vector DB is not available.")
        history.append({"role": "assistant", "content": answer})
        return history, qa_chain_state_dict
    try:
        # 1. Retrieve relevant documents from the vector store
        retriever = vector_db.as_retriever(search_kwargs={"k": 3})
        retrieved_docs = retriever.invoke(message)

        # 2. Format the retrieved chunks into a single context block
        context = "\n\n".join([doc.page_content for doc in retrieved_docs])

        # 3. Build the prompt using the Zephyr chat template
        qa_prompt_template = create_qa_prompt()
        formatted_prompt = qa_prompt_template.format(context=context, question=message)

        # 4. Prepare the message payload for the conversational API
        messages = [
            {
                "role": "user",
                "content": formatted_prompt,
            }
        ]

        # 5. Call the chat-completion endpoint
        print("Attempting to call chat_completion API...")
        client = InferenceClient(MODEL_NAME, token=os.getenv("HF_TOKEN", ""))
        response = client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature if temperature > 0 else 0.1,  # temperature must be > 0 for chat
            top_p=top_p,
            stream=False,
        )

        # 6. Extract the answer
        if response.choices and response.choices[0].message:
            answer = response.choices[0].message.content.strip()
            print(f"API call successful, answer length: {len(answer)}")
        else:
            print("API returned an empty response.")
    except Exception as e:
        print(f"An error occurred in get_assistant_response: {type(e).__name__} - {repr(e)}")
        answer = f"I'm experiencing technical difficulties. Please try again. (Error: {str(e)[:100]})"
    history.append({"role": "assistant", "content": answer})
    return history, qa_chain_state_dict
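
# A minimal sketch of a streaming variant, assuming the same payload as above.
# With stream=True, chat_completion yields incremental chunks; the delta access
# pattern below follows the huggingface_hub client, but treat it as illustrative.
#
#   response = client.chat_completion(messages=messages, max_tokens=max_tokens,
#                                     temperature=temperature, top_p=top_p, stream=True)
#   answer = ""
#   for chunk in response:
#       answer += chunk.choices[0].delta.content or ""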
# --- Initialize components ---
HF_TOKEN = os.getenv("HF_TOKEN", "")
if not HF_TOKEN:
    print("Warning: HF_TOKEN not set in environment variables!")

# Load the vector database
vector_db = load_vector_db("faiss_index")

# Prepare the initial state dictionary with the vector_db
qa_chain_state_initial = {"vector_db": vector_db}

# Smoke-test the vector DB setup
if vector_db:
    print("Testing vector database...")
    try:
        test_retriever = vector_db.as_retriever(search_kwargs={"k": 1})
        test_docs = test_retriever.invoke("test query")
        print("Vector DB test successful, can retrieve documents")
    except Exception as e:
        print(f"Vector DB test failed: {e}")
# ------------------------------------------------------------------
# Gradio interface layout
# ------------------------------------------------------------------
with gr.Blocks(fill_width=True, theme=gr.themes.Default(primary_hue="sky")) as demo:
    # This HTML block contains all the CSS and JS for the desired layout
    gr.HTML("""
    <script>
        window.addEventListener("load", () => {
            document.documentElement.setAttribute("data-theme", "light");
        });
    </script>
    <style>
        :root {
            --primary-200: transparent !important;
            color-scheme: light !important;
            background-color: #fff !important;
            color: #333 !important;
        }
        #chatbot .message.user {
            background-color: #ccc !important;
            color: #222 !important;
        }
        .gradio-container footer {
            display: none !important;
        }
        .gradio-container {
            width: 100% !important;
            max-width: none !important;
            margin: 0;
        }
        .gradio-container .fillable {
            width: 100% !important;
            max-width: unset !important;
            margin: 0;
        }
        .hf-chat-input textarea:focus,
        .hf-chat-input:focus {
            outline: none !important;
            box-shadow: none !important;
            border-color: #c2c2c2 !important;
        }
        .block-container {
            width: 100% !important;
            max-width: none !important;
        }
        body .gradio-container .chatbot .hf-chat-input button .textbox textarea {
            background-color: #fff !important;
            color: #333 !important;
        }
        .example-row {
            flex-grow: 1 !important;
            width: 100% !important;
            display: flex;
            flex-direction: row;
            flex-wrap: wrap;
            justify-content: center;
            gap: 10px;
        }
        .input-container {
            position: relative;
            width: 100%;
        }
        .hf-chat-input {
            background-color: #f9f9f9;
            border: 1px solid #e0e0e0;
            border-radius: 20px;
            padding: 10px 50px 10px 20px;
            font-size: 16px;
            width: 100%;
            box-sizing: border-box;
            transition: border-color 0.2s ease;
        }
        .send-button {
            position: absolute;
            right: 10px;
            top: 50%;
            transform: translateY(-50%);
            width: 15px !important;
            height: 30px !important;
            padding: 0;
            background: #fff;
            border: none;
            border-radius: 50%;
            cursor: pointer;
            transition: background-color 0.2s ease;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 16px;
            line-height: 1;
        }
        .send-button:hover,
        .send-button:focus,
        .send-button:active {
            background-color: #f0f0f0;
            outline: none;
            top: 50% !important;
            transform: translateY(-50%) !important;
        }
        .input-row {
            display: flex;
            align-items: center;
            width: 100%;
            gap: 10px;
        }
    </style>
    """)
    # State management
    qa_chain_state = gr.State(value=qa_chain_state_initial)
    user_message_state = gr.State()

    chatbot = gr.Chatbot(label="AMAbot", show_label=True, elem_id="chatbot",
                         height=250, type="messages", visible=False)

    with gr.Row(elem_classes="example-row", visible=True) as examples_container:
        ex1 = gr.Button("Who?")
        ex2 = gr.Button("Where?")
        ex3 = gr.Button("What?")

    # Reveal the (initially hidden) chatbot as soon as an example is clicked
    ex1.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    ex2.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    ex3.click(lambda: gr.update(visible=True), None, chatbot, queue=False)

    with gr.Row(elem_classes="input-row"):
        with gr.Column(elem_classes="input-container"):
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Ask AMAbot anything about Christopher",
                container=False,
                elem_classes="hf-chat-input",
            )
            send_btn = gr.Button("❯❯", elem_classes="send-button")

    # Hidden inputs for model parameters
    max_tokens_input = gr.Number(value=DEFAULT_MAX_TOKENS, visible=False)
    temperature_input = gr.Number(value=DEFAULT_TEMPERATURE, visible=False)
    top_p_input = gr.Number(value=DEFAULT_TOP_P, visible=False)
    # --- Event handlers ---
    # Reveal the chatbot when the user submits text or clicks send
    user_input.submit(lambda: gr.update(visible=True), None, chatbot, queue=False)
    send_btn.click(lambda: gr.update(visible=True), None, chatbot, queue=False)

    # Submit action for text input: echo the user message first, then chain
    # the model call with .then() so the UI updates in two steps
    user_input.submit(
        update_chat,
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_message_state, user_input],
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state],
    )

    # Click action for send button
    send_btn.click(
        update_chat,
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_message_state, user_input],
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state],
    )
    # Click actions for example buttons. update_chat() returns three values;
    # [:2] drops the textbox-clearing value, which these buttons don't need.
    ex1.click(
        lambda history: update_chat("Who is Christopher?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state],
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state],
    )
    ex2.click(
        lambda history: update_chat("Where is Christopher from?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state],
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state],
    )
    ex3.click(
        lambda history: update_chat("What degrees does Christopher have, and what technical experience does he have?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state],
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state],
    )
if __name__ == "__main__":
    demo.queue().launch(show_api=False, share=True)