Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,6 @@ import torch
|
|
| 8 |
from PIL import Image
|
| 9 |
import os
|
| 10 |
import re
|
| 11 |
-
import openai # OpenAI 추가
|
| 12 |
|
| 13 |
# 경로 및 설정
|
| 14 |
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
|
@@ -18,7 +17,6 @@ CHECKPOINT_PATH = Path("wpkklhc6")
|
|
| 18 |
TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
|
| 19 |
|
| 20 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 21 |
-
openai.api_key = os.getenv("OPENAI_API_KEY") # OpenAI API 키 설정
|
| 22 |
|
| 23 |
# 이미지 어댑터 정의
|
| 24 |
class ImageAdapter(nn.Module):
|
|
@@ -83,8 +81,19 @@ def stream_chat(input_image: Image.Image):
|
|
| 83 |
embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device='cuda', dtype=torch.int64))
|
| 84 |
|
| 85 |
# 프롬프트 구성
|
| 86 |
-
inputs_embeds = torch.cat([
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
attention_mask = torch.ones_like(input_ids)
|
| 89 |
|
| 90 |
# 텍스트 생성
|
|
@@ -140,10 +149,10 @@ def replace_gender_specific_words(caption, gender_prefix):
|
|
| 140 |
if gender_prefix == "Korean man":
|
| 141 |
caption = re.sub(r'\bwoman\b', "man", caption, flags=re.IGNORECASE)
|
| 142 |
caption = re.sub(r'\bgirl\b', "boy", caption, flags=re.IGNORECASE)
|
| 143 |
-
caption = re.sub(r'\blady\b', "gentleman", flags=re.IGNORECASE)
|
| 144 |
-
caption = re.sub(r'\bshe\b', "he", flags=re.IGNORECASE)
|
| 145 |
-
caption = re.sub(r'\bher\b', "his", flags=re.IGNORECASE)
|
| 146 |
-
caption = re.sub(r'\bherself\b', "himself", flags=re.IGNORECASE)
|
| 147 |
elif gender_prefix == "Korean woman":
|
| 148 |
caption = re.sub(r'\bman\b', "woman", caption, flags=re.IGNORECASE)
|
| 149 |
caption = re.sub(r'\bboy\b', "girl", caption, flags=re.IGNORECASE)
|
|
@@ -167,20 +176,6 @@ def replace_gender_words(caption, gender, age, hair_length, hair_style, hair_col
|
|
| 167 |
caption = replace_gender_specific_words(caption, gender_prefix)
|
| 168 |
return f"{gender_prefix}, age {age}, {hair_description}: {caption}"
|
| 169 |
|
| 170 |
-
# GPT-4o mini API 호출 함수 추가
|
| 171 |
-
def call_gpt4o_api(content, system_message, max_tokens, temperature, top_p):
|
| 172 |
-
response = openai.ChatCompletion.create(
|
| 173 |
-
model="gpt-4o-mini", # OpenAI GPT-4o Mini ๋ชจ๋ธ ์ฌ์ฉ
|
| 174 |
-
messages=[
|
| 175 |
-
{"role": "system", "content": system_message},
|
| 176 |
-
{"role": "user", "content": content},
|
| 177 |
-
],
|
| 178 |
-
max_tokens=max_tokens,
|
| 179 |
-
temperature=temperature,
|
| 180 |
-
top_p=top_p,
|
| 181 |
-
)
|
| 182 |
-
return response.choices[0].message['content']
|
| 183 |
-
|
| 184 |
# Recaption 함수
|
| 185 |
def recaption(input_image: Image.Image, prefix: str, age: int, hair_length: str, hair_style: str, hair_color: str, hair_accessory: str):
|
| 186 |
original_caption = stream_chat(input_image)
|
|
|
|
| 8 |
from PIL import Image
|
| 9 |
import os
|
| 10 |
import re
|
|
|
|
| 11 |
|
| 12 |
# 경로 및 설정
|
| 13 |
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
|
|
|
| 17 |
TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
|
| 18 |
|
| 19 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
|
|
|
| 20 |
|
| 21 |
# 이미지 어댑터 정의
|
| 22 |
class ImageAdapter(nn.Module):
|
|
|
|
| 81 |
embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device='cuda', dtype=torch.int64))
|
| 82 |
|
| 83 |
# 프롬프트 구성
|
| 84 |
+
inputs_embeds = torch.cat([
|
| 85 |
+
embedded_bos.expand(embedded_images.shape[0], -1, -1),
|
| 86 |
+
embedded_images,
|
| 87 |
+
prompt_embeds
|
| 88 |
+
], dim=1)
|
| 89 |
+
|
| 90 |
+
# CPU에 있는 텐서를 GPU로 이동
|
| 91 |
+
input_ids = torch.cat([
|
| 92 |
+
torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to('cuda'),
|
| 93 |
+
torch.zeros((1, embedded_images.shape[1]), dtype=torch.long).to('cuda'),
|
| 94 |
+
prompt.to('cuda')
|
| 95 |
+
], dim=1)
|
| 96 |
+
|
| 97 |
attention_mask = torch.ones_like(input_ids)
|
| 98 |
|
| 99 |
# 텍스트 생성
|
|
|
|
| 149 |
if gender_prefix == "Korean man":
|
| 150 |
caption = re.sub(r'\bwoman\b', "man", caption, flags=re.IGNORECASE)
|
| 151 |
caption = re.sub(r'\bgirl\b', "boy", caption, flags=re.IGNORECASE)
|
| 152 |
+
caption = re.sub(r'\blady\b', "gentleman", caption, flags=re.IGNORECASE)
|
| 153 |
+
caption = re.sub(r'\bshe\b', "he", caption, flags=re.IGNORECASE)
|
| 154 |
+
caption = re.sub(r'\bher\b', "his", caption, flags=re.IGNORECASE)
|
| 155 |
+
caption = re.sub(r'\bherself\b', "himself", caption, flags=re.IGNORECASE)
|
| 156 |
elif gender_prefix == "Korean woman":
|
| 157 |
caption = re.sub(r'\bman\b', "woman", caption, flags=re.IGNORECASE)
|
| 158 |
caption = re.sub(r'\bboy\b', "girl", caption, flags=re.IGNORECASE)
|
|
|
|
| 176 |
caption = replace_gender_specific_words(caption, gender_prefix)
|
| 177 |
return f"{gender_prefix}, age {age}, {hair_description}: {caption}"
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
# Recaption 함수
|
| 180 |
def recaption(input_image: Image.Image, prefix: str, age: int, hair_length: str, hair_style: str, hair_color: str, hair_accessory: str):
|
| 181 |
original_caption = stream_chat(input_image)
|