File size: 7,756 Bytes
dec533d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""
BeRU RAG Chat App - Optimized for Hugging Face Spaces
Deployment: https://huggingface.co/spaces/AnwinMJ/Beru
"""

import streamlit as st
import torch
import os
import pickle
import faiss
import numpy as np
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from typing import List, Dict
import time
import logging

# Configure root logging at INFO once at import time, then obtain the
# module-scoped logger used throughout this app.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# ========================================
# 🎨 STREAMLIT PAGE CONFIG
# ========================================
# Must be the first Streamlit command executed by the script.
# ``page_icon`` must be a real emoji (or shortcode / URL / image): the
# previous value was mojibake (the UTF-8 bytes of 🤖 decoded with the wrong
# codec), which is none of the forms Streamlit accepts.
st.set_page_config(
    page_title="BeRU Chat - RAG Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "About": "BeRU - Offline RAG System with VLM2Vec and Mistral 7B"
    }
)

# ========================================
# 🌍 ENVIRONMENT DETECTION
# ========================================
def detect_environment():
    """Collect runtime environment details used by the rest of the app.

    Returns:
        dict with keys:
          - ``is_spaces``: True when the ``SPACES`` env var equals ``"true"``
            (case-insensitive) or a non-empty ``SPACE_ID`` env var is set.
          - ``device``: ``"cuda"`` if a CUDA device is available, else ``"cpu"``.
          - ``model_cache``: the ``HF_HOME`` env var if set, else ``"./cache"``.
          - ``gpu_memory``: total memory of CUDA device 0 in bytes, or 0
            when CUDA is unavailable.
    """
    # Hugging Face Spaces exports SPACE_ID (``owner/space``) at runtime; the
    # SPACES flag is kept as a manual override. Both are plain env lookups, so
    # this never raises (membership tests on bools did).
    is_spaces = (
        os.getenv('SPACES', 'false').lower() == 'true'
        or bool(os.getenv('SPACE_ID'))
    )
    # Probe CUDA once and reuse the answer for both device-dependent fields.
    has_cuda = torch.cuda.is_available()
    return {
        'is_spaces': is_spaces,
        'device': 'cuda' if has_cuda else 'cpu',
        'model_cache': os.getenv('HF_HOME', './cache'),
        'gpu_memory': torch.cuda.get_device_properties(0).total_memory if has_cuda else 0,
    }

env_info = detect_environment()

# Display environment info in sidebar.
# Emoji labels are real Unicode characters; the previous text was mojibake
# (UTF-8 emoji bytes decoded with the wrong codec) and rendered as garbage.
with st.sidebar:
    st.write("### System Info")
    st.write(f"🖥️ Device: `{env_info['device'].upper()}`")
    if env_info['device'] == 'cuda':
        # gpu_memory is in bytes; shown in decimal gigabytes.
        st.write(f"💾 GPU VRAM: `{env_info['gpu_memory'] / 1e9:.1f} GB`")
    st.write(f"📦 Cache: `{env_info['model_cache']}`")


# ========================================
# 🎯 MODEL LOADING WITH CACHING
# ========================================
@st.cache_resource
def load_embedding_model():
    """Load the VLM2Vec embedding model together with its processor and tokenizer.

    Decorated with ``st.cache_resource`` so the download and initialisation
    happen once per server process rather than on every Streamlit rerun.

    Returns:
        tuple ``(model, processor, tokenizer, device)``. ``device`` is
        ``"cuda"`` when available, else ``"cpu"``. The model is moved to that
        device, loaded in float16 on CUDA / float32 on CPU, and put in eval
        mode.

    Raises:
        Exception: anything raised while loading. The message is shown in the
        UI and the full traceback is logged before the exception is re-raised.
    """
    with st.spinner("⏳ Loading embedding model... (first time may take 5 min)"):
        try:
            logger.info("Loading VLM2Vec model...")
            model_id = "TIGER-Lab/VLM2Vec-Qwen2VL-2B"
            cache_dir = env_info['model_cache']
            device = "cuda" if torch.cuda.is_available() else "cpu"

            model = AutoModel.from_pretrained(
                model_id,
                trust_remote_code=True,
                # float16 halves GPU memory use; keep float32 on CPU.
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                cache_dir=cache_dir,
            ).to(device)

            processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=True,
                cache_dir=cache_dir,
            )

            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
                cache_dir=cache_dir,
            )

            # Inference only: switch off training-time behaviour (e.g. dropout).
            model.eval()
            logger.info("✅ Embedding model loaded successfully")
            st.success("✅ Embedding model loaded!")

            return model, processor, tokenizer, device

        except Exception as e:
            st.error(f"❌ Error loading embedding model: {str(e)}")
            # logger.exception records the traceback, which logger.error dropped.
            logger.exception("Model loading error: %s", e)
            raise


@st.cache_resource
def load_llm_model():
    """Load Mistral-7B-Instruct-v0.3 with 4-bit NF4 bitsandbytes quantization.

    Decorated with ``st.cache_resource`` so the model is loaded once per
    server process rather than on every Streamlit rerun.

    Returns:
        tuple ``(model, tokenizer, device)``. ``device`` is ``"cuda"`` when
        available, else ``"cpu"``. It is informational only: the weights are
        placed by ``device_map="auto"``.

    Raises:
        Exception: anything raised while loading. The message is shown in the
        UI and the full traceback is logged before the exception is re-raised.
    """
    with st.spinner("⏳ Loading LLM model... (first time may take 5 min)"):
        try:
            logger.info("Loading Mistral-7B model...")
            model_id = "mistralai/Mistral-7B-Instruct-v0.3"
            cache_dir = env_info['model_cache']
            device = "cuda" if torch.cuda.is_available() else "cpu"

            # Only the LLM path needs these; AutoTokenizer already comes from
            # the module-level transformers import.
            from transformers import AutoModelForCausalLM, BitsAndBytesConfig

            # Use bfloat16 compute unless we are on a CUDA device for which
            # torch reports no bf16 support, where float16 is the safe choice.
            if device == "cuda" and not torch.cuda.is_bf16_supported():
                compute_dtype = torch.float16
            else:
                compute_dtype = torch.bfloat16

            # 4-bit quantization config for memory efficiency
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            )

            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=cache_dir,
            )

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=bnb_config,
                device_map="auto",
                cache_dir=cache_dir,
            )

            logger.info("✅ LLM model loaded successfully")
            st.success("✅ LLM model loaded!")

            return model, tokenizer, device

        except Exception as e:
            st.error(f"❌ Error loading LLM: {str(e)}")
            # logger.exception records the traceback, which logger.error dropped.
            logger.exception("LLM loading error: %s", e)
            raise


# ========================================
# 🏠 UI LAYOUT
# ========================================
# Title emoji is a real character; the previous text was UTF-8 mojibake.
st.title("🤖 BeRU Chat - RAG Assistant")
st.markdown("""
A powerful offline RAG system combining Mistral 7B LLM with VLM2Vec embeddings 
for intelligent document search and conversation.

**Status**: Models loading on first access (5-8 minutes)
""")

# Load models. Both loaders are st.cache_resource-cached, so only the first
# run pays the loading cost. They report their own errors; this handler only
# records that the app must skip the main UI below.
try:
    embedding_model, processor, tokenizer, device = load_embedding_model()
    llm_model, llm_tokenizer, llm_device = load_llm_model()
    models_loaded = True
except Exception as e:
    st.error(f"Failed to load models: {str(e)}")
    models_loaded = False

if models_loaded:
    # Main chat interface: conversation in a 2/3-width left column, model info
    # and generation settings in a 1/3-width right column.
    left_col, right_col = st.columns([2, 1])

    with left_col:
        st.subheader("💬 Chat")

        # Conversation history lives in session_state so it survives reruns.
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Replay the stored conversation on every rerun.
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.write(msg["content"])

        # NOTE(review): st.chat_input inside st.columns needs a Streamlit
        # version that supports inline chat input — confirm the pinned version.
        user_input = st.chat_input("Ask a question about your documents...")

        if user_input:
            # Record and echo the user's turn.
            st.session_state.messages.append({"role": "user", "content": user_input})

            with st.chat_message("user"):
                st.write(user_input)

            # Generate response
            with st.chat_message("assistant"):
                with st.spinner("🤔 Thinking..."):
                    # TODO: replace this placeholder with retrieval + LLM
                    # generation. The temperature / max_tokens sliders in the
                    # right column are not used by anything yet.
                    response = "Response generated from RAG system..."
                    st.write(response)
                    st.session_state.messages.append({"role": "assistant", "content": response})

    with right_col:
        # Emoji labels below are real characters; the previous text was
        # UTF-8 mojibake and rendered as garbage in the UI.
        st.subheader("📊 Info")
        st.info("""
        **Model Info:**
        - 🧠 Embedding: VLM2Vec-Qwen2VL-2B
        - 💬 LLM: Mistral-7B-Instruct
        - 🔍 Search: FAISS + BM25
        
        **Performance:**
        - Device: GPU if available
        - Quantization: 4-bit
        - Context: Multi-turn
        """)

        st.subheader("⚙️ Settings")
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7)
        max_tokens = st.slider("Max Tokens", 100, 2000, 512)

else:
    st.error("❌ Failed to initialize models. Check logs for details.")
    st.info("Try refreshing the page or restarting the Space.")

# ========================================
# 📝 FOOTER
# ========================================
# Centered footer with links to the Space and the GitHub repository. The HTML
# is rendered raw, so unsafe_allow_html must be enabled.
_footer_html = """
<div style='text-align: center'>
    <small>
        BeRU RAG System | 
        <a href='https://huggingface.co/spaces/AnwinMJ/Beru'>Space</a> | 
        <a href='https://github.com/AnwinMJ/BeRU'>GitHub</a>
    </small>
</div>
"""
st.markdown("---")
st.markdown(_footer_html, unsafe_allow_html=True)