# NOTE(review): these lines were residue from a Hugging Face Space page
# ("Spaces: / Build error"); the code below was recovered from that page's
# rendering — table pipes and mojibake repaired throughout.
# --- Dependencies -----------------------------------------------------------
# Standard library
import logging
import os
import zipfile

# Third-party
import faiss
import gradio as gr
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# LangChain
from langchain.docstore import InMemoryDocstore
from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS as LangChainFAISS
# === Logger ===
logging.basicConfig(level=logging.ERROR)

# === 1. Extract the knowledge base ZIP (if it ships alongside the script) ===
if os.path.exists("roosevelt_knowledge_base.zip"):
    with zipfile.ZipFile("roosevelt_knowledge_base.zip", "r") as zip_ref:
        zip_ref.extractall("roosevelt_knowledge_base")
    # Original check-mark emoji was mojibake ("โ"); restored.
    print("✅ Knowledge base extracted.")
# === 2. Load Markdown files ===
# Fix: the ZIP above extracts into a local "roosevelt_knowledge_base" folder,
# but the original hard-coded KB_PATH only exists on the author's machine.
# Prefer the local folder and keep the author's path as a fallback.
_LOCAL_KB = "roosevelt_knowledge_base"
_AUTHOR_KB = "/Users/toajibul/Documents/Data_Science/Amdari/Chatbot Uni/roosevelt_knowledge_base"
KB_PATH = _LOCAL_KB if os.path.isdir(_LOCAL_KB) else _AUTHOR_KB

files = [
    os.path.join(dirpath, fname)
    for dirpath, _, fnames in os.walk(KB_PATH)
    for fname in fnames
    if fname.endswith(".md")
]
docs = [doc for f in files for doc in TextLoader(f, encoding="utf-8").load()]
print(f"✅ Loaded {len(docs)} documents.")
| # === 3. Chunking === | |
def get_dynamic_chunk_size(text: str) -> int:
    """Pick a splitter chunk size based on document length.

    Short documents get small chunks so retrieval stays fine-grained;
    long documents get bigger chunks to keep the total chunk count down.

    Returns 300 for < 1000 chars, 500 for < 5000 chars, else 1000.
    """
    n = len(text)
    if n < 1000:
        return 300
    if n < 5000:
        return 500
    return 1000
# Split each document with a chunk size tuned to its own length.
chunks = []
for doc in docs:
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=get_dynamic_chunk_size(doc.page_content),
        chunk_overlap=100,
    )
    chunks.extend(splitter.split_documents([doc]))

# Plain strings used both for embedding and for rebuilding Documents below.
texts = [chunk.page_content for chunk in chunks]
# === 4. Vector store (FAISS) ===
embed_model_id = "sentence-transformers/all-mpnet-base-v2"

# Embed every chunk once and index with a flat L2 index (exact
# nearest-neighbour search — fine at this corpus size).
embedder = SentenceTransformer(embed_model_id)
embeddings = embedder.encode(texts, show_progress_bar=False)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(np.asarray(embeddings, dtype="float32"))

# Wrap the raw index in LangChain's FAISS store so similarity_search()
# accepts plain-text queries.  Query-time embeddings come from
# HuggingFaceEmbeddings built on the SAME model id, so query and corpus
# vectors live in the same space.
docs = [Document(page_content=t) for t in texts]
docstore = InMemoryDocstore({str(i): d for i, d in enumerate(docs)})
id_map = {i: str(i) for i in range(len(docs))}
embed_fn = HuggingFaceEmbeddings(model_name=embed_model_id)
vectorstore = LangChainFAISS(
    index=index,
    docstore=docstore,
    index_to_docstore_id=id_map,
    embedding_function=embed_fn,
)
print("✅ FAISS vectorstore ready.")
# === 5. Load the Falcon generator (tiiuae/falcon-rw-1b) ===
print("🔄 Loading Falcon-1B model...")  # original emoji was mojibake ("๐")
model_id = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",           # let accelerate place layers on available devices
    offload_folder="./offload",  # spill weights to disk when memory is short
    low_cpu_mem_usage=True,
)

# Deterministic generation (do_sample=False); return_full_text=False keeps
# the prompt itself out of the generated reply.
text_gen_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    return_full_text=False,
    do_sample=False,
    max_new_tokens=200,
    pad_token_id=tokenizer.eos_token_id,
)
llm = HuggingFacePipeline(pipeline=text_gen_pipeline)
| # === 6. Prompt Format and Q&A === | |
def truncate_context(context: str, max_tokens: int = 1024) -> str:
    """Trim *context* to at most *max_tokens* tokenizer tokens.

    Encodes with truncation and decodes back to a plain string, presumably
    to keep the final prompt inside the model's context window — confirm
    the exact limit for falcon-rw-1b.
    """
    input_ids = tokenizer(
        context, truncation=True, max_length=max_tokens, return_tensors="pt"
    ).input_ids
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
def format_prompt(context: str, question: str) -> str:
    """Build the instruction prompt sent to the LLM.

    The model is told to answer strictly from *context* and to emit a fixed
    apology sentence when the context lacks the answer.
    """
    # Fix: the apology sentence contained a mojibake apostrophe ("donโt"),
    # which the model would have echoed verbatim to users.
    return (
        "You are the Roosevelt University Assistant.\n\n"
        "Use only the information provided in the context to answer the student's question. "
        "If the answer is not found in the context, respond: "
        "\"I'm sorry, but I don't have that information available right now.\"\n\n"
        f"Context:\n{context.strip()}\n\n"
        f"Student Question: {question.strip()}\n"
        "Assistant Answer:"
    )
def answer_fn(question: str) -> str:
    """Retrieve relevant chunks and generate an answer for *question*.

    Always returns a user-facing string: a validation message for blank
    input, an apology when nothing is retrieved, or a generic error
    message when anything raises (logged, never propagated to the UI).
    """
    if not question.strip():
        return "Please enter a valid question."
    try:
        hits = vectorstore.similarity_search(question, k=5)
        if not hits:
            return "I'm sorry, I couldn't find any relevant information for your query."
        context = "\n\n".join(d.page_content for d in hits)
        # Fix: the original never called truncate_context, so five large
        # chunks could overflow the model's context window.
        context = truncate_context(context)
        prompt = format_prompt(context, question)
        return llm.invoke(prompt).strip()
    except Exception as e:  # top-level boundary: log and degrade gracefully
        logging.error("Error while answering question: %s", e)
        return "An internal error occurred while processing your question."
| # === 7. Gradio Interface === | |
def chat_fn(user_message, history):
    """Gradio callback: answer *user_message* and extend the chat history.

    Returns the updated history twice — once for the Chatbot component and
    once for the session State.  (Removed a leftover debug print.)
    """
    bot_response = answer_fn(user_message)
    updated = history + [(user_message, bot_response)]
    return updated, updated
# === 7. Gradio interface ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="main-col", scale=1):
        gr.Markdown(
            """
            <div style="text-align:center">
            <h1 style="color:#003366;">🎓 Roosevelt University Assistant</h1>
            <p style="font-size:16px; color:#444;">
            Ask questions about Roosevelt courses, admissions, student life, or tuition fees.
            The assistant replies using only the university's official documentation.
            </p>
            </div>
            """,
            elem_id="title",
        )
        # NOTE(review): the original heading/label emoji were mojibake ("๐");
        # 🎓 is a best-guess replacement — confirm the intended glyphs.
        chatbot = gr.Chatbot(label="🎓 Roosevelt Assistant Chat", height=400)
        state = gr.State([])  # per-session chat history
        user_input = gr.Textbox(
            placeholder="Ask a question about Roosevelt...",
            show_label=False,
            container=False,
        )
        submit_btn = gr.Button("Ask")
        # Answer, then clear the input box for the next question.
        submit_btn.click(
            chat_fn, inputs=[user_input, state], outputs=[chatbot, state]
        ).then(lambda: "", None, user_input)

demo.launch()