Wengelawiit commited on
Commit
0858872
·
verified ·
1 Parent(s): 19f248b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -29
app.py CHANGED
@@ -22,7 +22,6 @@ base = AutoModelForCausalLM.from_pretrained(
22
  model = PeftModel.from_pretrained(base, ADAPTER_REPO).to(device)
23
  model.eval()
24
 
25
- # Optional merge (can speed up). If it fails, just continue.
26
  try:
27
  model = model.merge_and_unload()
28
  model.to(device)
@@ -30,11 +29,18 @@ try:
30
  except Exception:
31
  pass
32
 
 
 
 
 
 
 
 
33
  finance_words = [
34
  "stock","shares","profit","profits","loss","losses","revenue","earnings","dividend","market",
35
  "bank","loan","interest","inflation","bond","equity","merger","acquisition",
36
  "ipo","valuation","cash","cashflow","forecast","guidance","quarter","q1","q2","q3","q4",
37
- "ceo","cfo","board","layoffs","bankruptcy","debt","default","margin"
38
  ]
39
 
40
  def looks_finance(text: str) -> bool:
@@ -45,33 +51,29 @@ def is_greeting(text: str) -> bool:
45
  t = (text or "").lower().strip()
46
  return t in ["hi", "hello", "hey", "good morning", "good afternoon", "good evening"]
47
 
48
- def extract_label(gen_text: str) -> str:
49
- """
50
- Extract the first occurrence of one of the labels from generated text only.
51
- """
52
- t = (gen_text or "").lower()
53
- m = re.search(r"\b(negative|neutral|positive)\b", t)
54
- return m.group(1) if m else "neutral"
55
-
56
  @torch.inference_mode()
57
- def generate_answer_only(prompt: str, max_new_tokens: int = 4) -> str:
58
  """
59
- Generate ONLY the new tokens after the prompt.
60
- This avoids the 'prompt contains the labels' bug.
61
  """
62
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
63
- input_len = inputs["input_ids"].shape[1]
64
-
65
- out = model.generate(
66
- **inputs,
67
- max_new_tokens=max_new_tokens,
68
- do_sample=False,
69
- temperature=0.0,
70
- pad_token_id=tokenizer.eos_token_id
71
- )
 
 
72
 
73
- gen_tokens = out[0][input_len:]
74
- return tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
 
 
75
 
76
  @torch.inference_mode()
77
  def predict_label(msg: str) -> str:
@@ -80,11 +82,10 @@ def predict_label(msg: str) -> str:
80
  f"Text: {msg.strip()}\n"
81
  "Answer:"
82
  )
 
83
 
84
- gen = generate_answer_only(prompt, max_new_tokens=4)
85
- label = extract_label(gen)
86
-
87
- return label
88
 
89
  def chat(msg, history):
90
  msg = (msg or "").strip()
 
22
  model = PeftModel.from_pretrained(base, ADAPTER_REPO).to(device)
23
  model.eval()
24
 
 
25
  try:
26
  model = model.merge_and_unload()
27
  model.to(device)
 
29
  except Exception:
30
  pass
31
 
32
+ LABELS = ["negative", "neutral", "positive"]
33
+
34
+ label_token_ids = {
35
+ lab: tokenizer(" " + lab, add_special_tokens=False)["input_ids"]
36
+ for lab in LABELS
37
+ }
38
+
39
  finance_words = [
40
  "stock","shares","profit","profits","loss","losses","revenue","earnings","dividend","market",
41
  "bank","loan","interest","inflation","bond","equity","merger","acquisition",
42
  "ipo","valuation","cash","cashflow","forecast","guidance","quarter","q1","q2","q3","q4",
43
+ "ceo","cfo","board","layoffs","bankruptcy","debt","default","margin","miss","downgrade"
44
  ]
45
 
46
  def looks_finance(text: str) -> bool:
 
51
  t = (text or "").lower().strip()
52
  return t in ["hi", "hello", "hey", "good morning", "good afternoon", "good evening"]
53
 
 
 
 
 
 
 
 
 
54
  @torch.inference_mode()
55
+ def score_label_with_cache(prompt_ids, lab_ids) -> float:
56
  """
57
+ Score P(label | prompt) using cached past_key_values.
58
+ Returns average log-prob per label token (length-normalized).
59
  """
60
+ # Run prompt once to get cache
61
+ prompt = torch.tensor([prompt_ids], device=device)
62
+ out = model(input_ids=prompt, use_cache=True)
63
+ past = out.past_key_values
64
+
65
+ logp_sum = 0.0
66
+ prev_token = prompt[:, -1:]
67
+
68
+ for tok_id in lab_ids:
69
+ step = model(input_ids=prev_token, past_key_values=past, use_cache=True)
70
+ logits = step.logits[:, -1, :]
71
+ logp_sum += torch.log_softmax(logits, dim=-1)[0, tok_id].item()
72
 
73
+ past = step.past_key_values
74
+ prev_token = torch.tensor([[tok_id]], device=device)
75
+
76
+ return logp_sum / max(len(lab_ids), 1)
77
 
78
  @torch.inference_mode()
79
  def predict_label(msg: str) -> str:
 
82
  f"Text: {msg.strip()}\n"
83
  "Answer:"
84
  )
85
+ prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
86
 
87
+ scores = {lab: score_label_with_cache(prompt_ids, label_token_ids[lab]) for lab in LABELS}
88
+ return max(scores, key=scores.get)
 
 
89
 
90
  def chat(msg, history):
91
  msg = (msg or "").strip()