Kims12 commited on
Commit
6027eaa
·
verified ·
1 Parent(s): a302432

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -38
app.py CHANGED
@@ -10,15 +10,16 @@ import aiohttp
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # OpenAI API 클라이언트 설정
14
  openai.api_key = os.getenv("OPENAI_API_KEY")
15
 
16
- # 요청사항에 따라 카테고리 수정
17
  CATEGORIES = [
18
  "공포 마케팅",
19
  "스토리텔링"
20
  ]
21
 
 
22
  def get_category_prompt(category):
23
  if category == "공포 마케팅":
24
  return """
@@ -68,7 +69,7 @@ def get_category_prompt(category):
68
  - 팔 없는 장애인 화가의 꿈을 이뤄준 첨단 그림 도구
69
  """
70
 
71
- # 비동기 API 호출 함수
72
  async def call_api_async(session, content, system_message, max_tokens, temperature, top_p):
73
  url = "https://api.openai.com/v1/chat/completions"
74
  headers = {
@@ -85,11 +86,14 @@ async def call_api_async(session, content, system_message, max_tokens, temperatu
85
  "temperature": temperature,
86
  "top_p": top_p,
87
  }
88
- async with session.post(url, headers=headers, json=payload) as response:
89
- response_json = await response.json()
90
- return response_json["choices"][0]["message"]["content"]
 
 
 
91
 
92
- # 여러 카테고리를 "동시에" 요청 보내
93
  async def generate_copywriting_async(categories, topic):
94
  max_tokens = 1000
95
  temperature = 0.8
@@ -97,34 +101,37 @@ async def generate_copywriting_async(categories, topic):
97
  results = {}
98
 
99
  async with aiohttp.ClientSession() as session:
100
- # 카테고로 태스크 만들어 "동시에" 실행
101
- tasks = []
102
- cat_map = {}
103
-
104
- for category in categories:
105
- prompt = get_category_prompt(category)
106
  user_content = f"주제: {topic}"
107
- # create_task를 통해 태스크 생성 후 배열에 담음
108
- task = asyncio.create_task(
109
  call_api_async(session, user_content, prompt, max_tokens, temperature, top_p)
110
  )
111
- tasks.append(task)
112
- cat_map[task] = category
113
-
114
- # 모든 태스크가 한꺼번에 시작된 상태에서,
115
- # 먼저태스크부터 결과를 얻어옴 (as_completed)
116
- for done_task in asyncio.as_completed(tasks):
117
- finished_category = cat_map[done_task]
 
 
 
 
 
118
  try:
119
- result = await done_task
120
- results[finished_category] = result
121
  except Exception as e:
122
  results[finished_category] = f"에러 발생: {str(e)}"
123
-
124
- # 현재까지 완료된 카테고리 결과를 yield (부분완료 상태 반환)
125
  yield results
126
 
 
127
  # Gradio 인터페이스
 
128
  with gr.Blocks() as iface:
129
  gr.Markdown("# AI 카피라이팅 생성기")
130
 
@@ -140,37 +147,35 @@ with gr.Blocks() as iface:
140
  output_box = gr.Textbox(label=category, visible=True)
141
  output_boxes[category] = output_box
142
 
143
- # validate_and_generate: 비동기 + 스트리밍
144
  async def validate_and_generate(topic):
145
  if not topic:
146
- # 주제가 비어있을 경우
147
  yield [gr.update(value="주제를 입력하세요")] + [gr.update(value="") for _ in CATEGORIES]
148
  return
149
 
150
  # 초기 상태
151
  yield [gr.update(value="카피라이팅 생성 중...")] + [gr.update(value="대기중...") for _ in CATEGORIES]
152
 
153
- # 'generate_copywriting_async'에서 완료될 때마다 partial_results를 yield
154
  async for partial_results in generate_copywriting_async(CATEGORIES, topic):
155
- # partial_results에는 지금까지 완료된 카테고리들의 결과가 담김
156
- box_updates = []
157
  for cat in CATEGORIES:
158
  if cat in partial_results:
159
- box_updates.append(gr.update(value=partial_results[cat]))
160
  else:
161
- box_updates.append(gr.update(value="대기중..."))
162
 
163
- # 진행상황 안내 (완료된 카테고리 개수 / 전체)
164
- yield [gr.update(value=f"진행상황: {len(partial_results)}/{len(CATEGORIES)} 카테고리 완료")] + box_updates
165
 
166
- # 버튼 클릭 시 validate_and_generate 호출 (비동기 + 스트리밍)
167
  generate_btn.click(
168
  fn=validate_and_generate,
169
  inputs=[topic],
170
  outputs=[status] + [output_boxes[cat] for cat in CATEGORIES],
171
- api_name="generate_copywriting",
172
  queue=True
173
  )
174
 
175
  # 인터페이스 실행
176
- iface.launch()
 
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # OpenAI API 설정
14
  openai.api_key = os.getenv("OPENAI_API_KEY")
15
 
16
+ # 카테고리
17
  CATEGORIES = [
18
  "공포 마케팅",
19
  "스토리텔링"
20
  ]
21
 
22
+ # 카테고리별 프롬프트 반환 함수
23
  def get_category_prompt(category):
24
  if category == "공포 마케팅":
25
  return """
 
69
  - 팔 없는 장애인 화가의 꿈을 이뤄준 첨단 그림 도구
70
  """
71
 
72
+ # 비동기 OpenAI API 호출
73
  async def call_api_async(session, content, system_message, max_tokens, temperature, top_p):
74
  url = "https://api.openai.com/v1/chat/completions"
75
  headers = {
 
86
  "temperature": temperature,
87
  "top_p": top_p,
88
  }
89
+ async with session.post(url, headers=headers, json=payload) as resp:
90
+ resp_json = await resp.json()
91
+ # 응답 형식 예외처리
92
+ if "choices" not in resp_json or len(resp_json["choices"]) == 0:
93
+ return "에러 발생: API 응답이 올바르지 않습니다."
94
+ return resp_json["choices"][0]["message"]["content"]
95
 
96
+ # 여러 카테고리를 "동시에" 비동로 처리
97
  async def generate_copywriting_async(categories, topic):
98
  max_tokens = 1000
99
  temperature = 0.8
 
101
  results = {}
102
 
103
  async with aiohttp.ClientSession() as session:
104
+ # 딕셔너 형태로 태스크 등록: { 카테고리: Task }
105
+ tasks = {}
106
+ for cat in categories:
107
+ prompt = get_category_prompt(cat)
 
 
108
  user_content = f"주제: {topic}"
109
+ tasks[cat] = asyncio.create_task(
 
110
  call_api_async(session, user_content, prompt, max_tokens, temperature, top_p)
111
  )
112
+
113
+ # 모든 태스크를 "동시에" 시작하고,
114
+ # 가장 먼저 끝난 태스크부터 순서대로 반환
115
+ for done_task in asyncio.as_completed(tasks.values()):
116
+ # 어느 카테고리가 찾아냄
117
+ finished_category = None
118
+ for cat, t in tasks.items():
119
+ if t == done_task:
120
+ finished_category = cat
121
+ break
122
+
123
+ # 결과 저장
124
  try:
125
+ results[finished_category] = await done_task
 
126
  except Exception as e:
127
  results[finished_category] = f"에러 발생: {str(e)}"
128
+
129
+ # 지금까지 부분 결과를 스트리밍
130
  yield results
131
 
132
+ #############################
133
  # Gradio 인터페이스
134
+ #############################
135
  with gr.Blocks() as iface:
136
  gr.Markdown("# AI 카피라이팅 생성기")
137
 
 
147
  output_box = gr.Textbox(label=category, visible=True)
148
  output_boxes[category] = output_box
149
 
150
+ # 비동기+스트리밍 함수
151
  async def validate_and_generate(topic):
152
  if not topic:
 
153
  yield [gr.update(value="주제를 입력하세요")] + [gr.update(value="") for _ in CATEGORIES]
154
  return
155
 
156
  # 초기 상태
157
  yield [gr.update(value="카피라이팅 생성 중...")] + [gr.update(value="대기중...") for _ in CATEGORIES]
158
 
159
+ # generate_copywriting_async에서 부분적으로 완료될 때마다 partial_results를 yield
160
  async for partial_results in generate_copywriting_async(CATEGORIES, topic):
161
+ # partial_results = { "공포 마케팅": "...", ... } 이미 끝난 것들만 저장
162
+ updates = []
163
  for cat in CATEGORIES:
164
  if cat in partial_results:
165
+ updates.append(gr.update(value=partial_results[cat]))
166
  else:
167
+ updates.append(gr.update(value="대기중..."))
168
 
169
+ yield [gr.update(value=f"진행상황: {len(partial_results)}/{len(CATEGORIES)} 카테고리 완료")] + updates
 
170
 
171
+ # 버튼 클릭 시 validate_and_generate 실행
172
  generate_btn.click(
173
  fn=validate_and_generate,
174
  inputs=[topic],
175
  outputs=[status] + [output_boxes[cat] for cat in CATEGORIES],
176
+ # Gradio에서 비동기 스트리밍을 지원하기 위해 queue=True
177
  queue=True
178
  )
179
 
180
  # 인터페이스 실행
181
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=False)