Update app.py
Browse files
app.py
CHANGED
|
@@ -1114,50 +1114,36 @@ def recommend_content_based(user_profile: dict, top_n=5):
|
|
| 1114 |
#####################################
|
| 1115 |
# 5) ์ฑ๋ด ๋ก์ง
|
| 1116 |
#####################################
|
| 1117 |
-
|
| 1118 |
-
|
| 1119 |
-
|
|
|
|
| 1120 |
|
| 1121 |
-
def chat_response(user_input, mode="emotion"
|
| 1122 |
if mode not in ["emotion", "rational"]:
|
| 1123 |
raise HTTPException(status_code=400, detail="mode๋ 'emotion' ๋๋ 'rational'์ด์ด์ผ ํฉ๋๋ค.")
|
| 1124 |
|
| 1125 |
prompt = f"<{mode}><usr>{user_input}</usr><sys>"
|
| 1126 |
-
payload = {
|
| 1127 |
-
"inputs": prompt,
|
| 1128 |
-
"parameters": {
|
| 1129 |
-
"max_new_tokens": 128,
|
| 1130 |
-
"temperature": 0.7,
|
| 1131 |
-
"top_p": 0.9,
|
| 1132 |
-
"top_k": 50,
|
| 1133 |
-
"repetition_penalty": 1.2,
|
| 1134 |
-
"do_sample": True
|
| 1135 |
-
},
|
| 1136 |
-
"options": {"wait_for_model": True}
|
| 1137 |
-
}
|
| 1138 |
|
| 1139 |
-
|
| 1140 |
-
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
|
| 1147 |
-
|
| 1148 |
-
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
|
| 1152 |
-
|
| 1153 |
-
|
| 1154 |
-
|
| 1155 |
-
|
| 1156 |
-
|
| 1157 |
-
|
| 1158 |
-
return f"API Error: {response.status_code}, {response.text}"
|
| 1159 |
-
|
| 1160 |
-
return "๐จ ๋ชจ๋ธ ๋ก๋ฉ์ด ๋๋ฌด ์ค๋ ๊ฑธ๋ฆฝ๋๋ค. ์ ์ ํ ๋ค์ ์๋ํ์ธ์."
|
| 1161 |
|
| 1162 |
|
| 1163 |
#์ฐ์ธ๋ถ๋ฅ ๋ชจ๋ธ ์ถ๊ฐ
|
|
|
|
| 1114 |
#####################################
|
| 1115 |
# 5) ์ฑ๋ด ๋ก์ง
|
| 1116 |
#####################################
|
| 1117 |
+
tokenizer = AutoTokenizer.from_pretrained("Chanjeans/tfchatbot_2")
|
| 1118 |
+
model = AutoModelForCausalLM.from_pretrained("Chanjeans/tfchatbot_2")
|
| 1119 |
+
model.eval()
|
| 1120 |
+
print("Model loaded successfully.")
|
| 1121 |
|
| 1122 |
+
def chat_response(user_input, mode="emotion"):
|
| 1123 |
if mode not in ["emotion", "rational"]:
|
| 1124 |
raise HTTPException(status_code=400, detail="mode๋ 'emotion' ๋๋ 'rational'์ด์ด์ผ ํฉ๋๋ค.")
|
| 1125 |
|
| 1126 |
prompt = f"<{mode}><usr>{user_input}</usr><sys>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1127 |
|
| 1128 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 1129 |
+
|
| 1130 |
+
with torch.no_grad():
|
| 1131 |
+
outputs = model.generate(
|
| 1132 |
+
**inputs,
|
| 1133 |
+
max_new_tokens=128,
|
| 1134 |
+
temperature=0.7,
|
| 1135 |
+
top_p=0.9,
|
| 1136 |
+
top_k=50,
|
| 1137 |
+
repetition_penalty=1.2,
|
| 1138 |
+
do_sample=True
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 1142 |
+
# prompt ๋ถ๋ถ ์ ๊ฑฐ (๋ถํ์ํ ํ๋กฌํํธ๊น์ง ๋ฐํ๋์ง ์๋๋ก)
|
| 1143 |
+
response_text = generated_text.replace(prompt, "").strip()
|
| 1144 |
+
|
| 1145 |
+
return response_text
|
| 1146 |
+
|
|
|
|
|
|
|
|
|
|
| 1147 |
|
| 1148 |
|
| 1149 |
#์ฐ์ธ๋ถ๋ฅ ๋ชจ๋ธ ์ถ๊ฐ
|