Nagharjun Mathi Mariappan commited on
Commit ·
84629c8
1
Parent(s): 8064e4b
fix llm explanation
Browse files- backend/engines/cards_engine.py +37 -37
- backend/engines/movies_engine.py +22 -22
- backend/llm.py +26 -25
- backend/main.py +2 -1
- backend/requirements.txt +2 -1
- frontend/src/components/ResultsGrid.jsx +21 -11
- frontend/src/pages/CardsPage.jsx +1 -0
- frontend/src/pages/MoviesPage.jsx +2 -1
- frontend/src/styles.css +52 -0
backend/engines/cards_engine.py
CHANGED
|
@@ -108,40 +108,40 @@ class CardsEngine:
|
|
| 108 |
out["score"] = score[idx]
|
| 109 |
return out
|
| 110 |
|
| 111 |
-
def explain(self, user_input: dict, card_row: dict, score: float) -> str:
|
| 112 |
-
|
| 113 |
-
Keep it short (6-10 bullet lines). Be careful: do not invent fees/APR/rewards.
|
| 114 |
-
If a field is missing/blank, say "not listed".
|
| 115 |
-
|
| 116 |
-
User profile:
|
| 117 |
-
- credit_score: {user_input.get("credit_score")}
|
| 118 |
-
- annual_income: {user_input.get("annual_income")}
|
| 119 |
-
- carry_balance: {user_input.get("carry_balance")}
|
| 120 |
-
- travel_abroad: {user_input.get("travel_abroad")}
|
| 121 |
-
- no_annual_fee: {user_input.get("no_annual_fee")}
|
| 122 |
-
- balance_transfer: {user_input.get("balance_transfer")}
|
| 123 |
-
- rewards_pref: {user_input.get("rewards_pref")}
|
| 124 |
-
- monthly spend: groceries={user_input.get("spend_groceries")}, dining={user_input.get("spend_dining")},
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
Recommended card (from dataset row):
|
| 128 |
-
- Institution: {card_row.get("Institution Name")}
|
| 129 |
-
- Product: {card_row.get("Product Name")}
|
| 130 |
-
- Annual Fee: {card_row.get("Annual Fee")}
|
| 131 |
-
- Foreign Transaction Fee (%): {card_row.get("Foreign Transaction Fee (%)")}
|
| 132 |
-
- Purchase APR median: {card_row.get("Purchase APR median")}
|
| 133 |
-
- Balance Transfer Offered?: {card_row.get("Balance Transfer Offered?")}
|
| 134 |
-
- Rewards: {card_row.get("Rewards")}
|
| 135 |
-
- Card Features: {card_row.get("Card Features")}
|
| 136 |
-
- Website: {card_row.get("Website for Consumer")}
|
| 137 |
-
- Phone: {card_row.get("Telephone Number for Consumers")}
|
| 138 |
-
|
| 139 |
-
Model score: {score:.3f}
|
| 140 |
-
|
| 141 |
-
Explain:
|
| 142 |
-
1) why it matches this user's needs
|
| 143 |
-
2) what to watch out for (APR/fees)
|
| 144 |
-
3) who this card is best for (one sentence)
|
| 145 |
-
Add a final line: "Not financial advice."
|
| 146 |
-
"""
|
| 147 |
-
|
|
|
|
| 108 |
out["score"] = score[idx]
|
| 109 |
return out
|
| 110 |
|
| 111 |
+
def explain(self, user_input: dict, card_row: dict, score: float) -> str:
|
| 112 |
+
prompt = f"""You are explaining credit-card recommendations to a user.
|
| 113 |
+
Keep it short (6-10 bullet lines). Be careful: do not invent fees/APR/rewards.
|
| 114 |
+
If a field is missing/blank, say "not listed".
|
| 115 |
+
|
| 116 |
+
User profile:
|
| 117 |
+
- credit_score: {user_input.get("credit_score")}
|
| 118 |
+
- annual_income: {user_input.get("annual_income")}
|
| 119 |
+
- carry_balance: {user_input.get("carry_balance")}
|
| 120 |
+
- travel_abroad: {user_input.get("travel_abroad")}
|
| 121 |
+
- no_annual_fee: {user_input.get("no_annual_fee")}
|
| 122 |
+
- balance_transfer: {user_input.get("balance_transfer")}
|
| 123 |
+
- rewards_pref: {user_input.get("rewards_pref")}
|
| 124 |
+
- monthly spend: groceries={user_input.get("spend_groceries")}, dining={user_input.get("spend_dining")},
|
| 125 |
+
gas={user_input.get("spend_gas")}, travel={user_input.get("spend_travel")}, online={user_input.get("spend_online")}
|
| 126 |
+
|
| 127 |
+
Recommended card (from dataset row):
|
| 128 |
+
- Institution: {card_row.get("Institution Name")}
|
| 129 |
+
- Product: {card_row.get("Product Name")}
|
| 130 |
+
- Annual Fee: {card_row.get("Annual Fee")}
|
| 131 |
+
- Foreign Transaction Fee (%): {card_row.get("Foreign Transaction Fee (%)")}
|
| 132 |
+
- Purchase APR median: {card_row.get("Purchase APR median")}
|
| 133 |
+
- Balance Transfer Offered?: {card_row.get("Balance Transfer Offered?")}
|
| 134 |
+
- Rewards: {card_row.get("Rewards")}
|
| 135 |
+
- Card Features: {card_row.get("Card Features")}
|
| 136 |
+
- Website: {card_row.get("Website for Consumer")}
|
| 137 |
+
- Phone: {card_row.get("Telephone Number for Consumers")}
|
| 138 |
+
|
| 139 |
+
Model score: {score:.3f}
|
| 140 |
+
|
| 141 |
+
Explain:
|
| 142 |
+
1) why it matches this user's needs
|
| 143 |
+
2) what to watch out for (APR/fees)
|
| 144 |
+
3) who this card is best for (one sentence)
|
| 145 |
+
Add a final line: "Not financial advice."
|
| 146 |
+
"""
|
| 147 |
+
return generate_text(prompt)
|
backend/engines/movies_engine.py
CHANGED
|
@@ -115,25 +115,25 @@ class MoviesEngine:
|
|
| 115 |
out["score"] = scores[idx]
|
| 116 |
return out
|
| 117 |
|
| 118 |
-
def explain(self, user_input: dict, movie_row: dict, score: float) -> str:
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
Keep it short (6-10 bullet lines). Do not invent facts not provided.
|
| 124 |
-
|
| 125 |
-
User preferences:
|
| 126 |
-
- Selected genres: {genres}
|
| 127 |
-
- Liked movie ids: {liked_ids}
|
| 128 |
-
|
| 129 |
-
Recommended movie:
|
| 130 |
-
- Title: {movie_row.get("title")}
|
| 131 |
-
- Genres: {movie_row.get("genres")}
|
| 132 |
-
- Model score: {score:.3f}
|
| 133 |
-
|
| 134 |
-
Explain:
|
| 135 |
-
1) why it matches the user
|
| 136 |
-
2) what kind of vibe they should expect based on title/genres
|
| 137 |
-
3) when to pick this vs the other results (one sentence)
|
| 138 |
-
"""
|
| 139 |
-
|
|
|
|
| 115 |
out["score"] = scores[idx]
|
| 116 |
return out
|
| 117 |
|
| 118 |
+
def explain(self, user_input: dict, movie_row: dict, score: float) -> str:
|
| 119 |
+
genres = user_input.get("genres") or []
|
| 120 |
+
liked_ids = user_input.get("liked_movie_ids") or []
|
| 121 |
+
|
| 122 |
+
prompt = f"""You are explaining recommendations to a user in a friendly way.
|
| 123 |
+
Keep it short (6-10 bullet lines). Do not invent facts not provided.
|
| 124 |
+
|
| 125 |
+
User preferences:
|
| 126 |
+
- Selected genres: {genres}
|
| 127 |
+
- Liked movie ids: {liked_ids}
|
| 128 |
+
|
| 129 |
+
Recommended movie:
|
| 130 |
+
- Title: {movie_row.get("title")}
|
| 131 |
+
- Genres: {movie_row.get("genres")}
|
| 132 |
+
- Model score: {score:.3f}
|
| 133 |
+
|
| 134 |
+
Explain:
|
| 135 |
+
1) why it matches the user
|
| 136 |
+
2) what kind of vibe they should expect based on title/genres
|
| 137 |
+
3) when to pick this vs the other results (one sentence)
|
| 138 |
+
"""
|
| 139 |
+
return generate_text(prompt)
|
backend/llm.py
CHANGED
|
@@ -1,33 +1,34 @@
|
|
| 1 |
import os
|
| 2 |
-
from
|
| 3 |
-
from huggingface_hub import InferenceClient
|
| 4 |
|
| 5 |
-
def _fallback(prompt):
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
def generate_text(prompt, max_new_tokens=220):
|
| 10 |
provider = (os.getenv("LLM_PROVIDER") or "off").lower()
|
| 11 |
if provider == "off":
|
| 12 |
return _fallback(prompt)
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
return _fallback(prompt)
|
| 19 |
-
|
| 20 |
-
try:
|
| 21 |
-
client = InferenceClient(model=model, token=token)
|
| 22 |
-
out = client.text_generation(
|
| 23 |
-
prompt,
|
| 24 |
-
max_new_tokens=max_new_tokens,
|
| 25 |
-
temperature=0.3,
|
| 26 |
-
top_p=0.9,
|
| 27 |
-
repetition_penalty=1.05,
|
| 28 |
-
)
|
| 29 |
-
return (out or "").strip() or _fallback(prompt)
|
| 30 |
-
except Exception:
|
| 31 |
-
return _fallback(prompt)
|
| 32 |
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
from openai import OpenAI
|
|
|
|
| 3 |
|
| 4 |
+
def _fallback(prompt: str) -> str:
|
| 5 |
+
return (
|
| 6 |
+
"Here’s why this showed up:\n"
|
| 7 |
+
"- It matches your stated preferences.\n"
|
| 8 |
+
"- It scored highly in the ranking model.\n"
|
| 9 |
+
"- Compare a couple options to pick the best fit.\n"
|
| 10 |
+
)
|
| 11 |
|
| 12 |
+
def generate_text(prompt: str, max_new_tokens: int = 220) -> str:
|
| 13 |
provider = (os.getenv("LLM_PROVIDER") or "off").lower()
|
| 14 |
if provider == "off":
|
| 15 |
return _fallback(prompt)
|
| 16 |
|
| 17 |
+
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 18 |
+
model = os.getenv("HF_CHAT_MODEL") or "google/gemma-2-2b-it"
|
| 19 |
+
if not token:
|
| 20 |
+
return _fallback(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
try:
|
| 23 |
+
client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=token)
|
| 24 |
+
resp = client.chat.completions.create(
|
| 25 |
+
model=model,
|
| 26 |
+
messages=[{"role": "user", "content": prompt}],
|
| 27 |
+
max_tokens=max_new_tokens,
|
| 28 |
+
temperature=0.3,
|
| 29 |
+
top_p=0.9,
|
| 30 |
+
)
|
| 31 |
+
text = (resp.choices[0].message.content or "").strip()
|
| 32 |
+
return text or _fallback(prompt)
|
| 33 |
+
except Exception:
|
| 34 |
+
return _fallback(prompt)
|
backend/main.py
CHANGED
|
@@ -143,13 +143,14 @@ class ExplainCardReq(BaseModel):
|
|
| 143 |
|
| 144 |
@app.post("/api/cards/explain", response_model=ExplainOut)
|
| 145 |
def explain_card(req: ExplainCardReq):
|
| 146 |
-
row = cards_engine.
|
| 147 |
if row.empty:
|
| 148 |
return ExplainOut(text="Could not find that card in the catalog.")
|
| 149 |
card_row = row.iloc[0].to_dict()
|
| 150 |
text = cards_engine.explain(req.user_input.model_dump(), card_row, req.score)
|
| 151 |
return ExplainOut(text=text)
|
| 152 |
|
|
|
|
| 153 |
from fastapi.staticfiles import StaticFiles
|
| 154 |
|
| 155 |
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
|
|
|
|
| 143 |
|
| 144 |
@app.post("/api/cards/explain", response_model=ExplainOut)
|
| 145 |
def explain_card(req: ExplainCardReq):
|
| 146 |
+
row = cards_engine.meta[cards_engine.meta["Product Name"] == req.product]
|
| 147 |
if row.empty:
|
| 148 |
return ExplainOut(text="Could not find that card in the catalog.")
|
| 149 |
card_row = row.iloc[0].to_dict()
|
| 150 |
text = cards_engine.explain(req.user_input.model_dump(), card_row, req.score)
|
| 151 |
return ExplainOut(text=text)
|
| 152 |
|
| 153 |
+
|
| 154 |
from fastapi.staticfiles import StaticFiles
|
| 155 |
|
| 156 |
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
|
backend/requirements.txt
CHANGED
|
@@ -7,4 +7,5 @@ scikit-learn==1.5.2
|
|
| 7 |
joblib==1.4.2
|
| 8 |
torch==2.4.1
|
| 9 |
sentence-transformers==3.0.1
|
| 10 |
-
huggingface_hub==0.26.3
|
|
|
|
|
|
| 7 |
joblib==1.4.2
|
| 8 |
torch==2.4.1
|
| 9 |
sentence-transformers==3.0.1
|
| 10 |
+
huggingface_hub==0.26.3
|
| 11 |
+
openai>=1.0.0
|
frontend/src/components/ResultsGrid.jsx
CHANGED
|
@@ -1,20 +1,30 @@
|
|
| 1 |
-
export default function ResultsGrid({ items, renderTitle, renderMeta, onExplain }) {
|
| 2 |
if (!items?.length) return null;
|
| 3 |
|
| 4 |
return (
|
| 5 |
-
<div className="
|
| 6 |
{items.map((r, idx) => (
|
| 7 |
-
<div
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
| 16 |
</div>
|
| 17 |
))}
|
| 18 |
</div>
|
| 19 |
);
|
| 20 |
-
}
|
|
|
|
| 1 |
+
export default function ResultsGrid({ items, renderTitle, renderMeta, onExplain, getKey }) {
|
| 2 |
if (!items?.length) return null;
|
| 3 |
|
| 4 |
return (
|
| 5 |
+
<div className="resultsList">
|
| 6 |
{items.map((r, idx) => (
|
| 7 |
+
<div
|
| 8 |
+
className="resultCard"
|
| 9 |
+
key={getKey ? getKey(r) : (r.movieId ?? r.product ?? idx)}
|
| 10 |
+
>
|
| 11 |
+
<div className="resultRow">
|
| 12 |
+
<div className="resultLeft">
|
| 13 |
+
<div className="resultTitleRow">
|
| 14 |
+
<span className="resultIndex">{idx + 1}.</span>
|
| 15 |
+
<div className="resultTitle">{renderTitle(r)}</div>
|
| 16 |
+
</div>
|
| 17 |
+
<div className="muted small">{renderMeta(r)}</div>
|
| 18 |
+
</div>
|
| 19 |
|
| 20 |
+
{onExplain && (
|
| 21 |
+
<button className="secondary whyBtn" onClick={() => onExplain(r)}>
|
| 22 |
+
Why this?
|
| 23 |
+
</button>
|
| 24 |
+
)}
|
| 25 |
+
</div>
|
| 26 |
</div>
|
| 27 |
))}
|
| 28 |
</div>
|
| 29 |
);
|
| 30 |
+
}
|
frontend/src/pages/CardsPage.jsx
CHANGED
|
@@ -69,6 +69,7 @@ export default function CardsPage() {
|
|
| 69 |
|
| 70 |
<ResultsGrid
|
| 71 |
items={recs}
|
|
|
|
| 72 |
renderTitle={(r) => r.product}
|
| 73 |
renderMeta={(r) => `${r.institution || "Institution"} • Score: ${r.score.toFixed(3)}`}
|
| 74 |
onExplain={explainOne}
|
|
|
|
| 69 |
|
| 70 |
<ResultsGrid
|
| 71 |
items={recs}
|
| 72 |
+
getKey={(r) => r.product}
|
| 73 |
renderTitle={(r) => r.product}
|
| 74 |
renderMeta={(r) => `${r.institution || "Institution"} • Score: ${r.score.toFixed(3)}`}
|
| 75 |
onExplain={explainOne}
|
frontend/src/pages/MoviesPage.jsx
CHANGED
|
@@ -2,7 +2,7 @@ import { useState } from "react";
|
|
| 2 |
import MovieForm from "../components/MovieForm";
|
| 3 |
import ResultsGrid from "../components/ResultsGrid";
|
| 4 |
|
| 5 |
-
const API =
|
| 6 |
|
| 7 |
export default function MoviesPage() {
|
| 8 |
const [liked, setLiked] = useState([]);
|
|
@@ -102,6 +102,7 @@ export default function MoviesPage() {
|
|
| 102 |
|
| 103 |
<ResultsGrid
|
| 104 |
items={recs}
|
|
|
|
| 105 |
renderTitle={(r) => r.title}
|
| 106 |
renderMeta={(r) => `Score: ${r.score.toFixed(3)} • ${r.genres || "-"}`}
|
| 107 |
onExplain={explainOne}
|
|
|
|
| 2 |
import MovieForm from "../components/MovieForm";
|
| 3 |
import ResultsGrid from "../components/ResultsGrid";
|
| 4 |
|
| 5 |
+
const API = "";
|
| 6 |
|
| 7 |
export default function MoviesPage() {
|
| 8 |
const [liked, setLiked] = useState([]);
|
|
|
|
| 102 |
|
| 103 |
<ResultsGrid
|
| 104 |
items={recs}
|
| 105 |
+
getKey={(r) => r.movieId}
|
| 106 |
renderTitle={(r) => r.title}
|
| 107 |
renderMeta={(r) => `Score: ${r.score.toFixed(3)} • ${r.genres || "-"}`}
|
| 108 |
onExplain={explainOne}
|
frontend/src/styles.css
CHANGED
|
@@ -83,3 +83,55 @@ input:focus, select:focus { border-color: var(--accent); }
|
|
| 83 |
.grid2 { grid-template-columns: 1fr; }
|
| 84 |
.results { grid-template-columns: 1fr; }
|
| 85 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
.grid2 { grid-template-columns: 1fr; }
|
| 84 |
.results { grid-template-columns: 1fr; }
|
| 85 |
}
|
| 86 |
+
.resultsList{
|
| 87 |
+
display:flex;
|
| 88 |
+
flex-direction:column;
|
| 89 |
+
gap:12px;
|
| 90 |
+
margin-top:14px;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
.resultCard{
|
| 94 |
+
border:1px solid rgba(255,255,255,0.10);
|
| 95 |
+
border-radius:14px;
|
| 96 |
+
padding:14px;
|
| 97 |
+
background:rgba(255,255,255,0.03);
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
.resultCard:hover{
|
| 101 |
+
background:rgba(255,255,255,0.05);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
.resultRow{
|
| 105 |
+
display:flex;
|
| 106 |
+
align-items:flex-start;
|
| 107 |
+
justify-content:space-between;
|
| 108 |
+
gap:12px;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
.resultLeft{
|
| 112 |
+
flex:1;
|
| 113 |
+
min-width:0;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
.resultTitleRow{
|
| 117 |
+
display:flex;
|
| 118 |
+
align-items:baseline;
|
| 119 |
+
gap:10px;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
.resultIndex{
|
| 123 |
+
opacity:0.65;
|
| 124 |
+
font-size:13px;
|
| 125 |
+
flex:0 0 auto;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
.resultTitle{
|
| 129 |
+
font-size:18px;
|
| 130 |
+
font-weight:650;
|
| 131 |
+
line-height:1.25;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
.whyBtn{
|
| 135 |
+
white-space:nowrap;
|
| 136 |
+
align-self:center;
|
| 137 |
+
}
|