# demo3 / app.py
# seanerons's picture
# Upload app.py with huggingface_hub
# 9aedf04 verified
import os
import zipfile
import torch
import faiss
import numpy as np
import gradio as gr
from transformers import GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS as LangChainFAISS
from langchain.docstore import InMemoryDocstore
from langchain.schema import Document
from langchain.llms import HuggingFacePipeline
# === 1. Extract ZIP Knowledge Base ===
# The archive is optional: when it is missing this step does nothing.
kb_archive_present = os.path.exists("md_knowledge_base.zip")
if kb_archive_present:
    # Unpack into the directory that the Markdown-loading step walks.
    with zipfile.ZipFile("md_knowledge_base.zip", "r") as zip_ref:
        zip_ref.extractall("md_knowledge_base")
    print("✅ Knowledge base extracted.")
# === 2. Load Markdown Files ===
KB_PATH = "md_knowledge_base"

# Collect the path of every *.md file under KB_PATH, at any depth.
files = []
for dir_path, _subdirs, file_names in os.walk(KB_PATH):
    for file_name in file_names:
        if file_name.endswith(".md"):
            files.append(os.path.join(dir_path, file_name))

# Each loader call yields a sequence of documents; flatten them into one list.
docs = []
for md_path in files:
    docs.extend(TextLoader(md_path, encoding="utf-8").load())

print(f"✅ Loaded {len(docs)} documents.")
# === 3. Split into Chunks ===
# Fixed-size splitter (chunk_size=500, chunk_overlap=100).
# NOTE(review): nothing in the visible file references `splitter`; the chunks
# are produced by the per-document splitters created in the loop further down.
# Confirm it is unused before removing it.
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
def get_dynamic_chunk_size(text):
    """Choose a splitter chunk size from the length of *text*.

    Returns 300 when ``len(text)`` is below 1000, 500 when it is below 5000,
    and 1000 for anything longer.
    """
    length = len(text)
    # (exclusive upper bound on length, chunk size), checked in ascending order.
    for upper_bound, size in ((1000, 300), (5000, 500)):
        if length < upper_bound:
            return size
    return 1000
# Split every document with a splitter whose chunk size depends on that
# document's length; the overlap is always 100.
chunks = []
for doc in docs:
    sized_splitter = RecursiveCharacterTextSplitter(
        chunk_size=get_dynamic_chunk_size(doc.page_content),
        chunk_overlap=100,
    )
    chunks += sized_splitter.split_documents([doc])

# Plain-text view of the chunks, in the same order as `chunks`.
texts = []
for piece in chunks:
    texts.append(piece.page_content)
# === 4. Build Vectorstore ===
# Fail early with an actionable message. With no chunks there is nothing to
# embed, and the embedding-dimension lookup below cannot work on an empty
# result.
if not texts:
    raise RuntimeError(
        "No text chunks to index: the knowledge base produced no Markdown "
        "content. Check that md_knowledge_base.zip (or the md_knowledge_base "
        "directory) is present and contains .md files."
    )

embed_model_id = "distilbert-base-uncased"
embedder = SentenceTransformer(embed_model_id)

# Convert once to a contiguous float32 matrix (one row per chunk) so the same
# array is used for the dimension lookup and for populating the index.
embeddings = np.ascontiguousarray(
    embedder.encode(texts, show_progress_bar=False), dtype="float32"
)
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(embeddings)

# Row i of the index corresponds to docstore entry str(i).
docs = [Document(page_content=t) for t in texts]
docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(docs)})
id_map = {i: str(i) for i in range(len(docs))}

# The wrapper is given the same model id that produced the indexed vectors.
embed_fn = HuggingFaceEmbeddings(model_name=embed_model_id)
vectorstore = LangChainFAISS(
    index=index,
    docstore=docstore,
    index_to_docstore_id=id_map,
    embedding_function=embed_fn,
)
print("✅ FAISS vectorstore ready.")
# === 5. Load GPT-2 for Generation ===
use_cuda = torch.cuda.is_available()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = model.to("cuda" if use_cuda else "cpu")

# Generation settings handed to the pipeline: only the newly generated text is
# returned, sampling is disabled, and at most 200 new tokens are produced.
# The tokenizer's EOS id is used as the padding id.
generation_settings = {
    "return_full_text": False,
    "do_sample": False,
    "max_new_tokens": 200,
    "pad_token_id": tokenizer.eos_token_id,
}
text_gen_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if use_cuda else -1,
    **generation_settings,
)
llm = HuggingFacePipeline(pipeline=text_gen_pipeline)
print("✅ GPT-2 loaded.")
# === 6. Prompt Formatting and Answer Logic ===
def truncate_context(context, max_length=1024):
    """Trim *context* to at most *max_length* tokens.

    Tokens are counted with the module-level ``tokenizer``.

    Args:
        context: Text that will be inserted into the prompt.
        max_length: Maximum number of tokens to keep. Negative values are
            treated as 0.

    Returns:
        *context* unchanged when it already fits; otherwise the decoded text
        of its first *max_length* tokens.
    """
    # A negative limit would otherwise slice tokens off the *end*
    # (tokens[:-n]) instead of producing an empty context.
    max_length = max(0, max_length)

    tokens = tokenizer.encode(context)
    if len(tokens) <= max_length:
        # Return the caller's text verbatim: an encode/decode round trip is
        # unnecessary here, and decode() may alter the text (e.g. spacing
        # clean-up).
        return context
    return tokenizer.decode(tokens[:max_length], skip_special_tokens=True)
def format_prompt(context, question, context_window=1024, max_new_tokens=200):
    """Build the generation prompt, keeping it inside the model's window.

    The context used to be truncated to 1024 tokens on its own, leaving no
    room for the instructions, the question or the generated answer. The
    context budget is now whatever remains of *context_window* after reserving
    space for the fixed prompt text and for *max_new_tokens* of output.

    Args:
        context: Retrieved passages to ground the answer.
        question: The student's question.
        context_window: Total number of tokens the model can attend to
            (1024 for GPT-2).
        max_new_tokens: Tokens reserved for the generated answer; the default
            matches the value configured on the text-generation pipeline.

    Returns:
        The full prompt string. If the fixed text and question alone exceed
        the available budget, the context is dropped entirely; the prompt may
        then still be too long for the model.
    """
    instructions = (
        "You are the Cambridge University Assistant—helping students with questions about courses, admissions, fees, etc. "
        "Only use the information in the context below to answer the question.\n\n"
    )
    context_header = "Context:\n"
    tail = f"\n\nStudent Question: {question}\nAssistant Answer:"

    # Tokenizing the pieces separately and as one string can differ by a few
    # tokens at the seams, so keep a small safety margin.
    boundary_slack = 8
    fixed_tokens = len(tokenizer.encode(instructions + context_header + tail))
    context_budget = max(
        0, context_window - max_new_tokens - fixed_tokens - boundary_slack
    )

    trimmed_context = truncate_context(context, max_length=context_budget)
    return instructions + context_header + trimmed_context + tail
def answer_fn(question):
    """Answer *question* from the knowledge base.

    Retrieves the 5 most similar chunks from ``vectorstore``, builds a prompt
    with ``format_prompt`` and generates a reply with ``llm``.

    Args:
        question: The user's question text.

    Returns:
        The stripped model output, or a human-readable message when the
        question is blank, nothing was retrieved, or generation failed.
    """
    # A blank question has nothing to retrieve against or to answer; reply
    # immediately instead of running retrieval and generation on it.
    if not question or not question.strip():
        return "Please enter a question."

    hits = vectorstore.similarity_search(question, k=5)
    if not hits:
        return "I'm sorry, I couldn't find any relevant information for your query."

    context = "\n\n".join(hit.page_content for hit in hits)
    prompt = format_prompt(context, question)
    try:
        return llm.invoke(prompt).strip()
    except Exception as e:
        # UI boundary: report the failure in the chat instead of raising.
        return f"An error occurred: {e}"
# === 7. Gradio UI ===
def chat_fn(user_message, history):
    """Answer *user_message* and record the exchange.

    Builds a new history list containing the previous turns plus a
    ``(user_message, bot_response)`` pair; the list passed in is not mutated.
    The new history is returned twice.
    """
    turn = (user_message, answer_fn(user_message))
    updated_history = history + [turn]
    return updated_history, updated_history
# Layout: a heading, the chat transcript, hidden per-session history state,
# and a single-line input box.
with gr.Blocks() as demo:
    gr.Markdown("## 📘 University of Cambridge Assistant")
    chatbot = gr.Chatbot()
    state = gr.State([])
    user_input = gr.Textbox(
        show_label=False,
        placeholder="Ask a question about Cambridge...",
    )
    # Submitting the textbox sends the message and stored history to chat_fn,
    # whose two return values update the transcript and the stored history.
    user_input.submit(chat_fn, [user_input, state], [chatbot, state])

demo.launch()