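"""RAG demo: extract a Markdown knowledge base, chunk and embed it into a
FAISS index, and answer student questions with GPT-2 behind a Gradio chat UI."""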
import os
import zipfile
import torch
import faiss
import numpy as np
import gradio as gr

from transformers import GPT2Tokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS as LangChainFAISS
from langchain.docstore import InMemoryDocstore
from langchain.llms import HuggingFacePipeline
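# NOTE: these `langchain.*` import paths work on pre-0.3 releases (as
# deprecated shims from 0.1 onward); on current LangChain the same classes
# live under langchain_community / langchain_core.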

# === 1. Extract ZIP Knowledge Base ===
if os.path.exists("md_knowledge_base.zip"):
    with zipfile.ZipFile("md_knowledge_base.zip", "r") as zip_ref:
        zip_ref.extractall("md_knowledge_base")
    print("✅ Knowledge base extracted.")

# === 2. Load Markdown Files ===
KB_PATH = "md_knowledge_base"
files = [os.path.join(dp, f) for dp, _, fn in os.walk(KB_PATH) for f in fn if f.endswith(".md")]
docs = [doc for f in files for doc in TextLoader(f, encoding="utf-8").load()]
print(f"✅ Loaded {len(docs)} documents.")

# === 3. Split into Chunks ===
def get_dynamic_chunk_size(text):
    """Scale chunk size with document length: short pages stay intact,
    long pages split into more, smaller retrievable pieces."""
    if len(text) < 1000:
        return 300
    elif len(text) < 5000:
        return 500
    return 1000

# Split each document with a size suited to its length; the 100-character
# overlap preserves context across chunk boundaries.
chunks = []
for doc in docs:
    chunk_size = get_dynamic_chunk_size(doc.page_content)
    chunk_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=100)
    chunks.extend(chunk_splitter.split_documents([doc]))
texts = [chunk.page_content for chunk in chunks]
print(f"✅ Split into {len(chunks)} chunks.")
texts = [chunk.page_content for chunk in chunks]

# === 4. Build Vectorstore ===
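# NOTE: distilbert-base-uncased is not a sentence-transformers checkpoint, so
# SentenceTransformer falls back to mean pooling over token embeddings. It
# runs, but a model trained for similarity search (e.g.
# "sentence-transformers/all-MiniLM-L6-v2") would likely retrieve more accurately.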
embed_model_id = "distilbert-base-uncased"
embedder = SentenceTransformer(embed_model_id)
embeddings = embedder.encode(texts, show_progress_bar=False)

dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(np.array(embeddings, dtype="float32"))
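# IndexFlatL2 is exact brute-force L2 search, which is fine at this scale;
# for much larger corpora an approximate index (e.g. faiss.IndexIVFFlat)
# trades a little recall for speed.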

# The split chunks are already LangChain Documents (with source metadata),
# so register them in the docstore directly instead of rebuilding bare copies.
docstore = InMemoryDocstore({str(i): chunks[i] for i in range(len(chunks))})
id_map = {i: str(i) for i in range(len(chunks))}
embed_fn = HuggingFaceEmbeddings(model_name=embed_model_id)

vectorstore = LangChainFAISS(
    index=index,
    docstore=docstore,
    index_to_docstore_id=id_map,
    embedding_function=embed_fn
)

print("✅ FAISS vectorstore ready.")

# === 5. Load GPT-2 for Generation ===
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda" if torch.cuda.is_available() else "cpu")

text_gen_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    return_full_text=False,
    do_sample=False,
    max_new_tokens=200,
    pad_token_id=tokenizer.eos_token_id
)
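# do_sample=False selects greedy (deterministic) decoding, and
# return_full_text=False strips the prompt so only the answer comes back.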

llm = HuggingFacePipeline(pipeline=text_gen_pipeline)
print("✅ GPT-2 loaded.")

# === 6. Prompt Formatting and Answer Logic ===
def truncate_context(context, max_length=700):
    # GPT-2's window is 1,024 tokens *including* generation, so cap the
    # retrieved context to leave room for the question, the prompt
    # scaffolding, and max_new_tokens=200.
    tokens = tokenizer.encode(context)
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
    return tokenizer.decode(tokens, skip_special_tokens=True)

def format_prompt(context, question):
    return (
        "You are the Cambridge University Assistant—helping students with questions about courses, admissions, fees, etc. "
        "Only use the information in the context below to answer the question.\n\n"
        f"Context:\n{truncate_context(context)}\n\n"
        f"Student Question: {question}\n"
        "Assistant Answer:"
    )

def answer_fn(question):
    docs = vectorstore.similarity_search(question, k=5)
    if not docs:
        return "I'm sorry, I couldn't find any relevant information for your query."
    context = "\n\n".join(d.page_content for d in docs)
    prompt = format_prompt(context, question)
    try:
        response = llm.invoke(prompt).strip()
        return response
    except Exception as e:
        return f"An error occurred: {e}"

# === 7. Gradio UI ===
def chat_fn(user_message, history):
    bot_response = answer_fn(user_message)
    history = history + [(user_message, bot_response)]
    return history, history

with gr.Blocks() as demo:
    gr.Markdown("## 📘 University of Cambridge Assistant")
    chatbot = gr.Chatbot()
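    # History uses the legacy tuple format [(user, bot), ...]; newer Gradio
    # releases prefer gr.Chatbot(type="messages").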
    state = gr.State([])
    user_input = gr.Textbox(placeholder="Ask a question about Cambridge...", show_label=False)

    user_input.submit(fn=chat_fn, inputs=[user_input, state], outputs=[chatbot, state])

demo.launch()
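# Tip: in a hosted notebook, demo.launch(share=True) prints a temporary
# public URL instead of only serving on localhost.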