gopalagra committed
Commit f30c62d · verified · 1 Parent(s): 96877e1

Update app.py

Files changed (1):
  1. app.py +30 -24
app.py CHANGED
@@ -70,29 +70,42 @@ import gradio as gr
 from transformers import Blip2Processor, Blip2ForConditionalGeneration, pipeline
 from PIL import Image
 import torch
+import streamlit as st
 
 # ----------------------
-# Load BLIP2 for Captioning
+# Cached Model Loaders
 # ----------------------
-caption_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-caption_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
+@st.cache_resource
+def load_caption_model():
+    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
+    return processor, model
 
-# ----------------------
-# Load BLIP2 for VQA
-# ----------------------
-vqa_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
-vqa_model = Blip2ForConditionalGeneration.from_pretrained(
-    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16, device_map="auto"
-)
+@st.cache_resource
+def load_vqa_model():
+    processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+    model = Blip2ForConditionalGeneration.from_pretrained(
+        "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16, device_map="auto"
+    )
+    return processor, model
 
-# ----------------------
-# Translation pipelines
-# ----------------------
-translation_models = {
-    "Hindi": pipeline("translation", model="Helsinki-NLP/opus-mt-en-hi"),
-    "French": pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr"),
-    "Spanish": pipeline("translation", model="Helsinki-NLP/opus-mt-en-es"),
-}
+@st.cache_resource
+def load_translation_models():
+    return {
+        "Hindi": pipeline("translation", model="Helsinki-NLP/opus-mt-en-hi"),
+        "French": pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr"),
+        "Spanish": pipeline("translation", model="Helsinki-NLP/opus-mt-en-es"),
+    }
 
 # ----------------------
+# Load All Models with Spinner
+# ----------------------
+with st.spinner("Loading BLIP2 models... please wait ⏳"):
+    caption_processor, caption_model = load_caption_model()
+    vqa_processor, vqa_model = load_vqa_model()
+    translation_models = load_translation_models()
+
+st.success("✅ Models are ready!")
+
+# ----------------------
 # Caption + Translate Function
@@ -102,7 +115,6 @@ def generate_caption_translate(image, target_lang):
     out = caption_model.generate(**inputs, max_new_tokens=50)
     english_caption = caption_processor.decode(out[0], skip_special_tokens=True)
 
-    # Translate if chosen
     if target_lang in translation_models:
         translated = translation_models[target_lang](english_caption)[0]['translation_text']
    else:
@@ -142,10 +154,4 @@ with gr.Blocks(title="BLIP2 Vision App") as demo:
         btn2 = gr.Button("Ask")
         btn2.click(vqa, inputs=[img_vqa, q_in], outputs=ans_out)
 
-demo.launch(share="true")
-
-
-
-
-
-
+demo.launch()
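
For context, the second hunk touches only the middle of generate_caption_translate. A minimal sketch of the full function as it plausibly reads after this commit: the signature (from the hunk header) and the generate/decode/translation lines come from the diff; the preprocessing line, the fallback branch, and the return values are assumptions.

def generate_caption_translate(image, target_lang):
    # Preprocess the uploaded PIL image for BLIP2 (assumed; not shown in the diff).
    inputs = caption_processor(images=image, return_tensors="pt")
    out = caption_model.generate(**inputs, max_new_tokens=50)
    english_caption = caption_processor.decode(out[0], skip_special_tokens=True)

    if target_lang in translation_models:
        translated = translation_models[target_lang](english_caption)[0]['translation_text']
    else:
        # Assumed fallback: no pipeline for the requested language, keep the English caption.
        translated = english_caption

    return english_caption, translated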
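
One design note on the new loaders: st.cache_resource memoizes each loader per process, so the heavy from_pretrained calls run only once. It is a Streamlit API, though, while the interface here is Gradio (gr.Blocks / demo.launch()); outside a Streamlit runtime the same memoization is available in the standard library. A hypothetical framework-agnostic sketch, not what this commit does:

from functools import lru_cache

from transformers import Blip2Processor, Blip2ForConditionalGeneration

@lru_cache(maxsize=1)
def load_caption_model():
    # First call loads the model; every later call returns the cached pair.
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
    return processor, model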