"""Model utilities: Hugging Face Hub auth/upload, semantic search, and fine-tuning.

NOTE(review): this module depends on `sentence_transformers`, `huggingface_hub`,
`datasets`, and `transformers` being installed.
"""

from pathlib import Path
from typing import Callable, List, Optional

from datasets import Dataset
from huggingface_hub import HfApi, login, metadata_update, model_info
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    util,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback, TrainingArguments

from .config import AppConfig


# --- Model/Utility Functions ---

def authenticate_hf(token: Optional[str]) -> None:
    """Log into the Hugging Face Hub if a token is provided; otherwise skip."""
    if token:
        print("Logging into Hugging Face Hub...")
        login(token=token)
    else:
        print("Skipping Hugging Face login: HF_TOKEN not set.")


def load_embedding_model(model_name: str) -> SentenceTransformer:
    """Initialize and return a Sentence Transformer model.

    Uses ``device_map="auto"`` so the underlying HF loader places weights on
    the best available device. Re-raises any loading error after printing it.
    """
    print(f"Loading Sentence Transformer model: {model_name}")
    try:
        model = SentenceTransformer(model_name, model_kwargs={"device_map": "auto"})
        print(f"Model loaded successfully. {model.device}")
        return model
    except Exception as e:
        print(f"Error loading Sentence Transformer model {model_name}: {e}")
        raise


def get_top_hits(
    model: SentenceTransformer,
    target_titles: List[str],
    task_name: str,
    query: str = "MY_FAVORITE_NEWS",
    top_k: int = 5,
) -> str:
    """Semantic-search *target_titles* against *query*; return formatted hits.

    Args:
        model: Encoder used for both query and corpus embeddings.
        target_titles: Candidate strings to rank; empty list short-circuits.
        task_name: Prompt name passed to ``model.encode`` (model-specific).
        query: Search query text.
        top_k: Maximum number of hits to return.

    Returns:
        One "[title] score" line per hit, joined with newlines.
    """
    if not target_titles:
        return "No target titles available for search."

    query_embedding = model.encode(query, prompt_name=task_name)
    # NOTE: the corpus is re-encoded on every call; acceptable for small lists,
    # but consider caching embeddings if target_titles is large/static.
    title_embeddings = model.encode(target_titles, prompt_name=task_name)

    top_hits = util.semantic_search(query_embedding, title_embeddings, top_k=top_k)[0]

    result = []
    for hit in top_hits:
        title = target_titles[hit['corpus_id']]
        score = hit['score']
        result.append(f"[{title}] {score:.4f}")
    return "\n".join(result)


def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
    """Upload a local model folder to the Hugging Face Hub.

    Creates the repository under the authenticated user's namespace if it does
    not exist, uploads the folder, and tags the model card with
    "embeddinggemma-tuning-lab". Returns a human-readable status string
    (never raises: failures are reported in the return value).
    """
    try:
        api = HfApi(token=token)

        # Resolve the authenticated user's namespace for the full repo id.
        username = api.whoami()['name']
        repo_id = f"{username}/{repo_name}"
        print(f"Preparing to upload to: {repo_id}")

        # Safe if the repo already exists.
        api.create_repo(repo_id=repo_id, exist_ok=True)

        url = api.upload_folder(
            folder_path=folder_path,
            repo_id=repo_id,
            repo_type="model",
        )

        # Tag the model card. card_data (and its tags) may be missing on a
        # freshly created repo, so guard against None before appending, and
        # avoid duplicating the tag on repeated uploads.
        info = model_info(repo_id=repo_id, token=token)
        tags = list(getattr(info.card_data, "tags", None) or []) if info.card_data else []
        if "embeddinggemma-tuning-lab" not in tags:
            tags.append("embeddinggemma-tuning-lab")
        metadata_update(repo_id, {"tags": tags}, overwrite=True, token=token)

        return f"✅ Success! Model published at: {url}"
    except Exception as e:
        print(f"Upload failed: {e}")
        return f"❌ Upload failed: {str(e)}"


# --- Training Class and Function ---

class EvaluationCallback(TrainerCallback):
    """Trainer callback that runs a semantic-search evaluation on each log step.

    The zero-argument search function is injected at construction time and its
    string result is printed after every logging event.
    """

    def __init__(self, search_fn: Callable[[], str]):
        # search_fn: zero-arg callable returning a formatted evaluation string.
        self.search_fn = search_fn

    def on_log(self, args: TrainingArguments, state, control, **kwargs):
        print(f"Step {state.global_step} finished. \nRunning evaluation:")
        print(f"\n{self.search_fn()}\n")


def train_with_dataset(
    model: SentenceTransformer,
    dataset: List[List[str]],
    output_dir: Path,
    task_name: str,
    search_fn: Callable[[], str],
) -> None:
    """Fine-tune *model* in place on contrastive triplets and save to *output_dir*.

    Args:
        model: Sentence Transformer to fine-tune (weights updated in memory).
        dataset: List of ``[anchor, positive, negative]`` triplets.
        output_dir: Directory for checkpoints and the final saved model.
        task_name: Key into the model's ``prompts`` mapping used during training.
        search_fn: Zero-arg evaluation callable, run at each logging step.
    """
    # Convert triplets into the column format SentenceTransformerTrainer expects.
    data_as_dicts = [
        {"anchor": row[0], "positive": row[1], "negative": row[2]}
        for row in dataset
    ]
    train_dataset = Dataset.from_list(data_as_dicts)

    # MultipleNegativesRankingLoss is the standard contrastive loss for
    # (anchor, positive, negative) triplets.
    loss = MultipleNegativesRankingLoss(model)

    # Look up the training prompt configured on the model for this task.
    prompts = getattr(model, 'prompts', {}).get(task_name)
    if not prompts:
        print(f"Warning: Could not find prompts for task '{task_name}' in model. Training may be less effective.")
        # Fallback to an empty list or appropriate default if required by the model's structure
        prompts = []

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        prompts=prompts,
        num_train_epochs=4,
        per_device_train_batch_size=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        # Log exactly once per epoch (batch size 1 => one step per row).
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_strategy="no"  # No saving during training, only at the end
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        callbacks=[EvaluationCallback(search_fn)],
    )

    trainer.train()
    print("Training finished. Model weights are updated in memory.")

    # Persist the fine-tuned weights locally.
    trainer.save_model()
    print(f"Model saved locally to: {output_dir}")