SayknowLab committed on
Commit
cce1204
·
verified · 
1 Parent(s): 0ff739c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -56
app.py CHANGED
@@ -4,15 +4,16 @@ from flask import Flask, request, Response
4
  from transformers import AutoTokenizer, GPT2LMHeadModel
5
  from dicttoxml import dicttoxml
6
  import traceback
7
- from threading import Lock # โ† ์ถ”๊ฐ€
 
8
 
9
  app = Flask(__name__)
10
 
11
  # --- 1. ๋ชจ๋ธ ๋กœ๋“œ ---
12
  print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์ค‘...")
13
- tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
14
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
15
- model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
16
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
17
 
18
  # --- 2. ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ---
@@ -26,21 +27,21 @@ except Exception as e:
26
  # --- 3. ๋™์‹œ ์š”์ฒญ ์ œํ•œ์šฉ Lock ---
27
  request_lock = Lock()
28
 
 
29
  def find_relevant_context(query, top_n=2):
30
- """์งˆ๋ฌธ๊ณผ ๊ด€๋ จ๋œ ์ง€์‹ ๋ฐ์ดํ„ฐ ๋ฌธ์žฅ ์ตœ๋Œ€ top_n๊ฐœ ๋ฐ˜ํ™˜"""
31
  query_words = query.replace(" ", "").lower()
32
  relevant_sentences = []
33
  for s in knowledge_list:
34
  s_text = str(s).replace(" ", "").replace("\n", "").lower()
35
  if any(word.replace(" ", "") in s_text for word in query.split()):
36
  relevant_sentences.append(s)
37
- if relevant_sentences:
38
- return " ".join(str(s) for s in relevant_sentences[:top_n])
39
- return ""
40
 
 
41
  def ask_sayknow(query):
42
  try:
43
  context = find_relevant_context(query)
 
44
  persona_guide = (
45
  "๋„ˆ๋Š” ์ง€์‹ ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ์ฑ—๋ด‡ Sayknow์•ผ. ์ž๊ธฐ์†Œ๊ฐœ ์งˆ๋ฌธ์—๋Š” '์ €๋Š” Sayknow์ž…๋‹ˆ๋‹ค.'๋ผ๊ณ  ๋‹ตํ•ด. "
46
  "๊ทธ ์™ธ์—๋Š” ์•„๋ž˜ ์ฐธ๊ณ ํ•ด์„œ ์ •ํ™•ํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด ๋ฌธ์žฅ์œผ๋กœ 80์ž ์ด๋‚ด๋กœ ๋‹ตํ•ด.\n"
@@ -64,76 +65,46 @@ def ask_sayknow(query):
64
  gen_ids = model.generate(
65
  input_ids,
66
  attention_mask=attention_mask,
67
- max_new_tokens=100, # โ† 200 -> 100์œผ๋กœ ์ค„์—ฌ ์›Œ์ปค ์ ์œ  ์‹œ๊ฐ„ ๋‹จ์ถ•
68
  min_length=5,
69
  repetition_penalty=1.3,
70
  do_sample=True,
71
  top_k=30,
72
- top_p=0.85,
73
- pad_token_id=tokenizer.pad_token_id,
74
- temperature=0.5,
75
- num_beams=1
76
  )
 
77
  raw_response = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
78
 
79
- # ํ”„๋กฌํ”„ํŠธ ์ œ๊ฑฐ ํ›„ ์‹ค์ œ ๋‹ต๋ณ€ ์ถ”์ถœ
80
- if raw_response.startswith(prompt):
81
- answer = raw_response[len(prompt):].strip()
82
- else:
83
- answer = raw_response.strip()
84
-
85
- # 2. '๋‹ต๋ณ€:' ํ‚ค์›Œ๋“œ๋ฅผ ๊ธฐ์ค€์œผ๋กœ ์ง„์งœ ๋‹ต๋ณ€ ๋ถ€๋ถ„ ์ถ”์ถœ
86
- if "๋‹ต๋ณ€:" in extracted_answer:
87
- answer = extracted_answer.split("๋‹ต๋ณ€:", 1)[1].strip() # ์ฒซ ๋ฒˆ์งธ "๋‹ต๋ณ€:" ์ดํ›„๋งŒ
88
- else:
89
- # ๋งŒ์•ฝ "๋‹ต๋ณ€:" ํƒœ๊ทธ๊ฐ€ ์—†์œผ๋ฉด, ํ”„๋กฌํ”„ํŠธ์˜ ์ง€์‹œ์‚ฌํ•ญ ์ค‘๋ณต ๋“ฑ์„ ์ œ๊ฑฐ ์‹œ๋„
90
- persona_end_marker = "๋‹ตํ•ด.\n" # persona_guide์˜ ํŠน์ • ๋ ๋ถ€๋ถ„์„ ํ‘œ์‹œ
91
- if persona_end_marker in extracted_answer:
92
- try:
93
- answer = extracted_answer[extracted_answer.rindex(persona_end_marker) + len(persona_end_marker):].strip()
94
- except ValueError:
95
- answer = extracted_answer # ์•ˆ๋˜๋ฉด ๊ทธ๋ƒฅ ์ „์ฒด ์‚ฌ์šฉ
96
- else:
97
- answer = extracted_answer # ๊ทธ๊ฒƒ๋„ ์—†์œผ๋ฉด ๊ทธ๋ƒฅ ์ „์ฒด ์‚ฌ์šฉ
98
-
99
- # ๊ทธ๋ž˜๋„ ๋‹ต๋ณ€์ด ๋น„์–ด์žˆ์œผ๋ฉด ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€๋ฅผ ๋Œ€์ฒด
100
- if not answer:
101
- answer = "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์งˆ๋ฌธ์— ๋Œ€ํ•œ ๋‹ต๋ณ€์„ ์ฐพ์„ ์ˆ˜ ์—†๊ฑฐ๋‚˜ ๋‚ด์šฉ์ด ๋ช…ํ™•ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค."
102
-
103
-
104
- # 1. ์˜๋ฏธ ์—†๋Š” ์ˆ˜์‹/์˜๋ฌธ/ํŠน์ˆ˜๋ฌธ์ž/๋ฐ˜๋ณต๋ฌธ์ž ๋“ฑ ํ•„ํ„ฐ๋ง (๊ธฐ์กด๊ณผ ๋™์ผ)
105
- # ์ด ๋ถ€๋ถ„์„ ๋จผ์ € ํ•œ๋ฒˆ ์ ์šฉํ•ด์„œ answer๊ฐ€ ์—‰๋šฑํ•œ ๋ฌธ์ž์—ด์ด ๋˜๋Š” ๊ฑธ ๋ฐฉ์ง€
106
  answer = re.sub(r"[^๊ฐ€-ํžฃ0-9 .,!?~\n]", "", answer)
107
  answer = re.sub(r"([.,!?~])\1{2,}", r"\1", answer)
108
  answer = re.sub(r"[a-zA-Z]+", "", answer)
109
  answer = re.sub(r"[=^*/\\]+", "", answer)
110
  answer = re.sub(r"\s+", " ", answer).strip()
111
 
112
- # 2. 80์ž ์ด๋‚ด๋กœ ์ž๋ฅด๊ธฐ (ํ•œ๊ธ€ ๊ธฐ์ค€) (๊ธฐ์กด๊ณผ ๋™์ผ)
113
- def truncate_korean(text, max_len=80):
114
- count = 0
115
- result = ""
116
- for ch in text:
117
- result += ch
118
- count += 1
119
- if count >= max_len:
120
- break
121
- return result
122
- answer = truncate_korean(answer, 80)
123
-
124
- # ๋ฌธ์žฅ ๋ ์ฒ˜๋ฆฌ
125
  if answer and answer[-1] not in ".!?":
126
  answer += "."
127
  elif not answer:
128
- answer = "์•Œ ์ˆ˜ ์—†๋Š” ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."
129
 
130
  return answer
 
131
  except Exception as e:
132
  print(f"ask_sayknow ์—๋Ÿฌ: {e}")
133
  traceback.print_exc()
134
  return f"๋‚ด๋ถ€ ์˜ค๋ฅ˜: {str(e)}"
135
 
136
- # --- 4. API (XML ์‘๋‹ต) ---
137
  @app.route('/chatapi.html', methods=['GET'])
138
  @app.route('/index.html', methods=['GET'])
139
  def chat_api():
@@ -143,8 +114,7 @@ def chat_api():
143
  xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
144
  return Response(xml_output, mimetype='text/xml')
145
 
146
- # โ† Lock์œผ๋กœ ์š”์ฒญ ์ˆœ์ฐจ ์ฒ˜๋ฆฌ
147
- with request_lock:
148
  try:
149
  answer = ask_sayknow(query)
150
  result = {
@@ -165,5 +135,6 @@ def chat_api():
165
  xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
166
  return Response(xml_output, mimetype='text/xml')
167
 
 
168
  if __name__ == '__main__':
169
  app.run(host='0.0.0.0', port=7860)
 
4
  from transformers import AutoTokenizer, GPT2LMHeadModel
5
  from dicttoxml import dicttoxml
6
  import traceback
7
+ import re
8
+ from threading import Lock
9
 
10
  app = Flask(__name__)
11
 
12
  # --- 1. ๋ชจ๋ธ ๋กœ๋“œ ---
13
  print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์ค‘...")
14
+ tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2", trust_remote_code=True)
15
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
16
+ model = GPT2LMHeadModel.from_pretrained("skt/kogpt2", trust_remote_code=True)
17
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
18
 
19
  # --- 2. ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ---
 
27
  # --- 3. ๋™์‹œ ์š”์ฒญ ์ œํ•œ์šฉ Lock ---
28
  request_lock = Lock()
29
 
30
# --- 4. Search the knowledge base for sentences related to the query ---
def find_relevant_context(query, top_n=2, knowledge=None):
    """Return up to *top_n* knowledge sentences relevant to *query*.

    Matching is whitespace-insensitive and case-insensitive: each word of
    the query is lowercased and searched for inside each knowledge sentence
    with all spaces/newlines stripped and lowercased.

    Args:
        query: The user's question.
        top_n: Maximum number of matching sentences to return (joined).
        knowledge: Optional iterable of sentences to search. Defaults to
            the module-level ``knowledge_list`` for backward compatibility.

    Returns:
        The first *top_n* matching sentences joined by a single space,
        or ``""`` when nothing matches.
    """
    source = knowledge_list if knowledge is None else knowledge
    # Lowercase the query words up front. The previous version lowercased
    # only the knowledge text, so any uppercase character in the query made
    # the match silently fail. (It also computed an unused `query_words`
    # variable and applied a no-op `.replace(" ", "")` to split() output.)
    words = [w.lower() for w in query.split()]
    matches = []
    for sentence in source:
        normalized = str(sentence).replace(" ", "").replace("\n", "").lower()
        if any(w in normalized for w in words):
            matches.append(sentence)
    return " ".join(str(s) for s in matches[:top_n]) if matches else ""
 
 
39
 
40
+ # --- 5. Sayknow ๋‹ต๋ณ€ ์ƒ์„ฑ ---
41
  def ask_sayknow(query):
42
  try:
43
  context = find_relevant_context(query)
44
+
45
  persona_guide = (
46
  "๋„ˆ๋Š” ์ง€์‹ ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ์ฑ—๋ด‡ Sayknow์•ผ. ์ž๊ธฐ์†Œ๊ฐœ ์งˆ๋ฌธ์—๋Š” '์ €๋Š” Sayknow์ž…๋‹ˆ๋‹ค.'๋ผ๊ณ  ๋‹ตํ•ด. "
47
  "๊ทธ ์™ธ์—๋Š” ์•„๋ž˜ ์ฐธ๊ณ ํ•ด์„œ ์ •ํ™•ํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด ๋ฌธ์žฅ์œผ๋กœ 80์ž ์ด๋‚ด๋กœ ๋‹ตํ•ด.\n"
 
65
  gen_ids = model.generate(
66
  input_ids,
67
  attention_mask=attention_mask,
68
+ max_new_tokens=100,
69
  min_length=5,
70
  repetition_penalty=1.3,
71
  do_sample=True,
72
  top_k=30,
73
+ top_p=0.9, # ๋‹ค์–‘์„ฑ ์ฆ๊ฐ€
74
+ temperature=0.7, # ๋‹ค์–‘์„ฑ ์ฆ๊ฐ€
75
+ num_beams=1,
76
+ pad_token_id=tokenizer.pad_token_id
77
  )
78
+
79
  raw_response = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
80
 
81
+ # --- ๋‹ต๋ณ€ ์ถ”์ถœ ---
82
+ answer = raw_response.replace(prompt, '').strip()
83
+ if "๋‹ต๋ณ€:" in answer:
84
+ answer = answer.split("๋‹ต๋ณ€:", 1)[1].strip()
85
+
86
+ # ์˜๋ฏธ ์—†๋Š” ๋ฌธ์ž ์ œ๊ฑฐ
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  answer = re.sub(r"[^๊ฐ€-ํžฃ0-9 .,!?~\n]", "", answer)
88
  answer = re.sub(r"([.,!?~])\1{2,}", r"\1", answer)
89
  answer = re.sub(r"[a-zA-Z]+", "", answer)
90
  answer = re.sub(r"[=^*/\\]+", "", answer)
91
  answer = re.sub(r"\s+", " ", answer).strip()
92
 
93
+ # 80์ž ์ œํ•œ
94
+ answer = answer[:80]
 
 
 
 
 
 
 
 
 
 
 
95
  if answer and answer[-1] not in ".!?":
96
  answer += "."
97
  elif not answer:
98
+ answer = "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์งˆ๋ฌธ์— ๋Œ€ํ•œ ๋‹ต๋ณ€์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
99
 
100
  return answer
101
+
102
  except Exception as e:
103
  print(f"ask_sayknow ์—๋Ÿฌ: {e}")
104
  traceback.print_exc()
105
  return f"๋‚ด๋ถ€ ์˜ค๋ฅ˜: {str(e)}"
106
 
107
+ # --- 6. API (XML ์‘๋‹ต) ---
108
  @app.route('/chatapi.html', methods=['GET'])
109
  @app.route('/index.html', methods=['GET'])
110
  def chat_api():
 
114
  xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
115
  return Response(xml_output, mimetype='text/xml')
116
 
117
+ with request_lock: # knowledge_list ์ ‘๊ทผ ๋ณดํ˜ธ
 
118
  try:
119
  answer = ask_sayknow(query)
120
  result = {
 
135
  xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
136
  return Response(xml_output, mimetype='text/xml')
137
 
138
+
139
  if __name__ == '__main__':
140
  app.run(host='0.0.0.0', port=7860)