import os

import streamlit as st
import pandas as pd
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient

# ==============================
# CONFIG
# ==============================
st.set_page_config(page_title="Company ChatGPT", layout="wide")
st.title("🏢 Company AI Assistant (RAG Powered)")


# ==============================
# LOAD MODELS
# ==============================
@st.cache_resource
def load_models():
    """Create the sentence-embedding model and the hosted LLM client.

    Returns:
        An ``(embed_model, llm)`` tuple. Wrapped in ``st.cache_resource`` so
        the objects are reused across reruns instead of being rebuilt.

    If the ``HF_TOKEN`` environment variable is missing or empty, an error
    is shown and the script is stopped via ``st.stop()``.
    """
    # Validate the token first so we don't load the embedding model only to
    # abort right afterwards.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        st.error("❌ Please add HF_TOKEN in Hugging Face Secrets")
        st.stop()

    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    llm = InferenceClient(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        token=hf_token,
    )
    return embed_model, llm


embed_model, llm = load_models()


# ==============================
# LOAD DATA
# ==============================
@st.cache_data
def load_data():
    """Read the knowledge-base CSV at ``src/company_sample.csv``.

    Returns:
        The CSV as a ``pandas.DataFrame``.

    If the file does not exist, an error is shown and the script is stopped.
    """
    path = "src/company_sample.csv"
    if not os.path.exists(path):
        st.error(f"❌ File not found: {path}")
        st.stop()
    return pd.read_csv(path)


df = load_data()

if "text" not in df.columns:
    st.error("❌ CSV must contain 'text' column")
    st.stop()

# Coerce every cell to str (numeric/mixed cells would break the encoder) and
# drop blank rows so an empty string can never be retrieved as "context".
# The index built below is positionally aligned with this list.
documents = [
    doc for doc in df["text"].fillna("").astype(str).tolist() if doc.strip()
]
if not documents:
    st.error("❌ CSV 'text' column contains no usable rows")
    st.stop()


# ==============================
# CREATE VECTOR DB
# ==============================
@st.cache_resource
def create_faiss(docs):
    """Embed ``docs`` and store them in an exact L2 FAISS index.

    Args:
        docs: List of strings; vector ``i`` in the index corresponds to
            ``docs[i]``.

    Returns:
        A ``faiss.IndexFlatL2`` holding one embedding per document.

    Raises:
        ValueError: if ``docs`` is empty (the embedding dimension, taken from
            the encoded matrix, would be undefined).
    """
    if not docs:
        raise ValueError("create_faiss() needs at least one document")
    # FAISS requires a C-contiguous float32 matrix.
    embeddings = np.ascontiguousarray(embed_model.encode(docs), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index


index = create_faiss(documents)


# ==============================
# RETRIEVAL FUNCTION
# ==============================
def retrieve(query, top_k=3):
    """Return up to ``top_k`` documents closest to ``query``, nearest first.

    Args:
        query: The user's question.
        top_k: Maximum number of documents to return.

    Returns:
        A list of document strings (possibly shorter than ``top_k``).
    """
    # Never ask FAISS for more neighbours than it holds.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []
    q_emb = np.ascontiguousarray(embed_model.encode([query]), dtype="float32")
    _, ids = index.search(q_emb, k)
    # FAISS marks "no result" with id -1; used as a Python index that would
    # wrap around to the last document, so keep only real, in-range ids.
    return [documents[i] for i in ids[0] if 0 <= i < len(documents)]


# ==============================
# CHAT HISTORY
# ==============================
if "messages" not in st.session_state:
    st.session_state.messages = []

for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])
# ==============================
# USER INPUT
# ==============================
query = st.chat_input("Ask about company...")

if query:
    st.session_state.messages.append({"role": "user", "content": query})
    st.chat_message("user").write(query)

    # 🔍 Retrieve context
    context_docs = retrieve(query)
    context = "\n\n".join(context_docs)

    # ==============================
    # 🤖 LLM CALL
    # ==============================
    if not context.strip():
        # Nothing retrieved: reply with the same fallback the system prompt
        # mandates rather than letting the model answer without grounding.
        answer = "Not available in company data."
    else:
        try:
            response = llm.chat_completion(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a company assistant. Answer ONLY from given context. If not found, say 'Not available in company data.'",
                    },
                    {
                        "role": "user",
                        "content": f"Context:\n{context}\n\nQuestion:\n{query}",
                    },
                ],
                max_tokens=200,
                temperature=0.5,
            )
        except Exception as e:  # UI boundary: show any client/network error to the user
            answer = f"❌ Error: {e}"
        else:
            # `content` can be None; never store or display the string "None".
            answer = (response.choices[0].message.content or "").strip()
            if not answer:
                answer = "❌ Error: empty response from model"

    # ==============================
    # DISPLAY RESPONSE
    # ==============================
    st.session_state.messages.append({"role": "assistant", "content": answer})
    st.chat_message("assistant").write(answer)