Jay1121 committed on
Commit
3bbacc0
·
verified ·
1 Parent(s): 4985efe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -91
app.py CHANGED
@@ -1,57 +1,63 @@
1
  # -*- coding: utf-8 -*-
2
- # app.py — SOLAR 10.7B 친구 챗봇 (Gradio, 경량 설정)
3
 
4
- import os, re, random, difflib, torch
 
 
 
 
5
  from datetime import datetime
 
6
  try:
7
  from zoneinfo import ZoneInfo
8
  except Exception:
9
  ZoneInfo = None
10
 
11
  import gradio as gr
12
- from transformers import AutoModelForCausalLM, AutoTokenizer
13
  from peft import PeftModel
14
 
15
  # =========================================================
16
  # 기본 모델 / 파인튜닝 모델 경로
17
  # =========================================================
18
 
19
- # 베이스 SOLAR 모델 (토크나이저 fallback 용)
20
- BASE_MODEL_PATH = "Upstage/SOLAR-10.7B-Instruct-v1.0"
21
 
22
- # Hugging Face Hub 에 올려둔 병합 모델 리포 ID
23
- # - Colab 에서 별도 경로 쓰고 싶으면 환경변수 MODEL_DIR override
24
- MODEL_DIR = os.environ.get("MODEL_DIR", "Jay1121/my-solar-chatbot-merged")
25
 
26
  # =========================================================
27
  # 환경 변수 / 기본값 설정
28
  # =========================================================
29
 
30
- # 사전/욕설 경로 (Space에는 ./dictionaries 안에 같이 올리면 됨)
31
- DICT_PATH = os.environ.get("DICT_PATH", "./dictionaries/korean_words.txt")
32
  PROFANITY_PATH = os.environ.get("PROFANITY_PATH", "")
 
 
33
 
34
- # 속도/품질 옵션 (기본은 빠르게 쪽으로)
35
- OOV_THRESHOLD = int(os.environ.get("OOV_THRESHOLD", "0"))
36
- OOV_STRIP = os.environ.get("OOV_STRIP", "1") == "1"
37
- STRICT_MODE = os.environ.get("STRICT_MODE", "0") == "1" # 기본 OFF
38
- SAFETY_ON = os.environ.get("SAFETY_ON", "0") == "1" # 기본 OFF
39
- BAN_JAMO = os.environ.get("BAN_JAMO", "1") == "1"
40
- USE_FA = os.environ.get("USE_FLASH_ATTN", "1") == "1"
41
 
42
- STYLE_MODE = os.environ.get("STYLE_MODE", "auto") # auto | deadpan | neutral
43
- WHITELIST_JAMO = set([s.strip() for s in os.environ.get("WHITELIST_JAMO", "ㅎ,ㅋ").split(",") if s.strip()])
44
- KEEP_REPEATS = os.environ.get("KEEP_REPEATS", "0") == "1"
45
 
46
- ANTI_SMALLTALK = os.environ.get("ANTI_SMALLTALK", "0") == "1" # 기본 OFF
47
- SMALLTALK_TRIES= int(os.environ.get("SMALLTALK_TRIES", "1"))
 
 
 
 
48
 
49
  META_BANS = ["AI", "인공지능", "챗봇", "도와줄게", "역할"]
50
 
51
  DEFAULT_PROFANITY = {
52
- "씨발", "시발", "ㅅㅂ", "좆", "좆같", "개같", "개새끼", "개새", "개소리", "지랄",
53
- "병신", "븅신", "병쉰", "병1신", "염병", "닥쳐", "꺼져", "닥치", "ㅄ", "ㅗ", "씹",
54
- "ㅈ같", "개지랄", "싫다", "빡친", "개빡", "개빡침", "등신", "존나", "미친"
 
55
  }
56
 
57
  # =========================================================
@@ -61,28 +67,32 @@ DEFAULT_PROFANITY = {
61
  def _pick_attn_impl():
62
  return "flash_attention_2" if USE_FA else "sdpa"
63
 
 
64
  def _is_peft_adapter(model_dir: str) -> bool:
65
  return os.path.exists(os.path.join(model_dir, "adapter_config.json"))
66
 
 
67
  def _has_full_model(model_dir: str) -> bool:
68
  names = ["pytorch_model.bin", "model.safetensors", "consolidated.safetensors"]
69
  has_weight = any(os.path.exists(os.path.join(model_dir, n)) for n in names)
70
- has_cfg = os.path.exists(os.path.join(model_dir, "config.json"))
71
  return has_weight and has_cfg
72
 
 
73
  def _has_tokenizer_files(path: str) -> bool:
74
  if not path:
75
  return False
76
- return any(os.path.exists(os.path.join(path, n)) for n in [
77
- "tokenizer.model", "tokenizer.json", "vocab.json", "merges.txt"
78
- ])
 
79
 
80
- def _load_tokenizer_pref_local(local_dir: str, fallback_dir: str):
81
- tried = []
82
 
 
83
  def _try(path, fast):
84
- tried.append(f"{path} (fast={fast})")
85
- return AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=fast)
 
86
 
87
  # 1) 로컬 tokenizer.model 우선
88
  if local_dir and os.path.exists(os.path.join(local_dir, "tokenizer.model")):
@@ -90,7 +100,7 @@ def _load_tokenizer_pref_local(local_dir: str, fallback_dir: str):
90
  tok = _try(local_dir, False)
91
  if tok.pad_token is None:
92
  tok.pad_token = tok.eos_token
93
- print(f"🔤 토크나이저 OK: {local_dir} (use_fast=False, tokenizer.model)")
94
  return tok
95
  except Exception as e:
96
  print(f"⚠️ local slow 실패: {e}")
@@ -101,7 +111,7 @@ def _load_tokenizer_pref_local(local_dir: str, fallback_dir: str):
101
  tok = _try(local_dir, True)
102
  if tok.pad_token is None:
103
  tok.pad_token = tok.eos_token
104
- print(f"🔤 토크나이저 OK: {local_dir} (use_fast=True, tokenizer.json)")
105
  return tok
106
  except Exception as e:
107
  print(f"⚠️ local fast 실패: {e}")
@@ -112,7 +122,7 @@ def _load_tokenizer_pref_local(local_dir: str, fallback_dir: str):
112
  tok = _try(fallback_dir, fast)
113
  if tok.pad_token is None:
114
  tok.pad_token = tok.eos_token
115
- print(f"🔤 토크나이저 OK: {fallback_dir} (use_fast={fast})")
116
  return tok
117
  except Exception as e:
118
  print(f"⚠️ fallback (fast={fast}) 실패: {e}")
@@ -120,26 +130,39 @@ def _load_tokenizer_pref_local(local_dir: str, fallback_dir: str):
120
  raise RuntimeError("토크나이저 로드에 모두 실패했습니다.")
121
 
122
  # =========================================================
123
- # 모델 로드
124
  # =========================================================
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def load_model_for_chat(model_dir: str, tokenizer_dir: str | None = None):
127
  """
128
  model_dir:
129
- - 로컬 폴더일 수도 있고
130
- - Hugging Face Hub repo id (예: 'Jay1121/my-solar-chatbot-merged') 일 수도 있음
131
  """
132
-
133
  if os.path.isdir(model_dir):
134
  print(f"▶ 로컬 모델 폴더: {model_dir}")
135
  is_adapter = _is_peft_adapter(model_dir)
136
- is_full = _has_full_model(model_dir)
137
  else:
138
  print(f"▶ 로컬 폴더 없음 → HF Hub에서 '{model_dir}' 로드 시도")
139
  is_adapter = False
140
- is_full = False
141
 
142
  attn_impl = _pick_attn_impl()
 
143
 
144
  # 토크나이저 경로 선택
145
  if tokenizer_dir:
@@ -152,9 +175,9 @@ def load_model_for_chat(model_dir: str, tokenizer_dir: str | None = None):
152
  print(f"🔎 토크나이저 경로 선택: {tk_dir}")
153
  tok = _load_tokenizer_pref_local(tk_dir, BASE_MODEL_PATH)
154
 
155
- # 1) PEFT 어댑터 폴더인 경우 (로컬 디렉토리에서만 의미 있음)
156
  if is_adapter and not is_full:
157
- print("📦 감지: PEFT LoRA 어댑터 → 베이스(SOLAR) 로드 후 어댑터 적용")
158
  try:
159
  base = AutoModelForCausalLM.from_pretrained(
160
  BASE_MODEL_PATH,
@@ -182,30 +205,49 @@ def load_model_for_chat(model_dir: str, tokenizer_dir: str | None = None):
182
  print("✅ 어댑터 병합(merge_and_unload) 완료")
183
  except Exception as e:
184
  print(f"ℹ️ 병합 스킵: {e}")
 
185
  model.eval()
186
  print("✅ 모델 로드 완료!")
187
  return model, tok
188
 
189
- # 2) 병합된 풀 모델 or HF Hub 모델로 로드
190
  print("📦 감지: 병합된 '완전체' 모델 또는 HF Hub 모델 → from_pretrained 로 로드")
191
  try:
192
- model = AutoModelForCausalLM.from_pretrained(
193
- model_dir,
194
- torch_dtype=torch.float16,
195
- device_map="auto",
196
- trust_remote_code=True,
197
- attn_implementation=attn_impl,
198
- )
199
- except Exception as e:
200
- if attn_impl == "flash_attention_2":
201
- print(f"⚠️ flash-attn 실패 → SDPA로 전환: {e}")
202
  model = AutoModelForCausalLM.from_pretrained(
203
  model_dir,
204
  torch_dtype=torch.float16,
205
  device_map="auto",
206
  trust_remote_code=True,
207
- attn_implementation="sdpa",
208
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  else:
210
  raise
211
 
@@ -226,6 +268,7 @@ def load_dictionary(path=DICT_PATH):
226
  print(f"📚 사전 없음: {path} (OOV 검사 약화)")
227
  return set()
228
 
 
229
  def load_profanity(path=PROFANITY_PATH):
230
  prof = set(DEFAULT_PROFANITY)
231
  if path and os.path.exists(path):
@@ -241,12 +284,15 @@ def load_profanity(path=PROFANITY_PATH):
241
  # 전처리 / 검사
242
  # =========================================================
243
 
244
- RE_LAUGH = re.compile(r'(ㅋ|ㅎ|ㅠ|ㅜ)\1{2,}')
245
- RE_EN = re.compile(r'[A-Za-z]+')
246
- RE_WORDS = re.compile(r'[가-힣]{2,}')
247
 
248
  def build_bad_words_ids(tokenizer):
249
- ids = [tokenizer(w, add_special_tokens=False).input_ids for w in META_BANS]
 
 
 
250
  for ch in list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"):
251
  ids.append(tokenizer(ch, add_special_tokens=False).input_ids)
252
  if BAN_JAMO:
@@ -257,18 +303,21 @@ def build_bad_words_ids(tokenizer):
257
  ids.append(tokenizer(ch, add_special_tokens=False).input_ids)
258
  return ids
259
 
 
260
  def clean_text(txt: str):
261
  if not KEEP_REPEATS:
262
  txt = RE_LAUGH.sub(lambda m: m.group(1) * 2, txt)
263
- txt = RE_EN.sub('', txt)
264
  cut = txt.split("### User:")[0]
265
  return cut.strip()
266
 
 
267
  def count_oov(txt: str, dictionary, allowlist):
268
  words = RE_WORDS.findall(txt)
269
  oov = [w for w in words if (w not in dictionary and w not in allowlist)]
270
  return len(oov), oov
271
 
 
272
  def strip_oov(txt: str, dictionary, allowlist):
273
  kept, i = [], 0
274
  while i < len(txt):
@@ -282,28 +331,32 @@ def strip_oov(txt: str, dictionary, allowlist):
282
  kept.append(w)
283
  i = m.end()
284
  out = "".join(kept)
285
- out = re.sub(r'\s{2,}', ' ', out).strip()
286
  return out
287
 
288
  SMALLTALK_PATTERNS = [
289
- r'오늘\s*날씨', r'\b날씨\s*(가|는)?\s*(좋|괜찮|별로|따뜻|쌀쌀|시원|선선)',
290
- r'(하늘|기온|미세먼지)\s*(이|가)?\s*(좋|맑|깨끗|나쁨|흐림)',
291
- r'(더워|추워)\b', r'비(\s*가)?\s*(온|와|왔|올)\b'
 
 
292
  ]
293
  SMALLTALK_REGEXES = [re.compile(p) for p in SMALLTALK_PATTERNS]
294
 
295
  def normalize_for_sim(s: str):
296
- s = re.sub(r'\s+', '', s)
297
- s = re.sub(r'[.!?~…]+', '', s)
298
- s = re.sub(r'(.)\1{2,}', r'\1\1', s)
299
  return s
300
 
 
301
  def looks_smalltalk(text: str):
302
  t = normalize_for_sim(text)
303
  if "오늘날씨좋았어" in t:
304
  return True
305
  return any(rx.search(text) for rx in SMALLTALK_REGEXES)
306
 
 
307
  def too_similar_to_history(text: str, history_texts, thresh=0.86):
308
  t1 = normalize_for_sim(text)
309
  for h in history_texts:
@@ -317,9 +370,9 @@ def too_similar_to_history(text: str, history_texts, thresh=0.86):
317
  # =========================================================
318
 
319
  DEADPAN_TRIGGERS = [
320
- "심심", "귀찮", "짜증", "싫", "하..", "휴", "후", "지루", "그만", "피곤", "죽였어",
321
- "개소리", "뭐래", "에휴", "흥미없", "아...", "음....", ";;;;", "어쩌라고",
322
- "그건 본인 사정이죠", "그건 니사정이지"
323
  ]
324
 
325
  def should_deadpan(user_text: str):
@@ -330,11 +383,12 @@ def should_deadpan(user_text: str):
330
  return False
331
  return any(k in user_text for k in DEADPAN_TRIGGERS)
332
 
 
333
  def postprocess_deadpan(reply: str):
334
  reply = reply.replace("!", ".")
335
- reply = re.sub(r'[~…]+', '...', reply)
336
  if len(reply) > 120:
337
- cut = re.split(r'([.다]\s)', reply, maxsplit=1)
338
  if cut and len("".join(cut[:2])) > 0:
339
  reply = "".join(cut[:2]).strip()
340
  reply = reply[:120].rstrip() + "..."
@@ -345,17 +399,32 @@ def postprocess_deadpan(reply: str):
345
  return reply.strip()
346
 
347
  # =========================================================
348
- # 디코딩 (경량화)
349
  # =========================================================
350
 
351
  def decode_once(model, tok, prompt, bad_words_ids, *, deadpan=False):
352
- # max_new_tokens 줄여서 속도 확보
353
  if deadpan:
354
- cfg = dict(do_sample=True, temperature=0.25, top_p=0.85, max_new_tokens=96)
 
 
 
 
 
355
  elif STRICT_MODE:
356
- cfg = dict(do_sample=True, temperature=0.35, top_p=0.88, max_new_tokens=128)
 
 
 
 
 
357
  else:
358
- cfg = dict(do_sample=True, temperature=0.5, top_p=0.9, max_new_tokens=128)
 
 
 
 
 
359
 
360
  inputs = tok(prompt, return_tensors="pt").to(model.device)
361
  with torch.no_grad():
@@ -366,7 +435,7 @@ def decode_once(model, tok, prompt, bad_words_ids, *, deadpan=False):
366
  eos_token_id=tok.eos_token_id,
367
  pad_token_id=tok.pad_token_id,
368
  bad_words_ids=bad_words_ids,
369
- **cfg
370
  )
371
  gen = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
372
  return clean_text(gen)
@@ -376,17 +445,17 @@ def decode_once(model, tok, prompt, bad_words_ids, *, deadpan=False):
376
  # =========================================================
377
 
378
  SYSTEM_PROMPT = (
379
- "너는 사용자의 가장 친한 친구야. 순수한 한글 구어체로 말해. "
380
  f"영문/불필요한 낱자 자모 금지(허용: {','.join(sorted(WHITELIST_JAMO))}). "
381
  "메타 단어('AI','인공지능','챗봇','도와줄게','역할') 금지. "
382
- "필요하면 짧고 건조하게 답해도 돼.\n\n"
383
  "--- 대화 예시 ---\n"
384
  "User: 넌 누구야?\n"
385
- "Assistant: 나는.. 인간이야..\n"
386
  "User: 무슨 일 해?\n"
387
- "Assistant: 그냥 있어..\n"
388
  "User: 심심하다\n"
389
- "Assistant: 음.. 뭐 할래? 산책?\n"
390
  "--- 여기까지 예시 ---\n\n"
391
  )
392
 
@@ -408,20 +477,23 @@ print("✅ 초기화 완료")
408
  def chat_fn(user_input, history):
409
  # history: 리스트 [(user, bot), ...]
410
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
411
- for u, b in history[-5:]: # 최근 5턴만 사용
 
 
412
  messages.append({"role": "user", "content": u})
413
  messages.append({"role": "assistant", "content": b})
414
  messages.append({"role": "user", "content": user_input})
415
 
416
  prompt = tokenizer.apply_chat_template(
417
- messages, tokenize=False, add_generation_prompt=True
 
 
418
  )
419
 
420
  deadpan = should_deadpan(user_input)
421
-
422
  reply = decode_once(model, tokenizer, prompt, bad_words_ids, deadpan=deadpan)
423
- oov_cnt, _ = count_oov(reply, dictionary, profanity)
424
 
 
425
  if OOV_STRIP and oov_cnt > 0:
426
  reply = strip_oov(reply, dictionary, profanity)
427
 
@@ -434,11 +506,34 @@ def chat_fn(user_input, history):
434
  # Gradio UI
435
  # =========================================================
436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  demo = gr.ChatInterface(
438
  fn=chat_fn,
439
- title="SOLAR 친구 챗봇",
440
- description="SOLAR-10.7B 기반 한글 친구 챗봇 (가벼운 설정)",
441
- examples=[" 오늘 개피곤하다", "이직할까 말까 고민중이야", "나 칭찬해줘"],
 
 
 
 
 
 
 
 
442
  )
443
 
444
  if __name__ == "__main__":
 
1
  # -*- coding: utf-8 -*-
2
+ # app.py — 어느 MZ 친구의 느린 DM방 (Blossom 8B, 4bit, Gradio)
3
 
4
+ import os
5
+ import re
6
+ import random
7
+ import difflib
8
+ import torch
9
  from datetime import datetime
10
+
11
  try:
12
  from zoneinfo import ZoneInfo
13
  except Exception:
14
  ZoneInfo = None
15
 
16
  import gradio as gr
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
18
  from peft import PeftModel
19
 
20
  # =========================================================
21
  # 기본 모델 / 파인튜닝 모델 경로
22
  # =========================================================
23
 
24
+ BASE_MODEL_PATH = "MLP-KTLim/llama-3-Korean-Bllossom-8B"
 
25
 
26
+ # Hugging Face Hub에 올려둔 병합 모델 리포 ID
27
+ MODEL_DIR_DEFAULT = "Jay1121/blossom_v2/kakao_merged_v1" # << 여기를 repo바꿔!
28
+ MODEL_DIR = os.environ.get("MODEL_DIR", MODEL_DIR_DEFAULT)
29
 
30
  # =========================================================
31
  # 환경 변수 / 기본값 설정
32
  # =========================================================
33
 
34
+ DICT_PATH = os.environ.get("DICT_PATH", "./dictionaries/korean_words.txt")
 
35
  PROFANITY_PATH = os.environ.get("PROFANITY_PATH", "")
36
+ OOV_THRESHOLD = int(os.environ.get("OOV_THRESHOLD", "0"))
37
+ OOV_STRIP = os.environ.get("OOV_STRIP", "1") == "1"
38
 
39
+ STRICT_MODE = os.environ.get("STRICT_MODE", "0") == "1" # 기본 OFF
40
+ SAFETY_ON = os.environ.get("SAFETY_ON", "0") == "1" # 기본 OFF
41
+ BAN_JAMO = os.environ.get("BAN_JAMO", "1") == "1"
42
+ USE_FA = os.environ.get("USE_FLASH_ATTN", "1") == "1"
43
+ USE_4BIT = os.environ.get("USE_4BIT", "1") == "1" # 기본 4bit 사용
 
 
44
 
45
+ STYLE_MODE = os.environ.get("STYLE_MODE", "auto") # auto | deadpan | neutral
 
 
46
 
47
+ WHITELIST_JAMO = set(
48
+ [s.strip() for s in os.environ.get("WHITELIST_JAMO", "ㅎ,ㅋ").split(",") if s.strip()]
49
+ )
50
+ KEEP_REPEATS = os.environ.get("KEEP_REPEATS", "0") == "1"
51
+ ANTI_SMALLTALK = os.environ.get("ANTI_SMALLTALK", "0") == "1"
52
+ SMALLTALK_TRIES = int(os.environ.get("SMALLTALK_TRIES", "1"))
53
 
54
  META_BANS = ["AI", "인공지능", "챗봇", "도와줄게", "역할"]
55
 
56
  DEFAULT_PROFANITY = {
57
+ "씨발", "시발", "ㅅㅂ", "좆", "좆같", "개같", "개새끼", "개새", "개소리",
58
+ "지랄", "병신", "븅신", "병쉰", "병1신", "염병", "닥쳐", "꺼져", "닥치",
59
+ "", "", "", "ㅈ같", "개지랄", "싫다", "빡친", "개빡", "개빡침",
60
+ "등신", "존나", "미친"
61
  }
62
 
63
  # =========================================================
 
67
  def _pick_attn_impl():
68
  return "flash_attention_2" if USE_FA else "sdpa"
69
 
70
+
71
  def _is_peft_adapter(model_dir: str) -> bool:
72
  return os.path.exists(os.path.join(model_dir, "adapter_config.json"))
73
 
74
+
75
  def _has_full_model(model_dir: str) -> bool:
76
  names = ["pytorch_model.bin", "model.safetensors", "consolidated.safetensors"]
77
  has_weight = any(os.path.exists(os.path.join(model_dir, n)) for n in names)
78
+ has_cfg = os.path.exists(os.path.join(model_dir, "config.json"))
79
  return has_weight and has_cfg
80
 
81
+
82
  def _has_tokenizer_files(path: str) -> bool:
83
  if not path:
84
  return False
85
+ return any(
86
+ os.path.exists(os.path.join(path, n))
87
+ for n in ["tokenizer.model", "tokenizer.json", "vocab.json", "merges.txt"]
88
+ )
89
 
 
 
90
 
91
+ def _load_tokenizer_pref_local(local_dir: str, fallback_dir: str):
92
  def _try(path, fast):
93
+ return AutoTokenizer.from_pretrained(
94
+ path, trust_remote_code=True, use_fast=fast
95
+ )
96
 
97
  # 1) 로컬 tokenizer.model 우선
98
  if local_dir and os.path.exists(os.path.join(local_dir, "tokenizer.model")):
 
100
  tok = _try(local_dir, False)
101
  if tok.pad_token is None:
102
  tok.pad_token = tok.eos_token
103
+ print(f"🔤 토크나이저 OK: {local_dir} (slow, tokenizer.model)")
104
  return tok
105
  except Exception as e:
106
  print(f"⚠️ local slow 실패: {e}")
 
111
  tok = _try(local_dir, True)
112
  if tok.pad_token is None:
113
  tok.pad_token = tok.eos_token
114
+ print(f"🔤 토크나이저 OK: {local_dir} (fast, tokenizer.json)")
115
  return tok
116
  except Exception as e:
117
  print(f"⚠️ local fast 실패: {e}")
 
122
  tok = _try(fallback_dir, fast)
123
  if tok.pad_token is None:
124
  tok.pad_token = tok.eos_token
125
+ print(f"🔤 토크나이저 OK: {fallback_dir} (fast={fast})")
126
  return tok
127
  except Exception as e:
128
  print(f"⚠️ fallback (fast={fast}) 실패: {e}")
 
130
  raise RuntimeError("토크나이저 로드에 모두 실패했습니다.")
131
 
132
  # =========================================================
133
+ # 모델 로드 (4bit 지원)
134
  # =========================================================
135
 
136
def _get_bnb_config():
    """Build the 4-bit quantization config, or return None when USE_4BIT is off.

    The compute dtype is bfloat16 when CUDA is available and float16 otherwise;
    the chosen dtype is printed before the config is returned.
    """
    if USE_4BIT:
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
        print(f"🧮 4bit 양자화 사용 (compute_dtype={dtype})")
        quant_kwargs = dict(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
        )
        return BitsAndBytesConfig(**quant_kwargs)
    return None
147
+
148
+
149
  def load_model_for_chat(model_dir: str, tokenizer_dir: str | None = None):
150
  """
151
  model_dir:
152
+ - 로컬 폴더
153
+ - 또는 Hugging Face Hub repo id (예: 'Jay1121/blossom-kakao-merged')
154
  """
 
155
  if os.path.isdir(model_dir):
156
  print(f"▶ 로컬 모델 폴더: {model_dir}")
157
  is_adapter = _is_peft_adapter(model_dir)
158
+ is_full = _has_full_model(model_dir)
159
  else:
160
  print(f"▶ 로컬 폴더 없음 → HF Hub에서 '{model_dir}' 로드 시도")
161
  is_adapter = False
162
+ is_full = False
163
 
164
  attn_impl = _pick_attn_impl()
165
+ bnb_config = _get_bnb_config()
166
 
167
  # 토크나이저 경로 선택
168
  if tokenizer_dir:
 
175
  print(f"🔎 토크나이저 경로 선택: {tk_dir}")
176
  tok = _load_tokenizer_pref_local(tk_dir, BASE_MODEL_PATH)
177
 
178
+ # 1) PEFT 어댑터만 있는 경우 (로컬에서만 의미)
179
  if is_adapter and not is_full:
180
+ print("📦 감지: PEFT LoRA 어댑터 → 베이스(Bllossom) 로드 후 어댑터 적용")
181
  try:
182
  base = AutoModelForCausalLM.from_pretrained(
183
  BASE_MODEL_PATH,
 
205
  print("✅ 어댑터 병합(merge_and_unload) 완료")
206
  except Exception as e:
207
  print(f"ℹ️ 병합 스킵: {e}")
208
+
209
  model.eval()
210
  print("✅ 모델 로드 완료!")
211
  return model, tok
212
 
213
+ # 2) 병합된 풀 모델 or HF Hub 모델 (4bit 가능)
214
  print("📦 감지: 병합된 '완전체' 모델 또는 HF Hub 모델 → from_pretrained 로 로드")
215
  try:
216
+ if bnb_config is not None:
217
+ model = AutoModelForCausalLM.from_pretrained(
218
+ model_dir,
219
+ device_map="auto",
220
+ trust_remote_code=True,
221
+ attn_implementation=attn_impl,
222
+ quantization_config=bnb_config,
223
+ )
224
+ else:
 
225
  model = AutoModelForCausalLM.from_pretrained(
226
  model_dir,
227
  torch_dtype=torch.float16,
228
  device_map="auto",
229
  trust_remote_code=True,
230
+ attn_implementation=attn_impl,
231
  )
232
+ except Exception as e:
233
+ if attn_impl == "flash_attention_2":
234
+ print(f"⚠️ flash-attn 실패 → SDPA로 전환: {e}")
235
+ if bnb_config is not None:
236
+ model = AutoModelForCausalLM.from_pretrained(
237
+ model_dir,
238
+ device_map="auto",
239
+ trust_remote_code=True,
240
+ attn_implementation="sdpa",
241
+ quantization_config=bnb_config,
242
+ )
243
+ else:
244
+ model = AutoModelForCausalLM.from_pretrained(
245
+ model_dir,
246
+ torch_dtype=torch.float16,
247
+ device_map="auto",
248
+ trust_remote_code=True,
249
+ attn_implementation="sdpa",
250
+ )
251
  else:
252
  raise
253
 
 
268
  print(f"📚 사전 없음: {path} (OOV 검사 약화)")
269
  return set()
270
 
271
+
272
  def load_profanity(path=PROFANITY_PATH):
273
  prof = set(DEFAULT_PROFANITY)
274
  if path and os.path.exists(path):
 
284
  # 전처리 / 검사
285
  # =========================================================
286
 
287
+ RE_LAUGH = re.compile(r"(ㅋ|ㅎ|ㅠ|ㅜ)\1{2,}")
288
+ RE_EN = re.compile(r"[A-Za-z]+")
289
+ RE_WORDS = re.compile(r"[가-힣]{2,}")
290
 
291
  def build_bad_words_ids(tokenizer):
292
+ ids = [
293
+ tokenizer(w, add_special_tokens=False).input_ids
294
+ for w in META_BANS
295
+ ]
296
  for ch in list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"):
297
  ids.append(tokenizer(ch, add_special_tokens=False).input_ids)
298
  if BAN_JAMO:
 
303
  ids.append(tokenizer(ch, add_special_tokens=False).input_ids)
304
  return ids
305
 
306
+
307
  def clean_text(txt: str):
308
  if not KEEP_REPEATS:
309
  txt = RE_LAUGH.sub(lambda m: m.group(1) * 2, txt)
310
+ txt = RE_EN.sub("", txt)
311
  cut = txt.split("### User:")[0]
312
  return cut.strip()
313
 
314
+
315
  def count_oov(txt: str, dictionary, allowlist):
316
  words = RE_WORDS.findall(txt)
317
  oov = [w for w in words if (w not in dictionary and w not in allowlist)]
318
  return len(oov), oov
319
 
320
+
321
  def strip_oov(txt: str, dictionary, allowlist):
322
  kept, i = [], 0
323
  while i < len(txt):
 
331
  kept.append(w)
332
  i = m.end()
333
  out = "".join(kept)
334
+ out = re.sub(r"\s{2,}", " ", out).strip()
335
  return out
336
 
337
  SMALLTALK_PATTERNS = [
338
+ r"오늘\s*날씨",
339
+ r"\b날씨\s*(가|는)?\s*(좋|괜찮|별로|따뜻|쌀쌀|시원|선선)",
340
+ r"(하늘|기온|미세먼지)\s*(이|가)?\s*(좋|맑|깨끗|나쁨|흐림)",
341
+ r"(더워|추워)\b",
342
+ r"비(\s*가)?\s*(온|와|왔|올)\b",
343
  ]
344
  SMALLTALK_REGEXES = [re.compile(p) for p in SMALLTALK_PATTERNS]
345
 
346
  def normalize_for_sim(s: str):
347
+ s = re.sub(r"\s+", "", s)
348
+ s = re.sub(r"[.!?~…]+", "", s)
349
+ s = re.sub(r"(.)\1{2,}", r"\1\1", s)
350
  return s
351
 
352
+
353
  def looks_smalltalk(text: str):
354
  t = normalize_for_sim(text)
355
  if "오늘날씨좋았어" in t:
356
  return True
357
  return any(rx.search(text) for rx in SMALLTALK_REGEXES)
358
 
359
+
360
  def too_similar_to_history(text: str, history_texts, thresh=0.86):
361
  t1 = normalize_for_sim(text)
362
  for h in history_texts:
 
370
  # =========================================================
371
 
372
  DEADPAN_TRIGGERS = [
373
+ "심심", "귀찮", "짜증", "싫", "하..", "휴", "후", "지루",
374
+ "그만", "피곤", "죽였어", "개소리", "뭐래", "에휴", "흥미없",
375
+ "아...", "음....", ";;;;", "어쩌라고", "그건 본인 사정이죠", "그건 니사정이지"
376
  ]
377
 
378
  def should_deadpan(user_text: str):
 
383
  return False
384
  return any(k in user_text for k in DEADPAN_TRIGGERS)
385
 
386
+
387
  def postprocess_deadpan(reply: str):
388
  reply = reply.replace("!", ".")
389
+ reply = re.sub(r"[~…]+", "...", reply)
390
  if len(reply) > 120:
391
+ cut = re.split(r"([.다]\s)", reply, maxsplit=1)
392
  if cut and len("".join(cut[:2])) > 0:
393
  reply = "".join(cut[:2]).strip()
394
  reply = reply[:120].rstrip() + "..."
 
399
  return reply.strip()
400
 
401
  # =========================================================
402
+ # 디코딩
403
  # =========================================================
404
 
405
  def decode_once(model, tok, prompt, bad_words_ids, *, deadpan=False):
406
+ """답변 길이를 줄여서 속도 확보."""
407
  if deadpan:
408
+ cfg = dict(
409
+ do_sample=True,
410
+ temperature=0.25,
411
+ top_p=0.85,
412
+ max_new_tokens=48,
413
+ )
414
  elif STRICT_MODE:
415
+ cfg = dict(
416
+ do_sample=True,
417
+ temperature=0.35,
418
+ top_p=0.88,
419
+ max_new_tokens=56,
420
+ )
421
  else:
422
+ cfg = dict(
423
+ do_sample=True,
424
+ temperature=0.6,
425
+ top_p=0.9,
426
+ max_new_tokens=64,
427
+ )
428
 
429
  inputs = tok(prompt, return_tensors="pt").to(model.device)
430
  with torch.no_grad():
 
435
  eos_token_id=tok.eos_token_id,
436
  pad_token_id=tok.pad_token_id,
437
  bad_words_ids=bad_words_ids,
438
+ **cfg,
439
  )
440
  gen = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
441
  return clean_text(gen)
 
445
  # =========================================================
446
 
447
  SYSTEM_PROMPT = (
448
+ "너는 사용자의 가장 친한 친구야. 20~30대 MZ 말투 섞인 편안한 한국어 구어체로 말해. "
449
  f"영문/불필요한 낱자 자모 금지(허용: {','.join(sorted(WHITELIST_JAMO))}). "
450
  "메타 단어('AI','인공지능','챗봇','도와줄게','역할') 금지. "
451
+ "가끔 시크하게 줄만 대답해도 되고, 너무 설교하지 말고 현실 친구처럼 얘기해.\n\n"
452
  "--- 대화 예시 ---\n"
453
  "User: 넌 누구야?\n"
454
+ "Assistant: 그냥.. 네 얘기 들어주는 친구야.\n"
455
  "User: 무슨 일 해?\n"
456
+ "Assistant: 건 안 하고, 니 하소연 받아주는 정도?\n"
457
  "User: 심심하다\n"
458
+ "Assistant: 음.. 뭐 할래? 넷플? 산책? 아니면 그냥 수다?\n"
459
  "--- 여기까지 예시 ---\n\n"
460
  )
461
 
 
477
  def chat_fn(user_input, history):
478
  # history: 리스트 [(user, bot), ...]
479
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
480
+
481
+ # 속도 위해 최근 2턴만 유지
482
+ for u, b in history[-2:]:
483
  messages.append({"role": "user", "content": u})
484
  messages.append({"role": "assistant", "content": b})
485
  messages.append({"role": "user", "content": user_input})
486
 
487
  prompt = tokenizer.apply_chat_template(
488
+ messages,
489
+ tokenize=False,
490
+ add_generation_prompt=True,
491
  )
492
 
493
  deadpan = should_deadpan(user_input)
 
494
  reply = decode_once(model, tokenizer, prompt, bad_words_ids, deadpan=deadpan)
 
495
 
496
+ oov_cnt, _ = count_oov(reply, dictionary, profanity)
497
  if OOV_STRIP and oov_cnt > 0:
498
  reply = strip_oov(reply, dictionary, profanity)
499
 
 
506
  # Gradio UI
507
  # =========================================================
508
 
509
+ CUSTOM_CSS = """
510
+ .gradio-container {
511
+ font-family: "Noto Sans KR", system-ui, sans-serif;
512
+ }
513
+
514
+ /* 유저 메시지 텍스트를 진한 검정으로 */
515
+ .message.user,
516
+ .user .message,
517
+ .chat-message.user,
518
+ .gr-chatbot .message.user,
519
+ .gr-chatbot .user {
520
+ color: #111111 !important;
521
+ }
522
+ """
523
+
524
  demo = gr.ChatInterface(
525
  fn=chat_fn,
526
+ title="어느 MZ 친구의 느린 DM방",
527
+ description=(
528
+ "Blossom 8B + 카카오톡 말투 LoRA를 얹은, 어떤 MZ의 말투를 따라하는 한국어 친구 챗봇입니다.\n"
529
+ "(⚠️ 개 느림주의: 대답 늦어도 서운해하지 말 것)"
530
+ ),
531
+ examples=[
532
+ "야 나 오늘 개피곤하다",
533
+ "이직할까 말까 고민중이야",
534
+ "나 좀 칭찬해줘",
535
+ ],
536
+ css=CUSTOM_CSS,
537
  )
538
 
539
  if __name__ == "__main__":