Create app.py

app.py
ADDED
@@ -0,0 +1,350 @@
import torch
import torch.nn as nn
import os
import re
from peft import PeftModel
from PIL import Image
import gradio as gr
import librosa
import nltk

from transformers import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import CLIPProcessor, CLIPModel
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the CLIP model and processor used for the image embeddings
clipmodel = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clipprocessor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# NLTK tokenizer data
nltk.download('punkt')
nltk.download('punkt_tab')
def remove_punctuation(text):
    # Keep only alphanumeric characters and whitespace, then collapse repeated spaces
    newtext = ''.join([char for char in text if char.isalnum() or char.isspace()])
    newtext = ' '.join(newtext.split())
    return newtext

def preprocess_text(text):
    return remove_punctuation(text)
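# Example of the cleanup above: remove_punctuation("Hello, world!") returns "Hello world".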
# Load Whisper model and processor for speech-to-text
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

def transcribe_speech(audiopath):
    # Load and resample the audio to the 16 kHz rate Whisper expects
    speech, rate = librosa.load(audiopath, sr=16000)
    audio_input = whisper_processor(speech, return_tensors="pt", sampling_rate=16000)

    # Generate the transcription
    with torch.no_grad():
        generated_ids = whisper_model.generate(audio_input["input_features"])

    # Decode the generated token ids into text
    transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return transcription
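# Illustrative usage, commented out so nothing extra runs at app start; the
# path is the sample audio referenced in the examples list below:
#
#     transcript = transcribe_speech("./audio/03-01-01-01-01-01-01.wav")
#     print(transcript)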
class ProjectionBlock(nn.Module):
    """Projects CLIP image features into the Phi embedding space."""
    def __init__(self, input_dim_CLIP, input_dim_phi2):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_dim_CLIP)
        self.proj = nn.Sequential(
            nn.Linear(input_dim_CLIP, input_dim_phi2),
            nn.GELU(),
            nn.Linear(input_dim_phi2, input_dim_phi2)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return self.proj(x)
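# A minimal shape-check sketch for the projector, assuming the dimensions used
# below (512-dim CLIP features -> 3072-dim Phi-3.5 embeddings); commented out
# so it does not run at import time:
#
#     proj = ProjectionBlock(512, 3072)
#     print(proj(torch.randn(1, 1, 512)).shape)  # torch.Size([1, 1, 3072])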
# Wrap Phi-3.5 and the image projector so the combined model works with the
# HuggingFace Trainer
class MultimodalPhiModel(PreTrainedModel):

    def __init__(self, phi_model, tokenizer, projection):
        super().__init__(phi_model.config)
        self.phi_model = phi_model
        self.image_projection = projection
        self.tokenizer = tokenizer
        self.base_phi_model = None

    def gradient_checkpointing_enable(self, **kwargs):
        self.phi_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        self.phi_model.gradient_checkpointing_disable()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, debug=False, **kwargs):
        model_name = "microsoft/Phi-3.5-mini-instruct"
        base_phi_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        phi_path = pretrained_model_name_or_path

        # Load the LoRA adapter and merge it into the base model
        model = PeftModel.from_pretrained(base_phi_model, phi_path)
        phi_model = model.merge_and_unload()

        input_dim = 512
        output_dim = 3072

        # Load the projector weights if they were saved alongside the adapter
        projector_path = os.path.join(pretrained_model_name_or_path, "image_projector.pth")
        if os.path.exists(projector_path):
            projector_state_dict = torch.load(projector_path, map_location=phi_model.device)
            projector = ProjectionBlock(input_dim, output_dim)
            # Load the state dict, ignoring mismatched keys
            projector.load_state_dict(projector_state_dict, strict=False)
            print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
        else:
            print(f"Projector weights not found at {projector_path}. Initializing with default dimensions.")
            input_dim = 512  # Default CLIP embedding size
            output_dim = phi_model.config.hidden_size
            projector = ProjectionBlock(input_dim, output_dim)

        # Create and return the MultimodalPhiModel instance
        model = cls(phi_model, tokenizer, projector)
        model.base_phi_model = base_phi_model
        return model

    def save_pretrained(self, save_directory):
        # Save the merged Phi model
        self.phi_model.save_pretrained(save_directory)

        # Save the projector weights
        projector_path = os.path.join(save_directory, "image_projector.pth")
        torch.save(self.image_projection.state_dict(), projector_path)

        # Save the config
        self.config.save_pretrained(save_directory)

    def encode(self, image_features):
        return self.image_projection(image_features)

    def forward(self, start_input_ids, end_input_ids, image_features, attention_mask, labels):
        device = next(self.parameters()).device

        start_embeds = self.phi_model.get_input_embeddings()(start_input_ids.to(device))
        end_embeds = self.phi_model.get_input_embeddings()(end_input_ids.to(device))

        if image_features is not None:
            # Project the CLIP features and splice them between the prompt segments:
            # [start prompt embeddings | image embeddings | end prompt embeddings]
            image_embeddings = self.encode(image_features.to(device)).bfloat16()
            input_embeds = torch.cat([start_embeds, image_embeddings, end_embeds], dim=1)
        else:
            input_embeds = torch.cat([start_embeds, end_embeds], dim=1)

        # Forward pass through the language model
        outputs = self.phi_model(inputs_embeds=input_embeds.to(device),
                                 attention_mask=attention_mask.to(device),
                                 labels=labels,
                                 return_dict=True)

        return outputs
def getImageArray(image_path):
    image = Image.open(image_path)
    return image

def getAudioArray(audio_path):
    speech, rate = librosa.load(audio_path, sr=16000)
    return speech
def getInputs(image_path, question, answer=""):
    # Note: the prompt below is fixed; `question` is used later in generateOutput()
    image_features = None
    num_image_tokens = 0

    if image_path is not None:
        image = clipprocessor(images=Image.open(image_path), return_tensors="pt")

        # Generate the CLIP image embedding
        image_features = clipmodel.get_image_features(**image)
        image_features = torch.stack([image_features])
        num_image_tokens = image_features.shape[1]

    # Text that precedes the image embedding
    start_text = "<|system|>\nYou are an assistant good at understanding the objects and their relationship from the context.<|end|>\n<|user|>\n"

    # Text that follows the image embedding
    end_text = f"\nPlease describe the objects and their relationship from the context.<|end|>\n<|assistant|>\n{answer}"

    # Tokenize both text segments
    start_tokens = tokenizer(start_text, padding=True, truncation=True, max_length=512, return_tensors="pt")
    end_tokens = tokenizer(end_text, padding=True, truncation=True, max_length=512, return_tensors="pt")

    start_input_ids = start_tokens['input_ids']
    start_attention_mask = start_tokens['attention_mask']
    end_input_ids = end_tokens['input_ids']
    end_attention_mask = end_tokens['attention_mask']

    # Extend the attention mask to cover the image tokens the model splices in
    if image_path is not None:
        attention_mask = torch.cat([start_attention_mask, torch.ones((1, num_image_tokens), dtype=torch.long), end_attention_mask], dim=1)
    else:
        attention_mask = torch.cat([start_attention_mask, end_attention_mask], dim=1)

    return start_input_ids, end_input_ids, image_features, attention_mask
model_location = "./MM_FT_C1_V2"
print("Model location:", model_location)

model = MultimodalPhiModel.from_pretrained(model_location).to(device)
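# Illustrative single forward pass through the multimodal model, commented out
# so it does not run at import time; the image path comes from the examples list below:
#
#     start_ids, end_ids, feats, mask = getInputs("./images/COCO_train2014_000000581181.jpg", "")
#     out = model(start_ids, end_ids, feats, mask, labels=None)
#     print(out.logits.shape)  # (1, sequence_length, vocab_size)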
def getStringAfter(output, start_str):
    # Keep only the text after the first occurrence of start_str, then clean it
    if start_str in output:
        answer = output.split(start_str)[1]
    else:
        answer = output

    return preprocess_text(answer)


def getStringAfterAnswer(output):
    return getStringAfter(output, "<|assistant|>")
def generateOutput(image_path, audio_path, context_text, question, max_length=5):
    answerPart = ""
    speech_text = ""

    if image_path is not None:
        # Crude iterative decoding: feed the partial answer back in and take the
        # argmax tokens, max_length times, to grow the image description
        for i in range(max_length):
            start_tokens, end_tokens, image_features, attention_mask = getInputs(image_path, question, answer=answerPart)
            output = model(start_tokens, end_tokens, image_features, attention_mask, labels=None)
            tokens = output.logits.argmax(dim=-1)
            output = tokenizer.decode(
                tokens[0],
                skip_special_tokens=True
            )
            answerPart = getStringAfter(output, "<|assistant|>")
        print("Answerpart:", answerPart)

    if audio_path is not None:
        speech_text = transcribe_speech(audio_path)
        print("Speech Text:", speech_text)

    if (question is None) or (question == ""):
        question = "Provide only 1 sentence describing the objects and their relationships in it."

    # Combine the image description, speech transcript, and context text, then
    # let the base Phi model answer the question over that combined context
    input_text = (
        "<|system|>\nPlease understand the context "
        "and answer the question based on the context in 1 or 2 summarized sentences.\n"
        f"<|end|>\n<|user|>\n<|context|>{answerPart}\n{speech_text}\n{context_text}"
        f"\n<|question|>: {question}\n<|end|>\n<|assistant|>\n"
    )
    print("input_text:", input_text)
    start_tokens = tokenizer(input_text, padding=True, truncation=True, max_length=1024, return_tensors="pt")['input_ids'].to(device)

    output_text = tokenizer.decode(
        model.base_phi_model.generate(start_tokens, max_length=1024, do_sample=False, pad_token_id=tokenizer.pad_token_id)[0],
        skip_special_tokens=True
    )

    output_text = getStringAfter(output_text, question).strip()
    return output_text
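# Illustrative end-to-end call, commented out; the inputs mirror the first
# entry of the examples list below:
#
#     print(generateOutput("./images/COCO_train2014_000000581181.jpg", None, "",
#                          "Describe what is happening in this image."))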
title = "Fine-Tuned Multimodal Model"
description = "Test the fine-tuned multimodal model built from the CLIP, Phi-3.5-mini-instruct, and Whisper models"
examples = [
    ["./images/COCO_train2014_000000581181.jpg", None, None, None, None, "Describe what is happening in this image."],
    [None, "Audio File", "./audio/03-01-01-01-01-01-01.wav", None, None, "Describe what the person is trying to say in this audio."],
]

# [None, "Microphone", None, "example_audio_mic.wav", "Context without image.", "What is the result?"],
demo = gr.Blocks()

def process_inputs(image, audio_source, audio_file, audio_mic, context_text, question):
    # Pick the audio input that matches the selected source
    if audio_source == "Microphone":
        speech = audio_mic
    elif audio_source == "Audio File":
        speech = audio_file
    else:
        speech = None

    answer = generateOutput(image, speech, context_text, question)
    return answer

with demo:
    with gr.Row():
        audio_source = gr.Radio(choices=["Microphone", "Audio File"], label="Select Audio Source")
        audio_file = gr.Audio(sources="upload", type="filepath", visible=False)
        audio_mic = gr.Audio(sources="microphone", type="filepath", visible=False)
        image_input = gr.Image(type="filepath", label="Upload Image")
        context_text = gr.Textbox(label="Context Text")
        question = gr.Textbox(label="Question")
        output_text = gr.Textbox(label="Output")

    # Show only the audio widget that matches the selected source
    def update_audio_input(source):
        if source == "Microphone":
            return gr.update(visible=True), gr.update(visible=False)
        elif source == "Audio File":
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    audio_source.change(fn=update_audio_input, inputs=audio_source, outputs=[audio_mic, audio_file])
    submit_button = gr.Button("Submit")
    submit_button.click(fn=process_inputs, inputs=[image_input, audio_source, audio_file, audio_mic, context_text, question], outputs=output_text)

demo.launch(debug=True)