# spjimr-chatbot / app.py
# Hugging Face Space by Prof-Hunter — revision fdcf881 ("Update app.py", verified)
import os
import gradio as gr
from huggingface_hub import InferenceClient
from scraper import scrape
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# =====================================================
# 0. Config
# =====================================================
# Inference API token; may be None, in which case calls are anonymous
# and subject to tighter rate limits.
HF_TOKEN = os.environ.get("HF_API_KEY")
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # Much better for RAG

# =====================================================
# 1. Load + Build Knowledge Base (runs once at import/startup)
# =====================================================
print("🔄 Scraping website...")
# NOTE(review): scrape() is project-local; assumed to return dicts with
# "text" and "source" keys — confirm against scraper.py.
raw_docs = scrape()

# Separate page text from its source URL so each chunk carries
# per-document metadata through the splitter.
texts = [d["text"] for d in raw_docs]
metas = [{"source": d["source"]} for d in raw_docs]

print("✂️ Splitting documents...")
# 800-char chunks with 150-char overlap keep retrieval granular while
# preserving context across chunk boundaries.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=150,
)
# Pass metadatas explicitly by keyword for clarity.
documents = splitter.create_documents(texts, metadatas=metas)

print("🧠 Building embeddings...")
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)

print("📦 Building vector store...")
db = FAISS.from_documents(documents, embeddings)
# Retrieve the 4 most similar chunks per query.
retriever = db.as_retriever(search_kwargs={"k": 4})
print("✅ Knowledge base ready!")
# =====================================================
# 2. Prompt Builder
# =====================================================
def build_prompt(question, docs):
    """Assemble a grounded RAG prompt from retrieved chunks.

    Parameters
    ----------
    question : str
        The user's question, inserted verbatim.
    docs : list
        Retrieved documents; each must expose ``page_content`` and a
        ``metadata`` dict containing a ``"source"`` key.

    Returns
    -------
    str
        The final prompt with leading/trailing whitespace stripped.
    """
    # Label every chunk with its origin so the model can cite sources.
    labeled_chunks = []
    for doc in docs:
        labeled_chunks.append(
            f"[Source: {doc.metadata['source']}]\n{doc.page_content}"
        )
    context = "\n\n".join(labeled_chunks)

    prompt = f"""
You are an academic assistant for SPJIMR.
Answer ONLY using the context below.
If information is missing, say "I don't know."
---------------------
CONTEXT:
{context}
---------------------
QUESTION:
{question}
ANSWER:
"""
    return prompt.strip()
# =====================================================
# 3. LLM Client
# =====================================================
# Serverless Hugging Face Inference API client bound to MODEL_NAME.
# NOTE(review): if HF_TOKEN is None the client runs anonymously with
# stricter rate limits — confirm HF_API_KEY is set in the Space secrets.
client = InferenceClient(
model=MODEL_NAME,
token=HF_TOKEN
)
# =====================================================
# 4. Chat Function (Fixed Retriever API)
# =====================================================
def chat(message, history):
    """Stream a grounded answer to *message*.

    Retrieves the top chunks, builds the RAG prompt, and yields the
    progressively accumulated model response (Gradio streaming style).
    *history* is accepted for interface compatibility but not used —
    each turn is answered independently.
    """
    # New LangChain API: invoke() replaces get_relevant_documents().
    supporting_docs = retriever.invoke(message)
    prompt = build_prompt(message, supporting_docs)

    stream = client.chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=700,
        temperature=0.3,
        stream=True,
    )

    answer_so_far = ""
    for event in stream:
        piece = event.choices[0].delta.content
        if piece:  # skip empty keep-alive deltas
            answer_so_far += piece
            yield answer_so_far
# =====================================================
# 5. Minimal Dark UI
# =====================================================
# CSS injected into gr.Blocks: dark slate background, centered 900px
# column, muted subtitle, and the default Gradio footer hidden.
custom_css = """
body {
background: #0f172a !important;
}
.gradio-container {
max-width: 900px !important;
margin: auto !important;
}
h1 {
color: #e5e7eb;
text-align: center;
}
.subtitle {
text-align: center;
color: #9ca3af;
margin-bottom: 20px;
}
footer {
display: none !important;
}
"""
# =====================================================
# 6. App
# =====================================================
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Base(
        primary_hue="indigo",
        neutral_hue="slate",
    ),
) as demo:
    # Header: title plus subtitle styled via the custom CSS above.
    gr.Markdown(
        """
# 🎓 SPJIMR AI Assistant
<div class="subtitle">
Ask questions based on official SPJIMR website
</div>
""",
        elem_id="title"
    )

    chatbot = gr.Chatbot(
        height=520,
        bubble_full_width=False,
    )
    msg = gr.Textbox(
        placeholder="Ask about programs, admissions, faculty...",
        show_label=False,
    )
    clear = gr.Button("Clear Chat")

    def user(user_message, history):
        # Append the user's turn (answer pending) and clear the textbox.
        return "", history + [[user_message, None]]

    def bot(history):
        # Stream the assistant's reply into the newest history slot so
        # the UI updates token by token.
        user_message = history[-1][0]
        history[-1][1] = ""
        for chunk in chat(user_message, history):
            history[-1][1] = chunk
            yield history

    # Submit: record the user turn instantly (queue=False), then stream
    # the bot's answer.
    msg.submit(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=False,
    ).then(
        bot,
        chatbot,
        chatbot,
    )
    clear.click(lambda: [], None, chatbot)
# =====================================================
# 7. Launch
# =====================================================
if __name__ == "__main__":
    # Default launch: binds locally; Hugging Face Spaces handles serving.
    demo.launch()