Kims12 commited on
Commit
b80a15f
·
verified ·
1 Parent(s): 5d0f132

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -91
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  import gradio as gr
3
  import openai
4
- import aiohttp
5
- import asyncio
6
  import logging
7
 
8
  # 로깅 설정
@@ -68,54 +67,35 @@ def get_category_prompt(category):
68
  - 팔 없는 장애인 화가의 꿈을 이뤄준 첨단 그림 도구
69
  """
70
 
71
- async def call_api_async(session, content, system_message, max_tokens, temperature, top_p):
72
- payload = {
73
- "model": "gpt-4o-mini", # 실제 사용 가능한 모델 이름으로 변경하세요.
74
- "messages": [
75
- {"role": "system", "content": system_message},
76
- {"role": "user", "content": content},
77
- ],
78
- "max_tokens": max_tokens,
79
- "temperature": temperature,
80
- "top_p": top_p,
81
- }
82
- headers = {"Authorization": f"Bearer {openai.api_key}"}
83
- async with session.post("https://api.openai.com/v1/chat/completions", json=payload, headers=headers) as response:
84
- if response.status != 200:
85
- text = await response.text()
86
- logger.error(f"API 요청 실패: {response.status}, {text}")
87
- raise Exception(f"API 요청 실패: {response.status}")
88
- result = await response.json()
89
- return result['choices'][0]['message']['content']
90
-
91
- async def generate_copywriting_async(categories, topic):
92
  max_tokens = 1000
93
  temperature = 0.8
94
  top_p = 0.95
95
-
96
- results = {}
97
-
98
- async with aiohttp.ClientSession() as session:
99
- tasks = []
100
- for category in categories:
101
- prompt = get_category_prompt(category)
102
- user_content = f"주제: {topic}"
103
- task = asyncio.create_task(call_api_async(session, user_content, prompt, max_tokens, temperature, top_p))
104
- tasks.append((category, task))
105
-
106
- # Mapping tasks to categories
107
- task_to_category = {task: category for category, task in tasks}
108
-
109
- for task in asyncio.as_completed(task_to_category.keys()):
110
- category = task_to_category[task]
111
- try:
112
- copywriting = await task
113
- results[category] = copywriting
114
- logger.info(f"Generated {category}: {copywriting}")
115
- except Exception as e:
116
- logger.error(f"Error generating copywriting for {category}: {str(e)}")
117
- results[category] = f"오류 발생: {str(e)}"
118
 
 
 
 
 
 
 
119
  return results
120
 
121
  # Gradio 인터페이스 부분
@@ -123,7 +103,6 @@ with gr.Blocks() as iface:
123
  gr.Markdown("# AI 카피라이팅 생성기")
124
 
125
  with gr.Column():
126
- # 선택 기능 제거
127
  topic = gr.Textbox(lines=1, label="주제를 입력하세요")
128
 
129
  generate_btn = gr.Button("카피라이팅 생성하기")
@@ -136,63 +115,144 @@ with gr.Blocks() as iface:
136
  output_box = gr.Textbox(label=category, visible=True)
137
  output_boxes[category] = output_box
138
 
139
- async def validate_and_generate(topic_input):
 
140
  try:
141
- # 초기 상태 설정: "카피라이팅 생성 중..." 및 빈 출력 박스
142
- current_outputs = ["카피라이팅 생성 중..."] + [""] * len(CATEGORIES)
143
- yield [gr.Markdown.update(value=current_outputs[0])] + [gr.Textbox.update(value=o) for o in current_outputs[1:]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- # 비동기 카피라이팅 생성
146
- async with aiohttp.ClientSession() as session:
147
- tasks = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  for category in CATEGORIES:
149
  prompt = get_category_prompt(category)
150
- user_content = f"주제: {topic_input}"
151
- task = asyncio.create_task(call_api_async(session, user_content, prompt, 1000, 0.8, 0.95))
152
- tasks.append((category, task))
 
153
 
154
- # 태스크를 카테고리 매핑
155
- task_to_category = {task: category for category, task in tasks}
156
-
157
- for task in asyncio.as_completed(task_to_category.keys()):
158
- category = task_to_category[task]
159
- try:
160
- copywriting = await task
161
- index = CATEGORIES.index(category) + 1 # +1은 status를 제외하기 위함
162
- current_outputs[index] = copywriting
163
- # 업데이트할 출력 리스트
164
- updates = [gr.Markdown.update(value="카피라이팅 생성 중...")] + [
165
- gr.Textbox.update(value=current_outputs[i + 1]) if i + 1 == index else gr.Textbox.update()
166
- for i in range(len(CATEGORIES))
167
- ]
168
- yield updates
169
- except Exception as e:
170
- logger.error(f"Error generating copywriting for {category}: {str(e)}")
171
- index = CATEGORIES.index(category) + 1
172
- current_outputs[index] = f"오류 : {str(e)}"
173
- updates = [gr.Markdown.update(value="카피라이팅 생성 중...")] + [
174
- gr.Textbox.update(value=current_outputs[i + 1]) if i + 1 == index else gr.Textbox.update()
175
- for i in range(len(CATEGORIES))
176
- ]
177
- yield updates
178
-
179
- # 최종 상태 업데이트: "카피라이팅 생성이 완료되었습니다."
180
- final_updates = [gr.Markdown.update(value="피라이팅 생성이 완료되었습니다.")] + [
181
- gr.Textbox.update(value=current_outputs[i + 1]) for i in range(len(CATEGORIES))
182
- ]
183
- yield final_updates
184
 
185
  except Exception as e:
186
  logger.error(f"Error during copywriting generation: {str(e)}")
187
- error_message = f"오류 발생: {str(e)}"
188
- yield [gr.Markdown.update(value=error_message)] + [gr.Textbox.update(value="") for _ in CATEGORIES]
189
 
 
190
  generate_btn.click(
191
- fn=validate_and_generate,
192
  inputs=[topic],
193
  outputs=[status] + [output_boxes[category] for category in CATEGORIES],
194
- api_name=None,
195
  )
 
 
 
196
 
197
  # 인터페이스 실행
198
  iface.launch()
 
1
  import os
2
  import gradio as gr
3
  import openai
4
+ import requests
 
5
  import logging
6
 
7
  # 로깅 설정
 
67
  - 팔 없는 장애인 화가의 꿈을 이뤄준 첨단 그림 도구
68
  """
69
 
70
+ def call_api_sync(content, system_message, max_tokens, temperature, top_p):
71
+ response = requests.post(
72
+ "https://api.openai.com/v1/chat/completions",
73
+ headers={"Authorization": f"Bearer {openai.api_key}"},
74
+ json={
75
+ "model": "gpt-4o-mini",
76
+ "messages": [
77
+ {"role": "system", "content": system_message},
78
+ {"role": "user", "content": content},
79
+ ],
80
+ "max_tokens": max_tokens,
81
+ "temperature": temperature,
82
+ "top_p": top_p,
83
+ }
84
+ )
85
+ result = response.json()
86
+ return result['choices'][0]['message']['content']
87
+
88
+ def generate_copywriting(categories, topic):
 
 
89
  max_tokens = 1000
90
  temperature = 0.8
91
  top_p = 0.95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ results = {}
94
+ for category in categories:
95
+ prompt = get_category_prompt(category)
96
+ user_content = f"주제: {topic}"
97
+ copywriting = call_api_sync(user_content, prompt, max_tokens, temperature, top_p)
98
+ results[category] = copywriting
99
  return results
100
 
101
  # Gradio 인터페이스 부분
 
103
  gr.Markdown("# AI 카피라이팅 생성기")
104
 
105
  with gr.Column():
 
106
  topic = gr.Textbox(lines=1, label="주제를 입력하세요")
107
 
108
  generate_btn = gr.Button("카피라이팅 생성하기")
 
115
  output_box = gr.Textbox(label=category, visible=True)
116
  output_boxes[category] = output_box
117
 
118
+ # 기존 동기 함수 (수정/삭제 금지)
119
+ def validate_and_generate(topic):
120
  try:
121
+ results = generate_copywriting(CATEGORIES, topic)
122
+ logger.debug(f"Generated results: {results}")
123
+
124
+ outputs = []
125
+ for category in CATEGORIES:
126
+ if category in results:
127
+ outputs.append(gr.update(value=results[category]))
128
+ else:
129
+ outputs.append(gr.update(value=""))
130
+
131
+ return [gr.update(value="카피라이팅 생성이 완료되었습니다.")] + outputs
132
+ except Exception as e:
133
+ logger.error(f"Error during copywriting generation: {str(e)}")
134
+ return [gr.update(value=f"오류 발생: {str(e)}")] + [gr.update(value="") for _ in CATEGORIES]
135
+
136
+ ##########################################
137
+ # 추가��� 비동기/병렬 처리용 코드 시작
138
+ ##########################################
139
+ import asyncio
140
+ import aiohttp
141
+
142
+ async def call_api_async(content, system_message, max_tokens, temperature, top_p):
143
+ """
144
+ 비동기적으로 OpenAI API를 호출하는 함수
145
+ """
146
+ url = "https://api.openai.com/v1/chat/completions"
147
+ headers = {"Authorization": f"Bearer {openai.api_key}"}
148
+ payload = {
149
+ "model": "gpt-4o-mini",
150
+ "messages": [
151
+ {"role": "system", "content": system_message},
152
+ {"role": "user", "content": content},
153
+ ],
154
+ "max_tokens": max_tokens,
155
+ "temperature": temperature,
156
+ "top_p": top_p,
157
+ }
158
+ async with aiohttp.ClientSession() as session:
159
+ async with session.post(url, headers=headers, json=payload) as resp:
160
+ resp_json = await resp.json()
161
+ return resp_json["choices"][0]["message"]["content"]
162
 
163
+ async def generate_copywriting_async(categories, topic):
164
+ """
165
+ 여러 카테고리에 대해 비동기적으로 카피라이팅을 생성하는 함수
166
+ """
167
+ max_tokens = 1000
168
+ temperature = 0.8
169
+ top_p = 0.95
170
+
171
+ tasks = []
172
+ for category in categories:
173
+ prompt = get_category_prompt(category)
174
+ user_content = f"주제: {topic}"
175
+ tasks.append(
176
+ asyncio.create_task(
177
+ call_api_async(user_content, prompt, max_tokens, temperature, top_p)
178
+ )
179
+ )
180
+
181
+ # 병렬 실행
182
+ results = await asyncio.gather(*tasks)
183
+
184
+ # category 순서에 맞춰 딕셔너리화
185
+ result_dict = {}
186
+ for category, copywriting in zip(categories, results):
187
+ result_dict[category] = copywriting
188
+
189
+ return result_dict
190
+
191
+ async def validate_and_generate_async(topic):
192
+ """
193
+ Gradio에서 스트리밍(yield)을 통해
194
+ 각 카테고리의 결과가 나오면 즉시 전달하도록 하는 함수
195
+ """
196
+ try:
197
+ # 우선 상태창 업데이트
198
+ yield [gr.update(value="카피라이팅 생성 중...")] + [gr.update() for _ in CATEGORIES]
199
+
200
+ # 비동기로 카피라이팅 생성
201
+ # 각 카테고리별 결과를 기다리지 않고, 완료되는 순서대로 전달
202
+ async def async_run():
203
+ # as_completed로 각 결과가 나올 때마다 받기
204
+ tasks = {}
205
  for category in CATEGORIES:
206
  prompt = get_category_prompt(category)
207
+ user_content = f"주제: {topic}"
208
+ tasks[category] = asyncio.create_task(
209
+ call_api_async(user_content, prompt, 1000, 0.8, 0.95)
210
+ )
211
 
212
+ # 카테고리 결과가 끝나는 순서대로 반환
213
+ for finished_task in asyncio.as_completed(tasks.values()):
214
+ # finished_task가 어떤 카테고리에 해당되는지 찾는다
215
+ for cat, tsk in tasks.items():
216
+ if tsk == finished_task:
217
+ # 해당 카테고리의 결과
218
+ try:
219
+ result_text = await finished_task
220
+ except Exception as e:
221
+ result_text = f"오류 : {str(e)}"
222
+ yield cat, result_text
223
+ break
224
+
225
+ # 스트리밍 처리
226
+ # 각 카테고리가 끝날 때마다 yield 결과 갱신
227
+ # outputs 순서는 [status] + [카테고리1, 카테고리2, ...]
228
+ # 따라서 index를 찾아서 부분 업데이트
229
+ # status(0)번, 카테고리별 1~N
230
+ current_values = ["카피라이팅 중..."] + ["" for _ in CATEGORIES]
231
+
232
+ async for cat, result_text in async_run():
233
+ # cat의 인덱스를 찾아서 갱신
234
+ idx = CATEGORIES.index(cat) + 1 # status가 0번이므로 +1
235
+ current_values[idx] = result_text
236
+ yield current_values # 부분 업데이트
237
+
238
+ # 모든테고리가 끝난 최종 상태메시지
239
+ current_values[0] = "카피라이팅 생성이 모두 완료되었습니다."
240
+ yield current_values
 
241
 
242
  except Exception as e:
243
  logger.error(f"Error during copywriting generation: {str(e)}")
244
+ yield [f"오류 발생: {str(e)}"] + ["" for _ in CATEGORIES]
 
245
 
246
+ # 비동기 함수를 Gradio 이벤트에 연결
247
  generate_btn.click(
248
+ fn=validate_and_generate_async,
249
  inputs=[topic],
250
  outputs=[status] + [output_boxes[category] for category in CATEGORIES],
251
+ api_name="generate_copy_async" # 임의의 api_name
252
  )
253
+ ##########################################
254
+ # 추가된 비동기/병렬 처리용 코드 끝
255
+ ##########################################
256
 
257
  # 인터페이스 실행
258
  iface.launch()