Maheen Saleh committed on
Commit
5a0c8da
·
1 Parent(s): 4a3a2c0

moved files for hugging face spaces

Browse files
Files changed (10) hide show
  1. .gitignore +3 -5
  2. Dockerfile +20 -0
  3. __pycache__/qa_prompts.cpython-311.pyc +0 -0
  4. basic.py +0 -147
  5. extra_qa_chains.py +0 -109
  6. ingest.py +0 -69
  7. qa_chain.py +0 -101
  8. qa_prompts.py +0 -9
  9. readme.md +8 -5
  10. ui_qa.py +0 -181
.gitignore CHANGED
@@ -1,6 +1,4 @@
1
  venv/
2
- data/
3
- data_index/
4
- .streamlit/
5
- .env
6
- .DS_Store
 
1
  venv/
2
+ src/data/
3
+ # src/data_index/
4
+ src/.env
 
 
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.13.5-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ curl \
8
+ git \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ COPY requirements.txt ./
12
+ COPY src/ ./src/
13
+
14
+ RUN pip3 install -r requirements.txt
15
+
16
+ EXPOSE 8501
17
+
18
+ HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
+
20
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
__pycache__/qa_prompts.cpython-311.pyc DELETED
Binary file (456 Bytes)
 
basic.py DELETED
@@ -1,147 +0,0 @@
1
- from pathlib import Path
2
- import os
3
- import textwrap
4
-
5
- # LangChain (HF + community)
6
- from langchain_community.document_loaders import TextLoader, PyPDFLoader
7
- from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from langchain_community.vectorstores import FAISS
9
- from langchain_community.embeddings import HuggingFaceEmbeddings
10
- from langchain_community.embeddings import HuggingFaceHubEmbeddings
11
-
12
- from langchain.prompts import PromptTemplate
13
- from langchain.chains import RetrievalQA
14
- from langchain_community.llms import HuggingFacePipeline
15
- from langchain_community.llms import HuggingFaceHub
16
-
17
- # Hugging Face transformers
18
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
19
-
20
- ROOT_DIR = Path(__file__).parent
21
- DATA_DIR = Path(f"{ROOT_DIR}/data")
22
-
23
- def load_documents(data_dir: Path):
24
- docs = []
25
- for path in data_dir.rglob("*"):
26
- if path.is_dir():
27
- continue
28
- try:
29
- if path.suffix.lower() in [".txt", ".md"]:
30
- docs.extend(TextLoader(str(path), encoding="utf-8").load())
31
- elif path.suffix.lower() == ".pdf":
32
- docs.extend(PyPDFLoader(str(path)).load())
33
- except Exception as e:
34
- print(f"[skip] {path.name}: {e}")
35
- if not docs:
36
- raise RuntimeError(f"No documents found in {data_dir}. Put .txt/.md/.pdf files there.")
37
- return docs
38
-
39
- def build_retriever(docs):
40
- splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=120)
41
- chunks = splitter.split_documents(docs)
42
-
43
- # HF sentence-transformers embeddings (local)
44
- embed_model_name = "sentence-transformers/all-MiniLM-L6-v2"
45
- embeddings = HuggingFaceEmbeddings(model_name=embed_model_name)
46
-
47
- # # Embeddings via Hugging Face Inference API (no local model)
48
- # embed_model = "sentence-transformers/all-MiniLM-L6-v2"
49
- # embeddings = HuggingFaceHubEmbeddings(
50
- # repo_id=embed_model,
51
- # # Batch size helps when indexing many chunks (tune if needed)
52
- # task="feature-extraction",
53
- # )
54
-
55
- vs = FAISS.from_documents(chunks, embeddings)
56
- return vs.as_retriever(search_kwargs={"k": 4})
57
-
58
-
59
- PROMPT_TMPL = """You are a helpful chatbot that answers questions about the candidate's profile for recruiters.
60
- Use ONLY the provided context. If the answer is not in the context, say you don't know.
61
-
62
- Context:
63
- {context}
64
-
65
- Question: {question}
66
-
67
- Answer:"""
68
-
69
- def build_chain(retriever, model_name="google/flan-t5-base", llm_repo_id="mistralai/Mistral-7B-Instruct-v0.3"):
70
- # Local HF pipeline (CPU-friendly model)
71
- tokenizer = AutoTokenizer.from_pretrained(model_name)
72
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
73
- gen = pipeline(
74
- "text2text-generation",
75
- model=model,
76
- tokenizer=tokenizer,
77
- max_new_tokens=512,
78
- )
79
- llm = HuggingFacePipeline(pipeline=gen)
80
-
81
-
82
- # # Text-generation via Hugging Face Inference API
83
- # llm = HuggingFaceHub(
84
- # repo_id=llm_repo_id,
85
- # task="text-generation",
86
- # model_kwargs={
87
- # "max_new_tokens": 512,
88
- # "temperature": 0.1,
89
- # "return_full_text": False,
90
- # },
91
- # )
92
-
93
- prompt = PromptTemplate(
94
- input_variables=["context", "question"],
95
- template=PROMPT_TMPL,
96
- )
97
-
98
- qa = RetrievalQA.from_chain_type(
99
- llm=llm,
100
- chain_type="stuff",
101
- retriever=retriever,
102
- chain_type_kwargs={"prompt": prompt},
103
- return_source_documents=False,
104
- )
105
- return qa
106
-
107
-
108
- def main():
109
-
110
- if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
111
- print("Please set HUGGINGFACEHUB_API_TOKEN environment variable.")
112
- return
113
-
114
- print("Loading documents from", DATA_DIR)
115
- docs = load_documents(DATA_DIR)
116
- print(f"Loaded {len(docs)} documents. Building index…")
117
- retriever = build_retriever(docs)
118
- print('Retriever built successfully')
119
- # exit()
120
- chain = build_chain(retriever)
121
-
122
- print("\nRecruiter Chatbot ready. Ask about the candidate's profile.")
123
- print("Type 'exit' to quit.\n")
124
-
125
- while True:
126
- try:
127
- q = input("You: ").strip()
128
- except (EOFError, KeyboardInterrupt):
129
- print("\nBye!")
130
- break
131
- if not q:
132
- continue
133
- if q.lower() in {"exit", "quit", "q"}:
134
- print("Bye!")
135
- break
136
-
137
- try:
138
- res = chain.invoke({"query": q})
139
- answer = res["result"] if isinstance(res, dict) else str(res)
140
- except Exception as e:
141
- answer = f"[error] {e}"
142
-
143
- print("\nAssistant:", textwrap.fill(answer, width=100))
144
- print()
145
-
146
- if __name__ == "__main__":
147
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
extra_qa_chains.py DELETED
@@ -1,109 +0,0 @@
1
- def build_chain(retriever, model_name: str = LLM_MODEL_NAME):
2
- # Local HF pipeline
3
- tokenizer = AutoTokenizer.from_pretrained(model_name)
4
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
5
- gen = pipeline(
6
- "text2text-generation",
7
- model=model,
8
- tokenizer=tokenizer,
9
- max_new_tokens=512,
10
- )
11
- llm = HuggingFacePipeline(pipeline=gen)
12
-
13
-
14
- prompt = PromptTemplate(
15
- input_variables=["context", "question"],
16
- template=PROMPT_TMPL,
17
- )
18
-
19
- qa = RetrievalQA.from_chain_type(
20
- llm=llm,
21
- chain_type="stuff",
22
- retriever=retriever,
23
- chain_type_kwargs={"prompt": prompt},
24
- return_source_documents=True,
25
- )
26
- return qa
27
-
28
-
29
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
30
- from langchain_community.llms import HuggingFacePipeline
31
- from langchain.prompts import PromptTemplate
32
- from langchain.chains import RetrievalQA
33
-
34
- def build_chain_qwen(retriever, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
35
- # Qwen2.5 is a causal LM (decoder-only), not seq2seq.
36
- tokenizer = AutoTokenizer.from_pretrained(model_name)
37
- # Ensure padding token exists (use EOS as pad for causal models if missing)
38
- if tokenizer.pad_token_id is None:
39
- tokenizer.pad_token_id = tokenizer.eos_token_id
40
-
41
- model = AutoModelForCausalLM.from_pretrained(model_name)
42
-
43
- gen = pipeline(
44
- task="text-generation",
45
- model=model,
46
- tokenizer=tokenizer,
47
- max_new_tokens=512,
48
- do_sample=False, # deterministic for QA
49
- truncation=True, # avoid context overruns
50
- return_full_text=False, # only the generated answer
51
- eos_token_id=tokenizer.eos_token_id,
52
- pad_token_id=tokenizer.pad_token_id,
53
- )
54
- llm = HuggingFacePipeline(pipeline=gen)
55
-
56
- prompt = PromptTemplate(
57
- input_variables=["context", "question"],
58
- template=PROMPT_TMPL,
59
- )
60
-
61
- qa = RetrievalQA.from_chain_type(
62
- llm=llm,
63
- chain_type="stuff", # keep as in your snippet
64
- retriever=retriever,
65
- chain_type_kwargs={"prompt": prompt},
66
- return_source_documents=True,
67
- )
68
- return qa
69
-
70
-
71
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
72
- from langchain_community.llms import HuggingFacePipeline
73
- from langchain.prompts import PromptTemplate
74
- from langchain.chains import RetrievalQA
75
-
76
- def build_chain_gemma(retriever, model_name: str = "google/gemma-2-2b-it"):
77
- # Gemma 2 is a causal LM (decoder-only)
78
- tokenizer = AutoTokenizer.from_pretrained(model_name)
79
- if tokenizer.pad_token_id is None:
80
- tokenizer.pad_token_id = tokenizer.eos_token_id
81
-
82
- model = AutoModelForCausalLM.from_pretrained(model_name)
83
-
84
- gen = pipeline(
85
- task="text-generation",
86
- model=model,
87
- tokenizer=tokenizer,
88
- max_new_tokens=512,
89
- do_sample=False, # deterministic for QA
90
- truncation=True, # avoid context overruns
91
- return_full_text=False, # only generated continuation
92
- eos_token_id=tokenizer.eos_token_id,
93
- pad_token_id=tokenizer.pad_token_id,
94
- )
95
- llm = HuggingFacePipeline(pipeline=gen)
96
-
97
- prompt = PromptTemplate(
98
- input_variables=["context", "question"],
99
- template=PROMPT_TMPL,
100
- )
101
-
102
- qa = RetrievalQA.from_chain_type(
103
- llm=llm,
104
- chain_type="stuff", # keep your current behavior
105
- retriever=retriever,
106
- chain_type_kwargs={"prompt": prompt},
107
- return_source_documents=True,
108
- )
109
- return qa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ingest.py DELETED
@@ -1,69 +0,0 @@
1
- from pathlib import Path
2
- import argparse
3
- import sys
4
- import os
5
-
6
- from langchain_community.document_loaders import TextLoader, PyPDFLoader
7
- from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from langchain_community.vectorstores import FAISS
9
- from langchain_community.embeddings import HuggingFaceEmbeddings
10
-
11
- import os, streamlit as st
12
- from dotenv import load_dotenv
13
- load_dotenv() # still works locally
14
-
15
- GOOGLE_API_KEY = st.secrets.get("GOOGLE_API_KEY", os.getenv("GOOGLE_API_KEY"))
16
- HF_API_TOKEN = st.secrets.get("HUGGING_FACE_API_TOKEN", os.getenv("HUGGING_FACE_API_TOKEN"))
17
-
18
- EMBED_MODEL_NAME = st.secrets.get("HUGGING_FACE_EMBEDDING_MODEL", os.getenv("HUGGING_FACE_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"))
19
- LLM_MODEL_NAME = st.secrets.get("LLM_MODEL", os.getenv("LLM_MODEL", "gemini-1.5-flash"))
20
-
21
- ROOT_DIR = Path(__file__).parent
22
- INDEX_DIR = Path(f"{ROOT_DIR}/data_index")
23
- DATA_DIR = Path(f"{ROOT_DIR}/data")
24
-
25
-
26
- def load_documents(data_dir: Path):
27
- docs = []
28
- for path in data_dir.rglob("*"):
29
- if path.is_dir():
30
- continue
31
- try:
32
- if path.suffix.lower() in [".txt", ".md"]:
33
- docs.extend(TextLoader(str(path), encoding="utf-8").load())
34
- elif path.suffix.lower() == ".pdf":
35
- docs.extend(PyPDFLoader(str(path)).load())
36
- except Exception as e:
37
- print(f"[skip] {path.name}: {e}", file=sys.stderr)
38
- if not docs:
39
- raise RuntimeError(f"No documents found in {data_dir}. Put .txt/.md/.pdf files there.")
40
- return docs
41
-
42
- def build_vectorstore(docs):
43
- splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=120)
44
- chunks = splitter.split_documents(docs)
45
- embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
46
- vs = FAISS.from_documents(chunks, embeddings)
47
- return vs
48
-
49
- def main():
50
- parser = argparse.ArgumentParser(description="Ingest documents and build FAISS index.")
51
- args = parser.parse_args()
52
-
53
-
54
-
55
- print(f"Loading documents from {DATA_DIR}")
56
- docs = load_documents(DATA_DIR)
57
- print(f"Loaded {len(docs)} documents. Building index…")
58
-
59
- vs = build_vectorstore(docs)
60
- INDEX_DIR.mkdir(parents=True, exist_ok=True)
61
- vs.save_local(str(INDEX_DIR))
62
-
63
- # Persist embedding model name for safety
64
- (INDEX_DIR / "embeddings_model.txt").write_text(EMBED_MODEL_NAME, encoding="utf-8")
65
-
66
- print(f"Index saved to {INDEX_DIR.resolve()}")
67
-
68
- if __name__ == "__main__":
69
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
qa_chain.py DELETED
@@ -1,101 +0,0 @@
1
- import argparse
2
- import textwrap
3
- from pathlib import Path
4
- import os
5
- from dotenv import load_dotenv
6
- from qa_prompts import PROMPT_TMPL
7
-
8
- from langchain_community.vectorstores import FAISS
9
- from langchain_community.embeddings import HuggingFaceEmbeddings
10
- from langchain.prompts import PromptTemplate
11
- from langchain.chains import RetrievalQA
12
- from langchain_google_genai import ChatGoogleGenerativeAI
13
-
14
- load_dotenv()
15
-
16
- HF_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")
17
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
18
-
19
- EMBED_MODEL_NAME = os.getenv("HUGGING_FACE_EMBEDDING_MODEL")
20
- LLM_MODEL_NAME = os.getenv("LLM_MODEL")
21
-
22
- ROOT_DIR = Path(__file__).parent
23
- INDEX_DIR = Path(f"{ROOT_DIR}/data_index")
24
-
25
-
26
- def load_retriever(index_dir: Path, k: int = 4):
27
- # Ensure we use the same embedding model that was used during ingest
28
- embed_model_name_path = index_dir / "embeddings_model.txt"
29
- if not embed_model_name_path.exists():
30
- raise RuntimeError(f"Missing {embed_model_name_path}. Re-run ingest.py.")
31
- embed_model_name = embed_model_name_path.read_text(encoding="utf-8").strip()
32
-
33
- embeddings = HuggingFaceEmbeddings(model_name=embed_model_name)
34
- vs = FAISS.load_local(str(index_dir), embeddings, allow_dangerous_deserialization=True)
35
- return vs.as_retriever(search_kwargs={"k": k})
36
-
37
-
38
-
39
- def build_chain_gemini(retriever):
40
- if not GOOGLE_API_KEY:
41
- raise RuntimeError("Set GOOGLE_API_KEY in your .env to use the Gemini inference endpoint.")
42
-
43
- # Uses Google Generative AI (Gemini) hosted inference endpoint
44
- llm = ChatGoogleGenerativeAI(
45
- model=LLM_MODEL_NAME,
46
- api_key=GOOGLE_API_KEY,
47
- temperature=0.1,
48
- max_output_tokens=512,
49
- convert_system_message_to_human=True,
50
- )
51
-
52
- prompt = PromptTemplate(
53
- input_variables=["context", "question"],
54
- template=PROMPT_TMPL,
55
- )
56
-
57
- # NOTE: currently uses the "stuff" chain type; map_reduce would keep per-call size manageable and robust
58
- qa = RetrievalQA.from_chain_type(
59
- llm=llm,
60
- chain_type="stuff",
61
- retriever=retriever,
62
- chain_type_kwargs={"prompt": prompt},
63
- return_source_documents=True,
64
- )
65
- return qa
66
-
67
-
68
- def main():
69
- parser = argparse.ArgumentParser(description="Run recruiter Q/A over a saved FAISS index.")
70
- args = parser.parse_args()
71
-
72
- retriever = load_retriever(INDEX_DIR)
73
-
74
- chain = build_chain_gemini(retriever)
75
-
76
- print("\My Profile Chatbot ready. Ask about me.")
77
- print("Type 'exit' to quit.\n")
78
-
79
- while True:
80
- try:
81
- q = input("You: ").strip()
82
- except (EOFError, KeyboardInterrupt):
83
- print("\nBye!")
84
- break
85
- if not q:
86
- continue
87
- if q.lower() in {"exit", "quit", "q"}:
88
- print("Bye!")
89
- break
90
-
91
- try:
92
- res = chain.invoke({"query": q})
93
- answer = res["result"] if isinstance(res, dict) else str(res)
94
- except Exception as e:
95
- answer = f"[error] {e}"
96
-
97
- print("\nMaheen:", textwrap.fill(answer, width=100))
98
- print()
99
-
100
- if __name__ == "__main__":
101
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
qa_prompts.py DELETED
@@ -1,9 +0,0 @@
1
- PROMPT_TMPL = """You are a helpful chatbot that answers questions about the candidate's profile for recruiters.
2
- Use ONLY the provided context. If the answer is not in the context, say you don't know. Be concise and factual.
3
-
4
- Context:
5
- {context}
6
-
7
- Question: {question}
8
-
9
- Answer:"""
 
 
 
 
 
 
 
 
 
 
readme.md CHANGED
@@ -6,25 +6,28 @@ A chatbot about my profile, experience, education and skills
6
 
7
  - basic.py:
8
  - main workflow of the chatbot in cli
9
- - run with `python basic.py`
10
  - ui_qa.py
11
  - streamlit app for QA chatbot
12
- - run with `streamlit run ui_qa.py`
13
 
14
 
15
  ## Todos:
16
 
17
  - update readme for project structure, esp data, venv and env
18
  - add more data
 
19
  - educational docs
20
  - experience letters
21
  - project docs/ reports/ readme
22
- - linkedin stuff
23
  - better llm selection: using gemini for now
24
  - add ui selection for gemini free models
 
 
25
  - make router chains for better response
26
  - add router chains for education, skills, experience and default
27
  - UI: improve UI
28
- - enter for chat
29
  - chat sequence
30
- - deploy
 
 
6
 
7
  - basic.py:
8
  - main workflow of the chatbot in cli
9
+ - run with `python qa_chain_cli.py`
10
  - ui_qa.py
11
  - streamlit app for QA chatbot
12
+ - run with `streamlit run streamlit_app.py`
13
 
14
 
15
  ## Todos:
16
 
17
  - update readme for project structure, esp data, venv and env
18
  - add more data
19
+ - add updated cv with project
20
  - educational docs
21
  - experience letters
22
  - project docs/ reports/ readme
23
+ - linkedin stuff: get more from linkedin csv
24
  - better llm selection: using gemini for now
25
  - add ui selection for gemini free models
26
+ - improve qa pipeline and prompt
27
+ - add prompt history for agent mem
28
  - make router chains for better response
29
  - add router chains for education, skills, experience and default
30
  - UI: improve UI
 
31
  - chat sequence
32
+ - deploy:
33
+ - improve the huggingface github repo hosting setup
ui_qa.py DELETED
@@ -1,181 +0,0 @@
1
- import subprocess
2
- from pathlib import Path
3
- from typing import List
4
- import streamlit as st
5
- from qa_prompts import PROMPT_TMPL
6
-
7
- from langchain_community.vectorstores import FAISS
8
- from langchain.chains import RetrievalQA
9
- from langchain.prompts import PromptTemplate
10
- from langchain.embeddings.base import Embeddings
11
- from langchain_google_genai import ChatGoogleGenerativeAI
12
- from huggingface_hub import InferenceClient
13
-
14
- import os, streamlit as st
15
- from dotenv import load_dotenv
16
- load_dotenv() # still works locally
17
-
18
- GOOGLE_API_KEY = st.secrets.get("GOOGLE_API_KEY", os.getenv("GOOGLE_API_KEY"))
19
- HF_API_TOKEN = st.secrets.get("HUGGING_FACE_API_TOKEN", os.getenv("HUGGING_FACE_API_TOKEN"))
20
-
21
- EMBED_MODEL_NAME = st.secrets.get("HUGGING_FACE_EMBEDDING_MODEL", os.getenv("HUGGING_FACE_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"))
22
- LLM_MODEL_NAME = st.secrets.get("LLM_MODEL", os.getenv("LLM_MODEL", "gemini-1.5-flash"))
23
-
24
- ROOT_DIR = Path(__file__).parent
25
- INDEX_DIR = Path(f"{ROOT_DIR}/data_index")
26
-
27
-
28
- ###### run ingest.py ######
29
-
30
- if not INDEX_DIR.exists():
31
- with st.spinner("Index not found. Building FAISS index (first run)…"):
32
- # Ensure ingest.py reads the same env/secrets model and paths
33
- proc = subprocess.run(["python", "ingest.py"], capture_output=True, text=True)
34
- if proc.returncode != 0:
35
- st.error(f"ingest.py failed:\n{proc.stderr}")
36
- st.stop()
37
-
38
-
39
- class HFAPIEmbeddings(Embeddings):
40
- def __init__(self, repo_id: str, token: str | None = None, timeout: float = 120.0):
41
- self.client = InferenceClient(model=repo_id, token=token, timeout=timeout)
42
-
43
- def embed_documents(self, texts: List[str]) -> List[List[float]]:
44
- return self.client.feature_extraction(texts)
45
-
46
- def embed_query(self, text: str) -> List[float]:
47
- vec = self.client.feature_extraction(text)
48
- return vec[0] if (isinstance(vec, list) and vec and isinstance(vec[0], list)) else vec
49
-
50
-
51
-
52
- def build_chain_gemini(retriever, _llm_repo, _max_new, _temp, _show_sources):
53
- if not GOOGLE_API_KEY:
54
- raise RuntimeError("Set GOOGLE_API_KEY in your .env to use the Gemini inference endpoint.")
55
-
56
- # Uses Google Generative AI (Gemini) hosted inference endpoint
57
- llm = ChatGoogleGenerativeAI(
58
- model=_llm_repo,
59
- api_key=GOOGLE_API_KEY,
60
- temperature=_temp,
61
- max_output_tokens=_max_new,
62
- convert_system_message_to_human=True,
63
- )
64
-
65
- prompt = PromptTemplate(
66
- input_variables=["context", "question"],
67
- template=PROMPT_TMPL,
68
- )
69
-
70
- #map reduce or stuff
71
- qa = RetrievalQA.from_chain_type(
72
- llm=llm,
73
- chain_type="stuff",
74
- retriever=retriever,
75
- chain_type_kwargs={"prompt": prompt},
76
- return_source_documents=_show_sources,
77
- )
78
- return qa
79
-
80
-
81
-
82
- # ========================= Streamlit UI =========================
83
- st.set_page_config(page_title="Maheen's Profile Chatbot", page_icon="💬", layout="centered")
84
- st.title("Maheen's Profile Chatbot")
85
- st.caption("RAG over my profile docs using FAISS + Hugging Face Inference API")
86
-
87
- # Sidebar settings
88
- st.sidebar.header("Settings")
89
- hf_token = HF_API_TOKEN
90
- if not hf_token:
91
- st.sidebar.warning("HUGGINGFACEHUB_API_TOKEN is not set. Set it in your shell before running the app.")
92
-
93
- # store_dir = st.sidebar.text_input("FAISS store path", value=INDEX_DIR)
94
-
95
- # llm_repo_id = st.sidebar.text_input("LLM repo (HF)", value=LLM_MODEL_NAME)
96
- # embed_repo_id = st.sidebar.text_input("Embedding model (HF)", value=EMBED_MODEL_NAME)
97
-
98
- # Display model names as text (read-only)
99
- st.sidebar.markdown(f"**Embedding Model:** `{EMBED_MODEL_NAME}`")
100
- st.sidebar.markdown(f"**Chat Model:** `{LLM_MODEL_NAME}`")
101
-
102
-
103
- # k = st.sidebar.number_input("Top-k retrieved chunks", min_value=1, max_value=20, value=4, step=1)
104
- k = 4
105
- # max_new_tokens = st.sidebar.number_input("Max new tokens", min_value=64, max_value=2048, value=512, step=64)
106
- max_new_tokens = 512
107
- # temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.0, value=0.1, step=0.05)
108
- temperature = 0.1
109
- # show_sources = st.sidebar.checkbox("Show sources", value=False)
110
- show_sources = False
111
-
112
-
113
- ###################
114
-
115
- # Session state for chat history
116
- if "history" not in st.session_state:
117
- st.session_state.history = [] # list of (user, assistant, sources)
118
-
119
- # Load vector store & chain lazily, cache across reruns
120
- @st.cache_resource(show_spinner=True)
121
- def _load_chain(_store_dir: str, _embed_repo: str, _llm_repo: str, _k: int, _max_new: int, _temp: float, _show_sources: bool):
122
- if not Path(_store_dir).exists():
123
- raise FileNotFoundError(f"FAISS store not found at '{_store_dir}'. Run ingest.py first.")
124
- embeddings = HFAPIEmbeddings(repo_id=_embed_repo, token=hf_token)
125
- vs = FAISS.load_local(
126
- _store_dir,
127
- embeddings,
128
- allow_dangerous_deserialization=True, # required by newer LC versions
129
- )
130
- retriever = vs.as_retriever(search_kwargs={"k": 4}) # hardcoded, change later
131
- chain = build_chain_gemini(retriever, _llm_repo, _max_new, _temp, _show_sources)
132
- return chain
133
-
134
-
135
- # Prepare chain
136
- with st.spinner("Preparing retriever & LLM…"):
137
- chain = _load_chain(INDEX_DIR, EMBED_MODEL_NAME, LLM_MODEL_NAME, k, max_new_tokens, temperature, show_sources)
138
-
139
- def render_sources(docs):
140
- if not docs:
141
- return
142
- st.markdown("**Sources**")
143
- for i, d in enumerate(docs, start=1):
144
- src = d.metadata.get("source", "unknown")
145
- page = d.metadata.get("page", None)
146
- label = f"{Path(src).name}" + (f" (page {page+1})" if isinstance(page, int) else "")
147
- with st.expander(f"{i}. {label}"):
148
- st.write(d.page_content[:1500] + ("…" if len(d.page_content) > 1500 else ""))
149
-
150
- # --- Chat input with Enter submit ---
151
- with st.form("chat-form", clear_on_submit=True):
152
- user_input = st.text_input(
153
- "Ask about my profile:",
154
- placeholder="e.g., What are your key projects?"
155
- )
156
- submitted = st.form_submit_button("Ask")
157
-
158
- if submitted and user_input.strip():
159
- with st.spinner("Thinking…"):
160
- try:
161
- res = chain.invoke({"query": user_input.strip()})
162
- if isinstance(res, dict):
163
- answer = res.get("result", "")
164
- sources = res.get("source_documents", []) if show_sources else []
165
- else:
166
- answer, sources = str(res), []
167
- except Exception as e:
168
- answer, sources = f"[error] {e}", []
169
- st.session_state.history.append((user_input.strip(), answer, sources))
170
-
171
- # Display history
172
- for q, a, srcs in st.session_state.history:
173
- st.markdown(f"**You:** {q}")
174
- st.markdown(f"**Assistant:** {a}")
175
- if show_sources:
176
- render_sources(srcs)
177
- st.markdown("---")
178
-
179
- # Footer
180
- # st.caption("Enter submits. Datastore path fixed from code/env. Models shown read-only.")
181
-