# SPJIMR AI Assistant — Hugging Face Space (app.py)
import os

import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

from scraper import scrape
# =====================================================
# 0. Config
# =====================================================
# Token is optional for public models, but gated / rate-limited access
# needs it — warn loudly instead of failing later with an opaque 401.
HF_TOKEN = os.environ.get("HF_API_KEY")
if not HF_TOKEN:
    print("⚠️ HF_API_KEY not set — using unauthenticated inference (may be rate-limited).")
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # Much better for RAG

# =====================================================
# 1. Load + Build Knowledge Base
# =====================================================
print("🔄 Scraping website...")
raw_docs = scrape()

# Separate page text from provenance so the splitter can attach
# per-chunk metadata (used later to cite sources in the prompt).
texts = [d["text"] for d in raw_docs]
metas = [{"source": d["source"]} for d in raw_docs]

print("✂️ Splitting documents...")
splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,     # roughly 1-2 paragraphs per chunk
    chunk_overlap=150,  # overlap keeps context across chunk boundaries
)
documents = splitter.create_documents(texts, metadatas=metas)

print("🧠 Building embeddings...")
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)

print("📦 Building vector store...")
db = FAISS.from_documents(documents, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 4})  # top-4 chunks per query
print("✅ Knowledge base ready!")
# =====================================================
# 2. Prompt Builder
# =====================================================
def build_prompt(question, docs):
    """Assemble a grounded RAG prompt from retrieved docs and the question.

    Each retrieved chunk is tagged with its source URL so the model stays
    source-aware; the instructions force context-only answers.
    """
    snippets = []
    for doc in docs:
        snippets.append(f"[Source: {doc.metadata['source']}]\n{doc.page_content}")
    context = "\n\n".join(snippets)

    # Build the final prompt line-by-line; joining with "\n" yields exactly
    # the same text as the original stripped template.
    parts = [
        "You are an academic assistant for SPJIMR.",
        "Answer ONLY using the context below.",
        'If information is missing, say "I don\'t know."',
        "---------------------",
        "CONTEXT:",
        context,
        "---------------------",
        "QUESTION:",
        question,
        "ANSWER:",
    ]
    return "\n".join(parts)
# =====================================================
# 3. LLM Client
# =====================================================
# Hosted inference endpoint for the chat model; token may be None,
# in which case requests go out unauthenticated.
client = InferenceClient(model=MODEL_NAME, token=HF_TOKEN)
# =====================================================
# 4. Chat Function (Fixed Retriever API)
# =====================================================
def chat(message, history):
    """Stream a grounded answer for ``message``.

    Retrieves the top-k chunks, builds a context-only prompt, and yields the
    progressively accumulated response text (Gradio streaming convention).
    ``history`` is accepted for Gradio's callback signature but is not sent
    to the model — every turn is answered from retrieved context alone.
    """
    # New LangChain API: retriever.invoke replaces get_relevant_documents.
    docs = retriever.invoke(message)
    prompt = build_prompt(message, docs)

    messages = [{"role": "user", "content": prompt}]

    response = ""
    for chunk in client.chat_completion(
        messages=messages,
        max_tokens=700,
        temperature=0.3,
        stream=True,
    ):
        # Some streamed chunks carry no choices (e.g. role/finish events);
        # the original indexed choices[0] unguarded and could IndexError.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response
# =====================================================
# 5. Minimal Dark UI
# =====================================================
# Dark slate page, centered 900px column, hidden Gradio footer.
custom_css = """
body {
    background: #0f172a !important;
}
.gradio-container {
    max-width: 900px !important;
    margin: auto !important;
}
h1 {
    color: #e5e7eb;
    text-align: center;
}
.subtitle {
    text-align: center;
    color: #9ca3af;
    margin-bottom: 20px;
}
footer {
    display: none !important;
}
"""
# =====================================================
# 6. App
# =====================================================
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Base(
        primary_hue="indigo",
        neutral_hue="slate",
    ),
) as demo:
    gr.Markdown(
        """
# 🎓 SPJIMR AI Assistant
<div class="subtitle">
Ask questions based on official SPJIMR website
</div>
""",
        elem_id="title",
    )

    chatbot = gr.Chatbot(
        height=520,
        bubble_full_width=False,
    )
    msg = gr.Textbox(
        placeholder="Ask about programs, admissions, faculty...",
        show_label=False,
    )
    clear = gr.Button("Clear Chat")

    def user(user_message, history):
        # Clear the textbox and append a pending (user, bot) pair.
        return "", history + [[user_message, None]]

    def bot(history):
        # Stream the assistant reply into the last history slot.
        user_message = history[-1][0]
        history[-1][1] = ""
        for partial in chat(user_message, history):
            history[-1][1] = partial
            yield history

    # Submit: record the user turn immediately, then stream the bot turn.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: [], None, chatbot)
# =====================================================
# 7. Launch
# =====================================================
# Launch only when run as a script (Spaces executes this module directly).
if __name__ == "__main__":
    demo.launch()