"""
BeRU RAG Chat App - Optimized for Hugging Face Spaces
Deployment: https://huggingface.co/spaces/AnwinMJ/Beru
"""

import logging
import os
import pickle
import time
from typing import List, Dict

import faiss
import numpy as np
import streamlit as st
import torch
from transformers import AutoModel, AutoProcessor, AutoTokenizer

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ========================================
# 🎨 STREAMLIT PAGE CONFIG
# ========================================
# NOTE: set_page_config is issued before any other Streamlit command in
# this script; keep it that way when adding new UI code.
st.set_page_config(
    page_title="BeRU Chat - RAG Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "About": "BeRU - Offline RAG System with VLM2Vec and Mistral 7B"
    }
)


# ========================================
# 🌍 ENVIRONMENT DETECTION
# ========================================
def detect_environment():
    """Detect if running on HF Spaces and describe the compute environment.

    Returns:
        dict with the keys:
            'is_spaces'   - True when the ``SPACES`` env var is "true"
                            (case-insensitive), when ``SPACE_ID`` is set,
                            or when the ``/app`` directory exists.
            'device'      - 'cuda' if torch reports CUDA, otherwise 'cpu'.
            'model_cache' - value of ``HF_HOME``, defaulting to './cache'.
            'gpu_memory'  - total memory in bytes of CUDA device 0, or 0
                            when CUDA is unavailable.
    """
    has_cuda = torch.cuda.is_available()

    # The previous check evaluated `'huggingface' in os.path.exists('/app')`,
    # i.e. a membership test against a bool, which raises TypeError whenever
    # SPACES is not "true". Use plain boolean signals instead.
    is_spaces = (
        os.getenv('SPACES', 'false').lower() == 'true'
        or bool(os.getenv('SPACE_ID'))  # helper variable documented for HF Spaces
        or os.path.exists('/app')
    )

    return {
        'is_spaces': is_spaces,
        'device': 'cuda' if has_cuda else 'cpu',
        'model_cache': os.getenv('HF_HOME', './cache'),
        'gpu_memory': (
            torch.cuda.get_device_properties(0).total_memory if has_cuda else 0
        ),
    }


env_info = detect_environment()

# Display environment info in sidebar
with st.sidebar:
    st.write("### System Info")
    st.write(f"🖥️ Device: `{env_info['device'].upper()}`")
    if env_info['device'] == 'cuda':
        # total_memory is in bytes; show it in (decimal) gigabytes.
        st.write(f"💾 GPU VRAM: `{env_info['gpu_memory'] / 1e9:.1f} GB`")
    st.write(f"📦 Cache: `{env_info['model_cache']}`")


# ========================================
# 🎯 MODEL LOADING WITH CACHING
# ========================================
@st.cache_resource
def load_embedding_model():
    """Load VLM2Vec embedding model with error handling.

    Wrapped in ``st.cache_resource`` so the load is not repeated on every
    script rerun.

    Returns:
        Tuple ``(model, processor, tokenizer, device)`` where ``device`` is
        the string 'cuda' or 'cpu' the model was moved to.

    Raises:
        Exception: any loading error is shown in the UI, logged, and
            re-raised to the caller.
    """
    with st.spinner("⏳ Loading embedding model... \n(first time may take 5 min)"):
        try:
            logger.info("Loading VLM2Vec model...")

            device = "cuda" if torch.cuda.is_available() else "cpu"

            model = AutoModel.from_pretrained(
                "TIGER-Lab/VLM2Vec-Qwen2VL-2B",
                trust_remote_code=True,
                # Half precision only on GPU; CPU stays in float32.
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                cache_dir=env_info['model_cache']
            ).to(device)

            processor = AutoProcessor.from_pretrained(
                "TIGER-Lab/VLM2Vec-Qwen2VL-2B",
                trust_remote_code=True,
                cache_dir=env_info['model_cache']
            )

            tokenizer = AutoTokenizer.from_pretrained(
                "TIGER-Lab/VLM2Vec-Qwen2VL-2B",
                trust_remote_code=True,
                cache_dir=env_info['model_cache']
            )

            model.eval()
            logger.info("✅ Embedding model loaded successfully")
            st.success("✅ Embedding model loaded!")

            return model, processor, tokenizer, device

        except Exception as e:
            st.error(f"❌ Error loading embedding model: {e}")
            # logger.exception keeps the traceback in the logs.
            logger.exception("Model loading error: %s", e)
            raise


@st.cache_resource
def load_llm_model():
    """Load Mistral 7B LLM with quantization.

    Wrapped in ``st.cache_resource`` so the load is not repeated on every
    script rerun.

    Returns:
        Tuple ``(model, tokenizer, device)``. ``device`` is 'cuda' or 'cpu'
        as reported by torch; actual weight placement is delegated to
        ``device_map="auto"``.

    Raises:
        Exception: any loading error is shown in the UI, logged, and
            re-raised to the caller.
    """
    with st.spinner("⏳ Loading LLM model... \n(first time may take 5 min)"):
        try:
            logger.info("Loading Mistral-7B model...")

            device = "cuda" if torch.cuda.is_available() else "cpu"

            from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

            # 4-bit quantization config for memory efficiency
            # NOTE(review): 4-bit loading is requested even when device is
            # 'cpu' - confirm the installed bitsandbytes build supports that.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )

            tokenizer = AutoTokenizer.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.3",
                cache_dir=env_info['model_cache']
            )

            model = AutoModelForCausalLM.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.3",
                quantization_config=bnb_config,
                device_map="auto",
                cache_dir=env_info['model_cache']
            )

            logger.info("✅ LLM model loaded successfully")
            st.success("✅ LLM model loaded!")

            return model, tokenizer, device

        except Exception as e:
            st.error(f"❌ Error loading LLM: {e}")
            # logger.exception keeps the traceback in the logs.
            logger.exception("LLM loading error: %s", e)
            raise


# ========================================
# 🏠 UI LAYOUT
# ========================================
st.title("🤖 BeRU Chat - RAG Assistant")
st.markdown("""
A powerful offline RAG system combining Mistral 7B LLM with VLM2Vec embeddings
for intelligent document search and conversation.

**Status**: Models loading on first access (5-8 minutes)
""")

# Load models
# Broad catch is intentional here: this is the top-level boundary of the
# script, and any loader failure should degrade to the error view below.
try:
    embedding_model, processor, tokenizer, device = load_embedding_model()
    llm_model, llm_tokenizer, llm_device = load_llm_model()
    models_loaded = True
except Exception as e:
    st.error(f"Failed to load models: {e}")
    models_loaded = False

if models_loaded:
    # Main chat interface
    left_col, right_col = st.columns([2, 1])

    with left_col:
        st.subheader("💬 Chat")

        # Initialize session state
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Display chat history
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.write(msg["content"])

        # Chat input
        user_input = st.chat_input("Ask a question about your documents...")

        if user_input:
            # Add user message
            st.session_state.messages.append({"role": "user", "content": user_input})

            with st.chat_message("user"):
                st.write(user_input)

            # Generate response
            with st.chat_message("assistant"):
                with st.spinner("🤔 Thinking..."):
                    # Placeholder for RAG response
                    response = "Response generated from RAG system..."
                st.write(response)

            st.session_state.messages.append({"role": "assistant", "content": response})

    with right_col:
        st.subheader("📊 Info")
        st.info("""
        **Model Info:**
        - 🧠 Embedding: VLM2Vec-Qwen2VL-2B
        - 💬 LLM: Mistral-7B-Instruct
        - 🔍 Search: FAISS + BM25
        **Performance:**
        - Device: GPU if available
        - Quantization: 4-bit
        - Context: Multi-turn
        """)

        st.subheader("⚙️ Settings")
        # These values are collected but not yet consumed by the placeholder
        # response above.
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7)
        max_tokens = st.slider("Max Tokens", 100, 2000, 512)

else:
    st.error("❌ Failed to initialize models. Check logs for details.")
    st.info("Try refreshing the page or restarting the Space.")

# ========================================
# 📝 FOOTER
# ========================================
st.markdown("---")
st.markdown("""
""", unsafe_allow_html=True)