import os
import zipfile

import torch
import faiss
import numpy as np
import gradio as gr

from transformers import GPT2Tokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS as LangChainFAISS
from langchain.docstore import InMemoryDocstore
from langchain.schema import Document
from langchain.llms import HuggingFacePipeline
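
# End-to-end RAG pipeline: extract a Markdown knowledge base, chunk and
# embed it into a FAISS index, then answer student questions with GPT-2
# behind a Gradio chat interface.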

# Step 1: unpack the zipped knowledge base if it is present.
if os.path.exists("md_knowledge_base.zip"):
    with zipfile.ZipFile("md_knowledge_base.zip", "r") as zip_ref:
        zip_ref.extractall("md_knowledge_base")
    print("✅ Knowledge base extracted.")

# Step 2: load every Markdown file under the extracted tree.
KB_PATH = "md_knowledge_base"
files = [os.path.join(dp, f) for dp, _, fn in os.walk(KB_PATH) for f in fn if f.endswith(".md")]
docs = [doc for f in files for doc in TextLoader(f, encoding="utf-8").load()]
print(f"✅ Loaded {len(docs)} documents.")

# Step 3: split documents into overlapping chunks, sizing each chunk by how
# long the source document is (short docs get finer-grained chunks).
def get_dynamic_chunk_size(text):
    if len(text) < 1000:
        return 300
    elif len(text) < 5000:
        return 500
    else:
        return 1000

chunks = []
for doc in docs:
    chunk_size = get_dynamic_chunk_size(doc.page_content)
    chunk_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=100)
    chunks.extend(chunk_splitter.split_documents([doc]))

texts = [chunk.page_content for chunk in chunks]
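
# Optional sanity check (safe to remove): confirm the chunking produced output.
print(f"✅ Created {len(chunks)} chunks.")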

# Step 4: embed each chunk and build an exact (brute-force) L2 FAISS index.
# Note: distilbert-base-uncased is not a sentence-transformers checkpoint;
# SentenceTransformer wraps it with mean pooling automatically.
embed_model_id = "distilbert-base-uncased"
embedder = SentenceTransformer(embed_model_id)
embeddings = embedder.encode(texts, show_progress_bar=False)

dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(np.array(embeddings, dtype="float32"))

# Step 5: wrap the raw FAISS index in a LangChain vectorstore. The docstore
# maps index positions back to the chunk documents. (Renamed to chunk_docs
# so the loaded documents in `docs` are not shadowed.)
chunk_docs = [Document(page_content=t) for t in texts]
docstore = InMemoryDocstore({str(i): chunk_docs[i] for i in range(len(chunk_docs))})
id_map = {i: str(i) for i in range(len(chunk_docs))}
# Query embeddings must come from the same model that built the index.
embed_fn = HuggingFaceEmbeddings(model_name=embed_model_id)

vectorstore = LangChainFAISS(
    index=index,
    docstore=docstore,
    index_to_docstore_id=id_map,
    embedding_function=embed_fn,
)

print("✅ FAISS vectorstore ready.")

# Step 6: load GPT-2 for generation. Greedy decoding (do_sample=False)
# keeps answers deterministic across runs.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda" if torch.cuda.is_available() else "cpu")

text_gen_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    return_full_text=False,
    do_sample=False,
    max_new_tokens=200,
    pad_token_id=tokenizer.eos_token_id,
)

llm = HuggingFacePipeline(pipeline=text_gen_pipeline)
print("✅ GPT-2 loaded.")

def truncate_context(context, max_length=700):
    # Cap the retrieved context so prompt + 200 generated tokens stay inside
    # GPT-2's 1024-token window. 700 is a conservative choice that leaves
    # headroom for the instructions and the question; the original cap of
    # 1024 would overflow the model's context.
    tokens = tokenizer.encode(context)
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
    return tokenizer.decode(tokens, skip_special_tokens=True)

def format_prompt(context, question):
    return (
        "You are the Cambridge University Assistant, helping students with questions about courses, admissions, fees, etc. "
        "Only use the information in the context below to answer the question.\n\n"
        f"Context:\n{truncate_context(context)}\n\n"
        f"Student Question: {question}\n"
        "Assistant Answer:"
    )

def answer_fn(question):
    # Retrieve the top-k chunks, build the prompt, and generate an answer.
    docs = vectorstore.similarity_search(question, k=5)
    if not docs:
        return "I'm sorry, I couldn't find any relevant information for your query."
    context = "\n\n".join(d.page_content for d in docs)
    prompt = format_prompt(context, question)
    try:
        response = llm.invoke(prompt).strip()
        return response
    except Exception as e:
        return f"An error occurred: {e}"

# Step 7: Gradio chat UI. chat_fn appends each (question, answer) pair to
# the running history shown in the Chatbot component.
def chat_fn(user_message, history):
    bot_response = answer_fn(user_message)
    history = history + [(user_message, bot_response)]
    return history, history

with gr.Blocks() as demo:
    gr.Markdown("## 📘 University of Cambridge Assistant")
    chatbot = gr.Chatbot()
    state = gr.State([])
    user_input = gr.Textbox(placeholder="Ask a question about Cambridge...", show_label=False)

    user_input.submit(fn=chat_fn, inputs=[user_input, state], outputs=[chatbot, state])

demo.launch()
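
# launch() serves the app locally; passing share=True (a standard Gradio
# option) creates a temporary public link, handy in hosted notebooks.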