# Sayknow_v1 / app.py
# NOTE(review): the lines here were Hugging Face file-viewer chrome
# ("raw / history / blame", uploader "SayknowLab", revision 0661949) that
# leaked into the file; converted to a comment header so the module parses.
import pandas as pd
import torch
from flask import Flask, request, Response
from transformers import AutoTokenizer, AutoModelForCausalLM
from dicttoxml import dicttoxml
import traceback
import re
from threading import Lock
app = Flask(__name__)

# --- 1. Device setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"์‚ฌ์šฉ ๋””๋ฐ”์ด์Šค: {device}")
torch.set_grad_enabled(False)  # inference-only service; no autograd needed

# --- 2. Model load ---
print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์ค‘...")
tokenizer = AutoTokenizer.from_pretrained(
    "LiquidAI/LFM2.5-1.2B-Instruct",
    trust_remote_code=True
)
print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
try:
    # Try an 8-bit quantized load first (needs bitsandbytes + CUDA).
    # NOTE(review): load_in_8bit= is deprecated in newer transformers in
    # favor of BitsAndBytesConfig; kept as-is for compatibility.
    model = AutoModelForCausalLM.from_pretrained(
        "LiquidAI/LFM2.5-1.2B-Instruct",
        device_map="auto",
        load_in_8bit=True,
        trust_remote_code=True
    )
    print("8bit ๋กœ๋”ฉ ์„ฑ๊ณต")
except Exception as e:
    # Was a bare `except:` — that would also swallow KeyboardInterrupt/
    # SystemExit and hide WHY the 8-bit path failed. Log the reason, then
    # fall back to a full-precision load on the selected device.
    print(f"8bit load failed ({e}); falling back to full precision")
    model = AutoModelForCausalLM.from_pretrained(
        "LiquidAI/LFM2.5-1.2B-Instruct",
        trust_remote_code=True
    ).to(device)
    print("์ผ๋ฐ˜ ๋กœ๋”ฉ ์‚ฌ์šฉ")

# Compile the model when supported (torch >= 2.0); best-effort.
try:
    model = torch.compile(model)
    print("torch.compile ์ ์šฉ ์™„๋ฃŒ")
except Exception:
    # Narrowed from a bare `except:`; compilation is optional, so any
    # failure here is non-fatal by design.
    print("torch.compile ๋ฏธ์ ์šฉ (์ง€์› ์•ˆํ•จ)")
print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")

# --- 3. Dataset load ---
# Best-effort: on any failure the service still starts with an empty
# knowledge base (find_relevant_context then returns "").
try:
    df = pd.read_excel('dataset.xlsx')
    knowledge_list = df['๋ฐ์ดํ„ฐ์…‹์— ๋„ฃ์„ ๋‚ด์šฉ(*)'].tolist()
except Exception as e:
    print(f"๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ์—๋Ÿฌ: {e}")
    knowledge_list = []

# --- 4. Lock to serialize concurrent requests (structure kept) ---
request_lock = Lock()
# --- 5. ์งˆ๋ฌธ๊ณผ ๊ด€๋ จ๋œ ์ง€์‹ ๊ฒ€์ƒ‰ (๊ธฐ์กด ๋ฐฉ์‹ ์œ ์ง€) ---
def find_relevant_context(query, top_n=2, sentences=None):
    """Return up to *top_n* knowledge entries sharing a word with *query*.

    Each entry is flattened (spaces and newlines removed) and lowercased;
    an entry matches when any whitespace-split query word occurs in it as
    a substring.

    Args:
        query: user question text.
        top_n: maximum number of matching entries joined into the result.
        sentences: optional knowledge-base override; defaults to the
            module-level ``knowledge_list``.

    Returns:
        Matching entries joined by a single space, or "" when none match.
    """
    pool = knowledge_list if sentences is None else sentences
    # Bug fix: entries are lowercased below, but query words were not,
    # so matching was accidentally case-sensitive on the query side.
    # (Also dropped: an unused `query_words` local and a no-op
    # `word.replace(" ", "")` on already-space-free split() tokens.)
    words = [w.lower() for w in query.split()]
    matches = []
    for entry in pool:
        flat = str(entry).replace(" ", "").replace("\n", "").lower()
        if any(w in flat for w in words):
            matches.append(entry)
    return " ".join(str(m) for m in matches[:top_n]) if matches else ""
# --- 6. Sayknow ๋‹ต๋ณ€ ์ƒ์„ฑ ---
def ask_sayknow(query):
    """Generate a Sayknow answer for *query* with the global model/tokenizer.

    Pipeline: retrieve related knowledge via find_relevant_context, build a
    persona prompt, sample a short completion, then post-process it down to
    a Korean-only answer capped at 80 characters.

    Returns:
        The cleaned answer string; a fixed apology when generation produced
        nothing usable; or an internal-error message if anything raised.
    """
    try:
        context = find_relevant_context(query)
        persona_guide = (
            "๋„ˆ๋Š” ์ง€์‹ ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ์ฑ—๋ด‡ Sayknow์•ผ. ์ž๊ธฐ์†Œ๊ฐœ ์งˆ๋ฌธ์—๋Š” '์ €๋Š” Sayknow์ž…๋‹ˆ๋‹ค.'๋ผ๊ณ  ๋‹ตํ•ด. "
            "๊ทธ ์™ธ์—๋Š” ์•„๋ž˜ ์ฐธ๊ณ ํ•ด์„œ ์ •ํ™•ํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šด ํ•œ๊ตญ์–ด ๋ฌธ์žฅ์œผ๋กœ 80์ž ์ด๋‚ด๋กœ ๋‹ตํ•ด.\n"
            "์˜ˆ์‹œ: Q: ๋ถ„์ˆ˜์˜ ๋ง์…ˆ์ด ๋ญ์•ผ?\nA: ๋ถ„๋ชจ๊ฐ€ ๊ฐ™์„ ๋•Œ ๋ถ„์ž๋ผ๋ฆฌ ๋”ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.\n"
        )
        info = context if context else "์ •๋ณด ์—†์Œ"
        prompt = f"{persona_guide}---\n[์ •๋ณด]\n{info}\n[์งˆ๋ฌธ]\n{query}\n[๋‹ต๋ณ€] "
        # Many causal-LM tokenizers ship without a pad token; set it once
        # instead of unconditionally clobbering it on every request.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        encoded_input = tokenizer.encode_plus(
            prompt,
            return_tensors='pt',
            truncation=True,
            padding=True
        )
        input_ids = encoded_input['input_ids'].to(device)
        attention_mask = encoded_input['attention_mask'].to(device)
        model.eval()
        gen_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=60,  # kept short on purpose (80-char answers)
            min_length=5,
            repetition_penalty=1.2,
            do_sample=True,
            top_k=30,
            top_p=0.8,
            temperature=0.5,
            num_beams=1,
            pad_token_id=tokenizer.pad_token_id
        )
        # --- Answer extraction ---
        # Bug fix: decoding the full sequence and string-replacing the prompt
        # out is fragile (a tokenize->decode round-trip is not guaranteed to
        # reproduce the prompt byte-for-byte). Decode only the generated
        # suffix past the prompt tokens instead.
        answer = tokenizer.decode(
            gen_ids[0][input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
        if "๋‹ต๋ณ€:" in answer:
            answer = answer.split("๋‹ต๋ณ€:", 1)[1].strip()
        # --- Post-processing (behavior kept as requested in the original) ---
        answer = re.sub(r"[^๊ฐ€-ํžฃ0-9 .,!?~\n]", "", answer)   # Korean/digits/punct only
        answer = re.sub(r"([.,!?~])\1{2,}", r"\1", answer)  # collapse !!!! -> !
        answer = re.sub(r"[a-zA-Z]+", "", answer)           # strip latin letters
        answer = re.sub(r"[=^*/\\]+", "", answer)           # strip operator junk
        answer = re.sub(r"\s+", " ", answer).strip()
        answer = answer[:80]  # hard 80-character cap
        if answer and answer[-1] not in ".!?":
            answer += "."
        elif not answer:
            answer = "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์งˆ๋ฌธ์— ๋Œ€ํ•œ ๋‹ต๋ณ€์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
        return answer
    except Exception as e:
        print(f"ask_sayknow ์—๋Ÿฌ: {e}")
        traceback.print_exc()
        return f"๋‚ด๋ถ€ ์˜ค๋ฅ˜: {str(e)}"
# --- 7. API (XML ์‘๋‹ต) ---
@app.route('/chatapi.html', methods=['GET'])
@app.route('/index.html', methods=['GET'])
def chat_api():
    """GET endpoint: ?askdata=<question> -> XML-wrapped Sayknow answer."""
    def _as_xml(payload):
        # Serialize a dict into the <SayknowAPI> XML envelope.
        body = dicttoxml(payload, custom_root='SayknowAPI', attr_type=False)
        return Response(body, mimetype='text/xml')

    query = request.args.get('askdata', '')
    if not query:
        return _as_xml({"status": "error", "message": "No data"})

    # Lock kept per the original design: one generation at a time.
    with request_lock:
        try:
            payload = {
                "service": "Sayknow",
                "question": query,
                "answer": ask_sayknow(query),
            }
        except Exception as e:
            print(f"chat_api ์—๋Ÿฌ: {e}")
            traceback.print_exc()
            payload = {
                "service": "Sayknow",
                "question": query,
                "answer": f"์—๋Ÿฌ ๋ฐœ์ƒ: {str(e)}",
                "error": str(e),
            }
    return _as_xml(payload)
if __name__ == '__main__':
    # Listen on all interfaces; 7860 is the conventional Hugging Face
    # Spaces port — presumably this app is deployed there (TODO confirm).
    app.run(host='0.0.0.0', port=7860)