Nagharjun Mathi Mariappan commited on
Commit
84629c8
·
1 Parent(s): 8064e4b

fix llm explanation

Browse files
backend/engines/cards_engine.py CHANGED
@@ -108,40 +108,40 @@ class CardsEngine:
108
  out["score"] = score[idx]
109
  return out
110
 
111
- def explain(self, user_input: dict, card_row: dict, score: float) -> str:
112
- prompt = f"""You are explaining credit-card recommendations to a user.
113
- Keep it short (6-10 bullet lines). Be careful: do not invent fees/APR/rewards.
114
- If a field is missing/blank, say "not listed".
115
-
116
- User profile:
117
- - credit_score: {user_input.get("credit_score")}
118
- - annual_income: {user_input.get("annual_income")}
119
- - carry_balance: {user_input.get("carry_balance")}
120
- - travel_abroad: {user_input.get("travel_abroad")}
121
- - no_annual_fee: {user_input.get("no_annual_fee")}
122
- - balance_transfer: {user_input.get("balance_transfer")}
123
- - rewards_pref: {user_input.get("rewards_pref")}
124
- - monthly spend: groceries={user_input.get("spend_groceries")}, dining={user_input.get("spend_dining")},
125
- gas={user_input.get("spend_gas")}, travel={user_input.get("spend_travel")}, online={user_input.get("spend_online")}
126
-
127
- Recommended card (from dataset row):
128
- - Institution: {card_row.get("Institution Name")}
129
- - Product: {card_row.get("Product Name")}
130
- - Annual Fee: {card_row.get("Annual Fee")}
131
- - Foreign Transaction Fee (%): {card_row.get("Foreign Transaction Fee (%)")}
132
- - Purchase APR median: {card_row.get("Purchase APR median")}
133
- - Balance Transfer Offered?: {card_row.get("Balance Transfer Offered?")}
134
- - Rewards: {card_row.get("Rewards")}
135
- - Card Features: {card_row.get("Card Features")}
136
- - Website: {card_row.get("Website for Consumer")}
137
- - Phone: {card_row.get("Telephone Number for Consumers")}
138
-
139
- Model score: {score:.3f}
140
-
141
- Explain:
142
- 1) why it matches this user's needs
143
- 2) what to watch out for (APR/fees)
144
- 3) who this card is best for (one sentence)
145
- Add a final line: "Not financial advice."
146
- """
147
- return generate_text(prompt)
 
108
  out["score"] = score[idx]
109
  return out
110
 
111
+ def explain(self, user_input: dict, card_row: dict, score: float) -> str:
112
+ prompt = f"""You are explaining credit-card recommendations to a user.
113
+ Keep it short (6-10 bullet lines). Be careful: do not invent fees/APR/rewards.
114
+ If a field is missing/blank, say "not listed".
115
+
116
+ User profile:
117
+ - credit_score: {user_input.get("credit_score")}
118
+ - annual_income: {user_input.get("annual_income")}
119
+ - carry_balance: {user_input.get("carry_balance")}
120
+ - travel_abroad: {user_input.get("travel_abroad")}
121
+ - no_annual_fee: {user_input.get("no_annual_fee")}
122
+ - balance_transfer: {user_input.get("balance_transfer")}
123
+ - rewards_pref: {user_input.get("rewards_pref")}
124
+ - monthly spend: groceries={user_input.get("spend_groceries")}, dining={user_input.get("spend_dining")},
125
+ gas={user_input.get("spend_gas")}, travel={user_input.get("spend_travel")}, online={user_input.get("spend_online")}
126
+
127
+ Recommended card (from dataset row):
128
+ - Institution: {card_row.get("Institution Name")}
129
+ - Product: {card_row.get("Product Name")}
130
+ - Annual Fee: {card_row.get("Annual Fee")}
131
+ - Foreign Transaction Fee (%): {card_row.get("Foreign Transaction Fee (%)")}
132
+ - Purchase APR median: {card_row.get("Purchase APR median")}
133
+ - Balance Transfer Offered?: {card_row.get("Balance Transfer Offered?")}
134
+ - Rewards: {card_row.get("Rewards")}
135
+ - Card Features: {card_row.get("Card Features")}
136
+ - Website: {card_row.get("Website for Consumer")}
137
+ - Phone: {card_row.get("Telephone Number for Consumers")}
138
+
139
+ Model score: {score:.3f}
140
+
141
+ Explain:
142
+ 1) why it matches this user's needs
143
+ 2) what to watch out for (APR/fees)
144
+ 3) who this card is best for (one sentence)
145
+ Add a final line: "Not financial advice."
146
+ """
147
+ return generate_text(prompt)
backend/engines/movies_engine.py CHANGED
@@ -115,25 +115,25 @@ class MoviesEngine:
115
  out["score"] = scores[idx]
116
  return out
117
 
118
- def explain(self, user_input: dict, movie_row: dict, score: float) -> str:
119
- genres = user_input.get("genres") or []
120
- liked_ids = user_input.get("liked_movie_ids") or []
121
-
122
- prompt = f"""You are explaining recommendations to a user in a friendly way.
123
- Keep it short (6-10 bullet lines). Do not invent facts not provided.
124
-
125
- User preferences:
126
- - Selected genres: {genres}
127
- - Liked movie ids: {liked_ids}
128
-
129
- Recommended movie:
130
- - Title: {movie_row.get("title")}
131
- - Genres: {movie_row.get("genres")}
132
- - Model score: {score:.3f}
133
-
134
- Explain:
135
- 1) why it matches the user
136
- 2) what kind of vibe they should expect based on title/genres
137
- 3) when to pick this vs the other results (one sentence)
138
- """
139
- return generate_text(prompt)
 
115
  out["score"] = scores[idx]
116
  return out
117
 
118
+ def explain(self, user_input: dict, movie_row: dict, score: float) -> str:
119
+ genres = user_input.get("genres") or []
120
+ liked_ids = user_input.get("liked_movie_ids") or []
121
+
122
+ prompt = f"""You are explaining recommendations to a user in a friendly way.
123
+ Keep it short (6-10 bullet lines). Do not invent facts not provided.
124
+
125
+ User preferences:
126
+ - Selected genres: {genres}
127
+ - Liked movie ids: {liked_ids}
128
+
129
+ Recommended movie:
130
+ - Title: {movie_row.get("title")}
131
+ - Genres: {movie_row.get("genres")}
132
+ - Model score: {score:.3f}
133
+
134
+ Explain:
135
+ 1) why it matches the user
136
+ 2) what kind of vibe they should expect based on title/genres
137
+ 3) when to pick this vs the other results (one sentence)
138
+ """
139
+ return generate_text(prompt)
backend/llm.py CHANGED
@@ -1,33 +1,34 @@
1
  import os
2
- from typing import Optional
3
- from huggingface_hub import InferenceClient
4
 
5
- def _fallback(prompt):
6
- lines = prompt.splitlines()
7
- return "Here’s why this showed up:\n- It matches your preferences.\n- It scored high in the ranking model.\n\nTip: Click a few results and compare which fits your taste/budget best."
 
 
 
 
8
 
9
- def generate_text(prompt, max_new_tokens=220):
10
  provider = (os.getenv("LLM_PROVIDER") or "off").lower()
11
  if provider == "off":
12
  return _fallback(prompt)
13
 
14
- if provider == "hf":
15
- token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
16
- model = os.getenv("HF_LLM_MODEL") or "mistralai/Mistral-7B-Instruct-v0.3"
17
- if not token:
18
- return _fallback(prompt)
19
-
20
- try:
21
- client = InferenceClient(model=model, token=token)
22
- out = client.text_generation(
23
- prompt,
24
- max_new_tokens=max_new_tokens,
25
- temperature=0.3,
26
- top_p=0.9,
27
- repetition_penalty=1.05,
28
- )
29
- return (out or "").strip() or _fallback(prompt)
30
- except Exception:
31
- return _fallback(prompt)
32
 
33
- return _fallback(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from openai import OpenAI
 
3
 
4
+ def _fallback(prompt: str) -> str:
5
+ return (
6
+ "Here’s why this showed up:\n"
7
+ "- It matches your stated preferences.\n"
8
+ "- It scored highly in the ranking model.\n"
9
+ "- Compare a couple options to pick the best fit.\n"
10
+ )
11
 
12
+ def generate_text(prompt: str, max_new_tokens: int = 220) -> str:
13
  provider = (os.getenv("LLM_PROVIDER") or "off").lower()
14
  if provider == "off":
15
  return _fallback(prompt)
16
 
17
+ token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
18
+ model = os.getenv("HF_CHAT_MODEL") or "google/gemma-2-2b-it"
19
+ if not token:
20
+ return _fallback(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ try:
23
+ client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=token)
24
+ resp = client.chat.completions.create(
25
+ model=model,
26
+ messages=[{"role": "user", "content": prompt}],
27
+ max_tokens=max_new_tokens,
28
+ temperature=0.3,
29
+ top_p=0.9,
30
+ )
31
+ text = (resp.choices[0].message.content or "").strip()
32
+ return text or _fallback(prompt)
33
+ except Exception:
34
+ return _fallback(prompt)
backend/main.py CHANGED
@@ -143,13 +143,14 @@ class ExplainCardReq(BaseModel):
143
 
144
  @app.post("/api/cards/explain", response_model=ExplainOut)
145
  def explain_card(req: ExplainCardReq):
146
- row = cards_engine.cards[cards_engine.cards["Product Name"] == req.product]
147
  if row.empty:
148
  return ExplainOut(text="Could not find that card in the catalog.")
149
  card_row = row.iloc[0].to_dict()
150
  text = cards_engine.explain(req.user_input.model_dump(), card_row, req.score)
151
  return ExplainOut(text=text)
152
 
 
153
  from fastapi.staticfiles import StaticFiles
154
 
155
  STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
 
143
 
144
  @app.post("/api/cards/explain", response_model=ExplainOut)
145
  def explain_card(req: ExplainCardReq):
146
+ row = cards_engine.meta[cards_engine.meta["Product Name"] == req.product]
147
  if row.empty:
148
  return ExplainOut(text="Could not find that card in the catalog.")
149
  card_row = row.iloc[0].to_dict()
150
  text = cards_engine.explain(req.user_input.model_dump(), card_row, req.score)
151
  return ExplainOut(text=text)
152
 
153
+
154
  from fastapi.staticfiles import StaticFiles
155
 
156
  STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
backend/requirements.txt CHANGED
@@ -7,4 +7,5 @@ scikit-learn==1.5.2
7
  joblib==1.4.2
8
  torch==2.4.1
9
  sentence-transformers==3.0.1
10
- huggingface_hub==0.26.3
 
 
7
  joblib==1.4.2
8
  torch==2.4.1
9
  sentence-transformers==3.0.1
10
+ huggingface_hub==0.26.3
11
+ openai>=1.0.0
frontend/src/components/ResultsGrid.jsx CHANGED
@@ -1,20 +1,30 @@
1
- export default function ResultsGrid({ items, renderTitle, renderMeta, onExplain }) {
2
  if (!items?.length) return null;
3
 
4
  return (
5
- <div className="grid">
6
  {items.map((r, idx) => (
7
- <div className="resultCard" key={idx}>
8
- <div className="resultTitle">{renderTitle(r)}</div>
9
- <div className="muted small">{renderMeta(r)}</div>
 
 
 
 
 
 
 
 
 
10
 
11
- {onExplain && (
12
- <button className="secondary" style={{ marginTop: 10 }} onClick={() => onExplain(r)}>
13
- Why this?
14
- </button>
15
- )}
 
16
  </div>
17
  ))}
18
  </div>
19
  );
20
- }
 
1
+ export default function ResultsGrid({ items, renderTitle, renderMeta, onExplain, getKey }) {
2
  if (!items?.length) return null;
3
 
4
  return (
5
+ <div className="resultsList">
6
  {items.map((r, idx) => (
7
+ <div
8
+ className="resultCard"
9
+ key={getKey ? getKey(r) : (r.movieId ?? r.product ?? idx)}
10
+ >
11
+ <div className="resultRow">
12
+ <div className="resultLeft">
13
+ <div className="resultTitleRow">
14
+ <span className="resultIndex">{idx + 1}.</span>
15
+ <div className="resultTitle">{renderTitle(r)}</div>
16
+ </div>
17
+ <div className="muted small">{renderMeta(r)}</div>
18
+ </div>
19
 
20
+ {onExplain && (
21
+ <button className="secondary whyBtn" onClick={() => onExplain(r)}>
22
+ Why this?
23
+ </button>
24
+ )}
25
+ </div>
26
  </div>
27
  ))}
28
  </div>
29
  );
30
+ }
frontend/src/pages/CardsPage.jsx CHANGED
@@ -69,6 +69,7 @@ export default function CardsPage() {
69
 
70
  <ResultsGrid
71
  items={recs}
 
72
  renderTitle={(r) => r.product}
73
  renderMeta={(r) => `${r.institution || "Institution"} • Score: ${r.score.toFixed(3)}`}
74
  onExplain={explainOne}
 
69
 
70
  <ResultsGrid
71
  items={recs}
72
+ getKey={(r) => r.product}
73
  renderTitle={(r) => r.product}
74
  renderMeta={(r) => `${r.institution || "Institution"} • Score: ${r.score.toFixed(3)}`}
75
  onExplain={explainOne}
frontend/src/pages/MoviesPage.jsx CHANGED
@@ -2,7 +2,7 @@ import { useState } from "react";
2
  import MovieForm from "../components/MovieForm";
3
  import ResultsGrid from "../components/ResultsGrid";
4
 
5
- const API = import.meta.env.VITE_API_BASE || "";
6
 
7
  export default function MoviesPage() {
8
  const [liked, setLiked] = useState([]);
@@ -102,6 +102,7 @@ export default function MoviesPage() {
102
 
103
  <ResultsGrid
104
  items={recs}
 
105
  renderTitle={(r) => r.title}
106
  renderMeta={(r) => `Score: ${r.score.toFixed(3)} • ${r.genres || "-"}`}
107
  onExplain={explainOne}
 
2
  import MovieForm from "../components/MovieForm";
3
  import ResultsGrid from "../components/ResultsGrid";
4
 
5
+ const API = "";
6
 
7
  export default function MoviesPage() {
8
  const [liked, setLiked] = useState([]);
 
102
 
103
  <ResultsGrid
104
  items={recs}
105
+ getKey={(r) => r.movieId}
106
  renderTitle={(r) => r.title}
107
  renderMeta={(r) => `Score: ${r.score.toFixed(3)} • ${r.genres || "-"}`}
108
  onExplain={explainOne}
frontend/src/styles.css CHANGED
@@ -83,3 +83,55 @@ input:focus, select:focus { border-color: var(--accent); }
83
  .grid2 { grid-template-columns: 1fr; }
84
  .results { grid-template-columns: 1fr; }
85
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  .grid2 { grid-template-columns: 1fr; }
84
  .results { grid-template-columns: 1fr; }
85
  }
86
+ .resultsList{
87
+ display:flex;
88
+ flex-direction:column;
89
+ gap:12px;
90
+ margin-top:14px;
91
+ }
92
+
93
+ .resultCard{
94
+ border:1px solid rgba(255,255,255,0.10);
95
+ border-radius:14px;
96
+ padding:14px;
97
+ background:rgba(255,255,255,0.03);
98
+ }
99
+
100
+ .resultCard:hover{
101
+ background:rgba(255,255,255,0.05);
102
+ }
103
+
104
+ .resultRow{
105
+ display:flex;
106
+ align-items:flex-start;
107
+ justify-content:space-between;
108
+ gap:12px;
109
+ }
110
+
111
+ .resultLeft{
112
+ flex:1;
113
+ min-width:0;
114
+ }
115
+
116
+ .resultTitleRow{
117
+ display:flex;
118
+ align-items:baseline;
119
+ gap:10px;
120
+ }
121
+
122
+ .resultIndex{
123
+ opacity:0.65;
124
+ font-size:13px;
125
+ flex:0 0 auto;
126
+ }
127
+
128
+ .resultTitle{
129
+ font-size:18px;
130
+ font-weight:650;
131
+ line-height:1.25;
132
+ }
133
+
134
+ .whyBtn{
135
+ white-space:nowrap;
136
+ align-self:center;
137
+ }