import gradio as gr
# The bare `langchain.*` integration paths were removed in langchain 0.2;
# these classes now live in langchain_community / langchain_text_splitters.
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Load the source document and split it into overlapping chunks for retrieval
doc_loader = TextLoader("dataset.txt")
docs = doc_loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
split_docs = text_splitter.split_documents(docs)

# Embed the chunks and index them in a FAISS vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectordb = FAISS.from_documents(split_docs, embeddings)
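
# Optional: persist the index so it is not rebuilt on every restart.
# A minimal sketch using LangChain's FAISS helpers; "faiss_index" is an
# arbitrary folder name, and allow_dangerous_deserialization is needed
# when reloading pickled metadata you created yourself.
# vectordb.save_local("faiss_index")
# vectordb = FAISS.load_local("faiss_index", embeddings,
#                             allow_dangerous_deserialization=True)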
# Load the chat model and wrap it in a text-generation pipeline.
# Note: a 9B-parameter model in fp16 needs roughly 18 GB of GPU memory;
# with device_map="auto", layers spill to CPU (slowly) when that is missing.
model_name = "01-ai/Yi-Coder-9B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
qa_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=500,
    return_full_text=False,  # return only the completion, not the echoed prompt
    pad_token_id=tokenizer.eos_token_id,
)
# Wire the retriever and the LLM into a RetrievalQA chain
llm = HuggingFacePipeline(pipeline=qa_pipeline)
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
)
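
# Quick sanity check: the chain takes {"query": ...} and returns a dict
# whose "result" key holds the generated answer. (Hypothetical query;
# uncomment to test outside the UI.)
# print(qa_chain.invoke({"query": "What is CPSL?"})["result"])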
def preprocess_query(query):
    # Nudge the model toward CPSL output whenever the user asks for code
    lowered = query.lower()
    if "script" in lowered or "code" in lowered:
        return f"Write a CPSL script: {query}"
    return query
def clean_response(response):
    # Keep only the text after "Answer:" if the model echoed its prompt scaffold
    result = response.get("result", "")
    if "Answer:" in result:
        return result.split("Answer:")[1].strip()
    return result.strip()
def chatbot_response(user_input):
    processed_query = preprocess_query(user_input)
    raw_response = qa_chain.invoke({"query": processed_query})
    return clean_response(raw_response)
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# CPSL Chatbot")
    chat_history = gr.Chatbot()
    user_input = gr.Textbox(label="Your Message:")
    send_button = gr.Button("Send")

    def interact(user_message, history):
        bot_reply = chatbot_response(user_message)
        history.append((user_message, bot_reply))
        # Return the updated history and clear the input box
        return history, ""

    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history, user_input])
    user_input.submit(interact, inputs=[user_input, chat_history], outputs=[chat_history, user_input])
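
# Launch the app; without this call the interface never starts
# (on Hugging Face Spaces the script itself must call launch()).
demo.launch()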