Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import faiss | |
| import pickle | |
| import json | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
# Load everything at startup
print("Loading DMO RAG Chatbot...")

# Sentence encoder used by retrieve() to embed incoming queries.
embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Vector index searched by retrieve(); the ids it returns are used to index
# `docs` and `metadata` below, so those lists are presumably stored in the
# same order as the index rows -- TODO confirm against the build script.
index = faiss.read_index('dmo_knowledge_base.faiss')

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only ship a 'dmo_documents.pkl' that was produced by a trusted build step;
# consider a plain-data format (e.g. JSON) if the file could come from
# anywhere else.
with open('dmo_documents.pkl', 'rb') as f:
    doc_data = pickle.load(f)
docs = doc_data['texts']
metadata = doc_data['metadata']

# JSON is UTF-8 by specification; pin the encoding so parsing does not depend
# on the host's locale default.
with open('dmo_config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

# Load local LLM
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): device_map="auto" relies on the `accelerate` package being
# installed -- confirm it is listed in the Space's requirements.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)

# Text-generation pipeline shared by every request; sampling settings are
# fixed here and apply to all calls made through `pipe`.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
print("Bot ready!")
def retrieve(query, k=3):
    """Return up to *k* knowledge-base entries matching *query*.

    The query is embedded with ``embed_model``, converted to float32,
    L2-normalised and searched against the module-level FAISS ``index``.

    Args:
        query: The user's question as a string.
        k: Number of neighbours requested from the index.

    Returns:
        A list of dicts with keys ``'text'``, ``'metadata'`` and ``'score'``
        (the value FAISS reported for that hit, as a Python float), in the
        order the index returned them. Hits whose id falls outside
        ``range(len(docs))`` -- including negative ids -- are skipped.
    """
    query_vec = np.array(embed_model.encode([query])).astype('float32')
    # normalize_L2 works in place; its return value is not used.
    faiss.normalize_L2(query_vec)
    hit_scores, hit_ids = index.search(query_vec, k)
    # Only one query was submitted, so row 0 holds all of its hits.
    return [
        {
            'text': docs[doc_id],
            'metadata': metadata[doc_id],
            'score': float(hit_score),
        }
        for doc_id, hit_score in zip(hit_ids[0], hit_scores[0])
        if 0 <= doc_id < len(docs)
    ]
def generate(query, retrieved_docs):
    """Produce an answer to *query* grounded in *retrieved_docs*.

    Args:
        query: The user's question.
        retrieved_docs: Iterable of dicts that each carry a ``'text'`` entry
            (the shape returned by ``retrieve``).

    Returns:
        The text generated by ``pipe`` for the rendered chat prompt, with
        surrounding whitespace stripped.
    """
    # Number the passages from 1 so the prompt reads "Document 1", "Document 2", ...
    numbered_passages = (
        f"Document {position}: {hit['text']}"
        for position, hit in enumerate(retrieved_docs, start=1)
    )
    context = "\n\n".join(numbered_passages)

    system_turn = {
        "role": "system",
        "content": "You are a helpful travel assistant for a Destination Marketing Organization. Use the provided documents to answer accurately and concisely.",
    }
    user_turn = {
        "role": "user",
        "content": f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:",
    }

    # Render to a plain string (tokenize=False) because the pipeline is given
    # text; add_generation_prompt=True appends the assistant-turn marker.
    prompt = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    # return_full_text=False: keep only the completion, not the echoed prompt.
    completions = pipe(prompt, return_full_text=False)
    first_completion = completions[0]
    return first_completion['generated_text'].strip()
def chat(message, history):
    """Gradio chat callback: answer *message* and append a source list.

    Args:
        message: The latest user message.
        history: Conversation history supplied by the chat interface; it is
            not used when building the answer.

    Returns:
        The generated answer followed by a Markdown "Sources" section with
        one numbered line per retrieved entry. Each line shows the entry's
        ``'city'`` metadata, falling back to ``'topic'`` and then
        ``'unknown'``, plus its score to two decimals.
    """
    hits = retrieve(message)
    answer = generate(message, hits)

    source_lines = [
        "[{rank}] {label} (score: {score:.2f})".format(
            rank=rank,
            label=hit['metadata'].get(
                'city', hit['metadata'].get('topic', 'unknown')
            ),
            score=hit['score'],
        )
        for rank, hit in enumerate(hits, start=1)
    ]
    sources_text = "\n".join(source_lines)

    return f"{answer}\n\n---\n**Sources:**\n{sources_text}"
# Gradio UI
# Chat front end: every submitted message is routed to chat(); the example
# prompts are offered to the user as ready-made questions.
demo = gr.ChatInterface(
    title="DMO Destination Assistant",
    description=(
        "Ask me about travel destinations, attractions, tips, and more! "
        "I use a knowledge base of travel guides and FAQs to answer your questions."
    ),
    fn=chat,
    examples=[
        "What are the top attractions in Paris?",
        "Best time to visit Tokyo?",
        "How do I stay safe while traveling?",
        "What should I pack for Asia?",
        "Tell me about Barcelona beaches",
        "Best food to try in Bangkok?",
        "How do I get around in Europe?",
    ],
)
if __name__ == "__main__":
    import os

    # Serve on the port named by GRADIO_SERVER_PORT, defaulting to 7860, and
    # bind to all interfaces so the app is reachable from outside the host.
    port = int(os.getenv("GRADIO_SERVER_PORT", 7860))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
    )