File size: 5,803 Bytes
d035cda beabfb7 e6cb750 beabfb7 71bbed6 beabfb7 e6cb750 beabfb7 e6cb750 beabfb7 e6cb750 beabfb7 71bbed6 d035cda ad4164f e6cb750 71bbed6 beabfb7 e6cb750 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
from huggingface_hub import login, HfApi, model_info, metadata_update
from sentence_transformers import SentenceTransformer, util
from datasets import Dataset
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback, TrainingArguments
from typing import List, Callable, Optional
from pathlib import Path
from .config import AppConfig
# --- Model/Utility Functions ---
def authenticate_hf(token: Optional[str]) -> None:
    """Log into the Hugging Face Hub when a token is provided.

    Args:
        token: A Hugging Face access token; falsy values skip the login.
    """
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)
def load_embedding_model(model_name: str) -> SentenceTransformer:
    """Load a Sentence Transformer, letting HF place it on devices automatically.

    Args:
        model_name: Hub repo ID or local path of the embedding model.

    Returns:
        The loaded ``SentenceTransformer`` instance.

    Raises:
        Exception: Re-raised after logging if loading fails.
    """
    print(f"Loading Sentence Transformer model: {model_name}")
    try:
        st_model = SentenceTransformer(model_name, model_kwargs={"device_map": "auto"})
        print(f"Model loaded successfully. {st_model.device}")
        return st_model
    except Exception as err:
        print(f"Error loading Sentence Transformer model {model_name}: {err}")
        raise
def get_top_hits(
    model: SentenceTransformer,
    target_titles: List[str],
    task_name: str,
    query: str = "MY_FAVORITE_NEWS",
    top_k: int = 5
) -> str:
    """Semantic-search `query` against `target_titles`; return formatted hits.

    Args:
        model: Model used to embed both the query and the titles.
        target_titles: Candidate titles to rank.
        task_name: Prompt name passed to ``model.encode``.
        query: The search query text.
        top_k: Maximum number of hits to report.

    Returns:
        One "[title] score" line per hit, or a placeholder message when
        `target_titles` is empty.
    """
    if not target_titles:
        return "No target titles available for search."
    # Embed query and corpus with the same task prompt, then rank.
    query_emb = model.encode(query, prompt_name=task_name)
    corpus_emb = model.encode(target_titles, prompt_name=task_name)
    hits = util.semantic_search(query_emb, corpus_emb, top_k=top_k)[0]
    return "\n".join(
        f"[{target_titles[hit['corpus_id']]}] {hit['score']:.4f}"
        for hit in hits
    )
def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
    """
    Upload a local model folder to the Hugging Face Hub.

    Creates the repository if it doesn't exist, then tags it with
    "embeddinggemma-tuning-lab" (once) via a metadata update.

    Args:
        folder_path: Local directory containing the model files.
        repo_name: Repo name (without the username prefix).
        token: Hugging Face access token used for all API calls.

    Returns:
        A user-facing success message with the upload URL, or a failure
        message containing the exception text.
    """
    try:
        api = HfApi(token=token)
        # Get the authenticated user's username
        user_info = api.whoami()
        username = user_info['name']
        # Construct the full repo ID
        repo_id = f"{username}/{repo_name}"
        print(f"Preparing to upload to: {repo_id}")
        # Create the repo (safe if it already exists)
        api.create_repo(repo_id=repo_id, exist_ok=True)
        # Upload the folder
        url = api.upload_folder(
            folder_path=folder_path,
            repo_id=repo_id,
            repo_type="model"
        )
        info = model_info(
            repo_id=repo_id,
            token=token
        )
        # A freshly created repo may have no model card (card_data is None)
        # and an existing card may have no tags; guard both so a missing card
        # doesn't raise AttributeError and report failure after a successful
        # upload. Also avoid appending the tag again on repeat publishes.
        existing_tags = list((info.card_data.tags if info.card_data else None) or [])
        if "embeddinggemma-tuning-lab" not in existing_tags:
            existing_tags.append("embeddinggemma-tuning-lab")
        metadata_update(repo_id, {"tags": existing_tags}, overwrite=True, token=token)
        return f"✅ Success! Model published at: {url}"
    except Exception as e:
        print(f"Upload failed: {e}")
        return f"❌ Upload failed: {str(e)}"
# --- Training Class and Function ---
class EvaluationCallback(TrainerCallback):
    """
    Trainer callback that prints a semantic-search evaluation on every log event.

    The zero-argument search function is supplied at construction time and is
    invoked each time the trainer logs; its formatted output is printed.
    """

    def __init__(self, search_fn: Callable[[], str]):
        # Zero-arg callable returning the formatted search results to print.
        self.search_fn = search_fn

    def on_log(self, args: TrainingArguments, state, control, **kwargs):
        """Announce the current step, then run and print the evaluation."""
        step = state.global_step
        print(f"Step {step} finished. Running evaluation:")
        results = self.search_fn()
        print(f"\n{results}\n")
def train_with_dataset(
    model: SentenceTransformer,
    dataset: List[List[str]],
    output_dir: Path,
    task_name: str,
    search_fn: Callable[[], str]
) -> None:
    """
    Fine-tune the provided Sentence Transformer MODEL on the dataset.

    The dataset should be a list of lists: [[anchor, positive, negative], ...].
    The model's weights are updated in memory and the final model is saved to
    `output_dir`.

    Args:
        model: Model to fine-tune (mutated in place).
        dataset: Triplet rows of [anchor, positive, negative] strings.
        output_dir: Directory where the fine-tuned model is saved.
        task_name: Key into the model's `prompts` mapping used for training.
        search_fn: Zero-arg evaluation function run at each logging step.

    Raises:
        ValueError: If `dataset` is empty — training would be a no-op and
            `logging_steps` would be 0, which TrainingArguments rejects with a
            far less helpful error.
    """
    if not dataset:
        raise ValueError("Cannot train: dataset is empty.")
    # Convert to Hugging Face Dataset format
    data_as_dicts = [
        {"anchor": row[0], "positive": row[1], "negative": row[2]}
        for row in dataset
    ]
    train_dataset = Dataset.from_list(data_as_dicts)
    # Use MultipleNegativesRankingLoss, suitable for contrastive learning
    loss = MultipleNegativesRankingLoss(model)
    # SentenceTransformer models typically expose a 'prompts' mapping; look up
    # the prompt registered for this task.
    prompts = getattr(model, 'prompts', {}).get(task_name)
    if not prompts:
        print(f"Warning: Could not find prompts for task '{task_name}' in model. Training may be less effective.")
        # The `prompts` training argument expects a prompt string / mapping or
        # None — an empty list is not a valid value, so fall back to None
        # (i.e. no prompt applied during training).
        prompts = None
    args = SentenceTransformerTrainingArguments(
        output_dir=str(output_dir),  # TrainingArguments expects a str path
        prompts=prompts,
        num_train_epochs=4,
        per_device_train_batch_size=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        # With batch size 1, one row == one step, so this logs once per epoch.
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_strategy="no"  # No saving during training, only at the end
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        callbacks=[EvaluationCallback(search_fn)]
    )
    trainer.train()
    print("Training finished. Model weights are updated in memory.")
    # Save the final fine-tuned model
    trainer.save_model()
    print(f"Model saved locally to: {output_dir}")
|