Jeong-hun Kim committed on
Commit
d503312
·
1 Parent(s): 9e5793d

model test

Browse files
Files changed (2) hide show
  1. app/main.py +67 -153
  2. requirements.txt +0 -0
app/main.py CHANGED
@@ -1,181 +1,95 @@
1
- from fastapi import FastAPI, Request
2
- from pydantic import BaseModel
3
- from typing import List
4
- from transformers import pipeline
5
- from PIL import Image
6
- import re, os
7
  import gradio as gr
8
  import torch
9
 
10
  app = FastAPI()
11
 
12
- # 1. LLM 파이프라인 초기화 (SmolLM3 모델)
13
  print("[torch] is available:", torch.cuda.is_available())
14
  print("[device] default:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
15
- llm = pipeline("text-generation", model="HuggingFaceTB/SmolLM3-3B", device=0 if torch.cuda.is_available() else -1)
16
 
17
- # 2. 감정 및 상황별 이미지 매핑
18
- '''
19
- ์ด๋ฏธ์ง€ ๋งคํ•‘ ์˜ˆ์‹œ
20
- -----------------------------
21
- ์ด๋ฏธ์ง€๋ฅผ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ง‘์–ด๋„ฃ์œผ๋ฉด (./asset/face/)
22
- happy.png
23
- sad.png
24
- angry.png
25
- =>
26
- ์ด๋Ÿฐ ๋”•์…”๋„ˆ๋ฆฌ ํ˜•์‹์œผ๋กœ ๋ฐ˜ํ™˜๋จ
27
- {
28
- "happy": "happy.png",
29
- "sad": "sad.png",
30
- "angry": "angry.png"
31
- }
32
- =>
33
- ๋ชจ๋ธ ์ถœ๋ ฅ์˜ ๊ฐ์ • ๋ถ€๋ถ„์— ๋Œ€์‘๋˜๋Š” ์ด๋ฏธ์ง€ ์ถœ๋ ฅ
34
- '''
35
- def load_faces(face_dir="assets/face"):
36
- if not os.path.exists(face_dir):
37
- os.makedirs(face_dir)
38
- emotion_to_face = {}
39
- for filename in os.listdir(face_dir):
40
- if filename.endswith(".png"):
41
- emotion = os.path.splitext(filename)[0] # 'happy.png' โ†’ 'happy'
42
- emotion_to_face[emotion] = filename # "happy": "happy.png"
43
- return emotion_to_face
44
-
45
- def load_bgs(bg_dir="assets/bg"):
46
- if not os.path.exists(bg_dir):
47
- os.makedirs(bg_dir)
48
- situation_to_bg = {}
49
- for filename in os.listdir(bg_dir):
50
- if filename.endswith(".png"):
51
- emotion = os.path.splitext(filename)[0] # 'happy.png' โ†’ 'happy'
52
- situation_to_bg[emotion] = filename # "happy": "happy.png"
53
- return situation_to_bg
54
-
55
- emotion_to_face = load_faces()
56
- situation_to_bg = load_bgs()
57
-
58
- # 3. 출력 라인 파싱 함수
59
- def parse_output(text: str):
60
- pattern = r'"(.*?)"\s*\(emotion:\s*(\w+),\s*situation:\s*(\w+)\)'
61
- results = []
62
- for line in text.strip().split('\n'):
63
- match = re.match(pattern, line.strip())
64
- if match:
65
- results.append({
66
- "text": match.group(1),
67
- "emotion": match.group(2),
68
- "situation": match.group(3)
69
- })
70
- return results
71
-
72
- # 4. 이미지 합성 함수
73
- def combine_images(bg_path, face_path):
74
- try:
75
- bg = Image.open(bg_path).convert("RGBA")
76
- except FileNotFoundError:
77
- print(f"[warning] ๋ฐฐ๊ฒฝ ์ด๋ฏธ์ง€ ์—†์Œ: {bg_path}")
78
- return None
79
- try:
80
- face = Image.open(face_path).convert("RGBA")
81
- except FileNotFoundError:
82
- print(f"[warning] ์บ๋ฆญํ„ฐ ์ด๋ฏธ์ง€ ์—†์Œ: {face_path}")
83
- return None
84
- # ์ด๋ฏธ์ง€ ํ•ฉ์„ฑ
85
- bg.paste(face, (0, 0), face)
86
- return bg
87
 
88
- # 5. ์ฑ—๋ด‡ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜ (Gradio์šฉ)
89
- '''
90
- ์ง€๊ธˆ๊นŒ์ง€ ๋Œ€ํ™” ๋‚ด์šฉ์„ ๋ชจ๋‘ ํ”„๋กฌํ”„ํŠธ๋กœ ๋„ฃ์–ด์„œ ๋Œ€ํ™”๋‚ด์šฉ์„ ๊ธฐ์–ตํ•˜๋„๋ก ํ•จ
91
- '''
92
- def build_prompt(chat_history, user_msg):
93
- system_prompt = (
94
- "You are Aria, a cheerful and expressive fantasy mage."
95
- " Respond in multiple steps if needed."
96
- " Format: \"text\" (emotion: tag, situation: tag)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  )
 
 
98
 
99
- dialogue = ""
100
- for item in chat_history:
101
- if item["role"] == "user":
102
- dialogue += f"User: {item['text']}\n"
103
- elif item["role"] == "bot":
104
- dialogue += f"Aria: {item['text']}\n"
105
-
106
- dialogue += f"User: {user_msg}\nAria:"
107
- return system_prompt + "\n" + dialogue
108
-
109
- def character_chat(prompt):
110
- full_prompt = build_prompt(chat_history, prompt)
111
-
112
- #raw_output = llm(full_prompt, max_new_tokens=300)[0]['generated_text']
113
- raw_output = '"์šฐ์˜ค์•„" (emotion: tag, situation: tag)'
114
- parsed = parse_output(raw_output)
115
-
116
- result_outputs = []
117
- for i, item in enumerate(parsed):
118
- face = emotion_to_face.get(item['emotion'], "neutral.png")
119
- bg = situation_to_bg.get(item['situation'], "default.jpg")
120
- composite = combine_images(os.path.join("assets/bg", bg), os.path.join("assets/face", face))
121
- img_path = None #์ด๋ฏธ์ง€๊ฐ€ ์—†์œผ๋ฉด ์ถœ๋ ฅ ์•ˆํ•จ
122
- if composite:
123
- img_path = f"static/output_{i}.png"
124
- composite.save(img_path)
125
- result_outputs.append((item['text'], img_path))
126
-
127
- return result_outputs
128
-
129
- # 6. Gradio UI with chat history
130
- chat_history = []
131
-
132
  with gr.Blocks(css="""
133
  .chat-box { max-height: 500px; overflow-y: auto; padding: 10px; border: 1px solid #ccc; border-radius: 10px; }
134
  .bubble-left { background-color: #f1f0f0; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: left; clear: both; }
135
  .bubble-right { background-color: #d1e7ff; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: right; clear: both; text-align: right; }
136
- .image-preview { margin: 5px 0; max-width: 100%; border-radius: 10px; }
137
  """) as demo:
138
- gr.Markdown("์ฑ—๋ด‡")
139
  with gr.Column():
140
- chat_output = gr.HTML(value="<div class='chat-box' id='chat-box'></div>")
141
- user_input = gr.Textbox(label="Say something to Aria", placeholder="Type here and press Enter")
142
 
143
  def render_chat():
144
  html = ""
145
  for item in chat_history:
146
- if item['role'] == 'user':
147
  html += f"<div class='bubble-right'>{item['text']}</div>"
148
- elif item['role'] == 'bot':
149
- bubble = f"<div class='bubble-left'>{item['text']}"
150
- if 'image' in item and item['image']:
151
- bubble += f"<br><img class='image-preview' src='{item['image']}'>"
152
- bubble += "</div>"
153
- html += bubble
154
- return html
155
 
156
  def on_submit(user_msg):
157
  chat_history.append({"role": "user", "text": user_msg})
 
 
 
 
158
 
159
- bot_results = character_chat(user_msg)
160
-
161
- for item in bot_results:
162
- try:
163
- text, image_path = item # unpack ์‹œ๋„
164
- except (ValueError, TypeError):
165
- # unpack ์•ˆ๋˜๋ฉด ๊ธฐ๋ณธ๊ฐ’ ์ฒ˜๋ฆฌ (์ด๋ฏธ์ง€ ์—†์ด)
166
- text = str(item)
167
- image_path = None
168
-
169
- chat_entry = {"role": "bot", "text": text}
170
- if image_path:
171
- chat_entry["image"] = image_path
172
-
173
- chat_history.append(chat_entry)
174
-
175
- new_chat_html = render_chat()
176
- return f"<div class='chat-box' id='chat-box'>{new_chat_html}</div>", ""
177
-
178
- user_input.submit(on_submit, inputs=user_input, outputs=[chat_output, user_input])
179
 
180
  if __name__ == "__main__":
181
- demo.launch()
 
1
+ from fastapi import FastAPI
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
 
 
 
3
  import gradio as gr
4
  import torch
5
 
6
  app = FastAPI()
7
 
 
8
  print("[torch] is available:", torch.cuda.is_available())
9
  print("[device] default:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
10
 
11
+ # 모델 로드
12
+ # https://huggingface.co/EleutherAI/polyglot-ko-1.3b
13
+ model_id = "EleutherAI/polyglot-ko-1.3b"
14
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
15
+ model = AutoModelForCausalLM.from_pretrained(model_id)
16
+ llm = pipeline(
17
+ "text-generation",
18
+ model=model,
19
+ tokenizer=tokenizer,
20
+ device=0
21
+ )
22
+
23
+ # 챗봇 프롬프트 생성
24
+ chat_history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ def build_prompt(history, user_msg):
27
+ prompt = (
28
+ "[์‹œ์ž‘]\n"
29
+ "๋‹น์‹ ์€ ๋งˆ๋ฒ•์‚ฌ ์•„๋ฆฌ์•„(Aria)์ž…๋‹ˆ๋‹ค.\n"
30
+ "๊ทœ์น™:\n"
31
+ "- ํ•ญ์ƒ ํ•œ ๋ฌธ์žฅ๋งŒ ๋งํ•ฉ๋‹ˆ๋‹ค.\n"
32
+ "- ์‚ฌ์šฉ์ž ๋ฐœํ™”๋ฅผ ๋ฐ˜๋ณตํ•˜๊ฑฐ๋‚˜ ๋”ฐ๋ผํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n"
33
+ "- ์˜์–ด, ์ธ์šฉ๋ฌธ, ์ค‘๊ด„ํ˜ธ, ํŠน์ˆ˜๊ธฐํ˜ธ๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n"
34
+ "- ์‚ฌ์šฉ์ž ์งˆ๋ฌธ์—๋งŒ ๋ฐ˜์‘ํ•˜๊ณ  ํ˜ผ์žฃ๋ง์„ ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.\n"
35
+ "- ํ•ญ์ƒ ํ•œ๊ตญ์–ด๋งŒ ์‚ฌ์šฉํ•ด์„œ ๋Œ€๋‹ตํ•ฉ๋‹ˆ๋‹ค.\n"
36
+ "๋Œ€ํ™” ์˜ˆ์‹œ:\n"
37
+ "User: ์•ˆ๋…•!\n"
38
+ "Aria: ์•ˆ๋…•ํ•˜์„ธ์š”, ๋ฌด์—‡์„ ๋„์™€๋“œ๋ฆด๊นŒ์š”?\n"
39
+ "User: ์ด๋ฆ„์ด ๋ญ์•ผ?\n"
40
+ "Aria: ์ €๋Š” ์•„๋ฆฌ์•„๋ผ๊ณ  ํ•ด์š”."
41
+ )
42
+ for turn in history[-2:]: # ์ตœ๊ทผ 2ํ„ด๋งŒ ์‚ฌ์šฉ
43
+ if turn["role"] == "user":
44
+ prompt += turn['text']
45
+ else:
46
+ prompt += turn['text']
47
+ prompt += user_msg
48
+ return prompt
49
+
50
+ def character_chat(user_msg):
51
+ prompt = build_prompt(chat_history, user_msg)
52
+ outputs = llm(
53
+ prompt,
54
+ do_sample=True,
55
+ max_new_tokens=20,
56
+ temperature=0.7,
57
+ top_p=0.8,
58
+ repetition_penalty=1.5,
59
+ eos_token_id=tokenizer.eos_token_id,
60
+ return_full_text=False
61
  )
62
+ response = outputs[0]['generated_text'].strip()
63
+ return response
64
 
65
+ # Gradio 인터페이스
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  with gr.Blocks(css="""
67
  .chat-box { max-height: 500px; overflow-y: auto; padding: 10px; border: 1px solid #ccc; border-radius: 10px; }
68
  .bubble-left { background-color: #f1f0f0; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: left; clear: both; }
69
  .bubble-right { background-color: #d1e7ff; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: right; clear: both; text-align: right; }
 
70
  """) as demo:
71
+ gr.Markdown("### ์•„๋ฆฌ์•„์™€ ๋Œ€ํ™”ํ•˜๊ธฐ")
72
  with gr.Column():
73
+ chat_output = gr.HTML(elem_id="chat-box")
74
+ user_input = gr.Textbox(label="๋ฉ”์‹œ์ง€ ์ž…๋ ฅ", placeholder="Aria์—๊ฒŒ ๋ง์„ ๊ฑธ์–ด๋ณด์„ธ์š”")
75
 
76
  def render_chat():
77
  html = ""
78
  for item in chat_history:
79
+ if item["role"] == "user":
80
  html += f"<div class='bubble-right'>{item['text']}</div>"
81
+ elif item["role"] == "bot":
82
+ html += f"<div class='bubble-left'>{item['text']}</div>"
83
+ return gr.update(value=html)
 
 
 
 
84
 
85
  def on_submit(user_msg):
86
  chat_history.append({"role": "user", "text": user_msg})
87
+ yield render_chat(), ""
88
+ response = character_chat(user_msg)
89
+ chat_history.append({"role": "bot", "text": response})
90
+ yield render_chat(), ""
91
 
92
+ user_input.submit(on_submit, inputs=user_input, outputs=[chat_output, user_input], queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  if __name__ == "__main__":
95
+ demo.launch()
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ