# Spaces: Runtime error
| import os | |
| import gradio as gr | |
| from langchain_community.document_loaders import TextLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.llms import HuggingFacePipeline | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
# -------------------------------------------------------------------
# Constants
DB_DIR = "chroma_db"  # directory passed to Chroma as persist_directory
MODEL_NAME_EMBEDDINGS = "sentence-transformers/all-MiniLM-L6-v2"  # embedding model name
MODEL_ID_LLM = "google/flan-t5-base"  # seq2seq model that generates the answers
DOC_PATH = "temp_docs/samsung_manual.txt"  # fixed document path

# Globals
# Conversational retrieval chain; stays None until init_chatbot() succeeds.
conversation_chain = None
# Transcript shown in the chat widget, as a list of
# {"role": "user" | "assistant", "content": "..."} dicts.
chat_history = []
# -------------------------------------------------------------------
def load_and_process_document():
    """Load the Samsung manual, split it, embed it, and create the vectorstore.

    Reads the text file at ``DOC_PATH`` as UTF-8, splits it into overlapping
    chunks, embeds them with ``MODEL_NAME_EMBEDDINGS`` and stores them in a
    Chroma collection whose persist directory is ``DB_DIR``.

    Returns:
        A ``(vectorstore, chunk_count)`` tuple: the Chroma vector store and
        the number of chunks that were indexed.

    Raises:
        FileNotFoundError: if ``DOC_PATH`` is not an existing regular file.
        ValueError: if the document produces no chunks (e.g. it is empty).
    """
    # isfile() rather than exists(): a directory at DOC_PATH would otherwise
    # pass this check and only fail later inside the loader.
    if not os.path.isfile(DOC_PATH):
        raise FileNotFoundError(f"❌ Document not found at: {DOC_PATH}")

    print("📄 Loading document...")
    # Force UTF-8 encoding to handle special characters
    loader = TextLoader(DOC_PATH, encoding="utf-8")
    docs = loader.load()

    print("✂️ Splitting document into chunks...")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(docs)
    if not texts:
        # Fail here with a clear message instead of deep inside the vector
        # store when it is handed an empty batch.
        raise ValueError(f"❌ Document is empty, nothing to index: {DOC_PATH}")

    print("🧠 Creating embeddings...")
    embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME_EMBEDDINGS)

    print("💾 Building Chroma vectorstore...")
    vectorstore = Chroma.from_documents(
        documents=texts,
        embedding=embeddings,
        persist_directory=DB_DIR,
    )
    return vectorstore, len(texts)  # return number of chunks
def get_conversational_chain(vectorstore):
    """Create the conversational retrieval chain.

    Loads the tokenizer and seq2seq model named by ``MODEL_ID_LLM``, wraps
    them in a ``text2text-generation`` pipeline, and combines that LLM with a
    retriever over ``vectorstore`` and a buffer memory.

    Args:
        vectorstore: vector store exposing ``as_retriever``; the two most
            similar chunks (``k=2``) are retrieved per question.

    Returns:
        The ``ConversationalRetrievalChain`` built from the LLM, the
        retriever and the memory.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID_LLM)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID_LLM)

    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
        # temperature / top_p only take effect in sampling mode; without
        # do_sample=True they are ignored by generation.
        do_sample=True,
        temperature=0.1,
        top_p=0.95,
        repetition_penalty=1.2,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # The memory stores previous turns under the "chat_history" key.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )

    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
        memory=memory,
    )
    return chain
def chatbot_response(user_input):
    """Generate a chatbot response from the conversation chain.

    Appends the user's question and the assistant's reply to the module-level
    ``chat_history`` (shared by every caller of this function) and returns it
    so it can be rendered by the chat widget.

    Args:
        user_input: the question typed by the user. ``None``, empty or
            whitespace-only input is ignored.

    Returns:
        The updated ``chat_history`` list of
        ``{"role": ..., "content": ...}`` dicts.
    """
    if conversation_chain is None:
        chat_history.append({
            "role": "assistant",
            "content": "⚠️ The chatbot is not ready. Please check the server logs."
        })
        return chat_history

    question = (user_input or "").strip()
    if not question:
        # Nothing to ask: leave the transcript untouched instead of sending
        # an empty query through retrieval and generation.
        return chat_history

    chat_history.append({"role": "user", "content": question})
    try:
        response = conversation_chain.invoke({"question": question})
        ai_answer = response["answer"]
    except Exception as exc:
        # UI callback boundary: log the failure and answer with an error
        # message so the user's question is not left without a reply.
        import traceback
        traceback.print_exc()
        ai_answer = f"❌ Failed to generate an answer: {exc}"

    chat_history.append({"role": "assistant", "content": ai_answer})
    return chat_history
# -------------------------------------------------------------------
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Chat with Samsung Manual")
    gr.Markdown("Ask questions about the **Samsung Manual** document below:")

    # Status info
    status_md = gr.Markdown("⏳ Initializing chatbot...")

    # Chat interface
    chatbot = gr.Chatbot(label="Conversation", type="messages")
    user_input = gr.Textbox(
        label="Type your question here…",
        placeholder="Ask me about the Samsung manual..."
    )
    submit_btn = gr.Button("Ask")

    # -------------------------------------------------------------------
    # Initialization function to show status
    def init_chatbot():
        """Build or load the vector store, create the chain, report status.

        Rebuilds the Chroma store from the document when ``DB_DIR`` is
        missing or empty, otherwise opens the persisted store. On success the
        module-level ``conversation_chain`` is set.

        Returns:
            A status message for the UI. Failures are not raised: the
            traceback is printed and an error message is returned instead.
        """
        global conversation_chain
        try:
            if not os.path.exists(DB_DIR) or not os.listdir(DB_DIR):
                # Rebuild vectorstore
                vectorstore, chunks = load_and_process_document()
                msg = f"✅ Manual processed and stored! Total chunks: {chunks}"
            else:
                embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME_EMBEDDINGS)
                vectorstore = Chroma(persist_directory=DB_DIR, embedding_function=embeddings)
                # count() asks the collection for its size instead of loading
                # every stored metadata record just to measure the list.
                chunks = vectorstore._collection.count()
                msg = f"✅ Chroma DB loaded! Total chunks: {chunks}"
            conversation_chain = get_conversational_chain(vectorstore)
            return msg
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"❌ Failed to initialize chatbot: {e}"

    # Initialize on startup
    status_md.value = init_chatbot()

    # Event listeners have to be registered while the Blocks context is open.
    submit_btn.click(
        fn=chatbot_response,
        inputs=user_input,
        outputs=chatbot
    )
# -------------------------------------------------------------------
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch()