Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,7 +22,6 @@ base = AutoModelForCausalLM.from_pretrained(
|
|
| 22 |
model = PeftModel.from_pretrained(base, ADAPTER_REPO).to(device)
|
| 23 |
model.eval()
|
| 24 |
|
| 25 |
-
# Optional merge (can speed up). If it fails, just continue.
|
| 26 |
try:
|
| 27 |
model = model.merge_and_unload()
|
| 28 |
model.to(device)
|
|
@@ -30,11 +29,18 @@ try:
|
|
| 30 |
except Exception:
|
| 31 |
pass
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
finance_words = [
|
| 34 |
"stock","shares","profit","profits","loss","losses","revenue","earnings","dividend","market",
|
| 35 |
"bank","loan","interest","inflation","bond","equity","merger","acquisition",
|
| 36 |
"ipo","valuation","cash","cashflow","forecast","guidance","quarter","q1","q2","q3","q4",
|
| 37 |
-
"ceo","cfo","board","layoffs","bankruptcy","debt","default","margin"
|
| 38 |
]
|
| 39 |
|
| 40 |
def looks_finance(text: str) -> bool:
|
|
@@ -45,33 +51,29 @@ def is_greeting(text: str) -> bool:
|
|
| 45 |
t = (text or "").lower().strip()
|
| 46 |
return t in ["hi", "hello", "hey", "good morning", "good afternoon", "good evening"]
|
| 47 |
|
| 48 |
-
def extract_label(gen_text: str) -> str:
|
| 49 |
-
"""
|
| 50 |
-
Extract the first occurrence of one of the labels from generated text only.
|
| 51 |
-
"""
|
| 52 |
-
t = (gen_text or "").lower()
|
| 53 |
-
m = re.search(r"\b(negative|neutral|positive)\b", t)
|
| 54 |
-
return m.group(1) if m else "neutral"
|
| 55 |
-
|
| 56 |
@torch.inference_mode()
|
| 57 |
-
def
|
| 58 |
"""
|
| 59 |
-
|
| 60 |
-
|
| 61 |
"""
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
|
| 76 |
@torch.inference_mode()
|
| 77 |
def predict_label(msg: str) -> str:
|
|
@@ -80,11 +82,10 @@ def predict_label(msg: str) -> str:
|
|
| 80 |
f"Text: {msg.strip()}\n"
|
| 81 |
"Answer:"
|
| 82 |
)
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
return label
|
| 88 |
|
| 89 |
def chat(msg, history):
|
| 90 |
msg = (msg or "").strip()
|
|
|
|
| 22 |
model = PeftModel.from_pretrained(base, ADAPTER_REPO).to(device)
|
| 23 |
model.eval()
|
| 24 |
|
|
|
|
| 25 |
try:
|
| 26 |
model = model.merge_and_unload()
|
| 27 |
model.to(device)
|
|
|
|
| 29 |
except Exception:
|
| 30 |
pass
|
| 31 |
|
| 32 |
+
LABELS = ["negative", "neutral", "positive"]
|
| 33 |
+
|
| 34 |
+
label_token_ids = {
|
| 35 |
+
lab: tokenizer(" " + lab, add_special_tokens=False)["input_ids"]
|
| 36 |
+
for lab in LABELS
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
finance_words = [
|
| 40 |
"stock","shares","profit","profits","loss","losses","revenue","earnings","dividend","market",
|
| 41 |
"bank","loan","interest","inflation","bond","equity","merger","acquisition",
|
| 42 |
"ipo","valuation","cash","cashflow","forecast","guidance","quarter","q1","q2","q3","q4",
|
| 43 |
+
"ceo","cfo","board","layoffs","bankruptcy","debt","default","margin","miss","downgrade"
|
| 44 |
]
|
| 45 |
|
| 46 |
def looks_finance(text: str) -> bool:
|
|
|
|
| 51 |
t = (text or "").lower().strip()
|
| 52 |
return t in ["hi", "hello", "hey", "good morning", "good afternoon", "good evening"]
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
@torch.inference_mode()
|
| 55 |
+
def score_label_with_cache(prompt_ids, lab_ids) -> float:
|
| 56 |
"""
|
| 57 |
+
Score P(label | prompt) using cached past_key_values.
|
| 58 |
+
Returns average log-prob per label token (length-normalized).
|
| 59 |
"""
|
| 60 |
+
# Run prompt once to get cache
|
| 61 |
+
prompt = torch.tensor([prompt_ids], device=device)
|
| 62 |
+
out = model(input_ids=prompt, use_cache=True)
|
| 63 |
+
past = out.past_key_values
|
| 64 |
+
|
| 65 |
+
logp_sum = 0.0
|
| 66 |
+
prev_token = prompt[:, -1:]
|
| 67 |
+
|
| 68 |
+
for tok_id in lab_ids:
|
| 69 |
+
step = model(input_ids=prev_token, past_key_values=past, use_cache=True)
|
| 70 |
+
logits = step.logits[:, -1, :]
|
| 71 |
+
logp_sum += torch.log_softmax(logits, dim=-1)[0, tok_id].item()
|
| 72 |
|
| 73 |
+
past = step.past_key_values
|
| 74 |
+
prev_token = torch.tensor([[tok_id]], device=device)
|
| 75 |
+
|
| 76 |
+
return logp_sum / max(len(lab_ids), 1)
|
| 77 |
|
| 78 |
@torch.inference_mode()
|
| 79 |
def predict_label(msg: str) -> str:
|
|
|
|
| 82 |
f"Text: {msg.strip()}\n"
|
| 83 |
"Answer:"
|
| 84 |
)
|
| 85 |
+
prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
|
| 86 |
|
| 87 |
+
scores = {lab: score_label_with_cache(prompt_ids, label_token_ids[lab]) for lab in LABELS}
|
| 88 |
+
return max(scores, key=scores.get)
|
|
|
|
|
|
|
| 89 |
|
| 90 |
def chat(msg, history):
|
| 91 |
msg = (msg or "").strip()
|