Spaces:
Running
Running
| import pandas as pd | |
| import torch | |
| from flask import Flask, request, Response | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from dicttoxml import dicttoxml | |
| import traceback | |
| import re | |
| from threading import Lock | |
# Flask application instance.
app = Flask(__name__)

# --- 1. Device setup ---
# Prefer GPU when present; inference only, so autograd is disabled globally.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"์ฌ์ฉ ๋๋ฐ์ด์ค: {device}")
torch.set_grad_enabled(False)
# --- 2. Model load ---
print("ํ ํฌ๋์ด์ ๋ก๋ฉ ์ค...")
tokenizer = AutoTokenizer.from_pretrained(
    "LiquidAI/LFM2.5-1.2B-Instruct",
    trust_remote_code=True,
)

print("๋ชจ๋ธ ๋ก๋ฉ ์ค...")
try:
    # Try an 8-bit quantized load first (requires bitsandbytes + CUDA).
    model = AutoModelForCausalLM.from_pretrained(
        "LiquidAI/LFM2.5-1.2B-Instruct",
        device_map="auto",
        load_in_8bit=True,
        trust_remote_code=True,
    )
    print("8bit ๋ก๋ฉ ์ฑ๊ณต")
except Exception:
    # Fall back to a plain full-precision load on the selected device.
    # (Was a bare `except:`, which would also swallow KeyboardInterrupt
    # and SystemExit.)
    model = AutoModelForCausalLM.from_pretrained(
        "LiquidAI/LFM2.5-1.2B-Instruct",
        trust_remote_code=True,
    ).to(device)
    print("์ผ๋ฐ ๋ก๋ฉ ์ฌ์ฉ")

# Apply torch.compile when the installed torch supports it (>= 2.0).
try:
    model = torch.compile(model)
    print("torch.compile ์ ์ฉ ์๋ฃ")
except Exception:
    # Was a bare `except:`; narrowed for the same reason as above.
    print("torch.compile ๋ฏธ์ ์ฉ (์ง์ ์ํจ)")

print("๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ!")
# --- 3. Knowledge dataset ---
# Load the knowledge base from the bundled spreadsheet. If the file is
# missing or unreadable, fall back to an empty list so the service boots.
try:
    _df = pd.read_excel('dataset.xlsx')
    knowledge_list = _df['๋ฐ์ดํฐ์ ์ ๋ฃ์ ๋ด์ฉ(*)'].tolist()
except Exception as e:
    print(f"๋ฐ์ดํฐ์ ๋ก๋ ์๋ฌ: {e}")
    knowledge_list = []

# --- 4. Lock used to serialize concurrent generation requests ---
request_lock = Lock()
# --- 5. Retrieve knowledge sentences related to the question ---
def find_relevant_context(query, top_n=2, corpus=None):
    """Return up to `top_n` knowledge sentences that mention a query word.

    Matching is whitespace-insensitive and case-insensitive: each sentence
    is flattened (spaces/newlines removed, lower-cased) and kept when any
    whitespace-separated word of the query occurs in it as a substring.

    Fixes over the original: the unused `query_words` variable is gone, the
    no-op `word.replace(" ", "")` on already-split words is dropped, and the
    query words are now lower-cased so both sides of the comparison are
    case-insensitive (the original lowered only the corpus side).

    Args:
        query: The user's question; split on whitespace into keywords.
        top_n: Maximum number of matching sentences to join.
        corpus: Optional iterable of knowledge sentences; defaults to the
            module-level `knowledge_list` (backward compatible).

    Returns:
        The matching sentences joined by single spaces, or "" if none match.
    """
    if corpus is None:
        corpus = knowledge_list
    # Lower-case once, outside the loop.
    words = [w.lower() for w in query.split()]
    matches = []
    for sentence in corpus:
        flat = str(sentence).replace(" ", "").replace("\n", "").lower()
        if any(w in flat for w in words):
            matches.append(sentence)
    return " ".join(str(s) for s in matches[:top_n]) if matches else ""
# --- 6. Generate a Sayknow answer ---
def ask_sayknow(query):
    """Generate a short Korean answer for `query`, grounded in retrieved context.

    Builds a persona prompt plus any knowledge-base context, samples a
    continuation from the causal LM, then post-processes the text down to a
    Korean-only sentence of at most 80 characters.

    Args:
        query: The user's question (plain text).

    Returns:
        A cleaned answer string; on any internal failure an error-message
        string is returned instead (this function never raises).
    """
    try:
        context = find_relevant_context(query)
        persona_guide = (
            "๋๋ ์ง์ ๊ธฐ๋ฐ ํ๊ตญ์ด ์ฑ๋ด Sayknow์ผ. ์๊ธฐ์๊ฐ ์ง๋ฌธ์๋ '์ ๋ Sayknow์ ๋๋ค.'๋ผ๊ณ ๋ตํด. "
            "๊ทธ ์ธ์๋ ์๋ ์ฐธ๊ณ ํด์ ์ ํํ๊ณ ์์ฐ์ค๋ฌ์ด ํ๊ตญ์ด ๋ฌธ์ฅ์ผ๋ก 80์ ์ด๋ด๋ก ๋ตํด.\n"
            "์์: Q: ๋ถ์์ ๋ง์ ์ด ๋ญ์ผ?\nA: ๋ถ๋ชจ๊ฐ ๊ฐ์ ๋ ๋ถ์๋ผ๋ฆฌ ๋ํ๋ฉด ๋ฉ๋๋ค.\n"
        )
        info = context if context else "์ ๋ณด ์์"
        prompt = f"{persona_guide}---\n[์ ๋ณด]\n{info}\n[์ง๋ฌธ]\n{query}\n[๋ต๋ณ] "

        # Ensure a pad token exists; set once instead of on every request.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        encoded_input = tokenizer.encode_plus(
            prompt,
            return_tensors='pt',
            truncation=True,
            padding=True,
        )
        input_ids = encoded_input['input_ids'].to(device)
        attention_mask = encoded_input['attention_mask'].to(device)

        model.eval()
        gen_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=60,
            min_length=5,
            repetition_penalty=1.2,
            do_sample=True,
            top_k=30,
            top_p=0.8,
            temperature=0.5,
            num_beams=1,
            pad_token_id=tokenizer.pad_token_id,
        )

        # --- Answer extraction ---
        # Decode only the newly generated tokens. The original stripped the
        # prompt with raw_response.replace(prompt, ''), which silently leaves
        # the whole prompt in the answer whenever the tokenizer does not
        # round-trip the prompt text exactly.
        answer = tokenizer.decode(
            gen_ids[0][input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
        if "๋ต๋ณ:" in answer:
            answer = answer.split("๋ต๋ณ:", 1)[1].strip()

        # --- Post-processing: keep Korean syllables, digits, and basic
        # punctuation; collapse repeated punctuation and whitespace. ---
        answer = re.sub(r"[^๊ฐ-ํฃ0-9 .,!?~\n]", "", answer)
        answer = re.sub(r"([.,!?~])\1{2,}", r"\1", answer)
        # The two filters below are redundant after the first character-class
        # filter; kept as defensive no-ops to preserve behavior exactly.
        answer = re.sub(r"[a-zA-Z]+", "", answer)
        answer = re.sub(r"[=^*/\\]+", "", answer)
        answer = re.sub(r"\s+", " ", answer).strip()

        # Hard 80-character cap; ensure a closing punctuation mark.
        answer = answer[:80]
        if answer and answer[-1] not in ".!?":
            answer += "."
        elif not answer:
            answer = "์ฃ์กํฉ๋๋ค. ์ง๋ฌธ์ ๋ํ ๋ต๋ณ์ ์ฐพ์ ์ ์์ต๋๋ค."
        return answer
    except Exception as e:
        # Boundary handler: log the failure and return an error string so the
        # HTTP layer never sees an exception from here.
        print(f"ask_sayknow ์๋ฌ: {e}")
        traceback.print_exc()
        return f"๋ด๋ถ ์ค๋ฅ: {str(e)}"
# --- 7. API (XML response) ---
# NOTE(review): no @app.route decorator is visible on this handler, so it does
# not appear to be registered with Flask anywhere in this file — confirm the
# intended URL rule (possibly lost in a paste).
def chat_api():
    """Handle one chat request: read `askdata`, generate an answer, reply as XML."""
    query = request.args.get('askdata', '')
    if not query:
        # Missing query parameter: return an error document immediately.
        error_doc = {"status": "error", "message": "No data"}
        xml_output = dicttoxml(error_doc, custom_root='SayknowAPI', attr_type=False)
        return Response(xml_output, mimetype='text/xml')

    # Serialize generation: only one request may drive the model at a time.
    with request_lock:
        try:
            result = {
                "service": "Sayknow",
                "question": query,
                "answer": ask_sayknow(query),
            }
        except Exception as e:
            # ask_sayknow already shields most failures; this catches the rest
            # so the client always receives a well-formed XML document.
            print(f"chat_api ์๋ฌ: {e}")
            traceback.print_exc()
            result = {
                "service": "Sayknow",
                "question": query,
                "answer": f"์๋ฌ ๋ฐ์: {str(e)}",
                "error": str(e),
            }
    xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
    return Response(xml_output, mimetype='text/xml')
# Development server entry point: listen on all interfaces, port 7860
# (the port HuggingFace Spaces expects a web app to bind).
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)