HF-Pawan committed on
Commit
f24f5b3
·
1 Parent(s): d9e2463

Style Changes

Browse files
Files changed (2) hide show
  1. core/model_loader.py +34 -25
  2. ui/layout.py +5 -7
core/model_loader.py CHANGED
@@ -4,14 +4,10 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
4
  MODEL_ID = "Salesforce/blip-image-captioning-large"
5
  DEVICE = torch.device("cpu")
6
 
7
- # Prompt templates
8
  PROMPTS = {
9
  "Short Caption": "a photo of",
10
- "Detailed Caption": "this image shows",
11
- "Creative Caption": "this artistic scene depicts",
12
- "Image Explanation": (
13
- "this image shows a complete and detailed scene depicting"
14
- )
15
  }
16
 
17
  def load_model():
@@ -21,6 +17,26 @@ def load_model():
21
  model.eval()
22
  return model, processor
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def generate_caption(
25
  model,
26
  processor,
@@ -36,34 +52,27 @@ def generate_caption(
36
  ).to(DEVICE)
37
 
38
  # Style-specific decoding configuration
39
- if style == "Image Explanation":
40
  generation_kwargs = dict(
41
- min_length=90,
42
- max_length=160,
43
- num_beams=5,
44
  do_sample=False,
45
  repetition_penalty=1.25,
46
  length_penalty=1.1,
 
47
  early_stopping=True
48
  )
49
 
50
- elif style == "Detailed Caption":
51
  generation_kwargs = dict(
52
- min_length=60,
53
- max_length=120,
54
  num_beams=3,
55
  do_sample=False,
56
- repetition_penalty=1.2
57
- )
58
-
59
- else:
60
- generation_kwargs = dict(
61
- min_length=20,
62
- max_length=50,
63
- do_sample=True,
64
- top_p=0.9,
65
- temperature=0.8,
66
- repetition_penalty=1.1
67
  )
68
 
69
  with torch.inference_mode():
@@ -77,4 +86,4 @@ def generate_caption(
77
  skip_special_tokens=True
78
  )
79
 
80
- return caption.strip()
 
4
  MODEL_ID = "Salesforce/blip-image-captioning-large"
5
  DEVICE = torch.device("cpu")
6
 
7
+ # Prompt templates (kept short & stable for BLIP)
8
  PROMPTS = {
9
  "Short Caption": "a photo of",
10
+ "Detailed Caption": "this image shows"
 
 
 
 
11
  }
12
 
13
  def load_model():
 
17
  model.eval()
18
  return model, processor
19
 
20
+ def _finalize_sentence(text: str) -> str:
21
+ """
22
+ Ensures:
23
+ - no trailing commas / conjunctions
24
+ - sentence ends with a dot
25
+ """
26
+ text = text.strip()
27
+
28
+ # Remove dangling conjunctions
29
+ for suffix in [",", "and", "and a", "and the"]:
30
+ if text.lower().endswith(suffix):
31
+ text = text[: -len(suffix)].strip()
32
+
33
+ # Ensure final punctuation
34
+ if not text.endswith((".", "!", "?")):
35
+ text += "."
36
+
37
+ return text
38
+
39
+
40
  def generate_caption(
41
  model,
42
  processor,
 
52
  ).to(DEVICE)
53
 
54
  # Style-specific decoding configuration
55
+ if style == "Detailed Caption":
56
  generation_kwargs = dict(
57
+ min_length=55,
58
+ max_length=110,
59
+ num_beams=4,
60
  do_sample=False,
61
  repetition_penalty=1.25,
62
  length_penalty=1.1,
63
+ no_repeat_ngram_size=3,
64
  early_stopping=True
65
  )
66
 
67
+ else: # Short Caption
68
  generation_kwargs = dict(
69
+ min_length=18,
70
+ max_length=40,
71
  num_beams=3,
72
  do_sample=False,
73
+ repetition_penalty=1.15,
74
+ no_repeat_ngram_size=3,
75
+ early_stopping=True
 
 
 
 
 
 
 
 
76
  )
77
 
78
  with torch.inference_mode():
 
86
  skip_special_tokens=True
87
  )
88
 
89
+ return _finalize_sentence(caption)
ui/layout.py CHANGED
@@ -47,9 +47,7 @@ def build_ui(model, processor):
47
  style_select = gr.Dropdown(
48
  choices=[
49
  "Short Caption",
50
- "Detailed Caption",
51
- "Creative Caption",
52
- "Image Explanation"
53
  ],
54
  value="Detailed Caption",
55
  label="Caption Style"
@@ -69,11 +67,11 @@ def build_ui(model, processor):
69
 
70
  gr.Examples(
71
  examples=[
72
- ["./assets/zebra.jpg", "Detailed Caption"],
73
- ["./assets/cat.jpg", "Image Explanation"],
74
  ["./assets/fridge.jpg", "Detailed Caption"],
75
- ["./assets/marriage.jpg", "Creative Caption"],
76
- ["./assets/giraffe.jpg", "Short Caption"]
77
  ],
78
  inputs=[image_input, style_select]
79
  )
 
47
  style_select = gr.Dropdown(
48
  choices=[
49
  "Short Caption",
50
+ "Detailed Caption"
 
 
51
  ],
52
  value="Detailed Caption",
53
  label="Caption Style"
 
67
 
68
  gr.Examples(
69
  examples=[
70
+ ["./assets/zebra.jpg", "Short Caption"],
71
+ ["./assets/cat.jpg", "Short Caption"],
72
  ["./assets/fridge.jpg", "Detailed Caption"],
73
+ ["./assets/marriage.jpg", "Detailed Caption"],
74
+ ["./assets/giraffe.jpg", "Detailed Caption"]
75
  ],
76
  inputs=[image_input, style_select]
77
  )