Yash Sakhale commited on
Commit
54b0b19
·
1 Parent(s): 99f283a

Add HF Hub model loading support

Browse files
Files changed (3) hide show
  1. ml_models.py +93 -15
  2. requirements.txt +1 -0
  3. upload_models_to_hf.py +86 -0
ml_models.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  ML Model Loader and Utilities
3
  Handles loading and using the conflict prediction model and package embeddings.
 
4
  """
5
 
6
  import json
@@ -10,27 +11,70 @@ from typing import Dict, List, Tuple, Optional
10
  import numpy as np
11
  from packaging.requirements import Requirement
12
 
 
 
 
 
 
 
 
 
13
 
14
  class ConflictPredictor:
15
  """Load and use the conflict prediction model."""
16
 
17
- def __init__(self, model_path: Optional[Path] = None):
18
- """Initialize the conflict predictor."""
 
 
 
 
 
 
 
 
 
 
19
  if model_path is None:
20
  model_path = Path(__file__).parent / "models" / "conflict_predictor.pkl"
21
 
22
- self.model = None
23
  self.model_path = model_path
24
 
 
25
  if model_path.exists():
26
  try:
27
  with open(model_path, 'rb') as f:
28
  self.model = pickle.load(f)
29
- print(f"Loaded conflict prediction model from {model_path}")
 
30
  except Exception as e:
31
- print(f"⚠️ Could not load conflict prediction model: {e}")
32
- else:
33
- print(f"⚠️ Conflict prediction model not found at {model_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def extract_features(self, requirements_text: str) -> np.ndarray:
36
  """Extract features from requirements text (same as training)."""
@@ -130,24 +174,58 @@ class ConflictPredictor:
130
  class PackageEmbeddings:
131
  """Load and use package embeddings for similarity matching."""
132
 
133
- def __init__(self, embeddings_path: Optional[Path] = None):
134
- """Initialize package embeddings."""
135
- if embeddings_path is None:
136
- embeddings_path = Path(__file__).parent / "models" / "package_embeddings.json"
137
 
 
 
 
 
 
138
  self.embeddings = {}
139
  self.embeddings_path = embeddings_path
140
  self.model = None
141
 
 
 
 
 
 
 
142
  if embeddings_path.exists():
143
  try:
144
  with open(embeddings_path, 'r') as f:
145
  self.embeddings = json.load(f)
146
- print(f"Loaded {len(self.embeddings)} package embeddings from {embeddings_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  except Exception as e:
148
- print(f"⚠️ Could not load embeddings: {e}")
149
- else:
150
- print(f"⚠️ Embeddings not found at {embeddings_path}")
151
 
152
  def _load_model(self):
153
  """Lazy load the sentence transformer model."""
 
1
  """
2
  ML Model Loader and Utilities
3
  Handles loading and using the conflict prediction model and package embeddings.
4
+ Loads from local files if available, otherwise downloads from Hugging Face Hub.
5
  """
6
 
7
  import json
 
11
  import numpy as np
12
  from packaging.requirements import Requirement
13
 
14
+ # Try to import huggingface_hub for model downloading
15
+ try:
16
+ from huggingface_hub import hf_hub_download
17
+ HF_HUB_AVAILABLE = True
18
+ except ImportError:
19
+ HF_HUB_AVAILABLE = False
20
+ print("Warning: huggingface_hub not available. Models must be loaded locally.")
21
+
22
 
23
  class ConflictPredictor:
24
  """Load and use the conflict prediction model."""
25
 
26
+ def __init__(self, model_path: Optional[Path] = None, repo_id: str = "ysakhale/dependency-conflict-models"):
27
+ """Initialize the conflict predictor.
28
+
29
+ Args:
30
+ model_path: Local path to model file (optional)
31
+ repo_id: Hugging Face repository ID to download from if local file not found
32
+ """
33
+ self.repo_id = repo_id
34
+ self.model = None
35
+ self.model_path = model_path
36
+
37
+ # Try local path first
38
  if model_path is None:
39
  model_path = Path(__file__).parent / "models" / "conflict_predictor.pkl"
40
 
 
41
  self.model_path = model_path
42
 
43
+ # Try loading from local file
44
  if model_path.exists():
45
  try:
46
  with open(model_path, 'rb') as f:
47
  self.model = pickle.load(f)
48
+ print(f"Loaded conflict prediction model from {model_path}")
49
+ return
50
  except Exception as e:
51
+ print(f"Could not load conflict prediction model from local: {e}")
52
+
53
+ # If local file doesn't exist, try downloading from HF Hub
54
+ if HF_HUB_AVAILABLE:
55
+ try:
56
+ print(f"Model not found locally. Downloading from Hugging Face Hub: {repo_id}")
57
+ downloaded_path = hf_hub_download(
58
+ repo_id=repo_id,
59
+ filename="conflict_predictor.pkl",
60
+ repo_type="model"
61
+ )
62
+ with open(downloaded_path, 'rb') as f:
63
+ self.model = pickle.load(f)
64
+ print(f"Loaded conflict prediction model from Hugging Face Hub")
65
+ # Optionally cache it locally
66
+ try:
67
+ model_path.parent.mkdir(parents=True, exist_ok=True)
68
+ import shutil
69
+ shutil.copy(downloaded_path, model_path)
70
+ print(f"Cached model locally at {model_path}")
71
+ except:
72
+ pass
73
+ return
74
+ except Exception as e:
75
+ print(f"Could not download model from Hugging Face Hub: {e}")
76
+
77
+ print(f"Warning: Conflict prediction model not available")
78
 
79
  def extract_features(self, requirements_text: str) -> np.ndarray:
80
  """Extract features from requirements text (same as training)."""
 
174
  class PackageEmbeddings:
175
  """Load and use package embeddings for similarity matching."""
176
 
177
+ def __init__(self, embeddings_path: Optional[Path] = None, repo_id: str = "ysakhale/dependency-conflict-models"):
178
+ """Initialize package embeddings.
 
 
179
 
180
+ Args:
181
+ embeddings_path: Local path to embeddings file (optional)
182
+ repo_id: Hugging Face repository ID to download from if local file not found
183
+ """
184
+ self.repo_id = repo_id
185
  self.embeddings = {}
186
  self.embeddings_path = embeddings_path
187
  self.model = None
188
 
189
+ if embeddings_path is None:
190
+ embeddings_path = Path(__file__).parent / "models" / "package_embeddings.json"
191
+
192
+ self.embeddings_path = embeddings_path
193
+
194
+ # Try loading from local file
195
  if embeddings_path.exists():
196
  try:
197
  with open(embeddings_path, 'r') as f:
198
  self.embeddings = json.load(f)
199
+ print(f"Loaded {len(self.embeddings)} package embeddings from {embeddings_path}")
200
+ return
201
+ except Exception as e:
202
+ print(f"Could not load embeddings from local: {e}")
203
+
204
+ # If local file doesn't exist, try downloading from HF Hub
205
+ if HF_HUB_AVAILABLE:
206
+ try:
207
+ print(f"Embeddings not found locally. Downloading from Hugging Face Hub: {repo_id}")
208
+ downloaded_path = hf_hub_download(
209
+ repo_id=repo_id,
210
+ filename="package_embeddings.json",
211
+ repo_type="model"
212
+ )
213
+ with open(downloaded_path, 'r') as f:
214
+ self.embeddings = json.load(f)
215
+ print(f"Loaded {len(self.embeddings)} package embeddings from Hugging Face Hub")
216
+ # Optionally cache it locally
217
+ try:
218
+ embeddings_path.parent.mkdir(parents=True, exist_ok=True)
219
+ import shutil
220
+ shutil.copy(downloaded_path, embeddings_path)
221
+ print(f"Cached embeddings locally at {embeddings_path}")
222
+ except:
223
+ pass
224
+ return
225
  except Exception as e:
226
+ print(f"Could not download embeddings from Hugging Face Hub: {e}")
227
+
228
+ print(f"Warning: Package embeddings not available")
229
 
230
  def _load_model(self):
231
  """Lazy load the sentence transformer model."""
requirements.txt CHANGED
@@ -5,4 +5,5 @@ requests>=2.31.0
5
  scikit-learn>=1.3.0
6
  sentence-transformers>=2.2.0
7
  numpy>=1.24.0
 
8
 
 
5
  scikit-learn>=1.3.0
6
  sentence-transformers>=2.2.0
7
  numpy>=1.24.0
8
+ huggingface-hub>=0.20.0
9
 
upload_models_to_hf.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Upload ML models to Hugging Face Hub
3
+ This allows the models to be loaded in Hugging Face Spaces
4
+ """
5
+
6
+ from pathlib import Path
7
+ from huggingface_hub import HfApi, login
8
+ import os
9
+
10
+ def upload_models():
11
+ """Upload models to Hugging Face Hub."""
12
+
13
+ # Check if models exist
14
+ models_dir = Path("models")
15
+ if not models_dir.exists():
16
+ print("Error: models/ directory not found!")
17
+ print("Please train the models first:")
18
+ print(" python train_conflict_model.py")
19
+ print(" python generate_embeddings.py")
20
+ return
21
+
22
+ # Check for model files
23
+ model_files = {
24
+ "conflict_predictor.pkl": models_dir / "conflict_predictor.pkl",
25
+ "package_embeddings.json": models_dir / "package_embeddings.json",
26
+ "embedding_info.json": models_dir / "embedding_info.json"
27
+ }
28
+
29
+ missing = [name for name, path in model_files.items() if not path.exists()]
30
+ if missing:
31
+ print(f"Error: Missing model files: {missing}")
32
+ print("Please train the models first:")
33
+ print(" python train_conflict_model.py")
34
+ print(" python generate_embeddings.py")
35
+ return
36
+
37
+ # Login to Hugging Face
38
+ print("Logging in to Hugging Face...")
39
+ print("(You'll need to enter your HF token - get it from https://huggingface.co/settings/tokens)")
40
+ try:
41
+ login()
42
+ except Exception as e:
43
+ print(f"Login error: {e}")
44
+ print("\nYou can also set HF_TOKEN environment variable:")
45
+ print(" $env:HF_TOKEN='your_token_here' # PowerShell")
46
+ return
47
+
48
+ # Initialize API
49
+ api = HfApi()
50
+
51
+ # Repository name for models
52
+ repo_id = "ysakhale/dependency-conflict-models"
53
+
54
+ # Create repository if it doesn't exist
55
+ try:
56
+ api.create_repo(
57
+ repo_id=repo_id,
58
+ repo_type="model",
59
+ exist_ok=True,
60
+ private=False
61
+ )
62
+ print(f"Repository {repo_id} is ready!")
63
+ except Exception as e:
64
+ print(f"Note: {e}")
65
+
66
+ # Upload each model file
67
+ print("\nUploading models...")
68
+ for filename, filepath in model_files.items():
69
+ print(f"Uploading {filename}...")
70
+ try:
71
+ api.upload_file(
72
+ path_or_fileobj=str(filepath),
73
+ path_in_repo=filename,
74
+ repo_id=repo_id,
75
+ repo_type="model"
76
+ )
77
+ print(f" ✓ {filename} uploaded successfully!")
78
+ except Exception as e:
79
+ print(f" ✗ Error uploading {filename}: {e}")
80
+
81
+ print(f"\n✅ Models uploaded to: https://huggingface.co/{repo_id}")
82
+ print("\nNext step: Update ml_models.py to load from this repository")
83
+
84
+ if __name__ == "__main__":
85
+ upload_models()
86
+