SayknowLab committed on
Commit
0661949
·
verified ·
1 Parent(s): 5bb4640

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -27
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import pandas as pd
2
  import torch
3
  from flask import Flask, request, Response
4
- from transformers import AutoTokenizer, GPT2LMHeadModel
5
  from dicttoxml import dicttoxml
6
  import traceback
7
  import re
@@ -9,14 +9,47 @@ from threading import Lock
9
 
10
  app = Flask(__name__)
11
 
12
- # --- 1. 모델 로드 ---
 
 
 
 
 
 
13
  print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์ค‘...")
14
- tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
 
 
 
 
15
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
16
- model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
18
 
19
- # --- 2. 데이터셋 로드 ---
20
  try:
21
  df = pd.read_excel('dataset.xlsx')
22
  knowledge_list = df['๋ฐ์ดํ„ฐ์…‹์— ๋„ฃ์„ ๋‚ด์šฉ(*)'].tolist()
@@ -24,10 +57,10 @@ except Exception as e:
24
  print(f"๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ์—๋Ÿฌ: {e}")
25
  knowledge_list = []
26
 
27
- # --- 3. 동시 요청 제한용 Lock ---
28
  request_lock = Lock()
29
 
30
- # --- 4. 질문과 관련된 지식 검색 ---
31
  def find_relevant_context(query, top_n=2):
32
  query_words = query.replace(" ", "").lower()
33
  relevant_sentences = []
@@ -37,7 +70,7 @@ def find_relevant_context(query, top_n=2):
37
  relevant_sentences.append(s)
38
  return " ".join(str(s) for s in relevant_sentences[:top_n]) if relevant_sentences else ""
39
 
40
- # --- 5. Sayknow 답변 생성 ---
41
  def ask_sayknow(query):
42
  try:
43
  context = find_relevant_context(query)
@@ -47,43 +80,48 @@ def ask_sayknow(query):
47
  "๊ทธ ์™ธ์—๋Š” ์•„๋ž˜ ์ฐธ๊ณ ํ•ด์„œ ์ •ํ™•ํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด ๋ฌธ์žฅ์œผ๋กœ 80์ž ์ด๋‚ด๋กœ ๋‹ตํ•ด.\n"
48
  "์˜ˆ์‹œ: Q: ๋ถ„์ˆ˜์˜ ๋ง์…ˆ์ด ๋ญ์•ผ?\nA: ๋ถ„๋ชจ๊ฐ€ ๊ฐ™์„ ๋•Œ ๋ถ„์ž๋ผ๋ฆฌ ๋”ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.\n"
49
  )
 
50
  info = context if context else "์ •๋ณด ์—†์Œ"
 
51
  prompt = f"{persona_guide}---\n[์ •๋ณด]\n{info}\n[์งˆ๋ฌธ]\n{query}\n[๋‹ต๋ณ€] "
52
 
53
  tokenizer.pad_token = tokenizer.eos_token
 
54
  encoded_input = tokenizer.encode_plus(
55
  prompt,
56
  return_tensors='pt',
57
  truncation=True,
58
  padding=True
59
  )
60
- input_ids = encoded_input['input_ids']
61
- attention_mask = encoded_input['attention_mask']
 
62
 
63
  model.eval()
64
- with torch.no_grad():
65
- gen_ids = model.generate(
66
- input_ids,
67
- attention_mask=attention_mask,
68
- max_new_tokens=100,
69
- min_length=5,
70
- repetition_penalty=1.3,
71
- do_sample=True,
72
- top_k=30,
73
- top_p=0.9, # ๋‹ค์–‘์„ฑ ์ฆ๊ฐ€
74
- temperature=0.7, # ๋‹ค์–‘์„ฑ ์ฆ๊ฐ€
75
- num_beams=1,
76
- pad_token_id=tokenizer.pad_token_id
77
- )
78
 
79
  raw_response = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
80
 
81
  # --- ๋‹ต๋ณ€ ์ถ”์ถœ ---
82
  answer = raw_response.replace(prompt, '').strip()
 
83
  if "๋‹ต๋ณ€:" in answer:
84
  answer = answer.split("๋‹ต๋ณ€:", 1)[1].strip()
85
 
86
- # ์˜๋ฏธ ์—†๋Š” ๋ฌธ์ž ์ œ๊ฑฐ
87
  answer = re.sub(r"[^๊ฐ€-ํžฃ0-9 .,!?~\n]", "", answer)
88
  answer = re.sub(r"([.,!?~])\1{2,}", r"\1", answer)
89
  answer = re.sub(r"[a-zA-Z]+", "", answer)
@@ -92,6 +130,7 @@ def ask_sayknow(query):
92
 
93
  # 80์ž ์ œํ•œ
94
  answer = answer[:80]
 
95
  if answer and answer[-1] not in ".!?":
96
  answer += "."
97
  elif not answer:
@@ -104,17 +143,19 @@ def ask_sayknow(query):
104
  traceback.print_exc()
105
  return f"๋‚ด๋ถ€ ์˜ค๋ฅ˜: {str(e)}"
106
 
107
- # --- 6. API (XML 응답) ---
108
  @app.route('/chatapi.html', methods=['GET'])
109
  @app.route('/index.html', methods=['GET'])
110
  def chat_api():
111
  query = request.args.get('askdata', '')
 
112
  if not query:
113
  result = {"status": "error", "message": "No data"}
114
  xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
115
  return Response(xml_output, mimetype='text/xml')
116
 
117
- with request_lock: # knowledge_list ์ ‘๊ทผ ๋ณดํ˜ธ
 
118
  try:
119
  answer = ask_sayknow(query)
120
  result = {
 
1
  import pandas as pd
2
  import torch
3
  from flask import Flask, request, Response
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from dicttoxml import dicttoxml
6
  import traceback
7
  import re
 
9
 
10
  app = Flask(__name__)
11
 
12
+ # --- 1. 디바이스 설정 ---
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print(f"์‚ฌ์šฉ ๋””๋ฐ”์ด์Šค: {device}")
15
+
16
+ torch.set_grad_enabled(False)
17
+
18
+ # --- 2. 모델 로드 ---
19
  print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์ค‘...")
20
+ tokenizer = AutoTokenizer.from_pretrained(
21
+ "LiquidAI/LFM2.5-1.2B-Instruct",
22
+ trust_remote_code=True
23
+ )
24
+
25
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
26
+ try:
27
+ # 8bit ๋กœ๋“œ ์‹œ๋„
28
+ model = AutoModelForCausalLM.from_pretrained(
29
+ "LiquidAI/LFM2.5-1.2B-Instruct",
30
+ device_map="auto",
31
+ load_in_8bit=True,
32
+ trust_remote_code=True
33
+ )
34
+ print("8bit ๋กœ๋”ฉ ์„ฑ๊ณต")
35
+ except:
36
+ # ์‹คํŒจ ์‹œ ์ผ๋ฐ˜ ๋กœ๋“œ
37
+ model = AutoModelForCausalLM.from_pretrained(
38
+ "LiquidAI/LFM2.5-1.2B-Instruct",
39
+ trust_remote_code=True
40
+ ).to(device)
41
+ print("์ผ๋ฐ˜ ๋กœ๋”ฉ ์‚ฌ์šฉ")
42
+
43
+ # torch 2.0 ์ด์ƒ์ด๋ฉด ์ปดํŒŒ์ผ
44
+ try:
45
+ model = torch.compile(model)
46
+ print("torch.compile ์ ์šฉ ์™„๋ฃŒ")
47
+ except:
48
+ print("torch.compile ๋ฏธ์ ์šฉ (์ง€์› ์•ˆํ•จ)")
49
+
50
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
51
 
52
+ # --- 3. 데이터셋 로드 ---
53
  try:
54
  df = pd.read_excel('dataset.xlsx')
55
  knowledge_list = df['๋ฐ์ดํ„ฐ์…‹์— ๋„ฃ์„ ๋‚ด์šฉ(*)'].tolist()
 
57
  print(f"๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ์—๋Ÿฌ: {e}")
58
  knowledge_list = []
59
 
60
+ # --- 4. 동시 요청 제한용 Lock (구조 유지) ---
61
  request_lock = Lock()
62
 
63
+ # --- 5. 질문과 관련된 지식 검색 (기존 방식 유지) ---
64
  def find_relevant_context(query, top_n=2):
65
  query_words = query.replace(" ", "").lower()
66
  relevant_sentences = []
 
70
  relevant_sentences.append(s)
71
  return " ".join(str(s) for s in relevant_sentences[:top_n]) if relevant_sentences else ""
72
 
73
+ # --- 6. Sayknow 답변 생성 ---
74
  def ask_sayknow(query):
75
  try:
76
  context = find_relevant_context(query)
 
80
  "๊ทธ ์™ธ์—๋Š” ์•„๋ž˜ ์ฐธ๊ณ ํ•ด์„œ ์ •ํ™•ํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด ๋ฌธ์žฅ์œผ๋กœ 80์ž ์ด๋‚ด๋กœ ๋‹ตํ•ด.\n"
81
  "์˜ˆ์‹œ: Q: ๋ถ„์ˆ˜์˜ ๋ง์…ˆ์ด ๋ญ์•ผ?\nA: ๋ถ„๋ชจ๊ฐ€ ๊ฐ™์„ ๋•Œ ๋ถ„์ž๋ผ๋ฆฌ ๋”ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.\n"
82
  )
83
+
84
  info = context if context else "์ •๋ณด ์—†์Œ"
85
+
86
  prompt = f"{persona_guide}---\n[์ •๋ณด]\n{info}\n[์งˆ๋ฌธ]\n{query}\n[๋‹ต๋ณ€] "
87
 
88
  tokenizer.pad_token = tokenizer.eos_token
89
+
90
  encoded_input = tokenizer.encode_plus(
91
  prompt,
92
  return_tensors='pt',
93
  truncation=True,
94
  padding=True
95
  )
96
+
97
+ input_ids = encoded_input['input_ids'].to(device)
98
+ attention_mask = encoded_input['attention_mask'].to(device)
99
 
100
  model.eval()
101
+
102
+ gen_ids = model.generate(
103
+ input_ids,
104
+ attention_mask=attention_mask,
105
+ max_new_tokens=60, # ์ค„์ž„
106
+ min_length=5,
107
+ repetition_penalty=1.2,
108
+ do_sample=True,
109
+ top_k=30,
110
+ top_p=0.8,
111
+ temperature=0.5,
112
+ num_beams=1,
113
+ pad_token_id=tokenizer.pad_token_id
114
+ )
115
 
116
  raw_response = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
117
 
118
  # --- ๋‹ต๋ณ€ ์ถ”์ถœ ---
119
  answer = raw_response.replace(prompt, '').strip()
120
+
121
  if "๋‹ต๋ณ€:" in answer:
122
  answer = answer.split("๋‹ต๋ณ€:", 1)[1].strip()
123
 
124
+ # --- ํ›„์ฒ˜๋ฆฌ (5๋ฒˆ ์œ ์ง€ ์š”์ฒญ๋Œ€๋กœ ๊ทธ๋Œ€๋กœ ์œ ์ง€) ---
125
  answer = re.sub(r"[^๊ฐ€-ํžฃ0-9 .,!?~\n]", "", answer)
126
  answer = re.sub(r"([.,!?~])\1{2,}", r"\1", answer)
127
  answer = re.sub(r"[a-zA-Z]+", "", answer)
 
130
 
131
  # 80์ž ์ œํ•œ
132
  answer = answer[:80]
133
+
134
  if answer and answer[-1] not in ".!?":
135
  answer += "."
136
  elif not answer:
 
143
  traceback.print_exc()
144
  return f"๋‚ด๋ถ€ ์˜ค๋ฅ˜: {str(e)}"
145
 
146
+ # --- 7. API (XML 응답) ---
147
  @app.route('/chatapi.html', methods=['GET'])
148
  @app.route('/index.html', methods=['GET'])
149
  def chat_api():
150
  query = request.args.get('askdata', '')
151
+
152
  if not query:
153
  result = {"status": "error", "message": "No data"}
154
  xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
155
  return Response(xml_output, mimetype='text/xml')
156
 
157
+ # 6๋ฒˆ ์œ ์ง€ ์š”์ฒญ โ†’ Lock ์ „์ฒด ์œ ์ง€
158
+ with request_lock:
159
  try:
160
  answer = ask_sayknow(query)
161
  result = {