rishabhsetiya commited on
Commit
f19235b
·
verified ·
1 Parent(s): f5ae900

Update fine_tuning.py

Browse files
Files changed (1) hide show
  1. fine_tuning.py +36 -0
fine_tuning.py CHANGED
@@ -160,6 +160,38 @@ def load_and_train(model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
160
  trainer.train()
161
  model.eval()
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  # -----------------------------
164
  # GENERATE ANSWER
165
  # -----------------------------
@@ -167,6 +199,10 @@ def generate_answer(prompt, max_tokens=200):
167
  if prompt.strip() == "":
168
  return "Please enter a prompt!"
169
 
 
 
 
 
170
  system_prompt = "You are a helpful assistant that provides financial data from MakeMyTrip reports."
171
  messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
172
  input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
160
  trainer.train()
161
  model.eval()
162
 
163
+ # ---------------- Guardrails ----------------
164
+ BLOCKED_TERMS = ["weather", "cricket", "movie", "song", "football", "holiday",
165
+ "travel", "recipe", "music", "game", "sports", "politics", "election"]
166
+
167
+ FINANCE_DOMAINS = [
168
+ "financial reporting", "balance sheet", "income statement",
169
+ "assets and liabilities", "equity", "revenue", "profit and loss",
170
+ "goodwill impairment", "cash flow", "dividends", "taxation",
171
+ "investment", "valuation", "capital structure", "ownership interests",
172
+ "subsidiaries", "shareholders equity", "expenses", "earnings",
173
+ "debt", "amortization", "depreciation"
174
+ ]
175
+ finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)
176
+
177
+ #--------------------------------------------------------------
178
+ # GUARD RAIL
179
+ #--------------------------------------------------------------
180
+ def validate_query(query: str, threshold: float = 0.5) -> bool:
181
+ q_lower = query.lower()
182
+ if any(bad in q_lower for bad in BLOCKED_TERMS):
183
+ print("[Guardrail] Rejected by blocklist.")
184
+ return False
185
+ q_emb = embed_model.encode(query, convert_to_tensor=True)
186
+ sim_scores = util.cos_sim(q_emb, finance_embeds)
187
+ max_score = float(sim_scores.max())
188
+ if max_score > threshold:
189
+ print(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
190
+ return True
191
+ else:
192
+ print(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
193
+ return False
194
+
195
  # -----------------------------
196
  # GENERATE ANSWER
197
  # -----------------------------
 
199
  if prompt.strip() == "":
200
  return "Please enter a prompt!"
201
 
202
+ if not validate_query(query):
203
+ print("Query rejected: Not finance-related.")
204
+ return "Query rejected: Please ask finance-related questions."
205
+
206
  system_prompt = "You are a helpful assistant that provides financial data from MakeMyTrip reports."
207
  messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
208
  input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)