seawolf2357 commited on
Commit
f6d9128
·
verified ·
1 Parent(s): f024201

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -462
app.py DELETED
@@ -1,462 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import os
4
- import re
5
- import tempfile
6
- from collections.abc import Iterator
7
- from threading import Thread
8
-
9
- import cv2
10
- import gradio as gr
11
- import spaces
12
- import torch
13
- from loguru import logger
14
- from PIL import Image
15
- from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
16
-
17
- # [PDF] PyPDF2 추가
18
- import PyPDF2
19
- # [CSV] Pandas 추가
20
- import pandas as pd
21
-
22
- model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
23
- processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
24
- model = Gemma3ForConditionalGeneration.from_pretrained(
25
- model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
26
- )
27
-
28
- MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
29
-
30
- ###################################################################
31
- # CSV를 Markdown으로 변환하는 유틸 함수
32
- ###################################################################
33
- def csv_to_markdown(csv_path: str) -> str:
34
- """
35
- CSV 파일 전체를 문자열로 변환하여 Markdown 형태로 반환.
36
- (매우 큰 CSV라면 전체를 넘기는 것이 위험할 수 있음 -> 필요 시 잘라낼 것)
37
- """
38
- try:
39
- df = pd.read_csv(csv_path)
40
- df_str = df.to_string()
41
- # 필요하다면 길이 제한을 걸어도 됨
42
- # if len(df_str) > 10000:
43
- # df_str = df_str[:10000] + "\n...(truncated)..."
44
-
45
- return f"**[CSV File: {os.path.basename(csv_path)}]**\n\n```\n{df_str}\n```"
46
- except Exception as e:
47
- return f"Failed to read CSV ({os.path.basename(csv_path)}): {str(e)}"
48
-
49
- ###################################################################
50
- # PDF -> Markdown 변환 함수 (기존)
51
- ###################################################################
52
- def pdf_to_markdown(pdf_path: str) -> str:
53
- """
54
- PDF 파일을 텍스트로 추출 후, 간단한 Markdown 형태로 반환.
55
- """
56
- text_chunks = []
57
- with open(pdf_path, "rb") as f:
58
- reader = PyPDF2.PdfReader(f)
59
- for page_num, page in enumerate(reader.pages, start=1):
60
- page_text = page.extract_text()
61
- page_text = page_text.strip() if page_text else ""
62
- if page_text:
63
- # 페이지별로 간단한 헤더와 본문을 Markdown으로 합침
64
- text_chunks.append(f"## Page {page_num}\n\n{page_text}\n")
65
- return "\n".join(text_chunks)
66
-
67
- ###################################################################
68
- # 이미지/비디오 개수 카운트 (기존)
69
- ###################################################################
70
- def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
71
- image_count = 0
72
- video_count = 0
73
- for path in paths:
74
- if path.endswith(".mp4"):
75
- video_count += 1
76
- else:
77
- image_count += 1
78
- return image_count, video_count
79
-
80
- def count_files_in_history(history: list[dict]) -> tuple[int, int]:
81
- image_count = 0
82
- video_count = 0
83
- for item in history:
84
- if item["role"] != "user" or isinstance(item["content"], str):
85
- continue
86
- if item["content"][0].endswith(".mp4"):
87
- video_count += 1
88
- else:
89
- image_count += 1
90
- return image_count, video_count
91
-
92
- ###################################################################
93
- # 미디어(이미지/비디오) 제한 검사 + PDF/CSV 예외 (기존/수정)
94
- ###################################################################
95
- def validate_media_constraints(message: dict, history: list[dict]) -> bool:
96
- """
97
- 이미지/비디오 개수와 혼합 여부 등을 검사하는 함수.
98
- PDF, CSV 등은 검사 로직에서 제외하여 업로드만 허용.
99
- """
100
- # pdf, csv 파일 제외
101
- pdf_files = [f for f in message["files"] if f.endswith(".pdf")]
102
- csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
103
- non_pdf_csv_files = [f for f in message["files"]
104
- if not f.endswith(".pdf") and not f.lower().endswith(".csv")]
105
-
106
- # 기존 로직은 이미지/비디오에 대해서만 체크
107
- new_image_count, new_video_count = count_files_in_new_message(non_pdf_csv_files)
108
- history_image_count, history_video_count = count_files_in_history(history)
109
- image_count = history_image_count + new_image_count
110
- video_count = history_video_count + new_video_count
111
-
112
- if video_count > 1:
113
- gr.Warning("Only one video is supported.")
114
- return False
115
- if video_count == 1:
116
- if image_count > 0:
117
- gr.Warning("Mixing images and videos is not allowed.")
118
- return False
119
- if "<image>" in message["text"]:
120
- gr.Warning("Using <image> tags with video files is not supported.")
121
- return False
122
-
123
- if video_count == 0 and image_count > MAX_NUM_IMAGES:
124
- gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
125
- return False
126
-
127
- # <image> 태그가 있을 경우, 이미지 수와 태그 수 일치
128
- if "<image>" in message["text"]:
129
- if message["text"].count("<image>") != new_image_count:
130
- gr.Warning("The number of <image> tags in the text does not match the number of images.")
131
- return False
132
-
133
- return True
134
-
135
- ###################################################################
136
- # 동영상 처리 (기존)
137
- ###################################################################
138
- def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
139
- vidcap = cv2.VideoCapture(video_path)
140
- fps = vidcap.get(cv2.CAP_PROP_FPS)
141
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
142
-
143
- frame_interval = int(fps / 3)
144
- frames = []
145
-
146
- for i in range(0, total_frames, frame_interval):
147
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
148
- success, image = vidcap.read()
149
- if success:
150
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
151
- pil_image = Image.fromarray(image)
152
- timestamp = round(i / fps, 2)
153
- frames.append((pil_image, timestamp))
154
-
155
- vidcap.release()
156
- return frames
157
-
158
- def process_video(video_path: str) -> list[dict]:
159
- content = []
160
- frames = downsample_video(video_path)
161
- for frame in frames:
162
- pil_image, timestamp = frame
163
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
164
- pil_image.save(temp_file.name)
165
- content.append({"type": "text", "text": f"Frame {timestamp}:"})
166
- content.append({"type": "image", "url": temp_file.name})
167
- logger.debug(f"{content=}")
168
- return content
169
-
170
- ###################################################################
171
- # <image> 태그 interleaved 이미지 처리 (기존)
172
- ###################################################################
173
- def process_interleaved_images(message: dict) -> list[dict]:
174
- logger.debug(f"{message['files']=}")
175
- parts = re.split(r"(<image>)", message["text"])
176
- logger.debug(f"{parts=}")
177
-
178
- content = []
179
- image_index = 0
180
- for part in parts:
181
- logger.debug(f"{part=}")
182
- if part == "<image>":
183
- content.append({"type": "image", "url": message["files"][image_index]})
184
- logger.debug(f"file: {message['files'][image_index]}")
185
- image_index += 1
186
- elif part.strip():
187
- content.append({"type": "text", "text": part.strip()})
188
- elif isinstance(part, str) and part != "<image>":
189
- content.append({"type": "text", "text": part})
190
- logger.debug(f"{content=}")
191
- return content
192
-
193
- ###################################################################
194
- # 새 user message 처리 (PDF + CSV + 이미지/비디오)
195
- ###################################################################
196
- def process_new_user_message(message: dict) -> list[dict]:
197
- if not message["files"]:
198
- return [{"type": "text", "text": message["text"]}]
199
-
200
- # PDF 파일 목록
201
- pdf_files = [f for f in message["files"] if f.endswith(".pdf")]
202
- # CSV 파일 목록
203
- csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
204
- # 이미지/비디오 (기존)
205
- other_files = [f for f in message["files"]
206
- if not f.endswith(".pdf") and not f.lower().endswith(".csv")]
207
-
208
- # 일단 사용자의 text를 먼저 넣는다
209
- content_list = [{"type": "text", "text": message["text"]}]
210
-
211
- # [PDF] 변환 후 추가
212
- for pdf_path in pdf_files:
213
- pdf_markdown = pdf_to_markdown(pdf_path)
214
- if pdf_markdown.strip():
215
- content_list.append({"type": "text", "text": pdf_markdown})
216
- else:
217
- content_list.append({"type": "text", "text": "(PDF에서 텍스트 추출 실패)"})
218
-
219
- # [CSV] 변환 후 추가
220
- for cfile in csv_files:
221
- csv_md = csv_to_markdown(cfile)
222
- content_list.append({"type": "text", "text": csv_md})
223
-
224
- # 영상 처리
225
- video_files = [f for f in other_files if f.endswith(".mp4")]
226
- if video_files:
227
- content_list += process_video(video_files[0])
228
- return content_list
229
-
230
- # interleaved 이미지
231
- if "<image>" in message["text"]:
232
- return process_interleaved_images(message)
233
-
234
- # 일반 이미지(여러 장)
235
- image_files = [f for f in other_files if not f.endswith(".mp4")]
236
- if image_files:
237
- content_list += [{"type": "image", "url": path} for path in image_files]
238
-
239
- return content_list
240
-
241
- ###################################################################
242
- # 히스토리 -> LLM용 메시지 변환 (기존)
243
- ###################################################################
244
- def process_history(history: list[dict]) -> list[dict]:
245
- messages = []
246
- current_user_content: list[dict] = []
247
- for item in history:
248
- if item["role"] == "assistant":
249
- if current_user_content:
250
- messages.append({"role": "user", "content": current_user_content})
251
- current_user_content = []
252
- messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
253
- else:
254
- content = item["content"]
255
- if isinstance(content, str):
256
- current_user_content.append({"type": "text", "text": content})
257
- else:
258
- current_user_content.append({"type": "image", "url": content[0]})
259
- return messages
260
-
261
- ###################################################################
262
- # 메인 추론 함수 (기존)
263
- ###################################################################
264
- @spaces.GPU(duration=120)
265
- def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
266
- if not validate_media_constraints(message, history):
267
- yield ""
268
- return
269
-
270
- messages = []
271
- if system_prompt:
272
- messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
273
- messages.extend(process_history(history))
274
- messages.append({"role": "user", "content": process_new_user_message(message)})
275
-
276
- inputs = processor.apply_chat_template(
277
- messages,
278
- add_generation_prompt=True,
279
- tokenize=True,
280
- return_dict=True,
281
- return_tensors="pt",
282
- ).to(device=model.device, dtype=torch.bfloat16)
283
-
284
- streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
285
- generate_kwargs = dict(
286
- inputs,
287
- streamer=streamer,
288
- max_new_tokens=max_new_tokens,
289
- )
290
- t = Thread(target=model.generate, kwargs=generate_kwargs)
291
- t.start()
292
-
293
- output = ""
294
- for delta in streamer:
295
- output += delta
296
- yield output
297
-
298
- ###################################################################
299
- # 예시들 (기존 그대로)
300
- ###################################################################
301
- examples = [
302
- [
303
- {
304
- "text": "I need to be in Japan for 10 days, going to Tokyo, Kyoto and Osaka. Think about number of attractions in each of them and allocate number of days to each city. Make public transport recommendations.",
305
- "files": [],
306
- }
307
- ],
308
- [
309
- {
310
- "text": "Write the matplotlib code to generate the same bar chart.",
311
- "files": ["assets/additional-examples/barchart.png"],
312
- }
313
- ],
314
- [
315
- {
316
- "text": "What is odd about this video?",
317
- "files": ["assets/additional-examples/tmp.mp4"],
318
- }
319
- ],
320
- [
321
- {
322
- "text": "I already have this supplement <image> and I want to buy this one <image>. Any warnings I should know about?",
323
- "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
324
- }
325
- ],
326
- [
327
- {
328
- "text": "Write a poem inspired by the visual elements of the images.",
329
- "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
330
- }
331
- ],
332
- [
333
- {
334
- "text": "Compose a short musical piece inspired by the visual elements of the images.",
335
- "files": [
336
- "assets/sample-images/07-1.png",
337
- "assets/sample-images/07-2.png",
338
- "assets/sample-images/07-3.png",
339
- "assets/sample-images/07-4.png",
340
- ],
341
- }
342
- ],
343
- [
344
- {
345
- "text": "Write a short story about what might have happened in this house.",
346
- "files": ["assets/sample-images/08.png"],
347
- }
348
- ],
349
- [
350
- {
351
- "text": "Create a short story based on the sequence of images.",
352
- "files": [
353
- "assets/sample-images/09-1.png",
354
- "assets/sample-images/09-2.png",
355
- "assets/sample-images/09-3.png",
356
- "assets/sample-images/09-4.png",
357
- "assets/sample-images/09-5.png",
358
- ],
359
- }
360
- ],
361
- [
362
- {
363
- "text": "Describe the creatures that would live in this world.",
364
- "files": ["assets/sample-images/10.png"],
365
- }
366
- ],
367
- [
368
- {
369
- "text": "Read text in the image.",
370
- "files": ["assets/additional-examples/1.png"],
371
- }
372
- ],
373
- [
374
- {
375
- "text": "When is this ticket dated and how much did it cost?",
376
- "files": ["assets/additional-examples/2.png"],
377
- }
378
- ],
379
- [
380
- {
381
- "text": "Read the text in the image into markdown.",
382
- "files": ["assets/additional-examples/3.png"],
383
- }
384
- ],
385
- [
386
- {
387
- "text": "Evaluate this integral.",
388
- "files": ["assets/additional-examples/4.png"],
389
- }
390
- ],
391
- [
392
- {
393
- "text": "caption this image",
394
- "files": ["assets/sample-images/01.png"],
395
- }
396
- ],
397
- [
398
- {
399
- "text": "What's the sign says?",
400
- "files": ["assets/sample-images/02.png"],
401
- }
402
- ],
403
- [
404
- {
405
- "text": "Compare and contrast the two images.",
406
- "files": ["assets/sample-images/03.png"],
407
- }
408
- ],
409
- [
410
- {
411
- "text": "List all the objects in the image and their colors.",
412
- "files": ["assets/sample-images/04.png"],
413
- }
414
- ],
415
- [
416
- {
417
- "text": "Describe the atmosphere of the scene.",
418
- "files": ["assets/sample-images/05.png"],
419
- }
420
- ],
421
- ]
422
-
423
- ###################################################################
424
- # PDF + CSV를 허용하는 Gradio ChatInterface
425
- ###################################################################
426
- demo = gr.ChatInterface(
427
- fn=run,
428
- type="messages",
429
- chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
430
- textbox=gr.MultimodalTextbox(
431
- file_types=["image", ".mp4", ".pdf", ".csv"], # pdf & csv 허용
432
- file_count="multiple",
433
- autofocus=True
434
- ),
435
- multimodal=True,
436
- additional_inputs=[
437
- gr.Textbox(
438
- label="System Prompt",
439
- value=(
440
- "You are a deeply thoughtful AI. Consider problems thoroughly and derive correct "
441
- "solutions through systematic reasoning. Please answer in korean."
442
- )
443
- ),
444
- gr.Slider(
445
- label="Max New Tokens",
446
- minimum=100,
447
- maximum=8000,
448
- step=50,
449
- value=2000
450
- ),
451
- ],
452
- stop_btn=False,
453
- title="Gemma 3 27B IT",
454
- examples=examples,
455
- run_examples_on_click=False,
456
- cache_examples=False,
457
- css_paths="style.css",
458
- delete_cache=(1800, 1800),
459
- )
460
-
461
- if __name__ == "__main__":
462
- demo.launch()