Style_Rec / app.py
kyoyejin's picture
Upload 2 files
18bbed8 verified
import os
import gradio as gr
from openai import OpenAI
from PIL import Image
from io import BytesIO
import traceback
import requests
#os.environ["OPENAI_API_KEY"] = "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"
#os.environ["STABILITY_API_KEY"] = "sk-HvTCuoGI1JjUoJD2GoJ4xi7BDycodYLEIFFU6t3nJDORusLb"
# ===== 1. ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ API ํ‚ค ์ฝ๊ธฐ =====
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
STABILITY_API_KEY = os.getenv("STABILITY_API_KEY")
if OPENAI_API_KEY is None:
raise ValueError("OPENAI_API_KEY ํ™˜๊ฒฝ ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
if STABILITY_API_KEY is None:
raise ValueError("STABILITY_API_KEY ํ™˜๊ฒฝ ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
client = OpenAI(api_key=OPENAI_API_KEY)
MODEL = "gpt-4o-mini"
STABILITY_ENDPOINT = "https://api.stability.ai/v2beta/stable-image/generate/core"
# ===== 2. ํ”„๋กฌํ”„ํŠธ ๊ธฐ๋ฒ•๋ณ„ System Prompt =====
SYS_COT = """๋‹น์‹ ์€ ์‹œ์ƒ์‹ ๊ฒฝํ—˜์ด ๋งŽ์€ ์ „๋ฌธ ์Šคํƒ€์ผ๋ฆฌ์ŠคํŠธ AI์ž…๋‹ˆ๋‹ค.
- ๋‚ด๋ถ€ ์‚ฌ๊ณ ๊ณผ์ •์€ ์ˆจ๊ธฐ๊ณ  ๊ฒฐ๊ณผ ์š”์•ฝ๊ณผ ํ•ต์‹ฌ ๊ทผ๊ฑฐ๋งŒ ์ถœ๋ ฅ.
- ์˜ˆ์‚ฐ/์‹œ๊ฐ„/๋™์„ /์กฐ๋ช…/์ฒดํ˜•/ํ”ผ๋ถ€ํ†ค/๋“œ๋ ˆ์Šค์ฝ”๋“œ ๋ฐ˜์˜, ํ•œ๊ตญ์–ด๋กœ ๋‹ต๋ณ€.
"""
SYS_TOT = """๋‹น์‹ ์€ 'ํŒจ์…˜ ์ „๋žต๊ฐ€'์ž…๋‹ˆ๋‹ค.
- ์„œ๋กœ ๋‹ค๋ฅธ 3๊ฐ€์ง€ ์Šคํƒ€์ผ ๊ฒฝ๋กœ ์ œ์‹œ(๊ฐ 5~7์ค„) โ†’ ๊ฐ„๋‹จ ํ‰๊ฐ€ํ‘œ(1~5) โ†’ ์ตœ์ข…์•ˆ.
- ๋‚ด๋ถ€ ๋‚˜๋ญ‡๊ฐ€์ง€ ์ถ”๋ก ์€ ์ˆจ๊ธฐ๊ณ  ๊ฒฐ๊ณผ๋งŒ ์ถœ๋ ฅ, ํ•œ๊ตญ์–ด.
"""
SYS_SELFCONS = """๋‹น์‹ ์€ '์Šคํƒ€์ผ ์ œ์•ˆ ์ปจ์„คํ„ดํŠธ'์ž…๋‹ˆ๋‹ค.
- ๊ด€์ ์ด ๋‹ค๋ฅธ 3์•ˆ ์ œ์‹œ(ํด๋ž˜์‹/๋ชจ๋˜/๋Œ€๋‹ด) โ†’ ๊ธฐ์ค€๋ณ„(๋ฌด๋Œ€ ์ ํ•ฉ์„ฑ/์‚ฌ์ง„๋ฐœ/์ฐฉ์šฉ ๋‚œ์ด๋„/๋ฆฌ์Šคํฌ) ์ ์ˆ˜ํ™” โ†’ ์ตœ์ข…์•ˆ ์„ ํƒ.
- ๋‚ด๋ถ€ ํ•ฉ์˜ ๊ณผ์ •์€ ์ˆจ๊ธฐ๊ณ  ๊ฒฐ๊ณผ๋งŒ ์ถœ๋ ฅ, ํ•œ๊ตญ์–ด.
"""
SYS_REACT = """๋‹น์‹ ์€ '์Šคํƒ€์ผ๋ง ์˜คํผ๋ ˆ์ด์…˜ ๋งค๋‹ˆ์ €'์ž…๋‹ˆ๋‹ค.
- Plan โ†’ Action โ†’ Observation โ†’ Update 2~3ํšŒ ํ›„ Final ์ œ์•ˆ.
- ๋‚ด๋ถ€ ์‚ฌ๊ณ ๋Š” ์ˆจ๊ธฐ๊ณ  ๊ฐ ๋‹จ๊ณ„ 1~3์ค„, ํ•œ๊ตญ์–ด.
"""
METHOD_TO_SYS = {
"Chain-of-Thought(์š”์•ฝํ˜•)": SYS_COT,
"Tree-of-Thought": SYS_TOT,
"Self-Consistency": SYS_SELFCONS,
"ReAct": SYS_REACT,
}
# ===== 3. RCCF ํ”„๋กฌํ”„ํŠธ ๋นŒ๋” & ํ…์ŠคํŠธ ์Šคํƒ€์ผ ์ถ”์ฒœ =====
def build_rccf_prompt(role, context, constraints, fmt):
return f"""[Role]
{role}
[Context]
{context}
[Constraints]
{constraints}
[Format]
{fmt}
"""
def run_chat(method_name, role, context, constraints, fmt, model=MODEL, temperature=0.7):
if method_name not in METHOD_TO_SYS:
raise ValueError("๊ธฐ๋ฒ• ์„ ํƒ์ด ์˜ฌ๋ฐ”๋ฅด์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
system_prompt = METHOD_TO_SYS[method_name]
user_prompt = build_rccf_prompt(role, context, constraints, fmt)
resp = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=float(temperature),
)
return f"### ๊ธฐ๋ฒ•: **{method_name}**\n\n" + resp.choices[0].message.content
# ===== 4. ํ•œ๊ตญ์–ด ์Šคํƒ€์ผ ํ…์ŠคํŠธ โ†’ ์˜์–ด ์ด๋ฏธ์ง€ ํ”„๋กฌํ”„ํŠธ =====
def build_image_prompt_from_text(style_text: str) -> str:
system_msg = (
"You are a fashion illustration prompt engineer for an image model. "
"Given a Korean styling recommendation, output ONE concise English prompt "
"for a full-body fashion illustration for a red carpet or award show. "
"40 words or less. Focus on outfit, colors, silhouette, and overall vibe. "
"No explanations, just the prompt."
)
resp = client.chat.completions.create(
model=MODEL,
messages=[
{"role": "system", "content": system_msg},
{"role": "user", "content": style_text},
],
temperature=0.4,
)
prompt_en = resp.choices[0].message.content.strip()
return prompt_en
# ===== 5. Stability AI๋กœ ์ด๋ฏธ์ง€ ์ƒ์„ฑ =====
def generate_fashion_image_with_stability(style_text: str) -> Image.Image:
img_prompt = build_image_prompt_from_text(style_text)
headers = {
"Authorization": f"Bearer {STABILITY_API_KEY}",
"Accept": "image/*",
}
files = {
"prompt": (None, img_prompt),
"output_format": (None, "png"),
}
response = requests.post(
STABILITY_ENDPOINT,
headers=headers,
files=files,
timeout=120,
)
if response.status_code != 200:
raise RuntimeError(
f"Stability API ์˜ค๋ฅ˜: {response.status_code} - {response.text[:500]}"
)
img = Image.open(BytesIO(response.content))
return img
# ===== 6. Gradio ์ธํ„ฐํŽ˜์ด์Šค ํ•จ์ˆ˜ =====
def interface_fn(role, context, constraints, fmt, method, temperature):
try:
style_text = run_chat(method, role, context, constraints, fmt, MODEL, temperature)
style_image = generate_fashion_image_with_stability(style_text)
return style_text, style_image
except Exception as e:
traceback.print_exc()
err_text = (
f"โš  ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}\n\n"
f"Space Logs์—์„œ traceback์„ ํ™•์ธํ•ด ์ฃผ์„ธ์š”."
)
return err_text, None
# ===== 7. Gradio UI ์ •์˜ =====
demo = gr.Interface(
fn=interface_fn,
inputs=[
gr.Textbox(label="Role", value="์‹œ์ƒ์‹ ์ „๋ฌธ ์Šคํƒ€์ผ๋ฆฌ์ŠคํŠธ", lines=2),
gr.Textbox(
label="Context",
value="๋ฐฐ์šฐ ํ•œ์„œ์—ฐ์ด ์ฒญ๋ฃก์˜ํ™”์ƒ ๋ ˆ๋“œ์นดํŽซ ์ฐธ์„. ๋‰ดํŠธ๋Ÿด-์›œ ํ†ค, ํ‚ค 167cm. ๋ฏธ๋‹ˆ๋ฉ€ & ๋ชจ๋˜ ์„ ํ˜ธ.",
lines=3,
),
gr.Textbox(
label="Constraints",
value="์ด๋™ ๋™์„  ์ด‰๋ฐ•, 2๋ฒŒ ํ”ผํŒ…, ํ—ค์–ด/๋ฉ”์ดํฌ์—… ๊ฐ 50๋ถ„, ์˜ˆ์‚ฐ 1,000๋งŒ์› ํ˜‘์ฐฌ ์šฐ์„ ",
lines=3,
),
gr.Textbox(
label="Format",
value=(
"1) ๋ฃฉ ์ฝ˜์…‰ํŠธ(ํ•œ ์ค„)\n"
"2) ์˜์ƒ ์ œ์•ˆ(๋“œ๋ ˆ์Šค 2์•ˆ + ๋Œ€์•ˆ ์ˆ˜ํŠธ 1์•ˆ)\n"
"3) ์•ก์„ธ์„œ๋ฆฌยท์Šˆ์ฆˆยทํด๋Ÿฌ์น˜\n"
"4) ํ—ค์–ด & ๋ฉ”์ดํฌ์—…(ํ‚คํฌ์ธํŠธ 3๊ฐœ)\n"
"5) ๋ฆฌ์Šคํฌ & ์™„ํ™”(์ตœ๋Œ€ 3๊ฐœ)\n"
"6) ํ˜„์žฅ ํƒ€์ž„๋ผ์ธ ์ฒดํฌ๋ฆฌ์ŠคํŠธ(6๋‹จ๊ณ„)"
),
lines=8,
),
gr.Dropdown(
label="ํ”„๋กฌํ”„ํŠธ ๊ธฐ๋ฒ•",
choices=list(METHOD_TO_SYS.keys()),
value="Chain-of-Thought(์š”์•ฝํ˜•)",
),
gr.Slider(0.0, 1.2, value=0.7, step=0.1, label="Temperature"),
],
outputs=[
gr.Markdown(label="์Šคํƒ€์ผ ์ถ”์ฒœ ๊ฒฐ๊ณผ"),
gr.Image(label="ํŒจ์…˜ ์ผ๋Ÿฌ์ŠคํŠธ", type="pil"),
],
title="๋‚˜๋งŒ์˜ ์Šคํƒ€์ผ๋ฆฌ์ŠคํŠธ ์ฑ—๋ด‡ โ€” ํ…์ŠคํŠธ(OpenAI) + ์ด๋ฏธ์ง€(Stability AI)",
description="RCCF + ํ”„๋กฌํ”„ํŠธ ๊ธฐ๋ฒ•์œผ๋กœ ์Šคํƒ€์ผ์„ ์ถ”์ฒœํ•˜๊ณ , ๊ทธ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ Stable Diffusion์œผ๋กœ ํŒจ์…˜ ์ผ๋Ÿฌ์ŠคํŠธ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.",
)
# HF Spaces์—์„œ๋Š” launch()์— share ํ•„์š” X
if __name__ == "__main__":
demo.launch()