# sample_rag / src / streamlit_app.py
# pradeep4321's picture
# Update src/streamlit_app.py
# fccb3d2 verified
import streamlit as st
import pandas as pd
import numpy as np
import faiss
import os
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
# ==============================
# CONFIG
# ==============================
# Page-level settings; issued ahead of every other Streamlit call in this script.
st.set_page_config(page_title="Company ChatGPT", layout="wide")
st.title("🏢 Company AI Assistant (RAG Powered)")
# ==============================
# LOAD MODELS
# ==============================
@st.cache_resource
def load_models():
    """Create the sentence-embedding model and the hosted-LLM client.

    The function is wrapped in ``st.cache_resource`` so the returned objects
    are reused instead of being rebuilt on each script rerun.

    The Hugging Face token is read from the ``HF_TOKEN`` environment
    variable. When it is missing or empty, an error is rendered in the app
    and the current script run is halted with ``st.stop()``.

    Returns:
        tuple: ``(embed_model, llm)`` where ``embed_model`` is a
        ``SentenceTransformer`` for ``all-MiniLM-L6-v2`` and ``llm`` is an
        ``InferenceClient`` bound to ``meta-llama/Meta-Llama-3-8B-Instruct``.
    """
    # Check the token first so a misconfigured deployment fails fast rather
    # than paying for the embedding-model load before reporting the problem.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        st.error("❌ Please add HF_TOKEN in Hugging Face Secrets")
        st.stop()

    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    llm = InferenceClient(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        token=hf_token,
    )
    return embed_model, llm


embed_model, llm = load_models()
# ==============================
# LOAD DATA
# ==============================
@st.cache_data
def load_data(path="src/company_sample.csv"):
    """Read the company knowledge CSV into a DataFrame.

    Args:
        path: Location of the CSV file. Defaults to
            ``src/company_sample.csv``; a relative path is resolved against
            the process working directory.

    Returns:
        pandas.DataFrame: The parsed file contents.

    If ``path`` does not exist, an error is rendered in the app and the
    current script run is halted with ``st.stop()``.
    """
    if not os.path.exists(path):
        st.error(f"❌ File not found: {path}")
        st.stop()
    return pd.read_csv(path)
df = load_data()

# Retrieval is built on a single "text" column; refuse to continue without it.
if "text" not in df.columns:
    st.error("❌ CSV must contain 'text' column")
    st.stop()

# Blank cells become empty strings, and every value is coerced to str so that
# non-text cells (numbers, dates) cannot break the embedding step.
documents = df["text"].fillna("").astype(str).tolist()

# An empty table would otherwise surface later as an opaque indexing error
# while the vector index is being built.
if not documents:
    st.error("❌ CSV contains no rows to index")
    st.stop()
# ==============================
# CREATE VECTOR DB
# ==============================
@st.cache_resource
def create_faiss(docs):
    """Embed ``docs`` and load the vectors into a flat L2 FAISS index.

    Embeddings come from the module-level ``embed_model``.

    Args:
        docs: List of document strings to index.

    Returns:
        faiss.IndexFlatL2: Index holding one vector per entry of ``docs``,
        added in the same order as ``docs``.

    Raises:
        ValueError: If ``docs`` is empty, since the vector dimension cannot
            be determined without at least one embedding.
    """
    if not docs:
        raise ValueError("Cannot build a FAISS index from an empty document list")

    # FAISS expects a C-contiguous float32 matrix; make that explicit instead
    # of relying on the encoder's default output type.
    embeddings = np.ascontiguousarray(embed_model.encode(docs), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index


index = create_faiss(documents)
# ==============================
# RETRIEVAL FUNCTION
# ==============================
def retrieve(query, top_k=3):
    """Return the stored documents closest to ``query`` in embedding space.

    Uses the module-level ``embed_model``, ``index`` and ``documents``.

    Args:
        query: Free-text question to embed and search with.
        top_k: Maximum number of documents to return.

    Returns:
        list[str]: At most ``top_k`` entries of ``documents``, in the order
        the index search returns them. Empty when ``top_k`` is not positive.
    """
    if top_k <= 0:
        return []

    q_emb = np.ascontiguousarray(embed_model.encode([query]), dtype="float32")
    _, ids = index.search(q_emb, top_k)

    # FAISS fills the id row with -1 when fewer than ``top_k`` vectors are
    # stored. Without the lower bound, -1 would index the *last* document and
    # be returned as a bogus hit.
    return [documents[i] for i in ids[0] if 0 <= i < len(documents)]
# ==============================
# CHAT HISTORY
# ==============================
# Seed the conversation log the first time this session runs the script.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Redraw each stored turn, in order, under the role that produced it.
for turn in st.session_state["messages"]:
    bubble = st.chat_message(turn["role"])
    bubble.write(turn["content"])
# ==============================
# USER INPUT
# ==============================
query = st.chat_input("Ask about company...")

if query:
    # Record the user's turn and echo it into the transcript.
    st.session_state.messages.append({"role": "user", "content": query})
    st.chat_message("user").write(query)

    # Retrieve supporting passages and join them into one context string.
    context_docs = retrieve(query)
    context = "\n\n".join(context_docs)

    # ==============================
    # LLM CALL
    # ==============================
    system_prompt = (
        "You are a company assistant. Answer ONLY from given context. "
        "If not found, say 'Not available in company data.'"
    )
    user_prompt = f"\nContext:\n{context}\nQuestion:\n{query}\n"

    try:
        response = llm.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=200,
            temperature=0.5,
        )
        answer = response.choices[0].message.content
    except Exception as e:
        # UI boundary: show any failure as the assistant's reply rather than
        # letting it abort the script run.
        answer = f"❌ Error: {e}"

    # ==============================
    # DISPLAY RESPONSE
    # ==============================
    st.session_state.messages.append({"role": "assistant", "content": answer})
    st.chat_message("assistant").write(answer)