|
|
from huggingface_hub import login, HfApi, model_info, metadata_update |
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
from datasets import Dataset |
|
|
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments |
|
|
from sentence_transformers.losses import MultipleNegativesRankingLoss |
|
|
from transformers import TrainerCallback, TrainingArguments |
|
|
from typing import List, Callable, Optional |
|
|
from pathlib import Path |
|
|
from .config import AppConfig |
|
|
|
|
|
|
|
|
|
|
|
def authenticate_hf(token: Optional[str]) -> None:
    """Log into the Hugging Face Hub when a token is available.

    Args:
        token: HF access token; when falsy, the login step is skipped.
    """
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)
|
|
|
|
|
def load_embedding_model(model_name: str) -> SentenceTransformer:
    """Initialize a Sentence Transformer, letting HF place weights via device_map.

    Args:
        model_name: Hub id or local path of the model to load.

    Returns:
        The loaded SentenceTransformer instance.

    Raises:
        Exception: re-raised after logging when loading fails.
    """
    print(f"Loading Sentence Transformer model: {model_name}")
    try:
        embedder = SentenceTransformer(model_name, model_kwargs={"device_map": "auto"})
        print(f"Model loaded successfully. {embedder.device}")
        return embedder
    except Exception as err:
        print(f"Error loading Sentence Transformer model {model_name}: {err}")
        raise
|
|
|
|
|
def get_top_hits(
    model: SentenceTransformer,
    target_titles: List[str],
    task_name: str,
    query: str = "MY_FAVORITE_NEWS",
    top_k: int = 5
) -> str:
    """Semantically search `query` against `target_titles` and format the hits.

    Args:
        model: Encoder used for both the query and the corpus.
        target_titles: Candidate titles to rank.
        task_name: Prompt name passed to `model.encode` for both sides.
        query: Search query text.
        top_k: Maximum number of hits to return.

    Returns:
        One "[title] score" line per hit, newline-joined, or a notice when
        there is nothing to search.
    """
    if not target_titles:
        return "No target titles available for search."

    # Encode query and corpus with the same task prompt so the spaces match.
    query_vec = model.encode(query, prompt_name=task_name)
    corpus_vecs = model.encode(target_titles, prompt_name=task_name)

    hits = util.semantic_search(query_vec, corpus_vecs, top_k=top_k)[0]

    lines = [
        f"[{target_titles[hit['corpus_id']]}] {hit['score']:.4f}"
        for hit in hits
    ]
    return "\n".join(lines)
|
|
|
|
|
def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
    """
    Uploads a local model folder to the Hugging Face Hub.

    Creates the repository under the authenticated user's namespace if it
    doesn't exist, uploads the folder, then tags the repo.

    Args:
        folder_path: Local directory containing the saved model files.
        repo_name: Repository name (without the username prefix).
        token: Hugging Face access token with write permission.

    Returns:
        A human-readable success or failure message.
    """
    try:
        api = HfApi(token=token)

        # Resolve the namespace from the token owner so callers only need
        # to supply the bare repo name.
        user_info = api.whoami()
        username = user_info['name']

        repo_id = f"{username}/{repo_name}"
        print(f"Preparing to upload to: {repo_id}")

        api.create_repo(repo_id=repo_id, exist_ok=True)

        url = api.upload_folder(
            folder_path=folder_path,
            repo_id=repo_id,
            repo_type="model"
        )

        info = model_info(
            repo_id=repo_id,
            token=token
        )
        # card_data (and its `tags` field) can be None for a repo without
        # model-card metadata; the previous `info.card_data.tags` access
        # raised AttributeError there and turned a completed upload into a
        # reported failure.
        tags = list(getattr(info.card_data, "tags", None) or [])
        # Don't stack duplicate tags on repeated uploads.
        if "embeddinggemma-tuning-lab" not in tags:
            tags.append("embeddinggemma-tuning-lab")
        metadata_update(repo_id, {"tags": tags}, overwrite=True, token=token)

        return f"✅ Success! Model published at: {url}"
    except Exception as e:
        print(f"Upload failed: {e}")
        return f"❌ Upload failed: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
class EvaluationCallback(TrainerCallback):
    """
    Trainer callback that prints a semantic-search evaluation on every log event.

    The evaluation itself is delegated to a zero-argument callable supplied at
    construction time, keeping this class decoupled from the model and corpus.
    """

    def __init__(self, search_fn: Callable[[], str]):
        # Zero-arg function returning the formatted search results to print.
        self.search_fn = search_fn

    def on_log(self, args: TrainingArguments, state, control, **kwargs):
        """Announce the current step, then print the evaluation output."""
        print(f"Step {state.global_step} finished. Running evaluation:")
        print(f"\n{self.search_fn()}\n")
|
|
|
|
|
|
|
|
def train_with_dataset(
    model: SentenceTransformer,
    dataset: List[List[str]],
    output_dir: Path,
    task_name: str,
    search_fn: Callable[[], str]
) -> None:
    """
    Fine-tunes the provided Sentence Transformer MODEL on the dataset.

    The dataset should be a list of lists: [[anchor, positive, negative], ...].
    The model is updated in memory and also saved to `output_dir`.

    Args:
        model: Model to fine-tune (mutated in place).
        dataset: Triplet rows of [anchor, positive, negative] strings.
        output_dir: Directory where the fine-tuned model is saved.
        task_name: Key used to look up the model's own task prompt.
        search_fn: Zero-arg evaluation callable run at each logging step.
    """
    data_as_dicts = [
        {"anchor": row[0], "positive": row[1], "negative": row[2]}
        for row in dataset
    ]

    train_dataset = Dataset.from_list(data_as_dicts)

    # In-batch-negatives loss, suited to (anchor, positive, negative) triplets.
    loss = MultipleNegativesRankingLoss(model)

    # Reuse the model's own task prompt during training so train-time and
    # inference-time encodings match.
    prompts = getattr(model, 'prompts', {}).get(task_name)
    if not prompts:
        print(f"Warning: Could not find prompts for task '{task_name}' in model. Training may be less effective.")
        # None is the documented "no prompt" value for the `prompts` argument;
        # the previous empty-list fallback is not an accepted type.
        prompts = None

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        prompts=prompts,
        num_train_epochs=4,
        per_device_train_batch_size=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        # With batch size 1, one epoch is exactly `num_rows` steps, so this
        # logs (and triggers the evaluation callback) once per epoch.
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_strategy="no"
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        callbacks=[EvaluationCallback(search_fn)]
    )

    trainer.train()

    print("Training finished. Model weights are updated in memory.")

    trainer.save_model()

    print(f"Model saved locally to: {output_dir}")
|
|
|