Hug0endob committed on
Commit
5f87827
Β·
verified Β·
1 Parent(s): 02e399a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -17
app.py CHANGED
@@ -3,16 +3,23 @@ from PIL import Image
3
  import requests
4
  from io import BytesIO
5
  import re
6
-
7
  import torch
8
- from transformers import BlipForConditionalGeneration, BlipProcessor
 
 
 
9
 
10
- # Load model (free, publicly available) once at startup
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
13
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
 
 
 
 
 
14
 
15
- # Simple content-safety checker (avoid explicit sexual/graphic violent text)
16
  SENSITIVE_PATTERNS = [
17
  r"\b(nude|naked|porn|sex|sexual|explicit|hardcore)\b",
18
  r"\b(blood|gore|mutilat|disembowel|organs)\b",
@@ -33,29 +40,48 @@ def is_caption_allowed(text: str):
33
  return False
34
  return SENSITIVE_RE.search(text) is None
35
 
36
- def describe_image(url: str, max_length: int = 32):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  img, err = load_image_from_url(url)
38
  if err:
39
  return None, err
40
- # prepare and generate
41
- inputs = processor(images=img, return_tensors="pt").to(device)
42
- out = model.generate(**inputs, max_length=max_length, num_beams=5, early_stopping=True)
43
- caption = processor.decode(out[0], skip_special_tokens=True).strip()
44
- # apply simple safety filter
45
  if not is_caption_allowed(caption):
46
- # return a sanitized, non-graphic alternative
47
- return img, "Caption omitted for safety. The image may contain explicit content."
 
 
 
 
 
 
48
  return img, caption
49
 
50
  with gr.Blocks() as demo:
51
- gr.Markdown("## Image URL β†’ Display + Caption (open-source, free model)")
52
  with gr.Row():
53
  url_in = gr.Textbox(label="Image URL", placeholder="https://example.com/photo.jpg")
54
- max_len = gr.Slider(minimum=10, maximum=80, value=32, label="Max caption length")
 
55
  img_out = gr.Image(type="pil", label="Image")
56
  caption_out = gr.Textbox(label="Caption")
57
  go = gr.Button("Load & Describe")
58
- go.click(fn=describe_image, inputs=[url_in, max_len], outputs=[img_out, caption_out])
59
 
60
  if __name__ == "__main__":
61
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
3
  import requests
4
  from io import BytesIO
5
  import re
 
6
  import torch
7
+ from transformers import BlipForConditionalGeneration, BlipProcessor, T5ForConditionalGeneration, T5Tokenizer
8
+
9
# All inference runs on CPU; no GPU is assumed for this Space.
device = torch.device("cpu")

# Caption model: BLIP base — small enough for CPU inference.
# HF model card: Salesforce/blip-image-captioning-base
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# T5-small rewrites raw captions into more fluent text at low extra cost.
rewriter_tokenizer = T5Tokenizer.from_pretrained("t5-small")
rewriter = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
21
 
22
+ # Simple safety patterns (non-bypass)
23
  SENSITIVE_PATTERNS = [
24
  r"\b(nude|naked|porn|sex|sexual|explicit|hardcore)\b",
25
  r"\b(blood|gore|mutilat|disembowel|organs)\b",
 
40
  return False
41
  return SENSITIVE_RE.search(text) is None
42
 
43
def generate_caption(img: Image.Image, max_len: int = 30):
    """Produce a raw BLIP caption for *img*, capped at *max_len* tokens."""
    batch = processor(images=img, return_tensors="pt").to(device)
    # Three beams trade a little quality for CPU-friendly latency.
    generated = model.generate(
        **batch, max_length=max_len, num_beams=3, early_stopping=True
    )
    return processor.decode(generated[0], skip_special_tokens=True).strip()
50
+
51
def rewrite_caption(caption: str, max_len: int = 64):
    """Paraphrase *caption* with T5-small for more natural phrasing."""
    # The "paraphrase:" prefix is the task instruction given to T5.
    encoded = rewriter_tokenizer(
        "paraphrase: " + caption, return_tensors="pt", truncation=True
    ).to(device)
    output_ids = rewriter.generate(
        **encoded, max_length=max_len, num_beams=2, early_stopping=True
    )
    return rewriter_tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
58
+
59
def describe_image(url: str, max_caption_len: int = 30, expand: bool = True):
    """Fetch an image from *url* and return an (image, caption) pair for the UI.

    Returns (None, error_message) when the download fails, and substitutes a
    fixed notice whenever the safety filter rejects a caption.
    """
    img, err = load_image_from_url(url)
    if err:
        return None, err

    caption = generate_caption(img, max_len=max_caption_len)
    if not is_caption_allowed(caption):
        return img, "Caption omitted for safety."

    if not expand:
        return img, caption

    # Rewriting is best-effort: on any failure, keep the raw caption.
    try:
        rewritten = rewrite_caption(caption, max_len=64)
    except Exception:
        return img, caption
    if not is_caption_allowed(rewritten):
        return img, "Caption omitted for safety."
    return img, rewritten
74
 
75
# Gradio UI: URL in -> displayed image + (optionally rewritten) caption out.
with gr.Blocks() as demo:
    gr.Markdown("## CPU-optimized Image Captioning (BLIP-base + T5-small rewriter)")
    with gr.Row():
        url_in = gr.Textbox(label="Image URL", placeholder="https://example.com/photo.jpg")
    max_len = gr.Slider(minimum=10, maximum=60, value=30, label="Max caption token length")
    # Fixed typo in the user-facing label: "rewite" -> "rewrite".
    expand_chk = gr.Checkbox(label="Expand/rewrite caption (slower, more natural)", value=True)
    img_out = gr.Image(type="pil", label="Image")
    caption_out = gr.Textbox(label="Caption")
    go = gr.Button("Load & Describe")
    go.click(fn=describe_image, inputs=[url_in, max_len, expand_chk], outputs=[img_out, caption_out])

if __name__ == "__main__":
    # Bind to all interfaces on the conventional HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)