File size: 5,803 Bytes
d035cda
beabfb7
 
 
 
 
e6cb750
beabfb7
71bbed6
beabfb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6cb750
beabfb7
 
 
 
 
 
 
e6cb750
 
 
beabfb7
 
e6cb750
beabfb7
 
 
 
 
 
 
 
 
 
 
71bbed6
d035cda
 
 
 
ad4164f
e6cb750
 
71bbed6
beabfb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6cb750
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from huggingface_hub import login, HfApi, model_info, metadata_update
from sentence_transformers import SentenceTransformer, util
from datasets import Dataset
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from transformers import TrainerCallback, TrainingArguments
from typing import List, Callable, Optional
from pathlib import Path
from .config import AppConfig

# --- Model/Utility Functions ---

def authenticate_hf(token: Optional[str]) -> None:
    """Log into the Hugging Face Hub when a token is available.

    Args:
        token: A Hugging Face access token, or None/empty to skip login.
    """
    # Guard clause: nothing to do without a token.
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)

def load_embedding_model(model_name: str) -> SentenceTransformer:
    """Load a Sentence Transformer, letting HF place it on available devices.

    Args:
        model_name: Hub repo id or local path of the model to load.

    Returns:
        The loaded SentenceTransformer instance.

    Raises:
        Exception: re-raises whatever the underlying loader throws, after
            printing a diagnostic message.
    """
    print(f"Loading Sentence Transformer model: {model_name}")
    try:
        # device_map="auto" delegates device placement to the loader.
        embedder = SentenceTransformer(model_name, model_kwargs={"device_map": "auto"})
        print(f"Model loaded successfully. {embedder.device}")
    except Exception as e:
        print(f"Error loading Sentence Transformer model {model_name}: {e}")
        raise
    return embedder

def get_top_hits(
    model: SentenceTransformer,
    target_titles: List[str],
    task_name: str,
    query: str = "MY_FAVORITE_NEWS",
    top_k: int = 5
) -> str:
    """Performs semantic search on target_titles and returns a formatted result string."""
    if not target_titles:
        return "No target titles available for search."

    # Encode the query
    query_embedding = model.encode(query, prompt_name=task_name)

    # Encode the target titles (only done once per call)
    title_embeddings = model.encode(target_titles, prompt_name=task_name)

    # Perform semantic search
    top_hits = util.semantic_search(query_embedding, title_embeddings, top_k=top_k)[0]

    result = []
    for hit in top_hits:
        title = target_titles[hit['corpus_id']]
        score = hit['score']
        result.append(f"[{title}] {score:.4f}")

    return "\n".join(result)

def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
    """
    Uploads a local model folder to the Hugging Face Hub.
    Creates the repository if it doesn't exist, then tags it with
    "embeddinggemma-tuning-lab".

    Args:
        folder_path: Local directory containing the saved model.
        repo_name: Repo name (without the username prefix).
        token: Hugging Face access token with write permission.

    Returns:
        A human-readable success or failure message.
    """
    try:
        api = HfApi(token=token)

        # Get the authenticated user's username
        user_info = api.whoami()
        username = user_info['name']

        # Construct the full repo ID
        repo_id = f"{username}/{repo_name}"
        print(f"Preparing to upload to: {repo_id}")

        # Create the repo (safe if it already exists)
        api.create_repo(repo_id=repo_id, exist_ok=True)

        # Upload the folder
        url = api.upload_folder(
            folder_path=folder_path,
            repo_id=repo_id,
            repo_type="model"
        )

        info = model_info(
            repo_id=repo_id,
            token=token
        )
        # BUG FIX: a freshly created repo may have no model card, in which
        # case `card_data` (or its `tags`) is None and the old
        # `info.card_data.tags` access raised AttributeError. Also avoid
        # appending the tag again on repeat uploads.
        existing = getattr(info.card_data, "tags", None) if info.card_data else None
        tags = list(existing or [])
        if "embeddinggemma-tuning-lab" not in tags:
            tags.append("embeddinggemma-tuning-lab")
            metadata_update(repo_id, {"tags": tags}, overwrite=True, token=token)

        return f"✅ Success! Model published at: {url}"
    except Exception as e:
        print(f"Upload failed: {e}")
        return f"❌ Upload failed: {str(e)}"

# --- Training Class and Function ---

class EvaluationCallback(TrainerCallback):
    """Trainer callback that prints a semantic-search snapshot on each log event.

    The zero-argument search function is injected at construction time, so
    the callback stays decoupled from the model and the corpus it searches.
    """
    def __init__(self, search_fn: Callable[[], str]):
        # Stored as-is; invoked fresh on every on_log call.
        self.search_fn = search_fn

    def on_log(self, args: TrainingArguments, state, control, **kwargs):
        run_search = self.search_fn
        print(f"Step {state.global_step} finished. Running evaluation:")
        print(f"\n{run_search()}\n")


def train_with_dataset(
    model: SentenceTransformer,
    dataset: List[List[str]],
    output_dir: Path,
    task_name: str,
    search_fn: Callable[[], str]
) -> None:
    """
    Fine-tunes the provided Sentence Transformer MODEL on the dataset.

    The dataset should be a list of lists: [[anchor, positive, negative], ...].

    Args:
        model: Model to fine-tune in place.
        dataset: Triplet rows; only the first three elements of each row are used.
        output_dir: Directory where the fine-tuned model is saved.
        task_name: Key used to look up the model's training prompt.
        search_fn: Zero-arg evaluation function, run by the logging callback.

    Raises:
        ValueError: If `dataset` is empty (logging_steps must be positive).
    """
    # BUG FIX: an empty dataset previously propagated logging_steps=0 into
    # the training arguments and crashed deep inside the trainer; fail fast
    # with a clear message instead.
    if not dataset:
        raise ValueError("dataset is empty; nothing to train on.")

    # Convert to Hugging Face Dataset format
    data_as_dicts = [
        {"anchor": row[0], "positive": row[1], "negative": row[2]}
        for row in dataset
    ]

    train_dataset = Dataset.from_list(data_as_dicts)

    # Use MultipleNegativesRankingLoss, suitable for contrastive learning
    loss = MultipleNegativesRankingLoss(model)

    # SentenceTransformer models typically expose a 'prompts' dict mapping
    # task names to prompt strings.
    prompts = getattr(model, 'prompts', {}).get(task_name)
    if not prompts:
        print(f"Warning: Could not find prompts for task '{task_name}' in model. Training may be less effective.")
        # BUG FIX: the previous fallback was an empty list, which is not a
        # valid value for `prompts` (it expects a prompt string/dict or
        # None). None simply disables prompting during training.
        prompts = None

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        prompts=prompts,
        num_train_epochs=4,
        per_device_train_batch_size=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        # Log (and therefore evaluate) once per epoch.
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_strategy="no"  # No saving during training, only at the end
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        callbacks=[EvaluationCallback(search_fn)]
    )

    trainer.train()

    print("Training finished. Model weights are updated in memory.")

    # Save the final fine-tuned model
    trainer.save_model()

    print(f"Model saved locally to: {output_dir}")