Spaces:
Runtime error
Runtime error
| import os | |
| import zipfile | |
| import torch # β Import torch so empty_cache works | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.vectorstores import FAISS | |
| from langchain.llms import HuggingFacePipeline | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.prompts import PromptTemplate | |
# --- Step 1: Unzip FAISS index ---
# Extract the bundled archive into the working directory, but only when the
# "faiss_index" path does not exist yet and the archive is actually present.
if not os.path.exists("faiss_index") and os.path.exists("faiss_index.zip"):
    with zipfile.ZipFile("faiss_index.zip", "r") as zip_ref:
        zip_ref.extractall(".")
# --- Step 2: Load embedding and vectorstore ---
# NOTE(review): the embedding model presumably has to be the same one that was
# used when the FAISS index was built — confirm against the indexing script.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# NOTE(security): allow_dangerous_deserialization=True permits pickle-based
# loading of the stored index metadata. Keep it only while "faiss_index" is an
# artifact this project produced itself; never load an index from an
# untrusted source with this flag.
vectordb = FAISS.load_local(
    "faiss_index",
    embedding_model,
    allow_dangerous_deserialization=True,
)
# --- Step 3: Load the LLM ---
model_id = "tiiuae/falcon3-1b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Use device_map + float16 to save memory.
# NOTE(review): on a CPU-only host half precision can be slow or unsupported
# depending on the installed torch build — confirm the target hardware.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Text-generation pipeline: sampling is enabled, each reply is capped at
# 200 newly generated tokens, and the EOS token id is reused as the pad id.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=200,
    do_sample=True,
    temperature=1.0,
)

# Wrap the pipeline so it can be passed to LangChain as an LLM.
llm = HuggingFacePipeline(pipeline=pipe)
# --- Step 4: Setup memory and QA chain ---
# NOTE(review): this memory object is created once at module level, so all
# visitors of the app presumably share a single conversation history —
# confirm whether per-session memory is wanted.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# Prompt used when combining the retrieved documents: {context} and {question}
# are the two template variables.
prompt = PromptTemplate.from_template("""
You are a helpful assistant at the University of Hertfordshire. Use the context below to answer the question clearly and factually.
If the answer is not in the context, say you don't know.
Context:
{context}
Question:
{question}
Helpful Answer:
""")

# Conversational retrieval chain over the FAISS index; the retriever is asked
# for k=3 results per question and the custom prompt is passed to the
# "stuff" combine-documents step.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectordb.as_retriever(search_kwargs={"k": 3}),
    memory=memory,
    chain_type="stuff",
    combine_docs_chain_kwargs={"prompt": prompt},
)
# Relative path of the logo image displayed at the top of the UI.
UH_LOGO = "images/UH.png"
# --- Step 5: Define chatbot logic ---
def chat(message, history):
    """Answer one user message through the retrieval QA chain.

    Args:
        message: The user's question text; sent to the chain as "question".
        history: Chat history handed over by the UI callback. It is not read
            here.

    Returns:
        The chain's "answer" text (empty string if the key is missing), cut
        down to whatever follows the last "Answer:" marker, with any
        "<|assistant|>" tag removed and surrounding whitespace stripped.
    """
    result = qa_chain.invoke({"question": message})
    response = result.get("answer", "")

    # Keep only the text after the final "Answer:" marker and drop the
    # assistant tag, in case the output echoes parts of the prompt.
    response = response.split("Answer:")[-1].replace("<|assistant|>", "").strip()

    # Actually clear unused GPU memory after each request.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return response
# --- Step 6: UI ---
# Example questions offered as one-click buttons below the input box.
sample_questions = [
    "How do I register as a new student?",
    "Where can I find accommodation?",
    "Can I renew my tenancy agreement?",
    "What do I do on my first day?",
]
with gr.Blocks() as demo:
    gr.Image(UH_LOGO, show_label=False, container=False, scale=1)
    gr.Markdown("## ASK Herts Students Help Chatbot 🤖")
    chatbot = gr.Chatbot()
    txt = gr.Textbox(
        placeholder="Ask me anything about university life...",
        label="Your question",
    )
    submit = gr.Button("Submit")

    gr.Markdown("#### 💡 Sample Questions:")
    with gr.Row():
        for q in sample_questions:
            # Bind q as a default argument so each button keeps its own
            # question instead of the loop's final value.
            gr.Button(q).click(lambda x=q: gr.update(value=x), outputs=[txt])

    def respond(message, history):
        """Handle one submission from the textbox or the Submit button.

        Args:
            message: Current textbox content.
            history: Current chatbot value, a list of (user, bot) pairs;
                treated as empty when it is None.

        Returns:
            A pair ("", history): the empty string clears the textbox and the
            history (extended by the new exchange, if any) refreshes the
            chatbot.
        """
        history = history or []
        # Ignore blank submissions instead of sending an empty question
        # to the chain.
        if not message or not message.strip():
            return "", history
        answer = chat(message, history)
        history.append((message, answer))
        return "", history

    # Both the button and pressing Enter in the textbox trigger the same handler.
    submit.click(respond, [txt, chatbot], [txt, chatbot])
    txt.submit(respond, [txt, chatbot], [txt, chatbot])

demo.launch()