Spaces:
Build error
Build error
| import os | |
| import json | |
| import gradio as gr | |
| from pathlib import Path | |
| from typing import List, Tuple | |
| from chatbot_config import ChatbotConfig | |
| from chatbot_model import RetrievalChatbot | |
| from tf_data_pipeline import TFDataPipeline | |
| from response_quality_checker import ResponseQualityChecker | |
| from environment_setup import EnvironmentSetup | |
| from sentence_transformers import SentenceTransformer | |
| from logger_config import config_logger | |
# Module-level logger for this app, created via the project's logger_config helper.
logger = config_logger(__name__)
def load_pipeline(model_dir: str = "models"):
    """
    Load the config, encoder, data pipeline, FAISS index/response pool and chatbot.

    Args:
        model_dir: Directory holding ``config.json``, the saved chatbot model and a
            ``faiss_indices/`` subdirectory with the production index and its
            response pool. Defaults to ``"models"``.

    Returns:
        A ``(chatbot, quality_checker)`` tuple: the ``RetrievalChatbot`` loaded in
        inference mode from ``model_dir`` and a ``ResponseQualityChecker`` bound to
        the data pipeline built here.

    If the FAISS index or response-pool file is missing, a warning is logged and
    the pipeline is returned without them loaded.
    """
    faiss_indices_dir = os.path.join(model_dir, "faiss_indices")
    faiss_index_path = os.path.join(faiss_indices_dir, "faiss_index_production.index")
    # Swap only the file extension. str.replace(".index", ...) would also rewrite
    # any ".index" occurring earlier in the path (e.g. in a directory name).
    response_pool_path = os.path.splitext(faiss_index_path)[0] + "_responses.json"

    # Prefer the saved config; fall back to ChatbotConfig defaults when absent.
    config_path = Path(model_dir) / "config.json"
    if config_path.exists():
        with config_path.open("r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
    else:
        config = ChatbotConfig()

    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    # Load models and data. The pipeline starts with an empty response pool;
    # it is filled from disk below when the index files are present.
    encoder = SentenceTransformer(config.pretrained_model)
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache={},
        index_type='IndexFlatIP',
        faiss_index_file_path=faiss_index_path,
    )

    # Load FAISS index and response pool (both are required together).
    missing = [p for p in (faiss_index_path, response_pool_path) if not os.path.exists(p)]
    if not missing:
        data_pipeline.load_faiss_index(faiss_index_path)
        with open(response_pool_path, "r", encoding="utf-8") as f:
            data_pipeline.response_pool = json.load(f)
        data_pipeline.validate_faiss_index()
    else:
        logger.warning("FAISS index or responses are missing. The chatbot may not work properly.")
        logger.warning("Missing file(s): %s", ", ".join(missing))

    chatbot = RetrievalChatbot.load_model(load_dir=model_dir, mode="inference")
    quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
    return chatbot, quality_checker
# Load the chatbot and quality checker once, at module import time, so every
# request handled by `respond` reuses the same instances.
# NOTE: importing this module therefore triggers model/index loading.
chatbot, quality_checker = load_pipeline()
def respond(message: str, history: List[List[str]]) -> Tuple[str, List[List[str]]]:
    """
    Handle one chat turn from the Gradio UI.

    Args:
        message: Text the user submitted. Empty, whitespace-only, or ``None``
            input is ignored.
        history: Conversation shown in the ``gr.Chatbot`` component, as
            ``[user_message, bot_reply]`` pairs. ``None`` is treated as empty.

    Returns:
        ``("", history)``: an empty string that clears the input textbox, and the
        history with ``[message, reply]`` appended. The history is unchanged when
        the message is blank. If generating a reply raises, a generic apology is
        appended as the reply and the exception is logged with its traceback.
    """
    if history is None:
        history = []
    if not message or not message.strip():  # Skip blank input
        return "", history
    try:
        # Only the reply is used; the other three return values are ignored.
        # The UI history is not forwarded: per the original design, conversation
        # context is handled internally by chatbot.chat.
        response, _, _, _ = chatbot.chat(
            query=message,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=10,
        )
    except Exception as e:
        # UI boundary: never surface a raw exception to the user, but keep the
        # traceback in the logs.
        logger.exception("Error generating response: %s", e)
        response = "I apologize, but I encountered an error processing your request."
    history.append([message, response])
    return "", history
def main():
    """
    Build the Gradio Blocks interface for the chatbot demo.

    The UI is a chat display, a message textbox with a Send button, and a Clear
    button. Submitting the textbox or clicking Send routes to ``respond``.

    Returns:
        The configured ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    with gr.Blocks(
        title="Chatbot Demo",
        css="""
        .message-wrap { max-height: 800px !important; }
        .chatbot { min-height: 600px; }
        """,
    ) as demo:
        gr.Markdown(
            """
            # Retrieval-Based Chatbot Demo using Sentence Transformers + FAISS
            Knowledge areas: restaurants, movie tickets, rideshare, coffee, pizza, and auto repair.
            """
        )

        # Chat display with custom height. Named chat_display so it does not shadow
        # the module-level `chatbot` (the RetrievalChatbot used by `respond`).
        chat_display = gr.Chatbot(
            label="Conversation",
            container=True,
            height=600,
            show_label=True,
            elem_classes="chatbot",
        )

        # Input area with send button
        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
                scale=8,
            )
            send = gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=100,
            )

        clear = gr.Button("Clear Conversation", variant="secondary")

        # Event handlers: both Enter and Send run one chat turn.
        msg.submit(respond, [msg, chat_display], [msg, chat_display], queue=False)
        send.click(respond, [msg, chat_display], [msg, chat_display], queue=False)
        # Reset the chat display to an empty history and the textbox to an empty
        # string (a Textbox value is a str, not a list).
        # NOTE(review): this clears only the UI; any context chatbot.chat keeps
        # internally is not reset here.
        clear.click(lambda: ([], ""), outputs=[chat_display, msg], queue=False)

    return demo
if __name__ == "__main__":
    # Build the UI and serve it on all network interfaces, port 7860.
    demo = main()
    demo.launch(server_name="0.0.0.0", server_port=7860)