gopalagra committed
Commit 6ec7c3f · verified · 1 Parent(s): a88f95d

Update app.py

Files changed (1)
  1. app.py +17 -29
app.py CHANGED
@@ -70,42 +70,29 @@ import gradio as gr
 from transformers import Blip2Processor, Blip2ForConditionalGeneration, pipeline
 from PIL import Image
 import torch
-import streamlit as st
 
 # ----------------------
-# Cached Model Loaders
+# Load BLIP2 for Captioning
 # ----------------------
-@st.cache_resource
-def load_caption_model():
-    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
-    return processor, model
-
-@st.cache_resource
-def load_vqa_model():
-    processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
-    model = Blip2ForConditionalGeneration.from_pretrained(
-        "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16, device_map="auto"
-    )
-    return processor, model
-
-@st.cache_resource
-def load_translation_models():
-    return {
-        "Hindi": pipeline("translation", model="Helsinki-NLP/opus-mt-en-hi"),
-        "French": pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr"),
-        "Spanish": pipeline("translation", model="Helsinki-NLP/opus-mt-en-es"),
-    }
+caption_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+caption_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
 
 # ----------------------
-# Load All Models with Spinner
+# Load BLIP2 for VQA
 # ----------------------
-with st.spinner("Loading BLIP2 models... please wait ⏳"):
-    caption_processor, caption_model = load_caption_model()
-    vqa_processor, vqa_model = load_vqa_model()
-    translation_models = load_translation_models()
+vqa_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+vqa_model = Blip2ForConditionalGeneration.from_pretrained(
+    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16, device_map="auto"
+)
 
-st.success("✅ Models are ready!")
+# ----------------------
+# Translation pipelines
+# ----------------------
+translation_models = {
+    "Hindi": pipeline("translation", model="Helsinki-NLP/opus-mt-en-hi"),
+    "French": pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr"),
+    "Spanish": pipeline("translation", model="Helsinki-NLP/opus-mt-en-es"),
+}
 
 # ----------------------
 # Caption + Translate Function
@@ -115,6 +102,7 @@ def generate_caption_translate(image, target_lang):
     out = caption_model.generate(**inputs, max_new_tokens=50)
     english_caption = caption_processor.decode(out[0], skip_special_tokens=True)
 
+    # Translate if chosen
    if target_lang in translation_models:
         translated = translation_models[target_lang](english_caption)[0]['translation_text']
     else:
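
For context, here is a minimal sketch of how the post-commit pieces presumably fit together end to end. With Streamlit gone, module-level loading does what st.cache_resource did before: the Space process imports app.py once, so each model is instantiated exactly once at startup. Everything outside the two hunks is an assumption rather than code from this commit: the inputs = caption_processor(...) line, the else branch, the return values, and the gr.Interface wiring are not shown above, and the VQA loads are omitted here for brevity.

# Sketch of the refactored app.py; lines marked "assumed" are not in the commit.
import gradio as gr
from transformers import Blip2Processor, Blip2ForConditionalGeneration, pipeline

# Module-level loading runs once per process, which is what made
# st.cache_resource unnecessary after the move to Gradio.
caption_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
caption_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

translation_models = {
    "Hindi": pipeline("translation", model="Helsinki-NLP/opus-mt-en-hi"),
    "French": pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr"),
    "Spanish": pipeline("translation", model="Helsinki-NLP/opus-mt-en-es"),
}

def generate_caption_translate(image, target_lang):
    # Assumed preprocessing; only the generate/decode lines appear in the diff.
    inputs = caption_processor(images=image, return_tensors="pt")
    out = caption_model.generate(**inputs, max_new_tokens=50)
    english_caption = caption_processor.decode(out[0], skip_special_tokens=True)

    # Translate if chosen
    if target_lang in translation_models:
        translated = translation_models[target_lang](english_caption)[0]["translation_text"]
    else:
        translated = english_caption  # assumed fallback; the else body is outside the hunk
    return english_caption, translated

# Assumed Gradio wiring; the commit shows only "import gradio as gr" as context.
demo = gr.Interface(
    fn=generate_caption_translate,
    inputs=[gr.Image(type="pil"),
            gr.Dropdown(["Hindi", "French", "Spanish"], label="Target language")],
    outputs=[gr.Textbox(label="English caption"), gr.Textbox(label="Translation")],
)
demo.launch()

One asymmetry the commit keeps: the VQA model loads with torch_dtype=torch.float16 and device_map="auto", while the caption model loads in full precision on the default device, so on a GPU Space the caption path stays on CPU unless the model and its inputs are moved explicitly.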