Spaces:
Sleeping
Sleeping
Update fine_tuning.py
Browse files- fine_tuning.py +36 -0
fine_tuning.py
CHANGED
|
@@ -160,6 +160,38 @@ def load_and_train(model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
|
|
| 160 |
trainer.train()
|
| 161 |
model.eval()
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
# -----------------------------
|
| 164 |
# GENERATE ANSWER
|
| 165 |
# -----------------------------
|
|
@@ -167,6 +199,10 @@ def generate_answer(prompt, max_tokens=200):
|
|
| 167 |
if prompt.strip() == "":
|
| 168 |
return "Please enter a prompt!"
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
system_prompt = "You are a helpful assistant that provides financial data from MakeMyTrip reports."
|
| 171 |
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
|
| 172 |
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
|
|
| 160 |
trainer.train()
|
| 161 |
model.eval()
|
| 162 |
|
| 163 |
+
# ---------------- Guardrails ----------------
|
| 164 |
+
BLOCKED_TERMS = ["weather", "cricket", "movie", "song", "football", "holiday",
|
| 165 |
+
"travel", "recipe", "music", "game", "sports", "politics", "election"]
|
| 166 |
+
|
| 167 |
+
FINANCE_DOMAINS = [
|
| 168 |
+
"financial reporting", "balance sheet", "income statement",
|
| 169 |
+
"assets and liabilities", "equity", "revenue", "profit and loss",
|
| 170 |
+
"goodwill impairment", "cash flow", "dividends", "taxation",
|
| 171 |
+
"investment", "valuation", "capital structure", "ownership interests",
|
| 172 |
+
"subsidiaries", "shareholders equity", "expenses", "earnings",
|
| 173 |
+
"debt", "amortization", "depreciation"
|
| 174 |
+
]
|
| 175 |
+
finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)
|
| 176 |
+
|
| 177 |
+
#--------------------------------------------------------------
|
| 178 |
+
# GUARD RAIL
|
| 179 |
+
#--------------------------------------------------------------
|
| 180 |
+
def validate_query(query: str, threshold: float = 0.5) -> bool:
|
| 181 |
+
q_lower = query.lower()
|
| 182 |
+
if any(bad in q_lower for bad in BLOCKED_TERMS):
|
| 183 |
+
print("[Guardrail] Rejected by blocklist.")
|
| 184 |
+
return False
|
| 185 |
+
q_emb = embed_model.encode(query, convert_to_tensor=True)
|
| 186 |
+
sim_scores = util.cos_sim(q_emb, finance_embeds)
|
| 187 |
+
max_score = float(sim_scores.max())
|
| 188 |
+
if max_score > threshold:
|
| 189 |
+
print(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
|
| 190 |
+
return True
|
| 191 |
+
else:
|
| 192 |
+
print(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
|
| 193 |
+
return False
|
| 194 |
+
|
| 195 |
# -----------------------------
|
| 196 |
# GENERATE ANSWER
|
| 197 |
# -----------------------------
|
|
|
|
| 199 |
if prompt.strip() == "":
|
| 200 |
return "Please enter a prompt!"
|
| 201 |
|
| 202 |
+
if not validate_query(query):
|
| 203 |
+
print("Query rejected: Not finance-related.")
|
| 204 |
+
return "Query rejected: Please ask finance-related questions."
|
| 205 |
+
|
| 206 |
system_prompt = "You are a helpful assistant that provides financial data from MakeMyTrip reports."
|
| 207 |
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
|
| 208 |
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|