# Hugging Face Spaces page header ("Spaces: Sleeping") captured by extraction — status banner, not application code.
| """ | |
| Gradio app for Hugging Face chatbot with RAG capabilities. | |
| """ | |
| import warnings | |
# Suppress deprecation from dependencies (e.g. accelerate) until they use torch.distributed.ReduceOp
# NOTE: this filter is installed before the imports below on purpose — the
# FutureWarning it matches is emitted at import time by downstream libraries.
warnings.filterwarnings(
    "ignore",
    message=".*torch.distributed.reduce_op.*ReduceOp.*",
    category=FutureWarning,
)
| import gradio as gr | |
| from gradio.themes.base import Base | |
| from gradio.themes.utils import colors, fonts, sizes | |
| import os | |
| from typing import List, Tuple | |
| from huggingface_hub import InferenceClient | |
| from ingestion import DocumentIngestion | |
# A clean minimalist theme for the Gradio UI.
class MinimalistTheme(Base):
    """A clean, minimalist theme with subtle colors and simple styling.

    Extends Gradio's ``Base`` theme with a blue/gray palette and the
    Inter / JetBrains Mono font stacks, then overrides individual style
    tokens via ``set()`` for flat surfaces, thin borders, and no shadows.
    """

    def __init__(self):
        # Font stacks: Google fonts first, then generic system fallbacks.
        sans_stack = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        )
        mono_stack = (
            fonts.GoogleFont("JetBrains Mono"),
            "ui-monospace",
            "monospace",
        )
        super().__init__(
            primary_hue=colors.blue,
            secondary_hue=colors.gray,
            neutral_hue=colors.gray,
            spacing_size=sizes.spacing_md,
            radius_size=sizes.radius_sm,
            text_size=sizes.text_md,
            font=sans_stack,
            font_mono=mono_stack,
        )
        # Per-token overrides: "*_dark" keys are the dark-mode variants.
        super().set(
            # Backgrounds: plain white surfaces, near-black in dark mode.
            body_background_fill="#ffffff",
            body_background_fill_dark="#0f0f0f",
            block_background_fill="#ffffff",
            block_background_fill_dark="#1a1a1a",
            # Borders: hairline, low-contrast, no drop shadows.
            block_border_width="1px",
            block_border_color="#e0e0e0",
            block_border_color_dark="#2a2a2a",
            block_shadow="none",
            # Primary buttons: solid blue with a darker hover state.
            button_primary_background_fill="#2563eb",
            button_primary_background_fill_hover="#1d4ed8",
            button_primary_text_color="#ffffff",
            button_primary_background_fill_dark="#3b82f6",
            button_primary_background_fill_hover_dark="#2563eb",
            # Secondary buttons: neutral grays.
            button_secondary_background_fill="#f3f4f6",
            button_secondary_background_fill_hover="#e5e7eb",
            button_secondary_text_color="#111827",
            button_secondary_background_fill_dark="#374151",
            button_secondary_background_fill_hover_dark="#4b5563",
            button_border_width="1px",
            # Text inputs: white fields with subtle gray borders.
            input_background_fill="#ffffff",
            input_background_fill_dark="#1a1a1a",
            input_border_width="1px",
            input_border_color="#d1d5db",
            input_border_color_dark="#374151",
            # Body and label text colors.
            body_text_color="#111827",
            body_text_color_dark="#e5e7eb",
            block_label_text_color="#374151",
            block_label_text_color_dark="#9ca3af",
        )
class RAGChatbot:
    """Chatbot with RAG capabilities.

    Answers questions by retrieving chunks from a FAISS-backed vector store
    (built by ``DocumentIngestion``) and asking a Hugging Face Inference API
    chat model to synthesize an answer, falling back through several
    candidate models if the primary one is not served by any provider.
    """

    # Default and fallback models (try in order until one is supported by your Inference API providers)
    DEFAULT_CHAT_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"
    FALLBACK_CHAT_MODELS = [
        "ServiceNow-AI/Apriel-1.6-15b-Thinker:together",
        "microsoft/phi-2",
        "HuggingFaceH4/zephyr-7b-beta",
    ]

    def __init__(
        self,
        model_name: str = None,
        embedding_model: str = "all-mpnet-base-v2",
        vector_store_path: str = "data/vector_store"
    ):
        """
        Initialize the RAG chatbot.

        Args:
            model_name: Hugging Face model name for the chatbot (via Inference
                API). Falls back to ``DEFAULT_CHAT_MODEL`` when None or empty.
            embedding_model: Model for document embeddings
            vector_store_path: Path to saved vector store
        """
        self.model_name = model_name if model_name else self.DEFAULT_CHAT_MODEL
        # Build list of models to try (primary first, then fallbacks not already primary)
        self._models_to_try = [self.model_name] + [
            m for m in self.FALLBACK_CHAT_MODELS if m != self.model_name
        ]

        # Initialize Inference API client (no model in constructor so we can try multiple)
        hf_token = os.environ.get("HF_TOKEN")
        # Debug: report HF_TOKEN status (only first/last 4 chars are printed)
        if not hf_token:
            print("[DEBUG] HF_TOKEN: not set (empty or missing)")
            print("Warning: HF_TOKEN not set. Inference API calls may fail.")
            print("Set HF_TOKEN environment variable or add it to Space secrets.")
        else:
            masked = f"{hf_token[:4]}...{hf_token[-4:]}" if len(hf_token) > 8 else "****"
            print(f"[DEBUG] HF_TOKEN: set (length={len(hf_token)}, masked={masked})")
            print("HF_TOKEN found. Inference API ready.")

        print(f"[DEBUG] Inference API client (models to try: {self._models_to_try})")
        try:
            self.inference_client = InferenceClient(token=hf_token)
            print("[DEBUG] Inference API client initialized (model chosen per request with fallbacks)")
        except Exception as e:
            # Leave the client as None; generate_response() checks for this
            # and returns an error message instead of crashing.
            print(f"[DEBUG] Error initializing Inference API client: {type(e).__name__}: {e}")
            self.inference_client = None

        # Initialize document ingestion (embeddings + FAISS index wrapper)
        self.ingestion = DocumentIngestion(embedding_model=embedding_model)

        # Load vector store if it exists (directory plus the FAISS index file)
        if os.path.exists(vector_store_path) and os.path.exists(
            os.path.join(vector_store_path, "index.faiss")
        ):
            try:
                self.ingestion.load(vector_store_path)
                print("Loaded existing vector store")
            except Exception as e:
                # Best effort: run without RAG context rather than fail startup.
                print(f"Could not load vector store: {e}")

        # NOTE(review): chat_history is initialized but not read by any method
        # shown here — Gradio passes history per call to chat() instead.
        self.chat_history = []

    def _generate_with_chat(self, user_content: str, max_new_tokens: int = 512) -> str:
        """Call the Inference API using chat_completion; try fallback models if current is not supported.

        Args:
            user_content: Full prompt, sent as a single "user" message.
            max_new_tokens: Forwarded as ``max_tokens`` to chat_completion.

        Returns:
            The stripped assistant reply, or "" if no model produced content
            and no error was raised.

        Raises:
            Exception: re-raises the original error for non-availability
                failures, or the last availability error if every candidate
                model was rejected.
        """
        last_error = None
        for model in self._models_to_try:
            print(f"[DEBUG] _generate_with_chat: trying model={model}, prompt_len={len(user_content)}, max_tokens={max_new_tokens}")
            try:
                response = self.inference_client.chat_completion(
                    model=model,
                    messages=[{"role": "user", "content": user_content}],
                    max_tokens=max_new_tokens,
                    temperature=0.7,
                )
                print(f"[DEBUG] chat_completion OK for model={model}, response type: {type(response).__name__}")
                if response and response.choices and len(response.choices) > 0:
                    msg = response.choices[0].message
                    if hasattr(msg, "content") and msg.content:
                        # Remember this model for next time: promote it to the
                        # front of the candidate list so later calls skip the
                        # models that already failed.
                        self.model_name = model
                        self._models_to_try = [model] + [m for m in self._models_to_try if m != model]
                        return msg.content.strip()
                print("[DEBUG] chat_completion returned empty or unexpected structure")
            except Exception as e:
                last_error = e
                err_str = str(e).lower()
                # Availability errors: move on to the next fallback model.
                if "model_not_supported" in err_str or "not supported by any provider" in err_str:
                    print(f"[DEBUG] Model {model} not available, trying next fallback.")
                    continue
                # Anything else is unexpected: log the traceback and re-raise.
                print(f"[DEBUG] _generate_with_chat exception for {model}: {type(e).__name__}: {e}")
                import traceback
                traceback.print_exc()
                raise
        if last_error is not None:
            raise last_error
        return ""

    def generate_response(self, query: str, use_rag: bool = True, num_results: int = 5) -> str:
        """
        Generate a response to the user query using RAG and Inference API.

        Args:
            query: User's question
            use_rag: Whether to use RAG (retrieve relevant documents)
            num_results: Number of document chunks to retrieve

        Returns:
            Generated response (always a displayable string; errors are
            reported in the returned text rather than raised)
        """
        if self.inference_client is None:
            return "Error: Inference API client not initialized. Please check HF_TOKEN configuration."

        # If RAG is enabled and we have a vector store, retrieve context and generate answer
        if use_rag and self.ingestion.index is not None:
            try:
                results = self.ingestion.search(query, k=num_results)
                if results:
                    # Build context from retrieved chunks; include source/title so the model can cite it.
                    # Empty chunks are skipped but keep their 1-based index so
                    # [Context N] labels stay aligned with the reference map below.
                    context_parts = []
                    for i, result in enumerate(results, 1):
                        text = result['text'].strip()
                        if not text:
                            continue
                        meta = result.get('metadata') or {}
                        source_label = meta.get('document_title') or meta.get('source') or f"Source {i}"
                        context_parts.append(f"[Context {i}] (Source: {source_label})\n{text}")
                    context = "\n\n".join(context_parts)

                    # Build instruction-tuned prompt
                    prompt = f"""
*You are an expert assistant specializing in organic farming, in particular in Canada and its legal context.
Answer the user's question using only the information provided in the context.
If the context does not include the information needed to answer the question, clearly say:
"The provided context does not contain enough information to answer this question."
When answering:
Respond in English only.
Do not use outside knowledge, assumptions, or guesswork.
Cite or reference the specific parts of the context your answer is based on.
Provide concise, accurate, and helpful explanations.
Do not reveal your internal reasoning. Provide only the final answer.
Structure your answer in the following format:
Summary — A brief, high‑level answer.
Supporting Details — Explain using information only from the provided context. When citing, use the Source label shown for that context (e.g. the document title or name in parentheses after [Context N]).
Context References — List each reference with the exact Source shown for that context (e.g. "CAN/CGSB-32.312-2018" or the document title). Include section name or page when that information appears in the context text. Format: document/source, section or location if available, and a short quote or paraphrase. Do not use only "Context 1" or "Context 5" as the reference; always include the document title/source.
Context:
{context}
Question: {query}
Answer:"""

                    # Build mapping from context index to source label for resolving references
                    context_index_to_source = {}
                    for i, result in enumerate(results, 1):
                        meta = result.get("metadata") or {}
                        context_index_to_source[i] = (
                            meta.get("document_title") or meta.get("source") or f"Source {i}"
                        )

                    # Generate the answer via the chat/conversational API (with model fallbacks)
                    try:
                        response_text = self._generate_with_chat(prompt, max_new_tokens=512)
                        if response_text:
                            # Resolve [Context N] to actual source labels in the body
                            for i, source_label in context_index_to_source.items():
                                response_text = response_text.replace(
                                    f"[Context {i}]",
                                    f"({source_label})",
                                )
                            # Append a References section so users see what each source is
                            ref_lines = [
                                "",
                                "---",
                                "**References**",
                            ]
                            for i, source_label in context_index_to_source.items():
                                ref_lines.append(f"{i}. {source_label}")
                            response_text = response_text.rstrip() + "\n\n" + "\n".join(ref_lines)
                            return response_text
                        # Empty reply: raised so the fallback branch below runs.
                        raise ValueError("Empty response from model")
                    except Exception as api_error:
                        print(f"[DEBUG] RAG generation failed: {type(api_error).__name__}: {api_error}")
                        err_str = str(api_error).lower()
                        if "model_not_supported" in err_str or "not supported by any provider" in err_str:
                            return (
                                "None of the configured chat models are available with your Inference API providers.\n\n"
                                "**How to fix:**\n"
                                "1. See which models are available: https://huggingface.co/inference/models\n"
                                "2. Enable providers (and pick a chat model): https://huggingface.co/settings/inference-api\n"
                                "3. In app.py, set RAGChatbot(model_name=\"your-chosen-model-id\") to match a model you enabled."
                            )
                        # Fallback: return formatted chunks with note
                        response_parts = []
                        response_parts.append("I retrieved relevant information, but couldn't generate a synthesized answer. Here are the relevant chunks:\n\n")
                        for i, result in enumerate(results, 1):
                            meta = result.get('metadata') or {}
                            source = meta.get('document_title') or meta.get('source', '')
                            text = result['text'].strip()
                            if text:
                                response_parts.append(f"**Relevant information {i}** (from {source}):\n{text}\n")
                        return "\n".join(response_parts)
                else:
                    # No results found
                    return "I couldn't find any relevant information in the documents to answer your question. Please try rephrasing or check if the documents contain information about this topic."
            except Exception as e:
                print(f"Error in RAG retrieval: {e}")
                return f"I encountered an error while searching the documents: {str(e)}"

        # If no RAG or no vector store, generate response without context
        try:
            prompt = f"""You are a helpful assistant. Answer the following question concisely.
Question: {query}
Answer:"""
            response_text = self._generate_with_chat(prompt, max_new_tokens=256)
            if response_text:
                return response_text
            return "I couldn't generate a response. Please try again."
        except Exception as e:
            print(f"Error generating response: {e}")
            return f"I encountered an error while generating a response: {str(e)}. Please check your HF_TOKEN configuration."

    def chat(self, message: str, history):
        """
        Handle chat interaction (Gradio event handler).

        Args:
            message: User message
            history: Chat history (list of ChatMessage or dicts with 'role' and 'content')

        Returns:
            Tuple of ("", updated history): the empty string clears the
            textbox component, the list refreshes the Chatbot component.
        """
        # Ignore empty / whitespace-only submissions.
        if not message or not message.strip():
            return "", history or []

        # Ensure history is a list
        if history is None:
            history = []

        # Add user message as dictionary (Chatbot "messages" format)
        history.append({"role": "user", "content": message})

        # Generate response (always use RAG)
        try:
            response = self.generate_response(message, use_rag=True)
            # Ensure response is not empty
            if not response or not response.strip():
                response = "I'm sorry, I couldn't generate a response. Please try again."
        except Exception as e:
            # Surface the error in the chat rather than crashing the UI.
            print(f"Error generating response: {e}")
            import traceback
            traceback.print_exc()
            response = f"I encountered an error: {str(e)}"

        # Add assistant response as dictionary
        history.append({"role": "assistant", "content": response})
        print(f"Debug - History length: {len(history)}")
        print(f"Debug - Response: {response[:100] if response else 'None'}...")
        return "", history
# Initialize chatbot (loads the vector store and Inference API client at import time)
chatbot = RAGChatbot()

# Create Gradio interface.
# FIX: the theme must be passed to gr.Blocks(); Blocks.launch() does not
# accept a `theme` keyword, so passing it there raises TypeError at startup.
with gr.Blocks(title="OCO Chatbot", theme=MinimalistTheme()) as app:
    gr.Markdown("OCO Chatbot")

    # FIX: chat() appends {"role": ..., "content": ...} dicts, so the Chatbot
    # component must use the "messages" format — the legacy tuple format
    # rejects dict-shaped history entries.
    chatbot_interface = gr.Chatbot(
        label="Chat",
        type="messages",
        height=500,
        value=[]  # Initialize with empty list
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Ask a question about your documents...",
            scale=4
        )
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    # Enter key and Send button both route through chatbot.chat, which returns
    # ("", updated_history): clears the textbox and refreshes the chat view.
    msg.submit(
        chatbot.chat,
        inputs=[msg, chatbot_interface],
        outputs=[msg, chatbot_interface]
    )
    submit_btn.click(
        chatbot.chat,
        inputs=[msg, chatbot_interface],
        outputs=[msg, chatbot_interface]
    )

    def clear_chat():
        """Reset the conversation: empty chat history and empty textbox."""
        return [], ""

    clear_btn.click(clear_chat, outputs=[chatbot_interface, msg])


if __name__ == "__main__":
    # Get port from environment variable (Hugging Face Spaces sets this) or default to 7860
    port = int(os.environ.get("PORT", 7860))
    app.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=port,
    )