gbrabbit commited on
Commit
5e29010
ยท
1 Parent(s): ea9b321

Auto commit at 07-2025-08 0:31:54

Browse files
Files changed (1) hide show
  1. app.py +71 -8
app.py CHANGED
@@ -56,8 +56,7 @@ try:
56
  torch_dtype=torch.float16,
57
  trust_remote_code=True,
58
  device_map=None,
59
- low_cpu_mem_usage=True,
60
- # max_memory={0: "4GB"} # GPU ๋ฉ”๋ชจ๋ฆฌ ์ œํ•œ
61
  )
62
  print(" โœ… ์ปค์Šคํ…€ ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ")
63
  else:
@@ -116,7 +115,8 @@ def chat_with_model(message, history, image=None):
116
  pixel_values = transform(pil_image).unsqueeze(0)
117
  image_metas = {"vision_grid_thw": torch.tensor([[1, 14, 14]])} # ๊ธฐ๋ณธ ๊ทธ๋ฆฌ๋“œ ํฌ๊ธฐ
118
 
119
- outputs = model.generate(
 
120
  input_ids=inputs["input_ids"],
121
  attention_mask=inputs["attention_mask"],
122
  pixel_values=[pixel_values],
@@ -128,7 +128,7 @@ def chat_with_model(message, history, image=None):
128
  )
129
  else:
130
  # ์ด๋ฏธ์ง€๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ํ…์ŠคํŠธ๋งŒ ์ƒ์„ฑ
131
- outputs = model.generate(
132
  input_ids=inputs["input_ids"],
133
  attention_mask=inputs["attention_mask"],
134
  max_new_tokens=200,
@@ -137,7 +137,38 @@ def chat_with_model(message, history, image=None):
137
  pad_token_id=tokenizer.eos_token_id
138
  )
139
 
140
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  if message in response:
142
  response = response.replace(message, "").strip()
143
  return response if response else "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์‘๋‹ต์„ ์ƒ์„ฑํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
@@ -172,7 +203,8 @@ def solve_math_problem(problem, image=None):
172
  pixel_values = transform(pil_image).unsqueeze(0)
173
  image_metas = {"vision_grid_thw": torch.tensor([[1, 14, 14]])} # ๊ธฐ๋ณธ ๊ทธ๋ฆฌ๋“œ ํฌ๊ธฐ
174
 
175
- outputs = model.generate(
 
176
  input_ids=inputs["input_ids"],
177
  attention_mask=inputs["attention_mask"],
178
  pixel_values=[pixel_values],
@@ -184,7 +216,7 @@ def solve_math_problem(problem, image=None):
184
  )
185
  else:
186
  # ์ด๋ฏธ์ง€๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ํ…์ŠคํŠธ๋งŒ ์ƒ์„ฑ
187
- outputs = model.generate(
188
  input_ids=inputs["input_ids"],
189
  attention_mask=inputs["attention_mask"],
190
  max_new_tokens=300,
@@ -193,7 +225,38 @@ def solve_math_problem(problem, image=None):
193
  pad_token_id=tokenizer.eos_token_id
194
  )
195
 
196
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  if prompt in response:
198
  response = response.replace(prompt, "").strip()
199
  return response if response else "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ˆ˜ํ•™ ๋ฌธ์ œ๋ฅผ ํ’€ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
 
56
  torch_dtype=torch.float16,
57
  trust_remote_code=True,
58
  device_map=None,
59
+ low_cpu_mem_usage=True
 
60
  )
61
  print(" โœ… ์ปค์Šคํ…€ ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ")
62
  else:
 
115
  pixel_values = transform(pil_image).unsqueeze(0)
116
  image_metas = {"vision_grid_thw": torch.tensor([[1, 14, 14]])} # ๊ธฐ๋ณธ ๊ทธ๋ฆฌ๋“œ ํฌ๊ธฐ
117
 
118
+ # ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ๋ชจ๋ธ์˜ forward ๋ฉ”์„œ๋“œ ์‚ฌ์šฉ
119
+ outputs = model(
120
  input_ids=inputs["input_ids"],
121
  attention_mask=inputs["attention_mask"],
122
  pixel_values=[pixel_values],
 
128
  )
129
  else:
130
  # ์ด๋ฏธ์ง€๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ํ…์ŠคํŠธ๋งŒ ์ƒ์„ฑ
131
+ outputs = model(
132
  input_ids=inputs["input_ids"],
133
  attention_mask=inputs["attention_mask"],
134
  max_new_tokens=200,
 
137
  pad_token_id=tokenizer.eos_token_id
138
  )
139
 
140
+ # outputs๊ฐ€ ํŠœํ”Œ์ธ ๊ฒฝ์šฐ ์ฒซ ๋ฒˆ์งธ ์š”์†Œ ์‚ฌ์šฉ
141
+ if isinstance(outputs, tuple):
142
+ logits = outputs[0]
143
+ else:
144
+ logits = outputs.logits if hasattr(outputs, 'logits') else outputs
145
+
146
+ # ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์˜ ํ† ํฐ ์„ ํƒ
147
+ next_token = torch.argmax(logits[:, -1, :], dim=-1)
148
+ generated_tokens = [next_token]
149
+
150
+ # ์ถ”๊ฐ€ ํ† ํฐ ์ƒ์„ฑ
151
+ for _ in range(199): # max_new_tokens - 1
152
+ inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token.unsqueeze(-1)], dim=-1)
153
+ inputs["attention_mask"] = torch.cat([inputs["attention_mask"], torch.ones_like(next_token.unsqueeze(-1))], dim=-1)
154
+
155
+ with torch.no_grad():
156
+ outputs = model(**inputs)
157
+ if isinstance(outputs, tuple):
158
+ logits = outputs[0]
159
+ else:
160
+ logits = outputs.logits if hasattr(outputs, 'logits') else outputs
161
+
162
+ next_token = torch.argmax(logits[:, -1, :], dim=-1)
163
+ generated_tokens.append(next_token)
164
+
165
+ if next_token.item() == tokenizer.eos_token_id:
166
+ break
167
+
168
+ # ์ƒ์„ฑ๋œ ํ† ํฐ๋“ค์„ ๋””์ฝ”๋”ฉ
169
+ generated_ids = torch.cat(generated_tokens, dim=0)
170
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True)
171
+
172
  if message in response:
173
  response = response.replace(message, "").strip()
174
  return response if response else "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์‘๋‹ต์„ ์ƒ์„ฑํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
 
203
  pixel_values = transform(pil_image).unsqueeze(0)
204
  image_metas = {"vision_grid_thw": torch.tensor([[1, 14, 14]])} # ๊ธฐ๋ณธ ๊ทธ๋ฆฌ๋“œ ํฌ๊ธฐ
205
 
206
+ # ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ๋ชจ๋ธ์˜ forward ๋ฉ”์„œ๋“œ ์‚ฌ์šฉ
207
+ outputs = model(
208
  input_ids=inputs["input_ids"],
209
  attention_mask=inputs["attention_mask"],
210
  pixel_values=[pixel_values],
 
216
  )
217
  else:
218
  # ์ด๋ฏธ์ง€๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ํ…์ŠคํŠธ๋งŒ ์ƒ์„ฑ
219
+ outputs = model(
220
  input_ids=inputs["input_ids"],
221
  attention_mask=inputs["attention_mask"],
222
  max_new_tokens=300,
 
225
  pad_token_id=tokenizer.eos_token_id
226
  )
227
 
228
+ # outputs๊ฐ€ ํŠœํ”Œ์ธ ๊ฒฝ์šฐ ์ฒซ ๋ฒˆ์งธ ์š”์†Œ ์‚ฌ์šฉ
229
+ if isinstance(outputs, tuple):
230
+ logits = outputs[0]
231
+ else:
232
+ logits = outputs.logits if hasattr(outputs, 'logits') else outputs
233
+
234
+ # ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์˜ ํ† ํฐ ์„ ํƒ
235
+ next_token = torch.argmax(logits[:, -1, :], dim=-1)
236
+ generated_tokens = [next_token]
237
+
238
+ # ์ถ”๊ฐ€ ํ† ํฐ ์ƒ์„ฑ
239
+ for _ in range(299): # max_new_tokens - 1
240
+ inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token.unsqueeze(-1)], dim=-1)
241
+ inputs["attention_mask"] = torch.cat([inputs["attention_mask"], torch.ones_like(next_token.unsqueeze(-1))], dim=-1)
242
+
243
+ with torch.no_grad():
244
+ outputs = model(**inputs)
245
+ if isinstance(outputs, tuple):
246
+ logits = outputs[0]
247
+ else:
248
+ logits = outputs.logits if hasattr(outputs, 'logits') else outputs
249
+
250
+ next_token = torch.argmax(logits[:, -1, :], dim=-1)
251
+ generated_tokens.append(next_token)
252
+
253
+ if next_token.item() == tokenizer.eos_token_id:
254
+ break
255
+
256
+ # ์ƒ์„ฑ๋œ ํ† ํฐ๋“ค์„ ๋””์ฝ”๋”ฉ
257
+ generated_ids = torch.cat(generated_tokens, dim=0)
258
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True)
259
+
260
  if prompt in response:
261
  response = response.replace(prompt, "").strip()
262
  return response if response else "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ˆ˜ํ•™ ๋ฌธ์ œ๋ฅผ ํ’€ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."