qewrufda committed on
Commit
91bd0b3
·
1 Parent(s): f446f27

Add Colab notebook converted to app.py

Files changed (1)
  1. app.py +121 -57
app.py CHANGED
@@ -1,70 +1,134 @@
-import gradio as gr
-from huggingface_hub import InferenceClient
-
-
-def respond(
-    message,
-    history: list[dict[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    hf_token: gr.OAuthToken,
-):
-    """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-    """
-    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-    messages = [{"role": "system", "content": system_message}]
-
-    messages.extend(history)
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-chatbot = gr.ChatInterface(
-    respond,
-    type="messages",
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
-with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.LoginButton()
-    chatbot.render()
-
-
-if __name__ == "__main__":
-    demo.launch()
+# Run this in a Colab cell before executing the script
+# (a bare `!pip` line is a syntax error in a plain .py file):
+# !pip install -q -U transformers peft accelerate bitsandbytes
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+from google.colab import drive
+
+# ============================================
+# 1️⃣ Mount Google Drive
+# ============================================
+drive.mount('/content/drive')
+
+# ============================================
+# 2️⃣ Environment setup
+# ============================================
+BASE_MODEL = "beomi/Llama-3-Open-Ko-8B"
+LORA_PATH = "/content/drive/MyDrive/at_last"
+
+print("🚀 Loading model...")
+model = AutoModelForCausalLM.from_pretrained(
+    BASE_MODEL,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    trust_remote_code=True
+)
+
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+tokenizer.pad_token = tokenizer.eos_token
+
+print("🔗 Attaching LoRA adapter...")
+model = PeftModel.from_pretrained(model, LORA_PATH)
+
+# ✅ Sync the model config's EOS/PAD token ids with the tokenizer
+model.config.eos_token_id = tokenizer.eos_token_id
+model.config.pad_token_id = tokenizer.pad_token_id
+
+print("✅ Model + LoRA ready!")
+
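+# Optional sketch, not part of the app's control flow: if the adapter never
+# needs to be swapped out, PEFT's merge_and_unload() bakes the LoRA weights
+# into the base model and removes the adapter wrapper, which slightly speeds
+# up inference (assumes enough GPU memory for the merged weights):
+#
+#   model = model.merge_and_unload()
+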
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+class StopOnTokens(StoppingCriteria):
+    def __init__(self, stop_ids):
+        self.stop_ids = stop_ids
+
+    def __call__(self, input_ids, scores, **kwargs):
+        # Stop as soon as the most recently generated token is a stop token
+        last_token = input_ids[0, -1].item()
+        return last_token in self.stop_ids
+
+
+# ✅ Register every stop-token candidate
+stop_words = ["<|eot|>", "</s>", "<|end_of_text|>"]
+stop_ids = [tokenizer.convert_tokens_to_ids(w) for w in stop_words if tokenizer.convert_tokens_to_ids(w) is not None]
+stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
+
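+# Sanity-check sketch: on Llama-3 tokenizers the actual turn terminator is
+# <|eot_id|>; printing the resolved ids shows which candidates really map to
+# vocabulary entries (strings not in the vocab resolve to the UNK id, or None
+# on some tokenizers, which the filter above relies on):
+#
+#   print({w: tokenizer.convert_tokens_to_ids(w) for w in stop_words})
+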
+# ============================================
+# 3️⃣ Prompt-building function
+# ============================================
+AI_PERSONALITY = """
+You are a friend who listens to the user with genuine sincerity.
+When the user strikes up a conversation, reply in a natural, everyday tone.
+Don't be long-winded; empathize, and keep your answers short and warm.
+You are a friendly companion who understands the user's requests precisely and gives realistic answers.
+Mix in jokes and empathy, but never dodge a request; answer it clearly.
+"""
+
+def build_prompt_full_history(history):
+    """
+    - history contains every user/assistant turn
+    - only the last user utterance is the generation target
+    """
+    prompt = "<|begin_of_text|>\n" + AI_PERSONALITY.strip() + "\n\n"
+    for turn in history:
+        role = turn["role"]
+        content = turn["content"].strip()
+        prompt += f"<|start_header_id|>{role}<|end_header_id|>\n{content}<|eot|>\n"
+
+    # Append an assistant header after the last user turn
+    prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
+    return prompt
+
+
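+# Alternative sketch, assuming the tokenizer ships Llama-3's chat template:
+# tokenizer.apply_chat_template() would emit the proper <|start_header_id|>/
+# <|eot_id|> framing using real special tokens instead of the hand-built
+# format (whose "<|eot|>" is a plain string, not a tokenizer special token):
+#
+#   prompt = tokenizer.apply_chat_template(
+#       [{"role": "system", "content": AI_PERSONALITY}] + history,
+#       tokenize=False,
+#       add_generation_prompt=True,
+#   )
+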
+# ============================================
+# 4️⃣ Chat loop
+# ============================================
+history = []
+
+while True:
+    user_input = input("👤 User: ").strip()
+    if user_input.lower() in ["종료", "exit", "quit"]:  # "종료" means "quit"
+        print("🛑 Chat ended!")
+        break
+
+    history.append({"role": "user", "content": user_input})
+    prompt = build_prompt_full_history(history)
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    with torch.no_grad():
+        output = model.generate(
+            **inputs,
+            max_new_tokens=256,
+            do_sample=True,  # required for temperature/top_p to take effect
+            temperature=0.6,
+            top_p=0.9,
+            repetition_penalty=1.1,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            stopping_criteria=stopping_criteria
+        )
+
+    # Decode only the newly generated tokens
+    response_full = tokenizer.decode(
+        output[0][inputs["input_ids"].shape[1]:],
+        skip_special_tokens=True
+    )
+
+    # Real special tokens such as <|eot_id|> are already stripped by
+    # skip_special_tokens=True; the literal "<|eot|>" from the hand-built
+    # prompt is not a special token, so trim at its first occurrence.
+    response = response_full.split("<|eot|>")[0].strip()
+
+    print(f"🤖 AI: {response}\n")
+
+    history.append({"role": "assistant", "content": response})
+    # Keep only the 10 most recent messages to bound the prompt length
+    if len(history) > 10:
+        history = history[-10:]