# Hugging Face Spaces status banner ("Spaces: Running") — web-page capture
# residue, not part of the program.
# Standard library
import re
import traceback

# Third-party
import pandas as pd
import torch
from dicttoxml import dicttoxml
from flask import Flask, Response, render_template_string, request
from transformers import AutoTokenizer, GPT2LMHeadModel

# Flask application instance (route registration not visible in this chunk).
app = Flask(__name__)
# --- All hCaptcha configuration code was removed ---

# 1. Load tokenizer and model (runs at import time; may download weights).
# NOTE(review): trust_remote_code=True executes code shipped in the model
# repository — confirm "skt/kogpt2-base-v2" is a trusted source.
print("토크나이저 로딩 중...")
tokenizer = AutoTokenizer.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
print("모델 로딩 중...")
model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2", trust_remote_code=True)
print("모델 로딩 완료!")

# 2. Load the knowledge dataset; fall back to an empty list on any failure
#    so the app still starts even when dataset.xlsx is missing or malformed.
try:
    df = pd.read_excel('dataset.xlsx')
    knowledge_list = df['데이터셋에 넣을 내용(*)'].tolist()
except Exception as e:
    print(f"데이터셋 로드 에러: {e}")
    knowledge_list = []
def find_relevant_context(query, top_n=2, corpus=None):
    """Return up to ``top_n`` knowledge sentences that look relevant to ``query``.

    A sentence matches when any whitespace-split token of the query occurs as
    a substring of the sentence after spaces/newlines are stripped and both
    sides are lowercased. Matches are joined with a single space; returns ''
    when nothing matches.

    Args:
        query: Free-text user question.
        top_n: Maximum number of matching sentences to return.
        corpus: Optional sentence list; defaults to the module-level
            ``knowledge_list`` loaded from dataset.xlsx.
    """
    if corpus is None:
        corpus = knowledge_list
    # Bug fix: query tokens are now lowercased to match the lowercased
    # sentence text (previously Latin-script queries could never match).
    # Also removed the old unused `query_words` variable.
    tokens = [w.replace(" ", "").lower() for w in query.split()]
    hits = []
    for sentence in corpus:
        text = str(sentence).replace(" ", "").replace("\n", "").lower()
        if any(tok in text for tok in tokens):
            hits.append(sentence)
    if hits:
        return " ".join(str(s) for s in hits[:top_n])
    return ""
| def ask_sayknow(query): | |
| try: | |
| context = find_relevant_context(query) | |
| persona_guide = ( | |
| "๋๋ ์ง์ ๊ธฐ๋ฐ ํ๊ตญ์ด ์ฑ๋ด Sayknow์ผ. ์๊ธฐ์๊ฐ(์ด๋ฆ, ์ ์ฒด, ์ธ์ฌ ๋ฑ) ์ง๋ฌธ์ '์ ๋ Sayknow์ ๋๋ค.'๋ผ๊ณ ๋ตํด. " | |
| "๊ทธ ์ธ์ ์๋ ์ฐธ๊ณ ํด์ ์ ํํ๊ณ ์์ฐ์ค๋ฌ์ด ํ๊ตญ์ด ๋ฌธ์ฅ์ผ๋ก 80์ ์ด๋ด๋ก ๋ตํด.\n" | |
| "์์: Q: ๋ถ์์ ๋ง์ ์ด ๋ญ์ผ?\nA: ๋ถ๋ชจ๊ฐ ๊ฐ์ ๋ ๋ถ์๋ผ๋ฆฌ ๋ํ๋ฉด ๋ฉ๋๋ค.\n" | |
| ) | |
| info = context if context else "์ ๋ณด ์์" | |
| prompt = f"{persona_guide}---\n[์ ๋ณด]\n{info}\n[์ง๋ฌธ]\n{query}\n[๋ต๋ณ] " | |
| # ์ด์ ๋ต๋ณ ๋ก์ง ๊ฐ์ (attention_mask ์ถ๊ฐ) - ์ด ๋ถ๋ถ์ ์ ์๋ํ๊ณ ์์ ๊ฑฐ์ผ! | |
| tokenizer.pad_token = tokenizer.eos_token | |
| encoded_input = tokenizer.encode_plus( | |
| prompt, | |
| return_tensors='pt', | |
| truncation=True, | |
| padding=True | |
| ) | |
| input_ids = encoded_input['input_ids'] | |
| attention_mask = encoded_input['attention_mask'] | |
| model.eval() | |
| with torch.no_grad(): | |
| gen_ids = model.generate( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| max_new_tokens=512, # ๋ต๋ณ์ด ์๋ฆฌ๋ ๋ฌธ์ ๋ฐฉ์ง๋ฅผ ์ํด ์กฐ๊ธ ๋๋ ค๋ดค์ด! (60 -> 80) | |
| min_length=5, | |
| repetition_penalty=1.3, | |
| do_sample=True, | |
| top_k=30, | |
| top_p=0.85, | |
| pad_token_id=tokenizer.pad_token_id, | |
| temperature=0.5, | |
| num_beams=1 | |
| ) | |
| raw_response = tokenizer.decode(gen_ids[0], skip_special_tokens=True) # ์๋ณธ ์๋ต ์ ์ฅ | |
| # --- ์๋ต ์ฒ๋ฆฌ ๋ก์ง ๊ฐ์ ๋ฒ์ (index out of range ์๋ฌ ๋ฐฉ์ง) --- | |
| # 1. ๋ชจ๋ธ์ด ์์ฑํ ์ ์ฒด ํ ์คํธ์์ ํ๋กฌํํธ ๋ถ๋ถ ์๋ฅด๊ธฐ (๋ฐ๋ณต๋๋ ๋ฌธ์ ๋ฐฉ์ง) | |
| # prompt๊ฐ raw_response์ ์์ ๋ถ๋ถ์ ์๋ค๋ฉด ๊ทธ ๋ถ๋ถ์ ์๋ผ๋ผ๊ฒ. | |
| if raw_response.startswith(prompt): | |
| extracted_answer = raw_response[len(prompt):].strip() | |
| else: | |
| extracted_answer = raw_response.strip() | |
| # 2. '๋ต๋ณ:' ํค์๋๋ฅผ ๊ธฐ์ค์ผ๋ก ์ง์ง ๋ต๋ณ ๋ถ๋ถ ์ถ์ถ | |
| if "๋ต๋ณ:" in extracted_answer: | |
| answer = extracted_answer.split("๋ต๋ณ:", 1)[1].strip() # ์ฒซ ๋ฒ์งธ "๋ต๋ณ:" ์ดํ๋ง | |
| else: | |
| # ๋ง์ฝ "๋ต๋ณ:" ํ๊ทธ๊ฐ ์์ผ๋ฉด, ํ๋กฌํํธ์ ์ง์์ฌํญ ์ค๋ณต ๋ฑ์ ์ ๊ฑฐ ์๋ | |
| persona_end_marker = "๋ตํด.\n" # persona_guide์ ํน์ ๋ ๋ถ๋ถ์ ํ์ | |
| if persona_end_marker in extracted_answer: | |
| try: | |
| answer = extracted_answer[extracted_answer.rindex(persona_end_marker) + len(persona_end_marker):].strip() | |
| except ValueError: | |
| answer = extracted_answer # ์๋๋ฉด ๊ทธ๋ฅ ์ ์ฒด ์ฌ์ฉ | |
| else: | |
| answer = extracted_answer # ๊ทธ๊ฒ๋ ์์ผ๋ฉด ๊ทธ๋ฅ ์ ์ฒด ์ฌ์ฉ | |
| # ๊ทธ๋๋ ๋ต๋ณ์ด ๋น์ด์์ผ๋ฉด ์ค๋ฅ ๋ฉ์์ง๋ฅผ ๋์ฒด | |
| if not answer: | |
| answer = "์ฃ์กํฉ๋๋ค. ์ง๋ฌธ์ ๋ํ ๋ต๋ณ์ ์ฐพ์ ์ ์๊ฑฐ๋ ๋ด์ฉ์ด ๋ช ํํ์ง ์์ต๋๋ค." | |
| # 1. ์๋ฏธ ์๋ ์์/์๋ฌธ/ํน์๋ฌธ์/๋ฐ๋ณต๋ฌธ์ ๋ฑ ํํฐ๋ง (๊ธฐ์กด๊ณผ ๋์ผ) | |
| # ์ด ๋ถ๋ถ์ ๋จผ์ ํ๋ฒ ์ ์ฉํด์ answer๊ฐ ์๋ฑํ ๋ฌธ์์ด์ด ๋๋ ๊ฑธ ๋ฐฉ์ง | |
| answer = re.sub(r"[^๊ฐ-ํฃ0-9 .,!?~\n]", "", answer) | |
| answer = re.sub(r"([.,!?~])\1{2,}", r"\1", answer) | |
| answer = re.sub(r"[a-zA-Z]+", "", answer) | |
| answer = re.sub(r"[=^*/\\]+", "", answer) | |
| answer = re.sub(r"\s+", " ", answer).strip() | |
| # 2. 80์ ์ด๋ด๋ก ์๋ฅด๊ธฐ (ํ๊ธ ๊ธฐ์ค) (๊ธฐ์กด๊ณผ ๋์ผ) | |
| def truncate_korean(text, max_len=80): | |
| count = 0 | |
| result = "" | |
| for ch in text: | |
| result += ch | |
| count += 1 | |
| if count >= max_len: | |
| break | |
| return result | |
| answer = truncate_korean(answer, 80) | |
| # 3. ๋ฌธ์ฅ ๋์ด ์์ฐ์ค๋ฝ์ง ์์ผ๋ฉด ๋ง์นจํ ์ถ๊ฐ | |
| if answer and answer[-1] not in ".!?": | |
| answer += "." | |
| elif not answer: # ๋น ๋ฌธ์์ด์ธ๋ฐ '.' ์ฐ์ผ๋ฉด ์๋ฌ๋๋ ํ๋ฒ ๋ ์ฒดํฌ | |
| answer = "์ ์ ์๋ ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค." # ์ตํ์ ๋ณด๋ฃจ | |
| return answer | |
| except Exception as e: | |
| print(f"ask_sayknow ์๋ฌ: {e}") | |
| traceback.print_exc() | |
| return f"๋ด๋ถ ์ค๋ฅ: {str(e)}" # ์ธ๋ถ ์ฌ์ฉ์์๊ฒ ๋ณด์ด๋ ๋ฉ์์ง! | |
# 3. XML API endpoint.
# NOTE(review): no @app.route decorator is visible in this chunk — confirm
# the route is registered elsewhere, otherwise this handler is unreachable.
def chat_api():
    """Answer ``?askdata=...`` with an XML document rooted at <SayknowAPI>.

    Returns an error payload when the query parameter is missing, and a
    generic apology (logged server-side) on unexpected failures.
    """
    query = request.args.get('askdata', '')
    if not query:
        result = {"status": "error", "message": "No data"}
    else:
        try:
            result = {
                "service": "Sayknow",
                "question": query,
                "answer": ask_sayknow(query),
            }
        except Exception as e:
            print(f"chat_api 에러: {e}")
            traceback.print_exc()
            # Security fix: the old code put str(e) into both "answer" and
            # an "error" field, leaking internal details to API clients.
            result = {
                "service": "Sayknow",
                "question": query,
                "answer": "죄송합니다. 내부 오류가 발생했습니다.",
            }
    xml_output = dicttoxml(result, custom_root='SayknowAPI', attr_type=False)
    return Response(xml_output, mimetype='text/xml')
| # 4. ์น UI (๊ฐ๋จํ ์ง๋ฌธ ํผ + ๋ต๋ณ) - hCaptcha ์ฝ๋ ์ ๋ถ ์ ๊ฑฐ! | |
| def index(): | |
| answer = "" | |
| question = "" | |
| # error_message ์ ๊ฑฐ | |
| if request.method == "POST": | |
| question = request.form.get('question', '') | |
| # hcaptcha_response ๊ด๋ จ ๋ก์ง ์ ๊ฑฐ | |
| # hCaptcha ๊ฒ์ฆ ๋ก์ง ์ ๊ฑฐ | |
| if question: # ์ง๋ฌธ์ด ์์ผ๋ฉด ๋ฐ๋ก ๋ต๋ณ ์์ฑ! | |
| answer = ask_sayknow(question) | |
| html = f""" | |
| <html> | |
| <head> | |
| <title>Sayknow ์ฑ๋ด</title> | |
| <!-- hCaptcha ์คํฌ๋ฆฝํธ ์ ๊ฑฐ --> | |
| </head> | |
| <body> | |
| <h2>Sayknow ํ๊ตญ์ด ์ฑ๋ด</h2> | |
| <form method="post" action="/"> | |
| <input type="text" name="question" value="{question}" placeholder="์ง๋ฌธ์ ์ ๋ ฅํ์ธ์" style="width:300px;" autofocus /> | |
| <br/><br/> | |
| <!-- hCaptcha ์์ ฏ ์ ๊ฑฐ --> | |
| <!-- ์๋ฌ ๋ฉ์์ง ๋ณด์ฌ์ฃผ๋ ๋ถ๋ถ ์ ๊ฑฐ --> | |
| <br/> | |
| <input type="submit" value="์ง๋ฌธํ๊ธฐ" /> | |
| </form> | |
| <hr> | |
| <h3>๋ต๋ณ:</h3> | |
| <p style="white-space: pre-wrap;">{answer}</p> | |
| </body> | |
| </html> | |
| """ | |
| return render_template_string(html) | |
# Dev-server entry point; 0.0.0.0 exposes it on all interfaces and port
# 7860 is the Hugging Face Spaces default.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)