olamideba committed on
Commit
55609c0
·
1 Parent(s): 7049205

add huggingface hub support

Browse files
.dockerignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore local environments
2
+ venv/
3
+ .venv/
4
+ env/
5
+
6
+ # Ignore Python cache
7
+ **/__pycache__/
8
+ *.pyc
9
+
10
+ # Ignore large data or logs
11
+ *.log
12
+ data/
13
+ *.csv
14
+ *.sqlite
15
+
16
+ # Ignore Git history
17
+ .git
18
+ .gitignore
19
+
20
+ # Ignore local IDE settings
21
+ .vscode/
22
+ .idea/
23
+
24
+
25
+ # Models and embeddings
26
+ .chroma/
27
+ embeddings/
28
+ data/
29
+ models/
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit RAG chatbot image: installs dependencies from the repo root,
# then runs the app out of /app via a generated entrypoint script.
FROM python:3.11-slim

# NOTE(review): COPY . . below copies the repo into the image root, so the
# repository's app/ directory becomes /app — the later WORKDIR /app relies
# on that layout. Confirm the repo actually ships app/main.py and
# app/scripts/download_assets.py, otherwise the entrypoint will fail.
WORKDIR /

# Install curl for healthcheck
RUN apt-get update && apt-get install -y \
    curl \
    && rm -rf /var/lib/apt/lists/*

# RUN git clone https://github.com/mujeeb-gh/rag-chatbot-final.git .

COPY . .

# requirements.txt was copied to / above, which is still the working dir here.
RUN pip3 install -r requirements.txt

EXPOSE 8501

# Streamlit's built-in health endpoint; curl --fail exits non-zero on HTTP errors.
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health

WORKDIR /app

# Create entrypoint script that downloads assets before starting Streamlit
# NOTE(review): relies on Debian's sh builtin echo interpreting \n escapes —
# verify the generated /app/entrypoint.sh has real newlines.
RUN echo '#!/bin/bash\n\
python3 /app/scripts/download_assets.py\n\
exec streamlit run main.py --server.port=8501 --server.address=0.0.0.0\n\
' > /app/entrypoint.sh && chmod +x /app/entrypoint.sh

ENTRYPOINT ["/app/entrypoint.sh"]
README.md CHANGED
@@ -6,6 +6,17 @@ To setup this project locally, you need to have a `.env` file in the root direct
6
 
7
  .venv is the new virtual env
8
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  Then, install the dependencies:
11
 
@@ -19,6 +30,40 @@ Without development dependencies:
19
  pip install -r requirements.txt
20
  ```
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ```
23
  astra
24
  ├─ .chroma
 
6
 
7
  .venv is the new virtual env
8
 
9
+ ### Environment Variables
10
+
11
+ Required for API keys:
12
+ - `GROQ_API_KEY`: Your Groq API key (for LLM)
13
+ - `OPENAI_API_KEY`: Your OpenAI API key (optional, for OpenAI models)
14
+ - `COHERE_API_KEY`: Your Cohere API key (optional)
15
+
16
+ For Docker deployment with HuggingFace Hub:
17
+ - `HF_MODELS_REPO`: HuggingFace Hub repository for models (e.g., "username/astra-models")
18
+ - `HF_CHROMADB_REPO`: HuggingFace Hub repository for ChromaDB embeddings (e.g., "username/astra-chromadb")
19
+ - `HF_TOKEN`: HuggingFace Hub token (required for private repositories)
20
 
21
  Then, install the dependencies:
22
 
 
30
  pip install -r requirements.txt
31
  ```
32
 
33
+ ## Docker Deployment
34
+
35
+ The Dockerfile is configured to automatically download models and ChromaDB embeddings from HuggingFace Hub at container startup.
36
+
37
+ ### Setting up HuggingFace Hub
38
+
39
+ 1. Create a HuggingFace account at https://huggingface.co/
40
+ 2. Create repositories for your models and ChromaDB:
41
+ - Create a repository for models (e.g., `your-username/astra-models`)
42
+ - Upload your model directories (`bge-large_finetuned/`, `bge-small_finetuned/`) to this repository
43
+ - Create a repository for ChromaDB (e.g., `your-username/astra-chromadb`)
44
+ - Compress your `.chroma/` directory and upload it as `chromadb.tar.gz` or `chromadb.zip`
45
+
46
+ 3. Set environment variables when running Docker:
47
+ ```bash
48
+ docker run -e HF_MODELS_REPO=your-username/astra-models \
49
+ -e HF_CHROMADB_REPO=your-username/astra-chromadb \
50
+ -e HF_TOKEN=your_hf_token \
51
+ -e GROQ_API_KEY=your_groq_key \
52
+ -p 8501:8501 your-image-name
53
+ ```
54
+
55
+ Or use a `.env` file with Docker Compose or `--env-file` flag.
56
+
57
+ ### Local Development
58
+
59
+ For local development, you can place models and ChromaDB files in the local directories:
60
+ - Models go in `models/` directory
61
+ - ChromaDB goes in `.chroma/` directory
62
+
63
+ The code will automatically use local files if available, falling back to HuggingFace Hub if not found.
64
+
65
+ ## Project Structure
66
+
67
  ```
68
  astra
69
  ├─ .chroma
app/scripts/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Scripts package
2
+
app/scripts/download_assets.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download assets (models and ChromaDB) from HuggingFace Hub if not already present locally.
3
+ This script runs at container startup to ensure required files are available.
4
+ """
5
+ import os
6
+ import shutil
7
+ import sys
8
+ import tarfile
9
+ import zipfile
10
+ from pathlib import Path
11
+
12
+ try:
13
+ from huggingface_hub import snapshot_download, hf_hub_download
14
+ from huggingface_hub.utils import HfHubHTTPError
15
+ except ImportError:
16
+ print("ERROR: huggingface_hub not installed. Please install it first.")
17
+ sys.exit(1)
18
+
19
+
20
def get_project_root():
    """Return the project root directory.

    This module lives at <root>/app/scripts/, so the root is the third
    ancestor of this file.
    """
    return Path(__file__).parent.parent.parent
25
+
26
+
27
+ def download_models(models_repo: str, models_dir: Path, hf_token: str | None = None) -> None:
28
+ """
29
+ Download models from HuggingFace Hub if not present locally.
30
+
31
+ Args:
32
+ models_repo: HuggingFace Hub repository (e.g., "username/astra-models")
33
+ models_dir: Local directory to store models
34
+ hf_token: Optional HuggingFace token for private repos
35
+ """
36
+ if not models_repo:
37
+ print("WARNING: HF_MODELS_REPO not set. Skipping model download.")
38
+ return
39
+
40
+ print(f"Checking models in {models_dir}...")
41
+
42
+ # Check if models directory already has content
43
+ if models_dir.exists() and any(models_dir.iterdir()):
44
+ print(f"Models directory already contains files. Skipping download.")
45
+ print(f"To force re-download, delete {models_dir} and restart.")
46
+ return
47
+
48
+ print(f"Downloading models from {models_repo}...")
49
+ try:
50
+ # Ensure models directory exists
51
+ models_dir.mkdir(parents=True, exist_ok=True)
52
+
53
+ # Download the entire repository
54
+ snapshot_download(
55
+ repo_id=models_repo,
56
+ local_dir=str(models_dir),
57
+ token=hf_token,
58
+ resume_download=True,
59
+ )
60
+ print(f"✓ Models downloaded successfully to {models_dir}")
61
+ except HfHubHTTPError as e:
62
+ print(f"ERROR: Failed to download models from {models_repo}")
63
+ print(f"Error: {e}")
64
+ print("Make sure the repository exists and is accessible.")
65
+ sys.exit(1)
66
+ except Exception as e:
67
+ print(f"ERROR: Unexpected error while downloading models: {e}")
68
+ sys.exit(1)
69
+
70
+
71
+ def download_chromadb(chromadb_repo: str, chromadb_dir: Path, hf_token: str | None = None) -> None:
72
+ """
73
+ Download ChromaDB archive from HuggingFace Hub and extract it.
74
+
75
+ Args:
76
+ chromadb_repo: HuggingFace Hub repository (e.g., "username/astra-chromadb")
77
+ chromadb_dir: Local directory for ChromaDB
78
+ hf_token: Optional HuggingFace token for private repos
79
+ """
80
+ if not chromadb_repo:
81
+ print("WARNING: HF_CHROMADB_REPO not set. Skipping ChromaDB download.")
82
+ return
83
+
84
+ print(f"Checking ChromaDB in {chromadb_dir}...")
85
+
86
+ # Check if ChromaDB directory already has content
87
+ if chromadb_dir.exists() and any(chromadb_dir.iterdir()):
88
+ print(f"ChromaDB directory already contains files. Skipping download.")
89
+ print(f"To force re-download, delete {chromadb_dir} and restart.")
90
+ return
91
+
92
+ print(f"Downloading ChromaDB from {chromadb_repo}...")
93
+ try:
94
+ # Ensure chromadb directory exists
95
+ chromadb_dir.mkdir(parents=True, exist_ok=True)
96
+
97
+ # Try common archive filenames
98
+ archive_names = ["chromadb.tar.gz", "chromadb.zip", "chroma.tar.gz", "chroma.zip", ".chroma.tar.gz", ".chroma.zip"]
99
+
100
+ downloaded = False
101
+ for archive_name in archive_names:
102
+ try:
103
+ archive_path = hf_hub_download(
104
+ repo_id=chromadb_repo,
105
+ filename=archive_name,
106
+ local_dir=str(chromadb_dir.parent),
107
+ token=hf_token,
108
+ resume_download=True,
109
+ )
110
+
111
+ # Extract the archive
112
+ print(f"Extracting {archive_name}...")
113
+ if archive_name.endswith('.tar.gz'):
114
+ with tarfile.open(archive_path, 'r:gz') as tar:
115
+ # Get members and check if they're in a subdirectory
116
+ members = tar.getmembers()
117
+ # Extract to parent directory
118
+ tar.extractall(path=chromadb_dir.parent)
119
+
120
+ # If archive contains .chroma subdirectory, move contents up
121
+ extracted_chroma = chromadb_dir.parent / ".chroma"
122
+ if extracted_chroma.exists() and extracted_chroma.is_dir():
123
+ # Move contents from .chroma to chromadb_dir
124
+ for item in extracted_chroma.iterdir():
125
+ shutil.move(str(item), str(chromadb_dir / item.name))
126
+ extracted_chroma.rmdir()
127
+ elif archive_name.endswith('.zip'):
128
+ with zipfile.ZipFile(archive_path, 'r') as zip_ref:
129
+ zip_ref.extractall(path=chromadb_dir.parent)
130
+
131
+ # If archive contains .chroma subdirectory, move contents up
132
+ extracted_chroma = chromadb_dir.parent / ".chroma"
133
+ if extracted_chroma.exists() and extracted_chroma.is_dir():
134
+ # Move contents from .chroma to chromadb_dir
135
+ for item in extracted_chroma.iterdir():
136
+ shutil.move(str(item), str(chromadb_dir / item.name))
137
+ extracted_chroma.rmdir()
138
+
139
+ # Clean up the archive file
140
+ os.remove(archive_path)
141
+ print(f"✓ ChromaDB downloaded and extracted successfully to {chromadb_dir}")
142
+ downloaded = True
143
+ break
144
+ except HfHubHTTPError:
145
+ # Try next archive name
146
+ continue
147
+
148
+ if not downloaded:
149
+ # If no archive found, try downloading as a snapshot (directory structure)
150
+ print("No archive found, trying to download as directory snapshot...")
151
+ snapshot_download(
152
+ repo_id=chromadb_repo,
153
+ local_dir=str(chromadb_dir),
154
+ token=hf_token,
155
+ resume_download=True,
156
+ )
157
+ print(f"✓ ChromaDB downloaded successfully to {chromadb_dir}")
158
+
159
+ except HfHubHTTPError as e:
160
+ print(f"ERROR: Failed to download ChromaDB from {chromadb_repo}")
161
+ print(f"Error: {e}")
162
+ print("Make sure the repository exists and is accessible.")
163
+ sys.exit(1)
164
+ except Exception as e:
165
+ print(f"ERROR: Unexpected error while downloading ChromaDB: {e}")
166
+ sys.exit(1)
167
+
168
+
169
def main():
    """Fetch all required assets (models and ChromaDB) at startup."""
    banner = "=" * 60

    print(banner)
    print("Downloading assets from HuggingFace Hub...")
    print(banner)

    root = get_project_root()

    # Read HuggingFace Hub configuration from the environment.
    repo_models = os.getenv("HF_MODELS_REPO", "")
    repo_chroma = os.getenv("HF_CHROMADB_REPO", "")
    token = os.getenv("HF_TOKEN", None)

    # Models: download from the Hub, or rely on local files when unset.
    if not repo_models:
        print("INFO: HF_MODELS_REPO not configured. Models must be available locally.")
    else:
        download_models(repo_models, root / "models", token)

    # ChromaDB: same fallback policy as the models.
    if not repo_chroma:
        print("INFO: HF_CHROMADB_REPO not configured. ChromaDB must be available locally.")
    else:
        download_chromadb(repo_chroma, root / ".chroma", token)

    print(banner)
    print("Asset download complete!")
    print(banner)
199
+
200
+
201
# Script entry point: run all asset downloads when executed directly
# (invoked by the Docker entrypoint before Streamlit starts).
if __name__ == "__main__":
    main()
203
+
app/src/llm.py CHANGED
@@ -9,7 +9,7 @@ load_dotenv()
9
 
10
  CHAT_MODEL = Literal["llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
11
  groq_api_key = os.getenv('GROQ_API_KEY')
12
- openrouter_api_key = os.getenv('OPENROUTER_API_KEY')
13
 
14
  client = Groq(
15
  api_key=groq_api_key,
 
9
 
10
  CHAT_MODEL = Literal["llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
11
  groq_api_key = os.getenv('GROQ_API_KEY')
12
+ openrouter_api_key = settings.openrouter_api_key
13
 
14
  client = Groq(
15
  api_key=groq_api_key,
app/src/sentence.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Literal, List
2
  import numpy as np
3
  import os
4
 
@@ -10,19 +10,36 @@ EMBED_MODEL = Literal["BAAI/bge-small-en-v1.5", "BAAI/bge-base-en-v1.5", "BAAI/b
10
 
11
 
12
  def sentence_embed(
13
- texts: str | List[str], model_name_or_path: EMBED_MODEL = "BAAI/bge-large-en-v1.5", device: str = "cpu"
14
  ) -> list[list[float]]:
15
  """
16
  Embeds the given texts using the specified model.
17
 
18
  Args:
19
- texts (str | List[str], str]): The list of texts or text to embed.
20
- model (EMBED_MODEL): The embedding model to use.
 
 
 
 
 
21
 
22
  Returns:
23
- np.ndarray: The embeddings of the texts.
24
  """
25
- model = SentenceTransformer(os.path.join(MODELS_DIR, model_name_or_path))
 
 
 
 
 
 
 
 
 
 
 
 
26
  embeddings: np.ndarray = model.encode(sentences=texts, device=device, show_progress_bar=True)
27
  embeddings_list: list = embeddings.tolist()
28
  return embeddings_list
 
1
+ from typing import Literal, List, Union
2
  import numpy as np
3
  import os
4
 
 
10
 
11
 
12
  def sentence_embed(
13
+ texts: str | List[str], model_name_or_path: Union[str, EMBED_MODEL] = "BAAI/bge-large-en-v1.5", device: str = "cpu"
14
  ) -> list[list[float]]:
15
  """
16
  Embeds the given texts using the specified model.
17
 
18
  Args:
19
+ texts (str | List[str]): The list of texts or text to embed.
20
+ model_name_or_path (Union[str, EMBED_MODEL]): The embedding model to use.
21
+ Can be:
22
+ - A HuggingFace Hub identifier (e.g., "BAAI/bge-large-en-v1.5" or "username/repo-name")
23
+ - A local path relative to MODELS_DIR (e.g., "bge-small_finetuned")
24
+ - An absolute path
25
+ device (str): Device to use for encoding (default: "cpu").
26
 
27
  Returns:
28
+ list[list[float]]: The embeddings of the texts.
29
  """
30
+ # Check if it's a local path (starts with / or ./ or exists in MODELS_DIR)
31
+ local_model_path = os.path.join(MODELS_DIR, model_name_or_path)
32
+
33
+ # If it's a HuggingFace Hub identifier (contains /) or local path exists, use it directly
34
+ # SentenceTransformer handles both HF Hub identifiers and local paths
35
+ if os.path.exists(local_model_path):
36
+ model_path = local_model_path
37
+ else:
38
+ # Assume it's either an HF Hub identifier or a local path that doesn't exist yet
39
+ # SentenceTransformer will handle HF Hub downloads automatically
40
+ model_path = model_name_or_path
41
+
42
+ model = SentenceTransformer(model_path)
43
  embeddings: np.ndarray = model.encode(sentences=texts, device=device, show_progress_bar=True)
44
  embeddings_list: list = embeddings.tolist()
45
  return embeddings_list
app/src/settings.py CHANGED
@@ -17,6 +17,11 @@ class Settings(BaseSettings):
17
  cohere_api_key: str = ""
18
  groq_api_key: str = ""
19
  openai_api_key: str = ""
 
 
 
 
 
20
 
21
 
22
  settings = Settings()
 
17
  cohere_api_key: str = ""
18
  groq_api_key: str = ""
19
  openai_api_key: str = ""
20
+ openrouter_api_key: str = ""
21
+ # HuggingFace Hub configuration
22
+ hf_models_repo: str = os.getenv("HF_MODELS_REPO", "")
23
+ hf_chromadb_repo: str = os.getenv("HF_CHROMADB_REPO", "")
24
+ hf_token: str = os.getenv("HF_TOKEN", "")
25
 
26
 
27
  settings = Settings()
requirements.txt CHANGED
@@ -2,6 +2,7 @@ chromadb==0.5.0
2
  datasets==2.19.0
3
  evaluate==0.4.2
4
  groq==1.0.0
 
5
  numpy==1.24.3
6
  openai==2.14.0
7
  pandas==2.0.3
 
2
  datasets==2.19.0
3
  evaluate==0.4.2
4
  groq==1.0.0
5
+ huggingface_hub==0.24.0
6
  numpy==1.24.3
7
  openai==2.14.0
8
  pandas==2.0.3