JLee0 commited on
Commit
f66c7a6
Β·
1 Parent(s): 80f4f7d

Update chat interface logic

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import faiss
4
  import re
5
  import gradio as gr
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
  from sentence_transformers import SentenceTransformer
8
  from prompt import PROMPTS
9
 
@@ -56,12 +56,12 @@ def generate_answer(tokenizer, model, system_prompt:str, query: str, context: st
56
  **inputs,
57
  max_new_tokens=512,
58
  do_sample=True,
59
- temperature=0.7,
60
  top_p=0.8,
61
  )
62
  decoded = tokenizer.decode(out[0], skip_special_tokens=False)
63
  answer = decoded.split(prompt, 1)[-1]
64
- # strip special tokens
65
  for tok in [B, SS, SU, SA, EOT]:
66
  answer = answer.replace(tok, "")
67
  return answer.strip()
@@ -69,13 +69,24 @@ def generate_answer(tokenizer, model, system_prompt:str, query: str, context: st
69
  def post_process_answer(raw: str, prev_answer: str = "") -> str:
70
  if not raw:
71
  return "제곡된 닡변이 μ—†μŠ΅λ‹ˆλ‹€."
72
- m = re.search(r"\<\|start_header_id\>assistant\<\|end_header_id\>(.*?)\<\|eot_id\>", raw, re.DOTALL)
73
- ans = m.group(1).strip() if m else raw.strip()
 
 
 
 
 
 
 
 
74
  ans = re.sub(r"\<\|.*?\|\>", "", ans).strip()
75
- if ans.lower().count("assistant") >= 2:
 
76
  return "제곡된 닡변이 μ—†μŠ΅λ‹ˆλ‹€."
 
77
  if not ans or ans == prev_answer.strip():
78
  return "제곡된 닡변이 μ—†μŠ΅λ‹ˆλ‹€."
 
79
  return ans
80
 
81
  def answer_query(
 
3
  import faiss
4
  import re
5
  import gradio as gr
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsEAndBytesConfig
7
  from sentence_transformers import SentenceTransformer
8
  from prompt import PROMPTS
9
 
 
56
  **inputs,
57
  max_new_tokens=512,
58
  do_sample=True,
59
+ temperature=0.6,
60
  top_p=0.8,
61
  )
62
  decoded = tokenizer.decode(out[0], skip_special_tokens=False)
63
  answer = decoded.split(prompt, 1)[-1]
64
+
65
  for tok in [B, SS, SU, SA, EOT]:
66
  answer = answer.replace(tok, "")
67
  return answer.strip()
 
69
  def post_process_answer(raw: str, prev_answer: str = "") -> str:
70
  if not raw:
71
  return "제곡된 닡변이 μ—†μŠ΅λ‹ˆλ‹€."
72
+
73
+ m = re.search(
74
+ r"\<\|start_header_id\>assistant\<\|end_header_id\>(.*?)\<\|eot_id\>",
75
+ raw, re.DOTALL
76
+ )
77
+ if m:
78
+ ans = m.group(1).strip()
79
+ else:
80
+ ans = raw.strip()
81
+
82
  ans = re.sub(r"\<\|.*?\|\>", "", ans).strip()
83
+
84
+ if ans.lower().count("assistant") >= 4:
85
  return "제곡된 닡변이 μ—†μŠ΅λ‹ˆλ‹€."
86
+
87
  if not ans or ans == prev_answer.strip():
88
  return "제곡된 닡변이 μ—†μŠ΅λ‹ˆλ‹€."
89
+
90
  return ans
91
 
92
  def answer_query(