Kims12 commited on
Commit
5d0f132
·
verified ·
1 Parent(s): 2c42789

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -15
app.py CHANGED
@@ -70,7 +70,7 @@ def get_category_prompt(category):
70
 
71
  async def call_api_async(session, content, system_message, max_tokens, temperature, top_p):
72
  payload = {
73
- "model": "gpt-4o-mini", # 모델 이름을 실제 사용 가능한 으로 변경하세요.
74
  "messages": [
75
  {"role": "system", "content": system_message},
76
  {"role": "user", "content": content},
@@ -103,7 +103,11 @@ async def generate_copywriting_async(categories, topic):
103
  task = asyncio.create_task(call_api_async(session, user_content, prompt, max_tokens, temperature, top_p))
104
  tasks.append((category, task))
105
 
106
- for category, task in tasks:
 
 
 
 
107
  try:
108
  copywriting = await task
109
  results[category] = copywriting
@@ -134,22 +138,54 @@ with gr.Blocks() as iface:
134
 
135
  async def validate_and_generate(topic_input):
136
  try:
137
- # 상태 업데이트: 생성 중
138
- status_update = "카피라이팅 생성 중..."
139
- outputs_initial = ["" for _ in CATEGORIES]
140
- yield [gr.update(value=status_update)] + [gr.update(value="") for _ in CATEGORIES]
141
-
142
- # 카피라이팅 생성
143
- results = await generate_copywriting_async(CATEGORIES, topic_input)
144
-
145
- # 상태 업데이트: 완료 및 결과 표시
146
- status_complete = "카피라이팅 생성이 완료되었습니다."
147
- outputs_final = [results.get(category, "") for category in CATEGORIES]
148
- yield [gr.update(value=status_complete)] + [gr.update(value=copy) for copy in outputs_final]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  except Exception as e:
150
  logger.error(f"Error during copywriting generation: {str(e)}")
151
  error_message = f"오류 발생: {str(e)}"
152
- yield [gr.update(value=error_message)] + [gr.update(value="") for _ in CATEGORIES]
153
 
154
  generate_btn.click(
155
  fn=validate_and_generate,
 
70
 
71
  async def call_api_async(session, content, system_message, max_tokens, temperature, top_p):
72
  payload = {
73
+ "model": "gpt-4o-mini", # 실제 사용 가능한 모델 이름으로 변경하세요.
74
  "messages": [
75
  {"role": "system", "content": system_message},
76
  {"role": "user", "content": content},
 
103
  task = asyncio.create_task(call_api_async(session, user_content, prompt, max_tokens, temperature, top_p))
104
  tasks.append((category, task))
105
 
106
+ # Mapping tasks to categories
107
+ task_to_category = {task: category for category, task in tasks}
108
+
109
+ for task in asyncio.as_completed(task_to_category.keys()):
110
+ category = task_to_category[task]
111
  try:
112
  copywriting = await task
113
  results[category] = copywriting
 
138
 
139
  async def validate_and_generate(topic_input):
140
  try:
141
+ # 초기 상태 설정: "카피라이팅 생성 중..." 및 빈 출력 박스
142
+ current_outputs = ["카피라이팅 생성 중..."] + [""] * len(CATEGORIES)
143
+ yield [gr.Markdown.update(value=current_outputs[0])] + [gr.Textbox.update(value=o) for o in current_outputs[1:]]
144
+
145
+ # 비동기 카피라이팅 생성
146
+ async with aiohttp.ClientSession() as session:
147
+ tasks = []
148
+ for category in CATEGORIES:
149
+ prompt = get_category_prompt(category)
150
+ user_content = f"주제: {topic_input}"
151
+ task = asyncio.create_task(call_api_async(session, user_content, prompt, 1000, 0.8, 0.95))
152
+ tasks.append((category, task))
153
+
154
+ # 태스크를 카테고리에 매핑
155
+ task_to_category = {task: category for category, task in tasks}
156
+
157
+ for task in asyncio.as_completed(task_to_category.keys()):
158
+ category = task_to_category[task]
159
+ try:
160
+ copywriting = await task
161
+ index = CATEGORIES.index(category) + 1 # +1은 status를 제외하기 위함
162
+ current_outputs[index] = copywriting
163
+ # 업데이트할 출력 리스트 생성
164
+ updates = [gr.Markdown.update(value="카피라이팅 생성 중...")] + [
165
+ gr.Textbox.update(value=current_outputs[i + 1]) if i + 1 == index else gr.Textbox.update()
166
+ for i in range(len(CATEGORIES))
167
+ ]
168
+ yield updates
169
+ except Exception as e:
170
+ logger.error(f"Error generating copywriting for {category}: {str(e)}")
171
+ index = CATEGORIES.index(category) + 1
172
+ current_outputs[index] = f"오류 발생: {str(e)}"
173
+ updates = [gr.Markdown.update(value="카피라이팅 생성 중...")] + [
174
+ gr.Textbox.update(value=current_outputs[i + 1]) if i + 1 == index else gr.Textbox.update()
175
+ for i in range(len(CATEGORIES))
176
+ ]
177
+ yield updates
178
+
179
+ # 최종 상태 업데이트: "카피라이팅 생성이 완료되었습니다."
180
+ final_updates = [gr.Markdown.update(value="카피라이팅 생성이 완료되었습니다.")] + [
181
+ gr.Textbox.update(value=current_outputs[i + 1]) for i in range(len(CATEGORIES))
182
+ ]
183
+ yield final_updates
184
+
185
  except Exception as e:
186
  logger.error(f"Error during copywriting generation: {str(e)}")
187
  error_message = f"오류 발생: {str(e)}"
188
+ yield [gr.Markdown.update(value=error_message)] + [gr.Textbox.update(value="") for _ in CATEGORIES]
189
 
190
  generate_btn.click(
191
  fn=validate_and_generate,