SayknowLab committed on
Commit
5914126
·
verified ·
1 Parent(s): 27924af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -55
app.py CHANGED
@@ -1,23 +1,20 @@
1
  import pandas as pd
2
  import torch
3
- from flask import Flask, request, Response, render_template_string
4
  from transformers import AutoTokenizer, GPT2LMHeadModel
5
  from dicttoxml import dicttoxml
6
- import re
7
  import traceback
8
 
9
  app = Flask(__name__)
10
 
11
- # --- hCaptcha 설정 관련 코드 전부 제거됨 ---
12
-
13
- # 1. 모델 로드 (기존과 동일)
14
  print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์ค‘...")
15
  tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
16
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
17
  model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
18
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
19
 
20
- # 2. 데이터셋 로드 (기존과 동일)
21
  try:
22
  df = pd.read_excel('dataset.xlsx')
23
  knowledge_list = df['๋ฐ์ดํ„ฐ์…‹์— ๋„ฃ์„ ๋‚ด์šฉ(*)'].tolist()
@@ -26,7 +23,7 @@ except Exception as e:
26
  knowledge_list = []
27
 
28
  def find_relevant_context(query, top_n=2):
29
- """์งˆ๋ฌธ๊ณผ ๊ด€๋ จ๋œ ์ง€์‹๋ฐ์ดํ„ฐ ๋ฌธ์žฅ ์ตœ๋Œ€ top_n๊ฐœ ์ฐพ์•„์„œ ๋ฐ˜ํ™˜ (๊ธฐ์กด๊ณผ ๋™์ผ)"""
30
  query_words = query.replace(" ", "").lower()
31
  relevant_sentences = []
32
  for s in knowledge_list:
@@ -41,14 +38,13 @@ def ask_sayknow(query):
41
  try:
42
  context = find_relevant_context(query)
43
  persona_guide = (
44
- "๋„ˆ๋Š” ์ง€์‹ ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ์ฑ—๋ด‡ Sayknow์•ผ. ์ž๊ธฐ์†Œ๊ฐœ(์ด๋ฆ„, ์ •์ฒด, ์ธ์‚ฌ ๋“ฑ) ์งˆ๋ฌธ์€ '์ €๋Š” Sayknow์ž…๋‹ˆ๋‹ค.'๋ผ๊ณ  ๋‹ตํ•ด. "
45
- "๊ทธ ์™ธ์—” ์•„๋ž˜ ์ฐธ๊ณ ํ•ด์„œ ์ •ํ™•ํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด ๋ฌธ์žฅ์œผ๋กœ 80์ž ์ด๋‚ด๋กœ ๋‹ตํ•ด.\n"
46
  "์˜ˆ์‹œ: Q: ๋ถ„์ˆ˜์˜ ๋ง์…ˆ์ด ๋ญ์•ผ?\nA: ๋ถ„๋ชจ๊ฐ€ ๊ฐ™์„ ๋•Œ ๋ถ„์ž๋ผ๋ฆฌ ๋”ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.\n"
47
  )
48
  info = context if context else "์ •๋ณด ์—†์Œ"
49
  prompt = f"{persona_guide}---\n[์ •๋ณด]\n{info}\n[์งˆ๋ฌธ]\n{query}\n[๋‹ต๋ณ€] "
50
 
51
- # 이전 답변 로직 개선 (attention_mask 추가) - 이 부분은 잘 작동하고 있을 거야!
52
  tokenizer.pad_token = tokenizer.eos_token
53
  encoded_input = tokenizer.encode_plus(
54
  prompt,
@@ -58,13 +54,13 @@ def ask_sayknow(query):
58
  )
59
  input_ids = encoded_input['input_ids']
60
  attention_mask = encoded_input['attention_mask']
61
-
62
  model.eval()
63
  with torch.no_grad():
64
  gen_ids = model.generate(
65
  input_ids,
66
  attention_mask=attention_mask,
67
- max_new_tokens=200, # ๋‹ต๋ณ€์ด ์ž˜๋ฆฌ๋Š” ๋ฌธ์ œ ๋ฐฉ์ง€๋ฅผ ์œ„ํ•ด ์กฐ๊ธˆ ๋Š˜๋ ค๋ดค์–ด! (60 -> 80)
68
  min_length=5,
69
  repetition_penalty=1.3,
70
  do_sample=True,
@@ -74,29 +70,27 @@ def ask_sayknow(query):
74
  temperature=0.5,
75
  num_beams=1
76
  )
77
- raw_response = tokenizer.decode(gen_ids[0], skip_special_tokens=True) # ์›๋ณธ ์‘๋‹ต ์ €์žฅ
78
-
79
- # --- 응답 처리 로직 개선 버전 (index out of range 에러 방지) ---
80
- # 1. 모델이 생성한 전체 텍스트에서 프롬프트 부분 자르기 (반복되는 문제 방지)
81
- # prompt๊ฐ€ raw_response์˜ ์‹œ์ž‘ ๋ถ€๋ถ„์— ์žˆ๋‹ค๋ฉด ๊ทธ ๋ถ€๋ถ„์„ ์ž˜๋ผ๋‚ผ๊ฒŒ.
82
  if raw_response.startswith(prompt):
83
- extracted_answer = raw_response[len(prompt):].strip()
84
  else:
85
- extracted_answer = raw_response.strip()
86
 
87
- # 3. 문장 끝이 자연스럽지 않으면 마침표 추가
88
- if answer and answer[-1] not in ".!?":
89
  answer += "."
90
- elif not answer: # ๋นˆ ๋ฌธ์ž์—ด์ธ๋ฐ '.' ์ฐ์œผ๋ฉด ์—๋Ÿฌ๋‚˜๋‹ˆ ํ•œ๋ฒˆ ๋” ์ฒดํฌ
91
- answer = "์•Œ ์ˆ˜ ์—†๋Š” ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค." # ์ตœํ›„์˜ ๋ณด๋ฃจ
92
 
93
  return answer
94
  except Exception as e:
95
  print(f"ask_sayknow ์—๋Ÿฌ: {e}")
96
  traceback.print_exc()
97
- return f"๋‚ด๋ถ€ ์˜ค๋ฅ˜: {str(e)}" # ์™ธ๋ถ€ ์‚ฌ์šฉ์ž์—๊ฒŒ ๋ณด์ด๋Š” ๋ฉ”์‹œ์ง€!
98
 
99
- # 3. API (XML 응답 유지) (기존과 동일)
100
  @app.route('/chatapi.html', methods=['GET'])
101
  @app.route('/index.html', methods=['GET'])
102
  def chat_api():
@@ -124,34 +118,5 @@ def chat_api():
124
  xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
125
  return Response(xml_output, mimetype='text/xml')
126
 
127
- # 4. 웹 UI (간단한 질문 폼 + 답변) - hCaptcha 코드 전부 제거!
128
- @app.route('/', methods=['GET', 'POST'])
129
- def index():
130
- answer = ""
131
- question = ""
132
- # error_message 제거
133
-
134
- if request.method == "POST":
135
- question = request.form.get('question', '')
136
- # hcaptcha_response 관련 로직 제거
137
-
138
- # hCaptcha 검증 로직 제거
139
- if question: # ์งˆ๋ฌธ์ด ์žˆ์œผ๋ฉด ๋ฐ”๋กœ ๋‹ต๋ณ€ ์ƒ์„ฑ!
140
- answer = ask_sayknow(question)
141
-
142
- html = f"""
143
- <html>
144
- <head>
145
- <title>Sayknow ์ฑ—๋ด‡</title>
146
- <!-- hCaptcha ์Šคํฌ๋ฆฝํŠธ ์ œ๊ฑฐ -->
147
- </head>
148
- <body>
149
- <h2>Sayknow ํ•œ๊ตญ์–ด ์ฑ—๋ด‡</h2>
150
- <h1>์„œ๋น„์Šค๋ฅผ ์‚ฌ์šฉํ•˜์‹œ๋ ค๋ฉด <a herf=sayknow.ggm.kr>sayknow.ggm.kr<a>๋กœ ์ด๋™ํ•ด์ฃผ์„ธ์š”.
151
- </body>
152
- </html>
153
- """
154
- return render_template_string(html)
155
-
156
  if __name__ == '__main__':
157
- app.run(host='0.0.0.0', port=7860)
 
1
  import pandas as pd
2
  import torch
3
+ from flask import Flask, request, Response
4
  from transformers import AutoTokenizer, GPT2LMHeadModel
5
  from dicttoxml import dicttoxml
 
6
  import traceback
7
 
8
  app = Flask(__name__)
9
 
10
+ # --- 1. 모델 로드 ---
 
 
11
  print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์ค‘...")
12
  tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
13
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
14
  model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
15
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
16
 
17
+ # --- 2. 데이터셋 로드 ---
18
  try:
19
  df = pd.read_excel('dataset.xlsx')
20
  knowledge_list = df['๋ฐ์ดํ„ฐ์…‹์— ๋„ฃ์„ ๋‚ด์šฉ(*)'].tolist()
 
23
  knowledge_list = []
24
 
25
  def find_relevant_context(query, top_n=2):
26
+ """์งˆ๋ฌธ๊ณผ ๊ด€๋ จ๋œ ์ง€์‹ ๋ฐ์ดํ„ฐ ๋ฌธ์žฅ ์ตœ๋Œ€ top_n๊ฐœ ๋ฐ˜ํ™˜"""
27
  query_words = query.replace(" ", "").lower()
28
  relevant_sentences = []
29
  for s in knowledge_list:
 
38
  try:
39
  context = find_relevant_context(query)
40
  persona_guide = (
41
+ "๋„ˆ๋Š” ์ง€์‹ ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ์ฑ—๋ด‡ Sayknow์•ผ. ์ž๊ธฐ์†Œ๊ฐœ ์งˆ๋ฌธ์—๋Š” '์ €๋Š” Sayknow์ž…๋‹ˆ๋‹ค.'๋ผ๊ณ  ๋‹ตํ•ด. "
42
+ "๊ทธ ์™ธ์—๋Š” ์•„๋ž˜ ์ฐธ๊ณ ํ•ด์„œ ์ •ํ™•ํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด ๋ฌธ์žฅ์œผ๋กœ 80์ž ์ด๋‚ด๋กœ ๋‹ตํ•ด.\n"
43
  "์˜ˆ์‹œ: Q: ๋ถ„์ˆ˜์˜ ๋ง์…ˆ์ด ๋ญ์•ผ?\nA: ๋ถ„๋ชจ๊ฐ€ ๊ฐ™์„ ๋•Œ ๋ถ„์ž๋ผ๋ฆฌ ๋”ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.\n"
44
  )
45
  info = context if context else "์ •๋ณด ์—†์Œ"
46
  prompt = f"{persona_guide}---\n[์ •๋ณด]\n{info}\n[์งˆ๋ฌธ]\n{query}\n[๋‹ต๋ณ€] "
47
 
 
48
  tokenizer.pad_token = tokenizer.eos_token
49
  encoded_input = tokenizer.encode_plus(
50
  prompt,
 
54
  )
55
  input_ids = encoded_input['input_ids']
56
  attention_mask = encoded_input['attention_mask']
57
+
58
  model.eval()
59
  with torch.no_grad():
60
  gen_ids = model.generate(
61
  input_ids,
62
  attention_mask=attention_mask,
63
+ max_new_tokens=200,
64
  min_length=5,
65
  repetition_penalty=1.3,
66
  do_sample=True,
 
70
  temperature=0.5,
71
  num_beams=1
72
  )
73
+ raw_response = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
74
+
75
+ # 프롬프트 제거 후 실제 답변 추출
 
 
76
  if raw_response.startswith(prompt):
77
+ answer = raw_response[len(prompt):].strip()
78
  else:
79
+ answer = raw_response.strip()
80
 
81
+ # 문장 끝 처리
82
+ if answer and answer[-1] not in ".!?":
83
  answer += "."
84
+ elif not answer:
85
+ answer = "์•Œ ์ˆ˜ ์—†๋Š” ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."
86
 
87
  return answer
88
  except Exception as e:
89
  print(f"ask_sayknow ์—๋Ÿฌ: {e}")
90
  traceback.print_exc()
91
+ return f"๋‚ด๋ถ€ ์˜ค๋ฅ˜: {str(e)}"
92
 
93
+ # --- 3. API (XML 응답) ---
94
  @app.route('/chatapi.html', methods=['GET'])
95
  @app.route('/index.html', methods=['GET'])
96
  def chat_api():
 
118
  xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
119
  return Response(xml_output, mimetype='text/xml')
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if __name__ == '__main__':
122
+ app.run(host='0.0.0.0', port=7860)