kijeoung committed on
Commit
bb5b825
·
verified ·
1 Parent(s): b933b5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -22
app.py CHANGED
@@ -8,7 +8,6 @@ import torch
8
  from PIL import Image
9
  import os
10
  import re
11
- import openai # OpenAI 추가
12
 
13
  # 경로 및 설정
14
  CLIP_PATH = "google/siglip-so400m-patch14-384"
@@ -18,7 +17,6 @@ CHECKPOINT_PATH = Path("wpkklhc6")
18
  TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
19
 
20
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
21
- openai.api_key = os.getenv("OPENAI_API_KEY") # OpenAI API 키 설정
22
 
23
  # 이미지 어댑터 정의
24
  class ImageAdapter(nn.Module):
@@ -83,8 +81,19 @@ def stream_chat(input_image: Image.Image):
83
  embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device='cuda', dtype=torch.int64))
84
 
85
  # 프롬프트 구성
86
- inputs_embeds = torch.cat([embedded_bos.expand(embedded_images.shape[0], -1, -1), embedded_images, prompt_embeds], dim=1)
87
- input_ids = torch.cat([torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long), torch.zeros((1, embedded_images.shape[1]), dtype=torch.long), prompt], dim=1).to('cuda')
 
 
 
 
 
 
 
 
 
 
 
88
  attention_mask = torch.ones_like(input_ids)
89
 
90
  # 텍스트 생성
@@ -140,10 +149,10 @@ def replace_gender_specific_words(caption, gender_prefix):
140
  if gender_prefix == "Korean man":
141
  caption = re.sub(r'\bwoman\b', "man", caption, flags=re.IGNORECASE)
142
  caption = re.sub(r'\bgirl\b', "boy", caption, flags=re.IGNORECASE)
143
- caption = re.sub(r'\blady\b', "gentleman", flags=re.IGNORECASE)
144
- caption = re.sub(r'\bshe\b', "he", flags=re.IGNORECASE)
145
- caption = re.sub(r'\bher\b', "his", flags=re.IGNORECASE)
146
- caption = re.sub(r'\bherself\b', "himself", flags=re.IGNORECASE)
147
  elif gender_prefix == "Korean woman":
148
  caption = re.sub(r'\bman\b', "woman", caption, flags=re.IGNORECASE)
149
  caption = re.sub(r'\bboy\b', "girl", caption, flags=re.IGNORECASE)
@@ -167,20 +176,6 @@ def replace_gender_words(caption, gender, age, hair_length, hair_style, hair_col
167
  caption = replace_gender_specific_words(caption, gender_prefix)
168
  return f"{gender_prefix}, age {age}, {hair_description}: {caption}"
169
 
170
- # GPT-4o mini API 호출 함수 추가
171
- def call_gpt4o_api(content, system_message, max_tokens, temperature, top_p):
172
- response = openai.ChatCompletion.create(
173
- model="gpt-4o-mini", # OpenAI GPT-4o Mini 모델 사용
174
- messages=[
175
- {"role": "system", "content": system_message},
176
- {"role": "user", "content": content},
177
- ],
178
- max_tokens=max_tokens,
179
- temperature=temperature,
180
- top_p=top_p,
181
- )
182
- return response.choices[0].message['content']
183
-
184
  # Recaption 함수
185
  def recaption(input_image: Image.Image, prefix: str, age: int, hair_length: str, hair_style: str, hair_color: str, hair_accessory: str):
186
  original_caption = stream_chat(input_image)
 
8
  from PIL import Image
9
  import os
10
  import re
 
11
 
12
  # 경로 및 설정
13
  CLIP_PATH = "google/siglip-so400m-patch14-384"
 
17
  TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
18
 
19
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
20
 
21
  # 이미지 어댑터 정의
22
  class ImageAdapter(nn.Module):
 
81
  embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device='cuda', dtype=torch.int64))
82
 
83
  # 프롬프트 구성
84
+ inputs_embeds = torch.cat([
85
+ embedded_bos.expand(embedded_images.shape[0], -1, -1),
86
+ embedded_images,
87
+ prompt_embeds
88
+ ], dim=1)
89
+
90
+ # CPU에 있는 텐서를 GPU로 이동
91
+ input_ids = torch.cat([
92
+ torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to('cuda'),
93
+ torch.zeros((1, embedded_images.shape[1]), dtype=torch.long).to('cuda'),
94
+ prompt.to('cuda')
95
+ ], dim=1)
96
+
97
  attention_mask = torch.ones_like(input_ids)
98
 
99
  # 텍스트 생성
 
149
  if gender_prefix == "Korean man":
150
  caption = re.sub(r'\bwoman\b', "man", caption, flags=re.IGNORECASE)
151
  caption = re.sub(r'\bgirl\b', "boy", caption, flags=re.IGNORECASE)
152
+ caption = re.sub(r'\blady\b', "gentleman", caption, flags=re.IGNORECASE)
153
+ caption = re.sub(r'\bshe\b', "he", caption, flags=re.IGNORECASE)
154
+ caption = re.sub(r'\bher\b', "his", caption, flags=re.IGNORECASE)
155
+ caption = re.sub(r'\bherself\b', "himself", caption, flags=re.IGNORECASE)
156
  elif gender_prefix == "Korean woman":
157
  caption = re.sub(r'\bman\b', "woman", caption, flags=re.IGNORECASE)
158
  caption = re.sub(r'\bboy\b', "girl", caption, flags=re.IGNORECASE)
 
176
  caption = replace_gender_specific_words(caption, gender_prefix)
177
  return f"{gender_prefix}, age {age}, {hair_description}: {caption}"
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  # Recaption 함수
180
  def recaption(input_image: Image.Image, prefix: str, age: int, hair_length: str, hair_style: str, hair_color: str, hair_accessory: str):
181
  original_caption = stream_chat(input_image)