seanerons committed · verified
Commit 9d6c111 · 1 parent: 8519bef

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +127 -0
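
The commit message says the file was pushed with the huggingface_hub client. A minimal sketch of that kind of upload, using the library's HfApi.upload_file, looks like the following; the repo id is a placeholder, not taken from this commit:

    from huggingface_hub import HfApi

    api = HfApi()
    # Push a local app.py into a Space; repo_id below is a placeholder.
    api.upload_file(
        path_or_fileobj="app.py",
        path_in_repo="app.py",
        repo_id="seanerons/some-space",  # placeholder repo id
        repo_type="space",
        commit_message="Upload app.py with huggingface_hub",
    )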
app.py ADDED
@@ -0,0 +1,127 @@
+import os
+import zipfile
+import torch
+import faiss
+import numpy as np
+import gradio as gr
+
+from transformers import GPT2Tokenizer, AutoModelForCausalLM, pipeline
+from sentence_transformers import SentenceTransformer
+from langchain.document_loaders import TextLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS as LangChainFAISS
+from langchain.docstore import InMemoryDocstore
+from langchain.schema import Document
+from langchain.llms import HuggingFacePipeline
+
+# === 1. Extract ZIP Knowledge Base ===
+if os.path.exists("md_knowledge_base.zip"):
+    with zipfile.ZipFile("md_knowledge_base.zip", "r") as zip_ref:
+        zip_ref.extractall("md_knowledge_base")
+    print("✅ Knowledge base extracted.")
+
+# === 2. Load Markdown Files ===
+KB_PATH = "md_knowledge_base"
+files = [os.path.join(dp, f) for dp, _, fn in os.walk(KB_PATH) for f in fn if f.endswith(".md")]
+docs = [doc for f in files for doc in TextLoader(f, encoding="utf-8").load()]
+print(f"✅ Loaded {len(docs)} documents.")
+
+# === 3. Split into Chunks ===
+# Scale the chunk size with document length so short files are not over-split.
+def get_dynamic_chunk_size(text):
+    if len(text) < 1000:
+        return 300
+    elif len(text) < 5000:
+        return 500
+    else:
+        return 1000
+
+chunks = []
+for doc in docs:
+    chunk_size = get_dynamic_chunk_size(doc.page_content)
+    chunk_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=100)
+    chunks.extend(chunk_splitter.split_documents([doc]))
+texts = [chunk.page_content for chunk in chunks]
+
+# === 4. Build Vectorstore ===
+embed_model_id = "distilbert-base-uncased"
+embedder = SentenceTransformer(embed_model_id)
+embeddings = embedder.encode(texts, show_progress_bar=False)
+
+dim = embeddings.shape[1]
+index = faiss.IndexFlatL2(dim)
+index.add(np.array(embeddings, dtype="float32"))
+
+# Wrap the raw FAISS index in LangChain's vectorstore so retrieval returns Documents.
+docs = [Document(page_content=t) for t in texts]
+docstore = InMemoryDocstore({str(i): docs[i] for i in range(len(docs))})
+id_map = {i: str(i) for i in range(len(docs))}
+# Query-time embeddings must come from the same model that built the index.
+embed_fn = HuggingFaceEmbeddings(model_name=embed_model_id)
+
+vectorstore = LangChainFAISS(
+    index=index,
+    docstore=docstore,
+    index_to_docstore_id=id_map,
+    embedding_function=embed_fn
+)
+
+print("✅ FAISS vectorstore ready.")
+
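+# Note: the embedding model above (distilbert-base-uncased) is a plain
+# masked-LM checkpoint, not a model trained for sentence similarity, so
+# SentenceTransformer falls back to mean pooling for it. A retrieval-tuned
+# model such as sentence-transformers/all-MiniLM-L6-v2 is a common alternative.
+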
+# === 5. Load GPT-2 for Generation ===
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda" if torch.cuda.is_available() else "cpu")
+
+text_gen_pipeline = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    device=0 if torch.cuda.is_available() else -1,
+    return_full_text=False,
+    do_sample=False,
+    max_new_tokens=200,
+    pad_token_id=tokenizer.eos_token_id
+)
+
+llm = HuggingFacePipeline(pipeline=text_gen_pipeline)
+print("✅ GPT-2 loaded.")
+
+# === 6. Prompt Formatting and Answer Logic ===
+# GPT-2's context window is 1024 tokens in total, so cap the retrieved context
+# low enough to leave room for the instructions, the question, and the 200
+# generated tokens.
+def truncate_context(context, max_length=700):
+    tokens = tokenizer.encode(context)
+    if len(tokens) > max_length:
+        tokens = tokens[:max_length]
+    return tokenizer.decode(tokens, skip_special_tokens=True)
+
+def format_prompt(context, question):
+    return (
+        "You are the Cambridge University Assistant, helping students with questions about courses, admissions, fees, etc. "
+        "Only use the information in the context below to answer the question.\n\n"
+        f"Context:\n{truncate_context(context)}\n\n"
+        f"Student Question: {question}\n"
+        "Assistant Answer:"
+    )
+
+def answer_fn(question):
+    docs = vectorstore.similarity_search(question, k=5)
+    if not docs:
+        return "I'm sorry, I couldn't find any relevant information for your query."
+    context = "\n\n".join(d.page_content for d in docs)
+    prompt = format_prompt(context, question)
+    try:
+        return llm.invoke(prompt).strip()
+    except Exception as e:
+        return f"An error occurred: {e}"
+
+# === 7. Gradio UI ===
+def chat_fn(user_message, history):
+    bot_response = answer_fn(user_message)
+    history = history + [(user_message, bot_response)]
+    return history, history
+
+with gr.Blocks() as demo:
+    gr.Markdown("## 📘 University of Cambridge Assistant")
+    chatbot = gr.Chatbot()
+    state = gr.State([])
+    user_input = gr.Textbox(placeholder="Ask a question about Cambridge...", show_label=False)
+
+    user_input.submit(fn=chat_fn, inputs=[user_input, state], outputs=[chatbot, state])
+
+demo.launch()
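
For a quick local smoke test, one option is to call answer_fn directly, e.g. just before demo.launch(); the question below is made up and assumes md_knowledge_base.zip sits next to app.py:

    # Hypothetical check: route one question through the same path the chatbot uses.
    print(answer_fn("What are the entry requirements for undergraduate Computer Science?"))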