Hug0endob committed on
Commit
4462cb3
Β·
verified Β·
1 Parent(s): d914e3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -34
app.py CHANGED
@@ -2,25 +2,28 @@ import gradio as gr
2
  from PIL import Image
3
  import requests
4
  from io import BytesIO
5
- import re
6
  import torch
7
- from transformers import BlipForConditionalGeneration, BlipProcessor, T5ForConditionalGeneration, T5Tokenizer
 
 
 
 
 
 
 
8
 
9
device = torch.device("cpu")

# Load models (CPU-friendly)
# BLIP image-captioning model: the processor handles image preprocessing and
# token decoding, the model generates caption token ids. Weights are fetched
# from the Hugging Face hub on first run (network side effect).
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# T5-small paraphraser used downstream to rewrite/expand the raw BLIP caption.
rewriter_tokenizer = T5Tokenizer.from_pretrained("t5-small")
rewriter = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
17
-
18
# Safety patterns (simple filter)
# Case-insensitive keyword alternations used to reject captions that mention
# explicit sexual content (first pattern) or graphic violence (second pattern).
# Word-boundary anchors keep e.g. "sex" from matching inside "Sussex";
# "mutilat" is deliberately a stem so it matches mutilate/mutilated/mutilation.
SENSITIVE_PATTERNS = [
    r"\b(nude|naked|porn|sex|sexual|explicit|hardcore)\b",
    r"\b(blood|gore|mutilat|disembowel|organs)\b",
]
# Single compiled regex: search() yields a match if ANY pattern occurs.
SENSITIVE_RE = re.compile("|".join(SENSITIVE_PATTERNS), flags=re.IGNORECASE)
24
 
25
  def load_image_from_url(url: str, timeout=10):
26
  try:
@@ -31,18 +34,14 @@ def load_image_from_url(url: str, timeout=10):
31
  except Exception as e:
32
  return None, f"Error loading image: {e}"
33
 
34
def is_caption_allowed(text: str):
    """Return False for empty/None text or text matching a sensitive pattern.

    A caption passes only when it is non-empty AND contains none of the
    keywords in SENSITIVE_RE.
    """
    # Short-circuit: falsy text never reaches the regex scan.
    return bool(text) and SENSITIVE_RE.search(text) is None
38
-
39
def generate_caption(img: Image.Image, max_len:int=30):
    """Caption *img* with the BLIP model using 3-beam search.

    Returns the decoded caption string, stripped of surrounding whitespace.
    """
    batch = processor(images=img, return_tensors="pt").to(device)
    token_ids = model.generate(**batch, max_length=max_len, num_beams=3, early_stopping=True)
    decoded = processor.decode(token_ids[0], skip_special_tokens=True)
    return decoded.strip()
44
 
45
- def rewrite_caption(caption: str, max_len:int=64):
46
  input_text = "paraphrase: " + caption
47
  tok = rewriter_tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
48
  out = rewriter.generate(**tok, max_length=max_len, num_beams=2, early_stopping=True)
@@ -54,29 +53,18 @@ def describe_image(url: str, max_caption_len: int = 30, expand: bool = True):
54
  if err:
55
  return None, f"Error: {err}"
56
  caption = generate_caption(img, max_len=max_caption_len)
57
- if not is_caption_allowed(caption):
58
- # Provide a neutral, respectful safety message with next steps
59
- safety_msg = ("A descriptive caption was not provided because the image may contain explicit or graphic content. "
60
- "If this is unexpected, try a different image or upload a cropped/edited version that removes sensitive content.")
61
- return img, safety_msg
62
  if expand:
63
  try:
64
  caption = rewrite_caption(caption, max_len=64)
65
- if not is_caption_allowed(caption):
66
- safety_msg = ("A descriptive caption was not provided because the generated text may describe explicit or graphic content. "
67
- "Try a different image or disable the expansion option.")
68
- return img, safety_msg
69
  except Exception:
70
  pass
71
- # Make caption more descriptive by appending structural cues (objects, colors, setting)
72
- # Quick heuristic: if short, expand with a simple template
73
  if len(caption.split()) < 6:
74
  caption = f"{caption}. The scene appears to contain: {caption.lower()}."
75
  return img, caption
76
 
77
  # Gradio UI: image left, caption right
78
  with gr.Blocks() as demo:
79
- gr.Markdown("## Image captioning β€” image on the left, descriptive caption on the right (CPU-optimized)")
80
  with gr.Row():
81
  with gr.Column(scale=1):
82
  url_in = gr.Textbox(label="Image URL", placeholder="https://example.com/photo.jpg")
 
2
  from PIL import Image
3
  import requests
4
  from io import BytesIO
 
5
  import torch
6
+ from transformers import (
7
+ VisionEncoderDecoderModel,
8
+ ViTImageProcessor,
9
+ AutoTokenizer,
10
+ T5ForConditionalGeneration,
11
+ T5Tokenizer,
12
+ )
13
+ import re
14
 
15
device = torch.device("cpu")

# Image captioning model (nlpconnect/vit-gpt2-image-captioning)
# ViT encoder + GPT-2 decoder: the image processor converts PIL images to
# pixel-value tensors, the tokenizer decodes generated ids back to text.
# Weights are fetched from the Hugging Face hub on first run.
processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
model.eval()  # inference mode: disables dropout etc.

# Rewriter (T5) used downstream to paraphrase/expand the raw caption.
rewriter_tokenizer = T5Tokenizer.from_pretrained("t5-small")
rewriter = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
rewriter.eval()  # inference mode: disables dropout etc.
 
 
 
 
 
 
27
 
28
  def load_image_from_url(url: str, timeout=10):
29
  try:
 
34
  except Exception as e:
35
  return None, f"Error loading image: {e}"
36
 
37
def generate_caption(img: Image.Image, max_len: int = 30):
    """Produce a short caption for *img* via the ViT-GPT2 captioning model.

    The image is preprocessed into pixel values, captioned with 2-beam
    search (capped at *max_len* tokens), and decoded to a stripped string.
    """
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values, max_length=max_len, num_beams=2, early_stopping=True)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
43
 
44
+ def rewrite_caption(caption: str, max_len: int = 64):
45
  input_text = "paraphrase: " + caption
46
  tok = rewriter_tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
47
  out = rewriter.generate(**tok, max_length=max_len, num_beams=2, early_stopping=True)
 
53
  if err:
54
  return None, f"Error: {err}"
55
  caption = generate_caption(img, max_len=max_caption_len)
 
 
 
 
 
56
  if expand:
57
  try:
58
  caption = rewrite_caption(caption, max_len=64)
 
 
 
 
59
  except Exception:
60
  pass
 
 
61
  if len(caption.split()) < 6:
62
  caption = f"{caption}. The scene appears to contain: {caption.lower()}."
63
  return img, caption
64
 
65
  # Gradio UI: image left, caption right
66
  with gr.Blocks() as demo:
67
+ gr.Markdown("## Image captioning β€” image on the left, descriptive caption on the right (CPU-optimized, uncensored)")
68
  with gr.Row():
69
  with gr.Column(scale=1):
70
  url_in = gr.Textbox(label="Image URL", placeholder="https://example.com/photo.jpg")