Jay1121 committed on
Commit da9f771 · verified · 1 Parent(s): 93cdf4b

Create app.py

Files changed (1)
  1. app.py +390 -63
app.py CHANGED
@@ -1,70 +1,397 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )

- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()


  if __name__ == "__main__":
      demo.launch()
 
+ # -*- coding: utf-8 -*-
+ # app.py — SOLAR 10.7B friend chatbot (Gradio, lightweight settings)
+
+ import os, re, random, difflib, torch
+ from datetime import datetime
+ try:
+     from zoneinfo import ZoneInfo
+ except Exception:
+     ZoneInfo = None
+
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+
+ BASE_MODEL_PATH = "Upstage/SOLAR-10.7B-Instruct-v1.0"
+
+ # =========================
+ # Environment variables / defaults
+ # =========================
+
+ # Shared between Hugging Face Spaces and Colab: model folder path
+ # - Colab: /content/my-solar-chatbot-merged
+ # - Space: ./my-solar-chatbot-merged (when the model folder is committed into the repo)
+ MODEL_DIR = os.environ.get("MODEL_DIR", "/content/my-solar-chatbot-merged")
+
+ # Dictionary / profanity list paths (on a Space, upload them under ./dictionaries)
+ DICT_PATH = os.environ.get("DICT_PATH", "./dictionaries/korean_words.txt")
+ PROFANITY_PATH = os.environ.get("PROFANITY_PATH", "")
+
+ # Speed/quality options (defaults favor speed)
+ OOV_THRESHOLD = int(os.environ.get("OOV_THRESHOLD", "0"))
+ OOV_STRIP = os.environ.get("OOV_STRIP","1") == "1"
+ STRICT_MODE = os.environ.get("STRICT_MODE","0") == "1"  # off by default
+ SAFETY_ON = os.environ.get("SAFETY_ON","0") == "1"  # off by default
+ BAN_JAMO = os.environ.get("BAN_JAMO","1") == "1"
+ USE_FA = os.environ.get("USE_FLASH_ATTN","1") == "1"
+
+ STYLE_MODE = os.environ.get("STYLE_MODE","auto")  # auto | deadpan | neutral
+ WHITELIST_JAMO = set([s.strip() for s in os.environ.get("WHITELIST_JAMO","ㅎ,ㅋ").split(",") if s.strip()])
+ KEEP_REPEATS = os.environ.get("KEEP_REPEATS","0") == "1"
+
+ ANTI_SMALLTALK = os.environ.get("ANTI_SMALLTALK","0") == "1"  # off by default
+ SMALLTALK_TRIES = int(os.environ.get("SMALLTALK_TRIES","1"))
+
+ META_BANS = ["AI","인공지능","챗봇","도와줄게","역할"]
+
+ DEFAULT_PROFANITY = {
+     "씨발","시발","ㅅㅂ","좆","좆같","개같","개새끼","개새","개소리","지랄",
+     "병신","븅신","병쉰","병1신","염병","닥쳐","꺼져","닥치","ㅄ","ㅗ","씹",
+     "ㅈ같","개지랄","싫다","빡친","개빡","개빡침","등신","존나","미친"
+ }
+
+ # =========================
+ # Loader helpers
+ # =========================
+
+ def _pick_attn_impl():
+     return "flash_attention_2" if USE_FA else "sdpa"
+
+ def _is_peft_adapter(model_dir: str) -> bool:
+     return os.path.exists(os.path.join(model_dir, "adapter_config.json"))
+
+ def _has_full_model(model_dir: str) -> bool:
+     names = ["pytorch_model.bin", "model.safetensors", "consolidated.safetensors"]
+     has_weight = any(os.path.exists(os.path.join(model_dir, n)) for n in names)
+     has_cfg = os.path.exists(os.path.join(model_dir, "config.json"))
+     return has_weight and has_cfg
+
+ def _has_tokenizer_files(path: str) -> bool:
+     if not path: return False
+     return any(os.path.exists(os.path.join(path, n)) for n in [
+         "tokenizer.model","tokenizer.json","vocab.json","merges.txt"
+     ])
+
+ def _load_tokenizer_pref_local(local_dir: str, fallback_dir: str):
+     tried = []
+     def _try(path, fast):
+         tried.append(f"{path} (fast={fast})")
+         return AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=fast)
+
+     if local_dir and os.path.exists(os.path.join(local_dir, "tokenizer.model")):
+         try:
+             tok = _try(local_dir, False)
+             if tok.pad_token is None: tok.pad_token = tok.eos_token
+             print(f"🔤 Tokenizer OK: {local_dir} (use_fast=False, tokenizer.model)")
+             return tok
+         except Exception as e:
+             print(f"⚠️ Local slow tokenizer load failed: {e}")
+
+     if local_dir and os.path.exists(os.path.join(local_dir, "tokenizer.json")):
+         try:
+             tok = _try(local_dir, True)
+             if tok.pad_token is None: tok.pad_token = tok.eos_token
+             print(f"🔤 Tokenizer OK: {local_dir} (use_fast=True, tokenizer.json)")
+             return tok
+         except Exception as e:
+             print(f"⚠️ Local fast tokenizer load failed: {e}")
+
+     for fast in (True, False):
+         try:
+             tok = _try(fallback_dir, fast)
+             if tok.pad_token is None: tok.pad_token = tok.eos_token
+             print(f"🔤 Tokenizer OK: {fallback_dir} (use_fast={fast})")
+             return tok
+         except Exception as e:
+             print(f"⚠️ Fallback (fast={fast}) failed: {e}")
+
+     raise RuntimeError("All tokenizer load attempts failed.")
+
+ def load_model_for_chat(model_dir: str, tokenizer_dir: str | None = None):
+     # On a Space the model folder is assumed to sit inside the repo → treat it as a local directory
+     if not os.path.isdir(model_dir):
+         raise FileNotFoundError(f"Model folder not found: {model_dir}")
+     print(f"▶ Model folder: {model_dir}")
+
+     attn_impl = _pick_attn_impl()
+     is_adapter = _is_peft_adapter(model_dir)
+     is_full = _has_full_model(model_dir)
+
+     tk_dir = tokenizer_dir if tokenizer_dir else (model_dir if _has_tokenizer_files(model_dir) else BASE_MODEL_PATH)
+     print(f"🔎 Selected tokenizer path: {tk_dir}")
+     tok = _load_tokenizer_pref_local(tk_dir, BASE_MODEL_PATH)
+
+     if is_adapter and not is_full:
+         print("📦 Detected PEFT LoRA adapter → loading base (SOLAR), then applying the adapter")
+         try:
+             base = AutoModelForCausalLM.from_pretrained(
+                 BASE_MODEL_PATH, torch_dtype=torch.float16,
+                 device_map="auto", trust_remote_code=True, attn_implementation=attn_impl
+             )
+         except Exception as e:
+             if attn_impl == "flash_attention_2":
+                 print(f"⚠️ flash-attn failed → falling back to SDPA: {e}")
+                 base = AutoModelForCausalLM.from_pretrained(
+                     BASE_MODEL_PATH, torch_dtype=torch.float16,
+                     device_map="auto", trust_remote_code=True, attn_implementation="sdpa"
+                 )
+             else:
+                 raise
+         model = PeftModel.from_pretrained(base, model_dir, offload_folder="offload")
+         try:
+             model = model.merge_and_unload()
+             print("✅ Adapter merge (merge_and_unload) complete")
+         except Exception as e:
+             print(f"ℹ️ Merge skipped: {e}")
+         model.eval()
+         print("✅ Model loaded!")
+         return model, tok
+
+     print("📦 Detected merged (full) model or plain folder → loading directly from it")
+     try:
+         model = AutoModelForCausalLM.from_pretrained(
+             model_dir, torch_dtype=torch.float16,
+             device_map="auto", trust_remote_code=True, attn_implementation=attn_impl
+         )
+     except Exception as e:
+         if attn_impl == "flash_attention_2":
+             print(f"⚠️ flash-attn failed → falling back to SDPA: {e}")
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_dir, torch_dtype=torch.float16,
+                 device_map="auto", trust_remote_code=True, attn_implementation="sdpa"
+             )
+         else:
+             raise
+     model.eval()
+     print("✅ Model loaded!")
+     return model, tok
+
+ # =========================
+ # Dictionary / profanity
+ # =========================
+
+ def load_dictionary(path=DICT_PATH):
+     if os.path.exists(path):
+         with open(path, "r", encoding="utf-8") as f:
+             words = set(w.strip() for w in f if w.strip())
+         print(f"📚 Dictionary loaded: {path} ({len(words)} words)")
+         return words
+     print(f"📚 No dictionary at {path} (OOV check weakened)")
+     return set()
+
+ def load_profanity(path=PROFANITY_PATH):
+     prof = set(DEFAULT_PROFANITY)
+     if path and os.path.exists(path):
+         with open(path, "r", encoding="utf-8") as f:
+             for line in f:
+                 w = line.strip()
+                 if w: prof.add(w)
+         print(f"📝 Additional profanity list loaded: {path}")
+     return prof
+
+ # =========================
+ # Preprocessing / checks
+ # =========================
+
+ RE_LAUGH = re.compile(r'(ㅋ|ㅎ|ㅠ|ㅜ)\1{2,}')
+ RE_EN = re.compile(r'[A-Za-z]+')
+ RE_WORDS = re.compile(r'[가-힣]{2,}')
+
+ def build_bad_words_ids(tokenizer):
+     ids = [tokenizer(w, add_special_tokens=False).input_ids for w in META_BANS]
+     for ch in list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"):
+         ids.append(tokenizer(ch, add_special_tokens=False).input_ids)
+     if BAN_JAMO:
+         for code in list(range(0x1100, 0x11FF+1)) + list(range(0x3130, 0x318F+1)):
+             ch = chr(code)
+             if ch in WHITELIST_JAMO:
+                 continue
+             ids.append(tokenizer(ch, add_special_tokens=False).input_ids)
+     return ids
+
+ def clean_text(txt: str):
+     if not KEEP_REPEATS:
+         txt = RE_LAUGH.sub(lambda m: m.group(1)*2, txt)
+     txt = RE_EN.sub('', txt)
+     cut = txt.split("### User:")[0]
+     return cut.strip()
+
+ def count_oov(txt: str, dictionary, allowlist):
+     words = RE_WORDS.findall(txt)
+     oov = [w for w in words if (w not in dictionary and w not in allowlist)]
+     return len(oov), oov
+
+ def strip_oov(txt: str, dictionary, allowlist):
+     kept, i = [], 0
+     while i < len(txt):
+         m = RE_WORDS.search(txt, i)
+         if not m:
+             kept.append(txt[i:]); break
+         kept.append(txt[i:m.start()])
+         w = m.group(0)
+         if (w in dictionary) or (w in allowlist):
+             kept.append(w)
+         i = m.end()
+     out = "".join(kept)
+     out = re.sub(r'\s{2,}', ' ', out).strip()
+     return out
+
+ SMALLTALK_PATTERNS = [
+     r'오늘\s*날씨', r'\b날씨\s*(가|는)?\s*(좋|괜찮|별로|따뜻|쌀쌀|시원|선선)',
+     r'(하늘|기온|미세먼지)\s*(이|가)?\s*(좋|맑|깨끗|나쁨|흐림)',
+     r'(더워|추워)\b', r'비(\s*가)?\s*(온|와|왔|올)\b'
+ ]
+ SMALLTALK_REGEXES = [re.compile(p) for p in SMALLTALK_PATTERNS]
+
+ def normalize_for_sim(s: str):
+     s = re.sub(r'\s+','',s)
+     s = re.sub(r'[.!?~…]+','',s)
+     s = re.sub(r'(.)\1{2,}', r'\1\1', s)
+     return s
+
+ def looks_smalltalk(text: str):
+     t = normalize_for_sim(text)
+     if "오늘날씨좋았어" in t:
+         return True
+     return any(rx.search(text) for rx in SMALLTALK_REGEXES)
+
+ def too_similar_to_history(text: str, history_texts, thresh=0.86):
+     t1 = normalize_for_sim(text)
+     for h in history_texts:
+         t2 = normalize_for_sim(h)
+         if difflib.SequenceMatcher(None, t1, t2).ratio() >= thresh:
+             return True
+     return False
+
+ # =========================
+ # Deadpan style
+ # =========================
+
+ DEADPAN_TRIGGERS = [
+     "심심","귀찮","짜증","싫","하..","휴","후","지루","그만","피곤","죽였어","개소리","뭐래","에휴","흥미없",
+     "아...", "음....", ";;;;", "어쩌라고", "그건 본인 사정이죠", "그건 니사정이지"
+ ]
+
+ def should_deadpan(user_text: str):
+     mode = STYLE_MODE
+     if mode == "deadpan":
+         return True
+     if mode == "neutral":
+         return False
+     return any(k in user_text for k in DEADPAN_TRIGGERS)
+
+ def postprocess_deadpan(reply: str):
+     reply = reply.replace("!", ".")
+     reply = re.sub(r'[~…]+', '...', reply)
+     if len(reply) > 120:
+         cut = re.split(r'([.다]\s)', reply, maxsplit=1)
+         if cut and len("".join(cut[:2])) > 0:
+             reply = "".join(cut[:2]).strip()
+         reply = reply[:120].rstrip() + "..."
+     if not reply.startswith(("음", "아니", "흠", "글쎄")):
+         reply = random.choice(["음.. ","아니.. ","흠.. ","글쎄.. "]) + reply
+     if random.random() < 0.3 and not reply.endswith(("..","...",".")):
+         reply = reply + "..."
+     return reply.strip()
+
+ # =========================
+ # Decoding (lightweight)
+ # =========================
+
+ def decode_once(model, tok, prompt, bad_words_ids, *, deadpan=False):
+     # keep max_new_tokens small for speed
+     if deadpan:
+         cfg = dict(do_sample=True, temperature=0.25, top_p=0.85, max_new_tokens=96)
+     elif STRICT_MODE:
+         cfg = dict(do_sample=True, temperature=0.35, top_p=0.88, max_new_tokens=128)
+     else:
+         cfg = dict(do_sample=True, temperature=0.5, top_p=0.9, max_new_tokens=128)
+
+     inputs = tok(prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         out = model.generate(
+             **inputs,
+             repetition_penalty=1.12,
+             no_repeat_ngram_size=3,
+             eos_token_id=tok.eos_token_id,
+             pad_token_id=tok.pad_token_id,
+             bad_words_ids=bad_words_ids,
+             **cfg
+         )
+     gen = tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+     return clean_text(gen)
+
+ # =========================
+ # System prompt
+ # =========================
+
+ SYSTEM_PROMPT = (
+     "너는 사용자의 가장 친한 친구야. 순수한 한글 구어체로 말해. "
+     f"영문/불필요한 낱자 자모 금지(허용: {','.join(sorted(WHITELIST_JAMO))}). "
+     "메타 단어('AI','인공지능','챗봇','도와줄게','역할') 금지. "
+     "필요하면 짧고 건조하게 답해도 돼.\n\n"
+     "--- 대화 예시 ---\n"
+     "User: 넌 누구야?\n"
+     "Assistant: 나는.. 인간이야..\n"
+     "User: 무슨 일 해?\n"
+     "Assistant: 그냥 있어..\n"
+     "User: 심심하다\n"
+     "Assistant: 음.. 뭐 할래? 산책?\n"
+     "--- 여기까지 예시 ---\n\n"
  )

+ # =========================
+ # Global initialization
+ # =========================
+
+ print("🚀 Loading model/tokenizer...")
+ model, tokenizer = load_model_for_chat(MODEL_DIR, tokenizer_dir=None)
+ dictionary = load_dictionary()
+ profanity = load_profanity()
+ bad_words_ids = build_bad_words_ids(tokenizer)
+ print("✅ Initialization complete")
+
+ # =========================
+ # Gradio chat function
+ # =========================
+
+ def chat_fn(user_input, history):
+     # history: list of (user, bot) tuples
+     messages = [{"role":"system","content":SYSTEM_PROMPT}]
+     for u, b in history[-5:]:  # use only the last 5 turns
+         messages.append({"role":"user","content":u})
+         messages.append({"role":"assistant","content":b})
+     messages.append({"role":"user","content":user_input})
+
+     prompt = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )

+     deadpan = should_deadpan(user_input)
+
+     # single generation pass (no retries; safety checks are off by default)
+     reply = decode_once(model, tokenizer, prompt, bad_words_ids, deadpan=deadpan)
+     oov_cnt, _ = count_oov(reply, dictionary, profanity)
+
+     # strip out-of-vocabulary words if needed
+     if OOV_STRIP and oov_cnt > 0:
+         reply = strip_oov(reply, dictionary, profanity)
+
+     if deadpan:
+         reply = postprocess_deadpan(reply)
+
+     return reply
+
+ # =========================
+ # Gradio UI
+ # =========================
+
+ demo = gr.ChatInterface(
+     fn=chat_fn,
+     title="SOLAR 친구 챗봇",
+     description="SOLAR-10.7B 기반 한글 친구 챗봇 (가벼운 설정)",
+     examples=["야 나 오늘 개피곤하다", "이직할까 말까 고민중이야", "나 좀 칭찬해줘"],
+ )

  if __name__ == "__main__":
      demo.launch()
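
The new app.py takes all of its configuration from environment variables (MODEL_DIR, DICT_PATH, PROFANITY_PATH, STYLE_MODE, USE_FLASH_ATTN, and so on). A minimal local smoke-test sketch, assuming the merged model folder sits next to the script and that importing app.py (which loads the model at import time) is acceptable on your machine; the paths below are illustrative, not part of the commit:

    import os

    # Illustrative paths — point these at wherever the model and dictionary actually live.
    os.environ["MODEL_DIR"] = "./my-solar-chatbot-merged"
    os.environ["DICT_PATH"] = "./dictionaries/korean_words.txt"
    os.environ["USE_FLASH_ATTN"] = "0"  # fall back to SDPA when flash-attn is not installed
    os.environ["STYLE_MODE"] = "auto"   # auto | deadpan | neutral

    import app  # module-level code loads the model and tokenizer here

    # Single call to the chat function with an empty history, reusing one of the UI examples.
    print(app.chat_fn("야 나 오늘 개피곤하다", []))

Running python app.py directly is unchanged and still ends in demo.launch().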