Maheen Saleh committed
Commit 4a3a2c0 · 1 Parent(s): e7a4534

updated proj structure
src/__pycache__/qa_prompts.cpython-311.pyc ADDED
Binary file (460 Bytes).
 
src/data_index/embeddings_model.txt ADDED
@@ -0,0 +1 @@
+ sentence-transformers/all-MiniLM-L6-v2
src/data_index/index.faiss ADDED
Binary file (7.73 kB).
 
src/data_index/index.pkl ADDED
Binary file (5.21 kB).
 
src/extra_qa_chains.py ADDED
@@ -0,0 +1,109 @@
+ import os
+
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     AutoModelForSeq2SeqLM,
+     pipeline,
+ )
+ from langchain_community.llms import HuggingFacePipeline
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+
+ from qa_prompts import PROMPT_TMPL
+
+ # Mirrors the env-driven config used in ingest.py and qa_chain_cli.py.
+ LLM_MODEL_NAME = os.getenv("LLM_MODEL")
+
+
+ def build_chain(retriever, model_name: str = LLM_MODEL_NAME):
+     # Local HF pipeline for a seq2seq (encoder-decoder) model.
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+     gen = pipeline(
+         "text2text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=512,
+     )
+     llm = HuggingFacePipeline(pipeline=gen)
+
+     prompt = PromptTemplate(
+         input_variables=["context", "question"],
+         template=PROMPT_TMPL,
+     )
+
+     qa = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=retriever,
+         chain_type_kwargs={"prompt": prompt},
+         return_source_documents=True,
+     )
+     return qa
+
+
+ def build_chain_qwen(retriever, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
+     # Qwen2.5 is a causal LM (decoder-only), not seq2seq.
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     # Ensure a padding token exists (use EOS as pad for causal models if missing).
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+
+     gen = pipeline(
+         task="text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=512,
+         do_sample=False,         # deterministic for QA
+         truncation=True,         # avoid context overruns
+         return_full_text=False,  # only the generated answer
+         eos_token_id=tokenizer.eos_token_id,
+         pad_token_id=tokenizer.pad_token_id,
+     )
+     llm = HuggingFacePipeline(pipeline=gen)
+
+     prompt = PromptTemplate(
+         input_variables=["context", "question"],
+         template=PROMPT_TMPL,
+     )
+
+     qa = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",  # keep as in the original snippet
+         retriever=retriever,
+         chain_type_kwargs={"prompt": prompt},
+         return_source_documents=True,
+     )
+     return qa
+
+
+ def build_chain_gemma(retriever, model_name: str = "google/gemma-2-2b-it"):
+     # Gemma 2 is a causal LM (decoder-only).
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+
+     gen = pipeline(
+         task="text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=512,
+         do_sample=False,         # deterministic for QA
+         truncation=True,         # avoid context overruns
+         return_full_text=False,  # only the generated continuation
+         eos_token_id=tokenizer.eos_token_id,
+         pad_token_id=tokenizer.pad_token_id,
+     )
+     llm = HuggingFacePipeline(pipeline=gen)
+
+     prompt = PromptTemplate(
+         input_variables=["context", "question"],
+         template=PROMPT_TMPL,
+     )
+
+     qa = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",  # keep current behavior
+         retriever=retriever,
+         chain_type_kwargs={"prompt": prompt},
+         return_source_documents=True,
+     )
+     return qa
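
A minimal usage sketch for these builders (illustrative, not part of the commit): it assumes an index already built by src/ingest.py, loaded with the same sentence-transformers model recorded in embeddings_model.txt.

    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    from extra_qa_chains import build_chain_qwen

    # Load the FAISS index saved by ingest.py with matching embeddings.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vs = FAISS.load_local("src/data_index", embeddings, allow_dangerous_deserialization=True)
    retriever = vs.as_retriever(search_kwargs={"k": 4})

    qa = build_chain_qwen(retriever)  # or build_chain(...) / build_chain_gemma(...)
    res = qa.invoke({"query": "What are the candidate's key projects?"})
    print(res["result"])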
src/ingest.py ADDED
@@ -0,0 +1,72 @@
+ from pathlib import Path
+ import argparse
+ import sys
+ import os
+
+ from dotenv import load_dotenv
+ from langchain_community.document_loaders import TextLoader, PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+
+ load_dotenv()  # still works locally
+
+ HF_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
+
+ # Default mirrors streamlit_app.py so ingest and query time stay in sync.
+ EMBED_MODEL_NAME = os.getenv("HUGGING_FACE_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
+ LLM_MODEL_NAME = os.getenv("LLM_MODEL")
+
+ ROOT_DIR = Path(__file__).parent
+ INDEX_DIR = ROOT_DIR / "data_index"
+ DATA_DIR = ROOT_DIR / "data"
+
+
+ def load_documents(data_dir: Path):
+     docs = []
+     for path in data_dir.rglob("*"):
+         if path.is_dir():
+             continue
+         try:
+             if path.suffix.lower() in [".txt", ".md"]:
+                 docs.extend(TextLoader(str(path), encoding="utf-8").load())
+             elif path.suffix.lower() == ".pdf":
+                 docs.extend(PyPDFLoader(str(path)).load())
+         except Exception as e:
+             print(f"[skip] {path.name}: {e}", file=sys.stderr)
+     if not docs:
+         raise RuntimeError(f"No documents found in {data_dir}. Put .txt/.md/.pdf files there.")
+     return docs
+
+
+ def build_vectorstore(docs):
+     splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=120)
+     chunks = splitter.split_documents(docs)
+     embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
+     vs = FAISS.from_documents(chunks, embeddings)
+     return vs
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Ingest documents and build FAISS index.")
+     parser.parse_args()
+
+     print(f"Loading documents from {DATA_DIR}")
+     docs = load_documents(DATA_DIR)
+     print(f"Loaded {len(docs)} documents. Building index…")
+
+     vs = build_vectorstore(docs)
+     INDEX_DIR.mkdir(parents=True, exist_ok=True)
+     vs.save_local(str(INDEX_DIR))
+
+     # Persist the embedding model name so query-time code can load the same model.
+     (INDEX_DIR / "embeddings_model.txt").write_text(EMBED_MODEL_NAME, encoding="utf-8")
+
+     print(f"Index saved to {INDEX_DIR.resolve()}")
+
+
+ if __name__ == "__main__":
+     main()
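
ingest.py is driven entirely by environment variables; a minimal .env sketch (values illustrative, matching the defaults used elsewhere in this commit; keep the embedding model identical at ingest and query time), then run `python src/ingest.py` from the repo root:

    HUGGING_FACE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
    LLM_MODEL=gemini-1.5-flash
    GOOGLE_API_KEY=<your key>
    HUGGING_FACE_API_TOKEN=<your token>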
src/qa_chain_cli.py ADDED
@@ -0,0 +1,101 @@
+ import argparse
+ import textwrap
+ from pathlib import Path
+ import os
+ from dotenv import load_dotenv
+ from qa_prompts import PROMPT_TMPL
+
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain_google_genai import ChatGoogleGenerativeAI
+
+ load_dotenv()
+
+ HF_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
+
+ EMBED_MODEL_NAME = os.getenv("HUGGING_FACE_EMBEDDING_MODEL")
+ LLM_MODEL_NAME = os.getenv("LLM_MODEL")
+
+ ROOT_DIR = Path(__file__).parent
+ INDEX_DIR = ROOT_DIR / "data_index"
+
+
+ def load_retriever(index_dir: Path, k: int = 4):
+     # Ensure we use the same embedding model that was used during ingest.
+     embed_model_name_path = index_dir / "embeddings_model.txt"
+     if not embed_model_name_path.exists():
+         raise RuntimeError(f"Missing {embed_model_name_path}. Re-run ingest.py.")
+     embed_model_name = embed_model_name_path.read_text(encoding="utf-8").strip()
+
+     embeddings = HuggingFaceEmbeddings(model_name=embed_model_name)
+     vs = FAISS.load_local(str(index_dir), embeddings, allow_dangerous_deserialization=True)
+     return vs.as_retriever(search_kwargs={"k": k})
+
+
+ def build_chain_gemini(retriever):
+     if not GOOGLE_API_KEY:
+         raise RuntimeError("Set GOOGLE_API_KEY in your .env to use the Gemini inference endpoint.")
+
+     # Uses the Google Generative AI (Gemini) hosted inference endpoint.
+     llm = ChatGoogleGenerativeAI(
+         model=LLM_MODEL_NAME,
+         api_key=GOOGLE_API_KEY,
+         temperature=0.1,
+         max_output_tokens=512,
+         convert_system_message_to_human=True,
+     )
+
+     prompt = PromptTemplate(
+         input_variables=["context", "question"],
+         template=PROMPT_TMPL,
+     )
+
+     # "stuff" packs all retrieved chunks into one prompt; switch to
+     # "map_reduce" if the context outgrows a single call.
+     qa = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=retriever,
+         chain_type_kwargs={"prompt": prompt},
+         return_source_documents=True,
+     )
+     return qa
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Run recruiter Q/A over a saved FAISS index.")
+     parser.parse_args()
+
+     retriever = load_retriever(INDEX_DIR)
+     chain = build_chain_gemini(retriever)
+
+     print("\nMy Profile Chatbot ready. Ask about me.")
+     print("Type 'exit' to quit.\n")
+
+     while True:
+         try:
+             q = input("You: ").strip()
+         except (EOFError, KeyboardInterrupt):
+             print("\nBye!")
+             break
+         if not q:
+             continue
+         if q.lower() in {"exit", "quit", "q"}:
+             print("Bye!")
+             break
+
+         try:
+             res = chain.invoke({"query": q})
+             answer = res["result"] if isinstance(res, dict) else str(res)
+         except Exception as e:
+             answer = f"[error] {e}"
+
+         print("\nMaheen:", textwrap.fill(answer, width=100))
+         print()
+
+
+ if __name__ == "__main__":
+     main()
src/qa_prompts.py ADDED
@@ -0,0 +1,9 @@
+ PROMPT_TMPL = """You are a helpful chatbot that answers questions about the candidate's profile for recruiters.
+ Use ONLY the provided context. If the answer is not in the context, say you don't know. Be concise and factual.
+
+ Context:
+ {context}
+
+ Question: {question}
+
+ Answer:"""
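
A quick sketch of how the chains in this commit render the template (the context and question values here are illustrative):

    from langchain.prompts import PromptTemplate
    from qa_prompts import PROMPT_TMPL

    prompt = PromptTemplate(input_variables=["context", "question"], template=PROMPT_TMPL)
    # Retrieved chunks are substituted for {context}, the user query for {question}.
    print(prompt.format(
        context="(retrieved profile chunks go here)",
        question="What are the candidate's key projects?",
    ))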
src/streamlit_app.py ADDED
@@ -0,0 +1,180 @@
+ import os
+ import sys
+ import subprocess
+ from pathlib import Path
+ from typing import List
+
+ import streamlit as st
+ from dotenv import load_dotenv
+ from qa_prompts import PROMPT_TMPL
+
+ from langchain_community.vectorstores import FAISS
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from langchain.embeddings.base import Embeddings
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from huggingface_hub import InferenceClient
+
+ load_dotenv()  # still works locally
+
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
+ HF_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")
+
+ EMBED_MODEL_NAME = os.getenv("HUGGING_FACE_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
+ LLM_MODEL_NAME = os.getenv("LLM_MODEL", "gemini-1.5-flash")
+
+ ROOT_DIR = Path(__file__).parent
+ INDEX_DIR = ROOT_DIR / "data_index"
+
+ # set_page_config must be the first Streamlit call on the page.
+ st.set_page_config(page_title="Maheen's Profile Chatbot", page_icon="💬", layout="centered")
+
+
+ ###### run ingest.py ######
+
+ if not INDEX_DIR.exists():
+     with st.spinner("Index not found. Building FAISS index (first run)…"):
+         # Ensure ingest.py reads the same env/secrets model and paths;
+         # resolve the script relative to this file so cwd doesn't matter.
+         proc = subprocess.run(
+             [sys.executable, str(ROOT_DIR / "ingest.py")],
+             capture_output=True,
+             text=True,
+         )
+         if proc.returncode != 0:
+             st.error(f"ingest.py failed:\n{proc.stderr}")
+             st.stop()
+
+
+ class HFAPIEmbeddings(Embeddings):
+     """Embeddings via the Hugging Face Inference API (no local model download)."""
+
+     def __init__(self, repo_id: str, token: str | None = None, timeout: float = 120.0):
+         self.client = InferenceClient(model=repo_id, token=token, timeout=timeout)
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         # One vector per input text.
+         return self.client.feature_extraction(texts)
+
+     def embed_query(self, text: str) -> List[float]:
+         vec = self.client.feature_extraction(text)
+         # The API may return a batch of one vector; unwrap it.
+         return vec[0] if (isinstance(vec, list) and vec and isinstance(vec[0], list)) else vec
+
+
+ def build_chain_gemini(retriever, _llm_repo, _max_new, _temp, _show_sources):
+     if not GOOGLE_API_KEY:
+         raise RuntimeError("Set GOOGLE_API_KEY in your .env to use the Gemini inference endpoint.")
+
+     # Uses the Google Generative AI (Gemini) hosted inference endpoint.
+     llm = ChatGoogleGenerativeAI(
+         model=_llm_repo,
+         api_key=GOOGLE_API_KEY,
+         temperature=_temp,
+         max_output_tokens=_max_new,
+         convert_system_message_to_human=True,
+     )
+
+     prompt = PromptTemplate(
+         input_variables=["context", "question"],
+         template=PROMPT_TMPL,
+     )
+
+     # "stuff" for now; "map_reduce" is an alternative for large contexts.
+     qa = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=retriever,
+         chain_type_kwargs={"prompt": prompt},
+         return_source_documents=_show_sources,
+     )
+     return qa
+
+
+ # ========================= Streamlit UI =========================
+ st.title("Maheen's Profile Chatbot")
+ st.caption("RAG over my profile docs using FAISS + Hugging Face Inference API embeddings + Gemini")
+
+ # Sidebar settings
+ st.sidebar.header("Settings")
+ hf_token = HF_API_TOKEN
+ if not hf_token:
+     st.sidebar.warning("HUGGING_FACE_API_TOKEN is not set. Set it in your shell before running the app.")
+
+ # store_dir = st.sidebar.text_input("FAISS store path", value=INDEX_DIR)
+ # llm_repo_id = st.sidebar.text_input("LLM repo (HF)", value=LLM_MODEL_NAME)
+ # embed_repo_id = st.sidebar.text_input("Embedding model (HF)", value=EMBED_MODEL_NAME)
+
+ # Display model names as text (read-only)
+ st.sidebar.markdown(f"**Embedding Model:** `{EMBED_MODEL_NAME}`")
+ st.sidebar.markdown(f"**Chat Model:** `{LLM_MODEL_NAME}`")
+
+ # Fixed settings; uncomment the sidebar widgets to make them adjustable.
+ # k = st.sidebar.number_input("Top-k retrieved chunks", min_value=1, max_value=20, value=4, step=1)
+ k = 4
+ # max_new_tokens = st.sidebar.number_input("Max new tokens", min_value=64, max_value=2048, value=512, step=64)
+ max_new_tokens = 512
+ # temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.0, value=0.1, step=0.05)
+ temperature = 0.1
+ # show_sources = st.sidebar.checkbox("Show sources", value=False)
+ show_sources = False
+
+
+ ###################
+
+ # Session state for chat history
+ if "history" not in st.session_state:
+     st.session_state.history = []  # list of (user, assistant, sources)
+
+ # Load vector store & chain lazily, cache across reruns
+ @st.cache_resource(show_spinner=True)
+ def _load_chain(_store_dir: str, _embed_repo: str, _llm_repo: str, _k: int, _max_new: int, _temp: float, _show_sources: bool):
+     if not Path(_store_dir).exists():
+         raise FileNotFoundError(f"FAISS store not found at '{_store_dir}'. Run ingest.py first.")
+     embeddings = HFAPIEmbeddings(repo_id=_embed_repo, token=hf_token)
+     vs = FAISS.load_local(
+         _store_dir,
+         embeddings,
+         allow_dangerous_deserialization=True,  # required by newer LC versions
+     )
+     retriever = vs.as_retriever(search_kwargs={"k": _k})
+     chain = build_chain_gemini(retriever, _llm_repo, _max_new, _temp, _show_sources)
+     return chain
+
+
+ # Prepare chain
+ with st.spinner("Preparing retriever & LLM…"):
+     chain = _load_chain(str(INDEX_DIR), EMBED_MODEL_NAME, LLM_MODEL_NAME, k, max_new_tokens, temperature, show_sources)
+
+
+ def render_sources(docs):
+     if not docs:
+         return
+     st.markdown("**Sources**")
+     for i, d in enumerate(docs, start=1):
+         src = d.metadata.get("source", "unknown")
+         page = d.metadata.get("page", None)
+         label = f"{Path(src).name}" + (f" (page {page+1})" if isinstance(page, int) else "")
+         with st.expander(f"{i}. {label}"):
+             st.write(d.page_content[:1500] + ("…" if len(d.page_content) > 1500 else ""))
+
+
+ # --- Chat input with Enter submit ---
+ with st.form("chat-form", clear_on_submit=True):
+     user_input = st.text_input(
+         "Ask about my profile:",
+         placeholder="e.g., What are your key projects?"
+     )
+     submitted = st.form_submit_button("Ask")
+
+ if submitted and user_input.strip():
+     with st.spinner("Thinking…"):
+         try:
+             res = chain.invoke({"query": user_input.strip()})
+             if isinstance(res, dict):
+                 answer = res.get("result", "")
+                 sources = res.get("source_documents", []) if show_sources else []
+             else:
+                 answer, sources = str(res), []
+         except Exception as e:
+             answer, sources = f"[error] {e}", []
+     st.session_state.history.append((user_input.strip(), answer, sources))
+
+ # Display history
+ for q, a, srcs in st.session_state.history:
+     st.markdown(f"**You:** {q}")
+     st.markdown(f"**Assistant:** {a}")
+     if show_sources:
+         render_sources(srcs)
+     st.markdown("---")
+
+ # Footer
+ # st.caption("Enter submits. Datastore path fixed from code/env. Models shown read-only.")
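
To try the app locally (assuming the dependencies are installed and src/data holds profile documents): run `streamlit run src/streamlit_app.py`. On first run the app shells out to ingest.py to build the FAISS index; the index can also be built ahead of time with `python src/ingest.py` and queried without the UI via `python src/qa_chain_cli.py`.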