davidtran999 committed on
Commit
d4fe0a4
·
verified ·
1 Parent(s): 7919014

Upload backend/scripts/generate_embeddings.py with huggingface_hub

Browse files
backend/scripts/generate_embeddings.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script to generate and store embeddings for Procedure, Fine, Office, Advisory, and LegalSection models.
3
+ """
4
+ import argparse
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import List, Tuple
9
+ import numpy as np
10
+
11
# Repository layout anchors, derived from this script's location
# (<root>/backend/scripts/generate_embeddings.py).
ROOT_DIR = Path(__file__).resolve().parents[2]
BACKEND_DIR = ROOT_DIR.joinpath("backend")
HUE_PORTAL_DIR = BACKEND_DIR.joinpath("hue_portal")

# Make the project importable before Django is configured.
# backend/ goes on sys.path so that the settings module
# hue_portal.hue_portal.settings resolves; the repository root is added for
# any other imports. HUE_PORTAL_DIR is deliberately NOT added: putting the
# package directory itself on the path breaks Django's imports.
# Each entry is inserted at position 0, so the root ends up ahead of backend/.
for _import_root in (str(BACKEND_DIR), str(ROOT_DIR)):
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)
del _import_root

# Respect a settings module chosen by the caller; otherwise use the project default.
if "DJANGO_SETTINGS_MODULE" not in os.environ:
    os.environ["DJANGO_SETTINGS_MODULE"] = "hue_portal.hue_portal.settings"

# Django must be set up before any model module is imported.
import django

django.setup()
30
+
31
+ from hue_portal.core.models import Procedure, Fine, Office, Advisory, LegalSection
32
+ from hue_portal.core.embeddings import (
33
+ get_embedding_model,
34
+ generate_embeddings_batch,
35
+ get_embedding_dimension
36
+ )
37
+
38
+
39
def prepare_text_for_embedding(obj) -> str:
    """
    Build the text that represents a model instance for embedding.

    The fields used depend on the instance type (Procedure, Fine, Office,
    Advisory or LegalSection). Falsy and whitespace-only field values are
    dropped, and the remaining values are joined with single spaces.

    Args:
        obj: Model instance to describe.

    Returns:
        The combined text, stripped of surrounding whitespace, or an empty
        string when ``obj`` is not one of the supported model types.
    """
    # One extractor per supported model; checked in order, first match wins.
    extractors = (
        (Procedure, lambda o: [o.title, o.domain, o.level, o.conditions, o.dossier]),
        (Fine, lambda o: [o.name, o.code, o.article, o.decree, o.remedial]),
        (Office, lambda o: [o.unit_name, o.address, o.district, o.service_scope]),
        (Advisory, lambda o: [o.title, o.summary]),
        (
            LegalSection,
            lambda o: [o.section_code, o.section_title, o.content, getattr(o.document, "title", "")],
        ),
    )

    for model_type, extract in extractors:
        if isinstance(obj, model_type):
            values = extract(obj)
            break
    else:
        # Unsupported type: nothing to embed.
        return ""

    # Keep only values that are truthy and not blank once rendered as text.
    kept = []
    for value in values:
        if not value:
            continue
        rendered = str(value)
        if rendered.strip():
            kept.append(rendered)

    return " ".join(kept).strip()
59
+
60
+
61
def generate_embeddings_for_model(model_class, model_name: str, batch_size: int = 32, dry_run: bool = False):
    """
    Generate embeddings for all instances of a model.

    Every instance is turned into text with ``prepare_text_for_embedding``;
    instances with empty text are skipped (and not counted). The remaining
    texts are embedded in one call to ``generate_embeddings_batch`` and each
    embedding is pickled into ``instance.embedding`` and saved with
    ``update_fields=['embedding']``.

    Args:
        model_class: Django model class.
        model_name: Name of the model (for display).
        batch_size: Batch size for processing.
        dry_run: If True, only show what would be done without saving.

    Returns:
        A ``(saved, failed)`` tuple. In dry-run mode ``saved`` counts the
        embeddings that would have been saved. ``failed`` counts instances
        whose embedding was ``None``, was not returned at all, or could not
        be saved. ``(0, 0)`` is returned when there is nothing to process or
        the embedding model cannot be loaded.
    """
    # Imported once per call rather than once per saved instance.
    import pickle

    print(f"\n{'='*60}")
    print(f"Processing {model_name}")
    print(f"{'='*60}")

    # Get all instances
    instances = list(model_class.objects.all())
    total = len(instances)

    if total == 0:
        print(f"No {model_name} instances found. Skipping.")
        return 0, 0

    print(f"Found {total} {model_name} instances")

    # Prepare texts, remembering which instance each text belongs to.
    texts = []
    valid_indices = []
    for idx, instance in enumerate(instances):
        text = prepare_text_for_embedding(instance)
        if text:
            texts.append(text)
            valid_indices.append(idx)
        else:
            print(f"⚠️ Skipping {model_name} ID {instance.id}: empty text")

    if not texts:
        print(f"No valid texts found for {model_name}. Skipping.")
        return 0, 0

    print(f"Generating embeddings for {len(texts)} valid instances...")

    # Load model
    model = get_embedding_model()
    if model is None:
        print(f"❌ Cannot load embedding model. Skipping {model_name}.")
        return 0, 0

    # Generate embeddings
    embeddings = generate_embeddings_batch(texts, model=model, batch_size=batch_size)
    # Materialise the result so a short (or missing) batch can be detected:
    # zip() alone would silently drop every instance without a counterpart.
    embeddings = [] if embeddings is None else list(embeddings)

    # Save embeddings (if not dry run)
    saved = 0
    failed = 0

    for idx, embedding in zip(valid_indices, embeddings):
        instance = instances[idx]

        if embedding is None:
            print(f"⚠️ Failed to generate embedding for {model_name} ID {instance.id}")
            failed += 1
            continue

        if dry_run:
            print(f"[DRY RUN] Would save embedding for {model_name} ID {instance.id} (dim={len(embedding)})")
            saved += 1
            continue

        # Serialise the embedding to bytes for storage.
        # NOTE(review): pickled data must only ever be unpickled from a
        # trusted store — confirm the reader side treats this column as such.
        try:
            instance.embedding = pickle.dumps(embedding)
            instance.save(update_fields=['embedding'])
        except Exception as e:
            # Best effort: one bad row must not abort the whole run.
            print(f"❌ Error saving embedding for {model_name} ID {instance.id}: {e}")
            failed += 1
        else:
            print(f"✅ Generated and saved embedding for {model_name} ID {instance.id} (dim={len(embedding)})")
            saved += 1

    # Instances for which the batch call returned no embedding at all.
    for idx in valid_indices[len(embeddings):]:
        print(f"⚠️ Failed to generate embedding for {model_name} ID {instances[idx].id}")
        failed += 1

    print(f"\n{model_name} Summary: {saved} saved, {failed} failed")
    return saved, failed
141
+
142
+
143
def main():
    """
    Command-line entry point: generate embeddings for the selected models.

    Options:
        --model       Which model to process ("all" processes every model).
        --batch-size  Batch size for embedding generation; must be positive.
        --dry-run     Report what would be saved without saving anything.
        --model-name  Embedding model name to load instead of the default.

    Exits through ``argparse`` (status 2) when the arguments are invalid.
    """
    parser = argparse.ArgumentParser(description="Generate embeddings for all models")
    parser.add_argument("--model", choices=["procedure", "fine", "office", "advisory", "legal", "all"],
                        default="all", help="Which model to process")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for embedding generation")
    parser.add_argument("--dry-run", action="store_true", help="Simulate without saving")
    parser.add_argument("--model-name", type=str, help="Override embedding model name")
    args = parser.parse_args()

    # Reject unusable batch sizes up front instead of failing mid-run.
    if args.batch_size <= 0:
        parser.error("--batch-size must be a positive integer")

    print("="*60)
    print("Embedding Generation Script")
    print("="*60)

    if args.dry_run:
        print("⚠️ DRY RUN MODE - No changes will be saved")

    if args.model_name:
        print(f"Using model: {args.model_name}")
        # Force a reload so the requested model replaces any previously loaded one.
        get_embedding_model(model_name=args.model_name, force_reload=True)
    else:
        print("Using default model: keepitreal/vietnamese-sbert-v2")

    # Check model dimension
    dim = get_embedding_dimension()
    if dim > 0:
        print(f"Embedding dimension: {dim}")
    else:
        print("⚠️ Could not determine embedding dimension")

    # Single table for both "all" and single-model runs; insertion order is
    # the processing order used by "all".
    model_map = {
        "procedure": (Procedure, "Procedure"),
        "fine": (Fine, "Fine"),
        "office": (Office, "Office"),
        "advisory": (Advisory, "Advisory"),
        "legal": (LegalSection, "LegalSection"),
    }
    if args.model == "all":
        models_to_process = list(model_map.values())
    else:
        # argparse `choices` guarantees the key exists.
        models_to_process = [model_map[args.model]]

    total_saved = 0
    total_failed = 0

    for model_class, model_name in models_to_process:
        saved, failed = generate_embeddings_for_model(
            model_class, model_name,
            batch_size=args.batch_size,
            dry_run=args.dry_run
        )
        total_saved += saved
        total_failed += failed

    print("\n" + "="*60)
    print("Final Summary")
    print("="*60)
    print(f"Total saved: {total_saved}")
    print(f"Total failed: {total_failed}")
    print("="*60)


if __name__ == "__main__":
    main()
214
+