gopalagra commited on
Commit
3981a40
·
verified ·
1 Parent(s): e1a7959

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -16
app.py CHANGED
@@ -69,10 +69,15 @@
69
  import gradio as gr
70
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, pipeline
71
  from PIL import Image
 
72
 
73
- # Load BLIP model
74
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
75
- model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
 
 
 
 
76
 
77
  # Translation pipelines
78
  translation_models = {
@@ -81,13 +86,12 @@ translation_models = {
81
  "Spanish": pipeline("translation", model="Helsinki-NLP/opus-mt-en-es"),
82
  }
83
 
 
84
  def generate_caption_translate(image, target_lang):
85
- # Step 1: Generate English caption
86
- inputs = processor(image, return_tensors="pt")
87
  out = model.generate(**inputs, max_new_tokens=50)
88
  english_caption = processor.decode(out[0], skip_special_tokens=True)
89
 
90
- # Step 2: Translate
91
  if target_lang in translation_models:
92
  translated = translation_models[target_lang](english_caption)[0]['translation_text']
93
  else:
@@ -95,18 +99,34 @@ def generate_caption_translate(image, target_lang):
95
 
96
  return english_caption, translated
97
 
98
- # Gradio Interface
99
- interface = gr.Interface(
100
- fn=generate_caption_translate,
101
- inputs=[gr.Image(type="pil"), gr.Dropdown(["Hindi", "French", "Spanish"], label="Translate To")],
102
- outputs=[
103
- gr.Textbox(label="English Caption"),
104
- gr.Textbox(label="Translated Caption")
105
- ],
106
- title="BLIP Captioning + Translation"
107
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- interface.launch()
110
 
111
 
112
 
 
69
  import gradio as gr
70
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, pipeline
71
  from PIL import Image
72
+ import torch
73
 
74
+ # Load BLIP2 model
75
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
76
+ model = Blip2ForConditionalGeneration.from_pretrained(
77
+ "Salesforce/blip2-opt-2.7b",
78
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
79
+ device_map="auto" if torch.cuda.is_available() else None
80
+ )
81
 
82
  # Translation pipelines
83
  translation_models = {
 
86
  "Spanish": pipeline("translation", model="Helsinki-NLP/opus-mt-en-es"),
87
  }
88
 
89
+ # ---- Caption + Translation ----
90
  def generate_caption_translate(image, target_lang):
91
+ inputs = processor(image, return_tensors="pt").to(model.device)
 
92
  out = model.generate(**inputs, max_new_tokens=50)
93
  english_caption = processor.decode(out[0], skip_special_tokens=True)
94
 
 
95
  if target_lang in translation_models:
96
  translated = translation_models[target_lang](english_caption)[0]['translation_text']
97
  else:
 
99
 
100
  return english_caption, translated
101
 
102
+ # ---- Visual Question Answering ----
103
+ def answer_question(image, question):
104
+ inputs = processor(image, text=question, return_tensors="pt").to(model.device)
105
+ out = model.generate(**inputs, max_new_tokens=50)
106
+ answer = processor.decode(out[0], skip_special_tokens=True)
107
+ return answer
108
+
109
+ # ---- Gradio Interface ----
110
+ with gr.Blocks() as demo:
111
+ gr.Markdown("## 🖼️ BLIP2: Image Captioning + Translation + VQA")
112
+
113
+ with gr.Tab("Caption + Translation"):
114
+ img1 = gr.Image(type="pil")
115
+ lang = gr.Dropdown(["Hindi", "French", "Spanish"], label="Translate To")
116
+ eng_cap = gr.Textbox(label="English Caption")
117
+ trans_cap = gr.Textbox(label="Translated Caption")
118
+ btn1 = gr.Button("Generate Caption + Translate")
119
+ btn1.click(generate_caption_translate, inputs=[img1, lang], outputs=[eng_cap, trans_cap])
120
+
121
+ with gr.Tab("Visual Question Answering"):
122
+ img2 = gr.Image(type="pil")
123
+ question = gr.Textbox(label="Ask a Question about the Image")
124
+ answer = gr.Textbox(label="Answer")
125
+ btn2 = gr.Button("Get Answer")
126
+ btn2.click(answer_question, inputs=[img2, question], outputs=answer)
127
+
128
+ demo.launch()
129
 
 
130
 
131
 
132