nickdigger commited on
Commit
6a131b7
Β·
verified Β·
1 Parent(s): ce6ad7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -154
app.py CHANGED
@@ -15,7 +15,6 @@ except Exception:
15
 
16
  @spaces.GPU()
17
  def _joycaption_register_gpu():
18
- """Dummy GPU registration for HF Spaces."""
19
  return None
20
 
21
  import gradio as gr
@@ -23,15 +22,12 @@ import torch
23
  from transformers import LlavaForConditionalGeneration, AutoProcessor
24
  from PIL import Image
25
  import tempfile, gc, os, shutil, json
26
- from pathlib import Path
27
  from hf_space_utils import fix_image_url, postprocess_caption
28
 
29
  # ---------- Cache paths ----------
30
  _tmpdir = tempfile.gettempdir()
31
- os.environ["HF_HOME"] = os.path.join(_tmpdir, "hf_cache")
32
- os.environ["TRANSFORMERS_CACHE"] = os.path.join(_tmpdir, "transformers_cache")
33
- os.environ["HF_DATASETS_CACHE"] = os.path.join(_tmpdir, "datasets_cache")
34
- os.environ["TORCH_HOME"] = os.path.join(_tmpdir, "torch_cache")
35
 
36
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
37
  SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None
@@ -52,38 +48,32 @@ def cleanup_storage():
52
  print(f"⚠️ Cleanup warning: {e}")
53
 
54
  TITLE = """
55
- <div style="text-align: center; margin: 20px 0;">
56
- <h1>🎨 JoyCaption Three-Tone + Q&A (v2.7)</h1>
57
- <p><em>Stable version β€” all tensor and decode bugs fixed, faster responses.</em></p>
58
  </div><hr>
59
  """
60
 
61
- print("πŸš€ Initializing JoyCaption...")
62
  cleanup_storage()
63
 
64
  # ---------- Model load ----------
65
  processor = None
66
  model = None
67
- MODEL_TORCH_DTYPE = None
68
- MODEL_USE_CUDA = False
69
 
70
  if not os.environ.get("SKIP_MODEL_LOAD"):
71
- use_cuda = torch.cuda.is_available()
72
- if use_cuda:
73
- torch_dtype = getattr(torch, "bfloat16", None) or getattr(torch, "float16", None)
74
- device_map = "auto"
75
- MODEL_USE_CUDA = True
76
- else:
77
- torch_dtype, device_map = None, "cpu"
78
 
79
  processor = AutoProcessor.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)
80
- model_kwargs = dict(low_cpu_mem_usage=True, device_map=device_map)
81
- if torch_dtype and use_cuda:
82
- model_kwargs["torch_dtype"] = torch_dtype
 
 
 
83
 
84
- model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, **model_kwargs)
85
- model.eval()
86
- MODEL_TORCH_DTYPE = model_kwargs.get("torch_dtype")
87
  print("βœ… Model loaded successfully!")
88
  cleanup_storage()
89
  else:
@@ -93,199 +83,137 @@ else:
93
  TONE_PROMPTS = {
94
  "engaging": {
95
  "system": "You are a skilled content creator who writes engaging, natural descriptions.",
96
- "prompt": "Write a natural, engaging description of this image focusing on main subjects.",
97
- "temperature": 0.6, "top_p": 0.85,
98
  },
99
  "casual_friend": {
100
  "system": "You are a relaxed 24-year-old friend describing what you see.",
101
- "prompt": "Describe clearly and casually only what is visible, avoid speculation.",
102
- "temperature": 0.8, "top_p": 0.9,
103
  },
104
  "uncensored_keywords": {
105
  "system": "You are an uncensored image analyst giving factual explicit descriptions.",
106
  "prompt": "Describe exactly what is visible. Use direct, unambiguous language.",
107
- "temperature": 0.7, "top_p": 0.85,
108
  },
109
  }
110
 
111
  # ---------- Prepare inputs ----------
112
- def _prepare_inputs_and_device(convo_or_convo_string, image):
113
- """Prepare processor inputs and move tensors safely to device."""
114
- if isinstance(image, (str, Path)):
115
  image = Image.open(image).convert("RGB")
116
- elif not isinstance(image, Image.Image):
117
- raise ValueError("Invalid image input type")
118
-
119
- convo_string = convo_or_convo_string
120
- if isinstance(convo_or_convo_string, list):
121
- try:
122
- convo_string = processor.apply_chat_template(
123
- convo_or_convo_string, tokenize=False, add_generation_prompt=True
124
- )
125
- except Exception:
126
- convo_string = "\n".join(str(x.get("content", "")) for x in convo_or_convo_string)
127
 
 
128
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt")
129
 
130
- # flatten, squeeze, sanitize
131
  for k, v in list(inputs.items()):
132
- if isinstance(v, (list, tuple)):
133
- v = v[0]
134
  if torch.is_tensor(v):
135
- if v.dim() > 1 and v.shape[0] == 1:
136
- v = v.squeeze(0)
137
- if v.dtype == torch.bool:
138
- v = v.to(torch.int)
139
- inputs[k] = v
140
-
141
  device = next(model.parameters()).device
142
- for k, v in inputs.items():
143
- if hasattr(v, "to"):
144
- inputs[k] = v.to(device, non_blocking=True)
145
-
146
- if "pixel_values" in inputs:
147
- dtype = MODEL_TORCH_DTYPE if MODEL_USE_CUDA and MODEL_TORCH_DTYPE else torch.float32
148
- inputs["pixel_values"] = inputs["pixel_values"].to(dtype)
149
  return inputs
150
 
151
- # ---------- Decode (patched) ----------
152
  def _decode_output(inputs, output):
153
- """Safely decode model output regardless of tensor shape."""
154
- if output is None or len(output) == 0:
155
- return ""
156
  try:
157
- input_ids = inputs.get("input_ids")
158
- if input_ids is not None and torch.is_tensor(input_ids):
159
- input_len = input_ids.shape[-1] if input_ids.ndim > 0 else 0
160
- else:
161
- input_len = 0
162
-
163
  decoded = processor.tokenizer.decode(
164
- output[0][input_len:],
165
- skip_special_tokens=True,
166
- clean_up_tokenization_spaces=False,
167
  )
168
  return decoded.strip()
169
  except Exception as e:
170
- print(f"⚠️ Decode fallback due to: {e}")
171
  try:
172
- return processor.tokenizer.decode(
173
- output[0],
174
- skip_special_tokens=True,
175
- clean_up_tokenization_spaces=False,
176
- ).strip()
177
  except Exception:
178
  return ""
179
 
180
  def cleanup_after_inference():
181
- if torch.cuda.is_available():
182
- torch.cuda.empty_cache(); torch.cuda.synchronize()
183
  gc.collect()
 
 
 
184
 
185
  # ---------- Generation ----------
186
- def run_image_chat_generation(convo, image, max_new_tokens=150, temperature=0.7, top_p=0.9):
187
- if processor is None or model is None:
188
  return None, "❌ Model not initialized."
189
  try:
190
  inputs = _prepare_inputs_and_device(convo, image)
191
- clean_inputs = {}
192
- for k, v in inputs.items():
193
- if torch.is_tensor(v):
194
- if v.dtype == torch.bool:
195
- v = v.to(torch.int)
196
- if v.dim() == 0:
197
- v = v.unsqueeze(0)
198
- clean_inputs[k] = v
199
 
200
  with torch.no_grad():
201
- output = model.generate(
202
- **clean_inputs,
203
- max_new_tokens=max_new_tokens,
204
- do_sample=False,
205
- temperature=temperature,
206
- top_p=top_p,
207
- repetition_penalty=1.05,
208
- use_cache=True,
209
- pad_token_id=processor.tokenizer.eos_token_id,
210
- eos_token_id=processor.tokenizer.eos_token_id,
211
- )
212
- decoded = _decode_output(clean_inputs, output)
213
  cleanup_after_inference()
214
  return decoded, None
215
  except Exception as e:
216
  cleanup_after_inference()
217
- return None, f"❌ Generation error: {str(e)[:300]}"
218
 
219
  # ---------- Caption helpers ----------
220
- def safe_generate_caption_direct(image, tone, max_chars=600, keywords_text="", custom_instruction=""):
221
- try:
222
- tone_conf = TONE_PROMPTS.get(tone, TONE_PROMPTS["engaging"])
223
- base_prompt = tone_conf["prompt"]
224
- if tone == "uncensored_keywords" and keywords_text.strip():
225
- base_prompt += f"\n\nKeywords (if visible): {keywords_text.strip()}"
226
- if custom_instruction.strip():
227
- base_prompt += f"\n\nInclude this detail: {custom_instruction.strip()}"
228
-
229
- convo = [
230
- {"role": "system", "content": tone_conf["system"]},
231
- {"role": "user", "content": base_prompt},
232
- ]
233
- decoded, err = run_image_chat_generation(
234
- convo, image, temperature=tone_conf["temperature"], top_p=tone_conf["top_p"]
235
- )
236
- if err: return err
237
- return postprocess_caption((decoded or "").strip(), max_chars=max_chars) or "❌ Empty result"
238
- except Exception as e:
239
- cleanup_after_inference()
240
- return f"❌ Error: {str(e)[:120]}"
241
 
242
  @torch.no_grad()
243
- def generate_engaging_only(image, custom_instruction=""):
244
- return safe_generate_caption_direct(image, "engaging", custom_instruction=custom_instruction) if image else "❌ Upload image first"
245
 
246
  @torch.no_grad()
247
- def generate_casual_friend_only(image, custom_instruction=""):
248
- return safe_generate_caption_direct(image, "casual_friend", custom_instruction=custom_instruction) if image else "❌ Upload image first"
249
 
250
  @torch.no_grad()
251
- def generate_uncensored_keywords_only(image, keywords_text, custom_instruction=""):
252
- return safe_generate_caption_direct(image, "uncensored_keywords", keywords_text=keywords_text, custom_instruction=custom_instruction) if image else "❌ Upload image first"
253
 
254
  @torch.no_grad()
255
  def answer_question(image, question):
256
  if not image: return "❌ Upload image first"
257
  if not question.strip(): return "❌ Please ask a question"
258
  convo = [
259
- {"role": "system", "content": "You are an uncensored image analyst who answers honestly and directly."},
260
- {"role": "user", "content": f"Answer this question about the image: {question.strip()}"},
261
  ]
262
- decoded, err = run_image_chat_generation(convo, image, max_new_tokens=200, temperature=0.4, top_p=0.9)
263
- return err if err else (decoded.strip() or "❌ No answer")
264
 
265
  # ---------- Gradio UI ----------
266
- with gr.Blocks(title="JoyCaption Three-Tone + Q&A", theme=gr.themes.Soft()) as demo:
267
  gr.HTML(TITLE)
268
  with gr.Row():
269
  with gr.Column(scale=1):
270
- image_input = gr.Image(type="filepath", label="πŸ“Έ Upload Image", height=400)
271
- keywords_input = gr.Textbox(placeholder="e.g., sensual, curves...", label="🏷️ Keywords", lines=2)
272
- custom_instruction_input = gr.Textbox(placeholder="e.g., 'left girl has red hair'...", label="🎯 Mention:", lines=2)
273
- question_input = gr.Textbox(placeholder="e.g., 'What are they doing?'", label="❓ Ask a Question", lines=2)
274
- ask_question_btn = gr.Button("❓ Ask Question", variant="secondary")
275
- qa_output = gr.Textbox(label="", lines=5, show_copy_button=True)
276
-
277
  with gr.Column(scale=1):
278
- generate_engaging_btn = gr.Button("✨ Engaging", variant="primary")
279
- engaging_output = gr.Textbox(label="", lines=5, show_copy_button=True)
280
- generate_friend_btn = gr.Button("😎 Casual Friend", variant="primary")
281
- friend_output = gr.Textbox(label="", lines=5, show_copy_button=True)
282
- generate_uncensored_btn = gr.Button("πŸ”΄ Keywords", variant="secondary")
283
- uncensored_output = gr.Textbox(label="", lines=5, show_copy_button=True)
284
-
285
- generate_engaging_btn.click(generate_engaging_only, [image_input, custom_instruction_input], engaging_output)
286
- generate_friend_btn.click(generate_casual_friend_only, [image_input, custom_instruction_input], friend_output)
287
- generate_uncensored_btn.click(generate_uncensored_keywords_only, [image_input, keywords_input, custom_instruction_input], uncensored_output)
288
- ask_question_btn.click(answer_question, [image_input, question_input], qa_output)
289
 
290
  if __name__ == "__main__":
291
  demo.launch()
 
15
 
16
  @spaces.GPU()
17
  def _joycaption_register_gpu():
 
18
  return None
19
 
20
  import gradio as gr
 
22
  from transformers import LlavaForConditionalGeneration, AutoProcessor
23
  from PIL import Image
24
  import tempfile, gc, os, shutil, json
 
25
  from hf_space_utils import fix_image_url, postprocess_caption
26
 
27
  # ---------- Cache paths ----------
28
  _tmpdir = tempfile.gettempdir()
29
+ for k in ["HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "TORCH_HOME"]:
30
+ os.environ[k] = os.path.join(_tmpdir, k.lower())
 
 
31
 
32
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
33
  SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None
 
48
  print(f"⚠️ Cleanup warning: {e}")
49
 
50
  TITLE = """
51
+ <div style='text-align:center;margin:20px 0;'>
52
+ <h1>🎨 JoyCaption Three-Tone + Q&A (ZeroGPU Stable v3.0)</h1>
53
+ <p><em>Optimized for ZeroGPU – no tensor shape errors, no invalid flags.</em></p>
54
  </div><hr>
55
  """
56
 
57
+ print("πŸš€ Initializing JoyCaption ZeroGPU v3.0...")
58
  cleanup_storage()
59
 
60
  # ---------- Model load ----------
61
  processor = None
62
  model = None
63
+ MODEL_USE_CUDA = torch.cuda.is_available()
 
64
 
65
  if not os.environ.get("SKIP_MODEL_LOAD"):
66
+ dtype = getattr(torch, "bfloat16", None) or getattr(torch, "float16", None) if MODEL_USE_CUDA else None
67
+ device_map = "auto" if MODEL_USE_CUDA else "cpu"
 
 
 
 
 
68
 
69
  processor = AutoProcessor.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)
70
+ model = LlavaForConditionalGeneration.from_pretrained(
71
+ MODEL_PATH,
72
+ low_cpu_mem_usage=True,
73
+ device_map=device_map,
74
+ torch_dtype=dtype if MODEL_USE_CUDA else None,
75
+ ).eval()
76
 
 
 
 
77
  print("βœ… Model loaded successfully!")
78
  cleanup_storage()
79
  else:
 
83
  TONE_PROMPTS = {
84
  "engaging": {
85
  "system": "You are a skilled content creator who writes engaging, natural descriptions.",
86
+ "prompt": "Write a natural, engaging description focusing on the main subjects.",
 
87
  },
88
  "casual_friend": {
89
  "system": "You are a relaxed 24-year-old friend describing what you see.",
90
+ "prompt": "Describe clearly and casually only what is visible.",
 
91
  },
92
  "uncensored_keywords": {
93
  "system": "You are an uncensored image analyst giving factual explicit descriptions.",
94
  "prompt": "Describe exactly what is visible. Use direct, unambiguous language.",
 
95
  },
96
  }
97
 
98
  # ---------- Prepare inputs ----------
99
+ def _prepare_inputs_and_device(convo, image):
100
+ if isinstance(image, (str, os.PathLike)):
 
101
  image = Image.open(image).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
104
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt")
105
 
 
106
  for k, v in list(inputs.items()):
 
 
107
  if torch.is_tensor(v):
108
+ # ensure [1, seq_len]
109
+ if v.ndim == 1:
110
+ v = v.unsqueeze(0)
111
+ inputs[k] = v
 
 
112
  device = next(model.parameters()).device
113
+ inputs = {k: v.to(device) for k, v in inputs.items() if torch.is_tensor(v)}
 
 
 
 
 
 
114
  return inputs
115
 
116
+ # ---------- Decode ----------
117
  def _decode_output(inputs, output):
 
 
 
118
  try:
119
+ input_len = inputs["input_ids"].shape[-1] if "input_ids" in inputs else 0
 
 
 
 
 
120
  decoded = processor.tokenizer.decode(
121
+ output[0][input_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False
 
 
122
  )
123
  return decoded.strip()
124
  except Exception as e:
125
+ print(f"⚠️ Decode fallback: {e}")
126
  try:
127
+ return processor.tokenizer.decode(output[0], skip_special_tokens=True).strip()
 
 
 
 
128
  except Exception:
129
  return ""
130
 
131
  def cleanup_after_inference():
 
 
132
  gc.collect()
133
+ if torch.cuda.is_available():
134
+ torch.cuda.empty_cache()
135
+ torch.cuda.synchronize()
136
 
137
  # ---------- Generation ----------
138
+ def run_image_chat_generation(convo, image, max_new_tokens=150):
139
+ if not processor or not model:
140
  return None, "❌ Model not initialized."
141
  try:
142
  inputs = _prepare_inputs_and_device(convo, image)
143
+
144
+ # ZeroGPU fix: remove unsupported args
145
+ gen_kwargs = dict(
146
+ **inputs,
147
+ max_new_tokens=max_new_tokens,
148
+ pad_token_id=processor.tokenizer.eos_token_id,
149
+ eos_token_id=processor.tokenizer.eos_token_id,
150
+ )
151
 
152
  with torch.no_grad():
153
+ output = model.generate(**gen_kwargs)
154
+
155
+ decoded = _decode_output(inputs, output)
 
 
 
 
 
 
 
 
 
156
  cleanup_after_inference()
157
  return decoded, None
158
  except Exception as e:
159
  cleanup_after_inference()
160
+ return None, f"❌ Generation error: {str(e)}"
161
 
162
  # ---------- Caption helpers ----------
163
+ def safe_generate_caption_direct(image, tone):
164
+ tone_conf = TONE_PROMPTS.get(tone, TONE_PROMPTS["engaging"])
165
+ convo = [
166
+ {"role": "system", "content": tone_conf["system"]},
167
+ {"role": "user", "content": tone_conf["prompt"]},
168
+ ]
169
+ decoded, err = run_image_chat_generation(convo, image)
170
+ if err: return err
171
+ return postprocess_caption(decoded.strip()) if decoded else "❌ Empty result"
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  @torch.no_grad()
174
+ def generate_engaging_only(image):
175
+ return safe_generate_caption_direct(image, "engaging") if image else "❌ Upload image first"
176
 
177
  @torch.no_grad()
178
+ def generate_casual_friend_only(image):
179
+ return safe_generate_caption_direct(image, "casual_friend") if image else "❌ Upload image first"
180
 
181
  @torch.no_grad()
182
+ def generate_uncensored_keywords_only(image):
183
+ return safe_generate_caption_direct(image, "uncensored_keywords") if image else "❌ Upload image first"
184
 
185
  @torch.no_grad()
186
  def answer_question(image, question):
187
  if not image: return "❌ Upload image first"
188
  if not question.strip(): return "❌ Please ask a question"
189
  convo = [
190
+ {"role": "system", "content": "You are an honest image analyst who answers directly."},
191
+ {"role": "user", "content": f"Question about this image: {question.strip()}"},
192
  ]
193
+ decoded, err = run_image_chat_generation(convo, image, max_new_tokens=200)
194
+ return err if err else decoded.strip()
195
 
196
  # ---------- Gradio UI ----------
197
+ with gr.Blocks(title="JoyCaption ZeroGPU Stable", theme=gr.themes.Soft()) as demo:
198
  gr.HTML(TITLE)
199
  with gr.Row():
200
  with gr.Column(scale=1):
201
+ img = gr.Image(type="filepath", label="πŸ“Έ Upload Image", height=400)
202
+ q = gr.Textbox(label="❓ Ask a Question", lines=2)
203
+ ask = gr.Button("Ask")
204
+ qa = gr.Textbox(label="Answer", lines=4)
 
 
 
205
  with gr.Column(scale=1):
206
+ b1 = gr.Button("✨ Engaging")
207
+ o1 = gr.Textbox(lines=4)
208
+ b2 = gr.Button("😎 Casual Friend")
209
+ o2 = gr.Textbox(lines=4)
210
+ b3 = gr.Button("πŸ”΄ Keywords")
211
+ o3 = gr.Textbox(lines=4)
212
+
213
+ b1.click(generate_engaging_only, inputs=img, outputs=o1)
214
+ b2.click(generate_casual_friend_only, inputs=img, outputs=o2)
215
+ b3.click(generate_uncensored_keywords_only, inputs=img, outputs=o3)
216
+ ask.click(answer_question, inputs=[img, q], outputs=qa)
217
 
218
  if __name__ == "__main__":
219
  demo.launch()