# dmo-rag-chatbot / app.py
# Uploaded by Aakash1703 — commit "Upload app.py" (47c1c57, verified)
import gradio as gr
import numpy as np
import faiss
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Load everything at startup
print("Loading DMO RAG Chatbot...")

# Sentence encoder used to embed user queries before they are searched
# against the FAISS index.
embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Vector index read from disk; `retrieve` uses the ids it returns as
# positions into `docs` and `metadata`.
index = faiss.read_index('dmo_knowledge_base.faiss')

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only deploy a dmo_documents.pkl produced by a trusted build step; prefer a
# data-only format (e.g. JSON) if this file could ever come from elsewhere.
with open('dmo_documents.pkl', 'rb') as f:
    doc_data = pickle.load(f)

# Presumably parallel sequences (same length, same order as the index rows) —
# verify against the script that built the index.
docs = doc_data['texts']
metadata = doc_data['metadata']

# Name the encoding explicitly so reading the JSON config does not depend on
# the platform's default locale encoding.
# NOTE(review): `config` is not read anywhere else in the visible code.
with open('dmo_config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)
# Load local LLM
# Hugging Face model id shared by the tokenizer and the model below.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# NOTE(review): trust_remote_code=True permits code shipped in the model
# repository to run locally — confirm it is actually needed for this model.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # "auto" delegates dtype and device placement to the library.
    torch_dtype="auto",
    device_map="auto"
)
# Single text-generation pipeline reused by `generate` for every request.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # Upper bound on the number of tokens produced per answer.
    max_new_tokens=256,
    # Sampling is enabled, so repeated calls with the same prompt may differ.
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
print("Bot ready!")
def retrieve(query, k=3):
    """Return up to ``k`` knowledge-base entries nearest to ``query``.

    The query is embedded with the module-level ``embed_model``, cast to
    float32, L2-normalized in place, and searched in the module-level
    FAISS ``index``.

    Args:
        query: Query text to embed.
        k: Number of neighbours requested from the index.

    Returns:
        A list of dicts with keys ``'text'``, ``'metadata'`` and ``'score'``
        (the value reported by ``index.search``, converted to ``float``), in
        the order the index returned them. Ids that are negative or not
        smaller than ``len(docs)`` are dropped.
    """
    vector = np.array(embed_model.encode([query])).astype('float32')
    faiss.normalize_L2(vector)
    scores, ids = index.search(vector, k)
    # Only the first (and only) query row is relevant.
    return [
        {
            'text': docs[doc_id],
            'metadata': metadata[doc_id],
            'score': float(raw_score),
        }
        for doc_id, raw_score in zip(ids[0], scores[0])
        if 0 <= doc_id < len(docs)
    ]
def generate(query, retrieved_docs):
    """Generate an answer to ``query`` from the retrieved documents.

    Args:
        query: The user's question.
        retrieved_docs: Sequence of dicts, each with a ``'text'`` entry, as
            produced by ``retrieve``.

    Returns:
        The generated text of the first pipeline result, with surrounding
        whitespace stripped.
    """
    # Number the documents from 1 and separate them with blank lines.
    numbered = (
        f"Document {number}: {entry['text']}"
        for number, entry in enumerate(retrieved_docs, start=1)
    )
    context = "\n\n".join(numbered)

    system_turn = {
        "role": "system",
        "content": "You are a helpful travel assistant for a Destination Marketing Organization. Use the provided documents to answer accurately and concisely.",
    }
    user_turn = {
        "role": "user",
        "content": f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:",
    }

    # Render the chat turns to one prompt string using the model's template.
    prompt = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    completions = pipe(prompt, return_full_text=False)
    return completions[0]['generated_text'].strip()
def chat(message, history):
    """Answer one chat message and append a list of the sources used.

    Args:
        message: The user's latest message.
        history: Earlier conversation turns; accepted but not used here.

    Returns:
        The generated answer, a ``---`` separator, and one numbered line per
        retrieved document showing its ``'city'`` (or ``'topic'``, or
        ``'unknown'``) and its score to two decimals.
    """
    hits = retrieve(message)
    answer = generate(message, hits)

    citations = []
    for rank, hit in enumerate(hits, start=1):
        meta = hit['metadata']
        # Prefer the city label; fall back to the topic, then to 'unknown'.
        origin = meta.get('city', meta.get('topic', 'unknown'))
        citations.append(f"[{rank}] {origin} (score: {hit['score']:.2f})")

    sources_text = "\n".join(citations)
    return f"{answer}\n\n---\n**Sources:**\n{sources_text}"
# Gradio UI
# Chat front end wired to `chat`: the function receives the user's message
# and the history, and its returned string is used as the reply.
demo = gr.ChatInterface(
    fn=chat,
    title="DMO Destination Assistant",
    description="Ask me about travel destinations, attractions, tips, and more! I use a knowledge base of travel guides and FAQs to answer your questions.",
    # Sample questions passed to the interface as ready-made prompts.
    examples=[
        "What are the top attractions in Paris?",
        "Best time to visit Tokyo?",
        "How do I stay safe while traveling?",
        "What should I pack for Asia?",
        "Tell me about Barcelona beaches",
        "Best food to try in Bangkok?",
        "How do I get around in Europe?",
    ],
)
if __name__ == "__main__":
    import os

    # Port comes from GRADIO_SERVER_PORT when set, otherwise 7860.
    listen_port = int(os.getenv("GRADIO_SERVER_PORT", 7860))
    # 0.0.0.0 binds every network interface, not just localhost.
    demo.launch(server_name="0.0.0.0", server_port=listen_port)