import os
import pickle
import time
from typing import Dict, List

import faiss
import numpy as np
import streamlit as st
import torch
from transformers import AutoModel, AutoProcessor, AutoTokenizer

# ========================================
# 🎨 STREAMLIT PAGE CONFIG
# ========================================
st.set_page_config(
    page_title="BeRU Chat - RAG Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Model identifiers and on-disk artifacts used by the loaders below.
EMBED_MODEL_ID = "TIGER-Lab/VLM2Vec-Qwen2VL-2B"
LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
FAISS_INDEX_PATH = "VLM2Vec-V2rag2/text_index.faiss"


# ========================================
# 🎯 CACHING FOR MODEL LOADING
# ========================================
@st.cache_resource
def load_embedding_model():
    """Load the VLM2Vec embedding model, its processor and its tokenizer.

    Cached with ``st.cache_resource`` so the weights are loaded once per
    server process instead of on every Streamlit rerun.

    Returns:
        Tuple ``(model, processor, tokenizer, device)`` where ``device`` is
        ``"cuda"`` when a GPU is available and ``"cpu"`` otherwise.
    """
    st.write("⏳ Loading embedding model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained(
        EMBED_MODEL_ID,
        trust_remote_code=True,
        # Half precision on GPU only; full precision on CPU.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    processor = AutoProcessor.from_pretrained(EMBED_MODEL_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID, trust_remote_code=True)
    model.eval()
    st.success("✅ Embedding model loaded!")
    return model, processor, tokenizer, device


@st.cache_resource
def load_llm_model():
    """Load the Mistral 7B Instruct LLM (4-bit NF4 quantized) and its tokenizer.

    NOTE(review): bitsandbytes 4-bit loading generally needs a CUDA GPU; on a
    CPU-only host ``from_pretrained`` is expected to raise, which the UI below
    catches and reports.

    Returns:
        Tuple ``(model, tokenizer, device)``.
    """
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    st.write("⏳ Loading language model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4-bit quantization config for efficiency
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_ID,
        quantization_config=quantization_config,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
    st.success("✅ Language model loaded!")
    return model, tokenizer, device


@st.cache_resource
def load_faiss_index():
    """Load the prebuilt FAISS text index from ``FAISS_INDEX_PATH``.

    The return value (including ``None``) is cached, so an index built after
    the app started is only picked up after clearing the cache or restarting.

    Returns:
        The loaded FAISS index, or ``None`` (after showing a warning) when the
        index file does not exist.
    """
    if not os.path.exists(FAISS_INDEX_PATH):
        st.warning("⚠️ FAISS index not found. Please build the index first.")
        return None
    st.write("⏳ Loading FAISS index...")
    index = faiss.read_index(FAISS_INDEX_PATH)
    st.success("✅ FAISS index loaded!")
    return index


# ========================================
# 💬 EMBEDDING & RETRIEVAL FUNCTIONS
# ========================================
def get_embeddings(texts: List[str], model, processor, tokenizer, device) -> np.ndarray:
    """Embed each text by mean-pooling the model's last hidden layer.

    Args:
        texts: Strings to embed; each is truncated to 512 tokens.
        model: Embedding model from ``load_embedding_model``.
        processor: Accepted for interface compatibility; currently unused.
        tokenizer: Tokenizer matching ``model``.
        device: Device the model lives on.

    Returns:
        A ``float32`` array with one row per input text. Activations are
        upcast from fp16 (GPU) before pooling so the result is always
        float32, which is what FAISS operates on.
    """
    rows = []
    for text in texts:
        inputs = tokenizer(
            text, return_tensors="pt", max_length=512, truncation=True
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # NOTE(review): this pooling must match the one used when the index
        # was built, otherwise the distances are meaningless — confirm.
        pooled = outputs.hidden_states[-1].float().mean(dim=1)
        rows.append(pooled.cpu().numpy().flatten())
    return np.asarray(rows, dtype=np.float32)


def retrieve_documents(query: str, model, processor, tokenizer, device,
                       faiss_index, k: int = 5) -> List[Dict]:
    """Return the ``k`` index entries nearest to ``query``.

    Args:
        query: User question to embed and search for.
        model, processor, tokenizer, device: As returned by
            ``load_embedding_model``.
        faiss_index: Loaded FAISS index, or ``None``.
        k: Number of neighbours to request.

    Returns:
        List of ``{"index": int, "distance": float}`` dicts in FAISS rank
        order; empty when ``faiss_index`` is ``None``. Result slots FAISS
        could not fill (reported as id ``-1``) are skipped.

    Raises:
        ValueError: If the query embedding width differs from the index's.
    """
    if faiss_index is None:
        return []

    query_embedding = get_embeddings([query], model, processor, tokenizer, device)
    if query_embedding.shape[1] != faiss_index.d:
        raise ValueError(
            f"Embedding dimension {query_embedding.shape[1]} does not match "
            f"FAISS index dimension {faiss_index.d}"
        )

    distances, indices = faiss_index.search(query_embedding, k)

    # TODO: map each id back to its document text/metadata; only ids and
    # distances are returned today.
    return [
        {"index": int(idx), "distance": float(dist)}
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]


def _build_prompt(context: str, query: str) -> str:
    """Format the Mistral instruction prompt around the context and question."""
    return (
        "[INST] You are a helpful assistant answering questions about "
        "technical documentation.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query} [/INST]"
    )


def generate_response(query: str, context: str, model, tokenizer, device,
                      temperature: float = 0.7, max_new_tokens: int = 512,
                      max_input_tokens: int = 2048) -> str:
    """Generate an answer to ``query`` grounded in ``context`` using Mistral.

    Args:
        query: User question.
        context: Retrieved context text; trimmed so the full prompt fits in
            ``max_input_tokens`` without cutting off the question.
        model, tokenizer, device: As returned by ``load_llm_model``.
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on generated tokens.
        max_input_tokens: Upper bound on prompt tokens.

    Returns:
        Only the newly generated text, stripped of surrounding whitespace.
    """
    # Budget the context rather than truncating the whole prompt from the
    # right, which would drop the question and the closing [/INST] marker.
    overhead = len(tokenizer(_build_prompt("", query))["input_ids"])
    budget = max(max_input_tokens - overhead, 0)
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) > budget:
        context = tokenizer.decode(context_ids[:budget], skip_special_tokens=True)

    prompt = _build_prompt(context, query)
    # Truncation kept only as a backstop for an extremely long query.
    inputs = tokenizer(
        prompt, return_tensors="pt", max_length=max_input_tokens, truncation=True
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            # Mistral's tokenizer defines no pad token; reuse EOS.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens after the prompt. Splitting the decoded text on
    # "[/INST]" is unreliable because skip_special_tokens may strip it.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


# ========================================
# 🎨 STREAMLIT UI
# ========================================
st.title("🤖 BeRU Chat Assistant")
st.markdown("*100% Offline RAG System with Mistral 7B & VLM2Vec*")

# Sidebar Configuration
with st.sidebar:
    st.header("⚙️ Configuration")
    device_info = "🟢 GPU" if torch.cuda.is_available() else "🔴 CPU"
    st.metric("Device", device_info)
    num_results = st.slider("Retrieve top K documents", 1, 10, 5)
    temperature = st.slider("Response Temperature", 0.1, 1.0, 0.7)
    st.divider()
    st.markdown("### 📊 Project Info")
    st.markdown(
        """
- **Model**: Mistral 7B Instruct v0.3
- **Embeddings**: VLM2Vec-Qwen2VL-2B
- **Vector Store**: FAISS with 10K+ documents
- **Retrieval**: Dense (FAISS nearest-neighbour search)
"""
    )

# Main Chat Interface
col1, col2 = st.columns([3, 1])
with col1:
    st.subheader("💬 Ask a Question")
with col2:
    if st.button("🔄 Clear Chat", use_container_width=True):
        st.session_state.messages = []
        st.rerun()

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "models_loaded" not in st.session_state:
    st.session_state.models_loaded = False

# Load models (the loaders are cached process-wide; session_state just keeps
# handles for this browser session).
if not st.session_state.models_loaded:
    st.info("📦 Loading models on first run... This may take 2-3 minutes.")
    try:
        embed_model, processor, tokenizer_embed, embed_device = load_embedding_model()
        llm_model, tokenizer_llm, llm_device = load_llm_model()
        faiss_idx = load_faiss_index()
        st.session_state.embed_model = embed_model
        st.session_state.processor = processor
        st.session_state.tokenizer_embed = tokenizer_embed
        st.session_state.embed_device = embed_device
        st.session_state.llm_model = llm_model
        st.session_state.tokenizer_llm = tokenizer_llm
        st.session_state.llm_device = llm_device
        st.session_state.faiss_idx = faiss_idx
        st.session_state.models_loaded = True
        st.success("✅ All models loaded successfully!")
    except Exception as e:
        st.error(f"❌ Error loading models: {str(e)}")
        st.stop()

# Chat Interface
st.markdown("---")

# Display chat history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# User input
user_input = st.chat_input("Type your question here...", key="user_input")
if user_input:
    # Add user message to chat
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("🔍 Retrieving relevant documents..."):
            retrieved = retrieve_documents(
                user_input,
                st.session_state.embed_model,
                st.session_state.processor,
                st.session_state.tokenizer_embed,
                st.session_state.embed_device,
                st.session_state.faiss_idx,
                k=num_results,
            )

        # TODO: substitute real document text once id -> text metadata is
        # loaded; today the context only names the matching index ids.
        context = "\n\n".join(
            f"Document {i+1}: Context from index {doc['index']}"
            for i, doc in enumerate(retrieved)
        )

        with st.spinner("💭 Generating response..."):
            response = generate_response(
                user_input,
                context,
                st.session_state.llm_model,
                st.session_state.tokenizer_llm,
                st.session_state.llm_device,
                temperature=temperature,
            )
        st.markdown(response)

    # Add to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})

# Footer
st.markdown("---")
# NOTE(review): unsafe_allow_html suggests this footer once carried HTML
# markup (e.g. GitHub / Hugging Face links) — confirm the intended content.
st.markdown("""

BeRU Chat Assistant | Powered by Mistral 7B + VLM2Vec | 100% Offline

GitHub | Hugging Face

""", unsafe_allow_html=True)