Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import faiss | |
| import os | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
# ------------------------------------------------------------------
# Page configuration (kept ahead of every other st.* call).
# ------------------------------------------------------------------
st.set_page_config(
    page_title="Company ChatGPT",
    layout="wide",
)
st.title("π’ Company AI Assistant (RAG Powered)")
# ==============================
# LOAD MODELS
# ==============================
@st.cache_resource
def load_models():
    """Load the sentence-embedding model and the hosted LLM client.

    Streamlit re-executes this script top to bottom on every interaction, so
    the result is cached with ``st.cache_resource``; without it the
    SentenceTransformer weights were reloaded on every chat message.

    Returns:
        tuple: ``(embed_model, llm)`` -- a ``SentenceTransformer`` for
        ``all-MiniLM-L6-v2`` and an ``InferenceClient`` bound to
        ``meta-llama/Meta-Llama-3-8B-Instruct``.

    Shows an error and halts the run via ``st.stop()`` when the ``HF_TOKEN``
    environment variable is missing or empty (nothing is cached in that case,
    because ``st.stop()`` raises).
    """
    # Validate the token before loading the embedding model so a misconfigured
    # deployment fails fast instead of loading weights it will never use.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        st.error("β Please add HF_TOKEN in Hugging Face Secrets")
        st.stop()

    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    llm = InferenceClient(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        token=hf_token,
    )
    return embed_model, llm


embed_model, llm = load_models()
# ==============================
# LOAD DATA
# ==============================
@st.cache_data
def load_data(path="src/company_sample.csv"):
    """Read the company knowledge-base CSV into a DataFrame.

    Cached with ``st.cache_data`` so the file is not re-parsed on every
    Streamlit rerun; the cache is keyed on ``path``.

    Args:
        path: CSV file to read; defaults to the bundled sample file.

    Returns:
        pandas.DataFrame with the file's contents.

    Shows an error and halts the run via ``st.stop()`` if ``path`` does not
    exist.
    """
    if not os.path.exists(path):
        st.error(f"β File not found: {path}")
        st.stop()
    return pd.read_csv(path)


df = load_data()
if "text" not in df.columns:
    st.error("β CSV must contain 'text' column")
    st.stop()

# Coerce every cell to str: a numeric or mixed-type column would otherwise
# hand non-string objects to the embedding model.
documents = df["text"].fillna("").astype(str).tolist()
if not documents:
    # With zero rows there are no embeddings to size a vector index from.
    st.error("CSV contains no rows to index")
    st.stop()
# ==============================
# CREATE VECTOR DB
# ==============================
@st.cache_resource
def create_faiss(docs):
    """Embed ``docs`` and store the vectors in an exact L2 FAISS index.

    Cached with ``st.cache_resource`` because Streamlit reruns the script on
    every interaction; without caching, each chat message re-embedded the
    entire corpus. The cache is keyed on the contents of ``docs``.

    Args:
        docs: list of document strings; index row ``i`` corresponds to
            ``docs[i]``.

    Returns:
        faiss.IndexFlatL2 holding one vector per document.
    """
    # FAISS requires a C-contiguous float32 matrix of shape (n_docs, dim).
    embeddings = np.ascontiguousarray(embed_model.encode(docs), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index


index = create_faiss(documents)
# ==============================
# RETRIEVAL FUNCTION
# ==============================
def retrieve(query, top_k=3):
    """Return up to ``top_k`` documents nearest to ``query`` by L2 distance.

    Uses the module-level ``embed_model``, ``index`` and ``documents``.

    Args:
        query: user question to embed and search for.
        top_k: maximum number of documents to return.

    Returns:
        list[str]: matching documents, nearest first. Shorter than ``top_k``
        when the index holds fewer documents; empty when ``top_k <= 0``.
    """
    # Never request more neighbours than the index holds: FAISS pads missing
    # slots with id -1, and documents[-1] would silently return the last row.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []
    q_emb = np.ascontiguousarray(embed_model.encode([query]), dtype="float32")
    _, ids = index.search(q_emb, k)
    # 0 <= i also filters any remaining -1 padding.
    return [documents[i] for i in ids[0] if 0 <= i < len(documents)]
# ------------------------------------------------------------------
# Chat transcript: initialise once per session, then replay it on every
# rerun so earlier turns stay visible.
# ------------------------------------------------------------------
if "messages" not in st.session_state:
    st.session_state["messages"] = []

for entry in st.session_state.messages:
    role, content = entry["role"], entry["content"]
    st.chat_message(role).write(content)
# ==============================
# USER INPUT
# ==============================
query = st.chat_input("Ask about company...")
if query:
    st.session_state.messages.append({"role": "user", "content": query})
    st.chat_message("user").write(query)

    # Retrieve the most relevant company documents to ground the answer.
    context_docs = retrieve(query)
    context = "\n\n".join(context_docs)

    # ==============================
    # LLM CALL
    # ==============================
    # Build the prompt explicitly: a triple-quoted f-string nested inside
    # if/try/dict also embeds source indentation and stray blank lines.
    user_prompt = f"Context:\n{context}\n\nQuestion:\n{query}"
    try:
        response = llm.chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": "You are a company assistant. Answer ONLY from given context. If not found, say 'Not available in company data.'",
                },
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=200,
            temperature=0.5,
        )
        # message.content is Optional in huggingface_hub's output types; never
        # store None in the transcript (it would render as "None").
        answer = response.choices[0].message.content or ""
    except Exception as e:  # UI boundary: show any API failure in the chat.
        answer = f"β Error: {e}"

    # ==============================
    # DISPLAY RESPONSE
    # ==============================
    st.session_state.messages.append({"role": "assistant", "content": answer})
    st.chat_message("assistant").write(answer)