# multimodal_phi2 / app.py
# Author: Dhairyashil Ghatage
# Commit: upload 12k iterations data (2b44451)
import gradio as gr
import torch
from PIL import Image
import librosa
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor, CLIPProcessor, CLIPModel
from peft import PeftConfig, PeftModel
from safetensors.torch import load_file
import sys
import os
import torch.nn as nn
from safetensors import safe_open
import json
# NOTE(review): anomaly detection is a *training* debugging aid that slows
# every autograd operation; this app only runs inference (all model calls are
# wrapped in torch.no_grad()), so it is disabled here.
# torch.autograd.set_detect_anomaly(True)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ----
def load_clip():
    """Load the CLIP processor and vision model used to embed uploaded images.

    Both the processor and the model are cached under ``data/`` after the
    first download so subsequent starts can run offline.  (The original code
    cached only the processor and re-downloaded the model on every start.)

    Returns:
        tuple: (CLIPProcessor, CLIPModel) — the model is moved to `device`,
        cast to float32 to match the rest of the pipeline, and set to eval mode.
    """
    processor_path = "data/processor_clip_embeddings_vit_base_patch32.pt"
    model_path = "data/model_clip_vit_base_patch32"
    if os.path.exists(processor_path):
        processor = CLIPProcessor.from_pretrained(processor_path)
    else:
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        processor.save_pretrained(processor_path)  # cache for offline reuse
    if os.path.exists(model_path):
        model_clip = CLIPModel.from_pretrained(model_path)
    else:
        model_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        model_clip.save_pretrained(model_path)  # cache for offline reuse
    # Single move/cast (the original called .to(device) twice in a row).
    model_clip = model_clip.to(device).to(torch.float32)
    model_clip.eval()  # inference only
    return processor, model_clip
class ProjectionLayer(nn.Module):
    """Maps a CLIP image embedding into the Phi hidden-state space.

    A single affine transformation; a trailing LayerNorm was tried during
    development and deliberately left out.
    """

    def __init__(self, clip_embedding_dim, phi_hidden_dim):
        super().__init__()
        self.linear = nn.Linear(clip_embedding_dim, phi_hidden_dim)

    def forward(self, image_embeddings):
        """Project (batch, clip_dim) embeddings to (batch, phi_hidden_dim)."""
        projected = self.linear(image_embeddings)
        return projected
class MultimodalPhiWithAdapter(nn.Module):
    """A Phi language model paired with an image projection layer.

    One projected CLIP embedding is prepended to the token-embedding sequence
    (via `torch.cat` along dim 1), so the LM attends to the image as a single
    extra position at index 0; the attention mask and labels are padded by one
    position to match.
    """
    def __init__(self, language_model, projection_layer, freeze_language_model=True, freeze_projection_layer=False):
        super().__init__()
        self.language_model = language_model
        self.projection_layer = projection_layer
        # Expose the LM config so HF utilities (e.g. generation) can find it.
        self.config = language_model.config
        self.set_trainable_params(freeze_language_model, freeze_projection_layer)
        # Convert all parameters to float32 (the whole app runs in fp32).
        self.to(torch.float32)
    def set_trainable_params(self, freeze_language_model, freeze_projection_layer):
        """Toggle requires_grad on the LM and the projection layer independently."""
        for param in self.language_model.parameters():
            param.requires_grad = not freeze_language_model
        for param in self.projection_layer.parameters():
            param.requires_grad = not freeze_projection_layer
    def forward(self, input_ids=None, attention_mask=None, image_embeddings=None, labels=None, inputs_embeds=None, **kwargs):
        """Run the LM, optionally prepending one projected image embedding.

        When `image_embeddings` is given (and `inputs_embeds` is not), the
        projected embedding occupies position 0; the attention mask gets a
        leading 1 and the labels a leading -100 so the image slot is ignored
        by the loss.  Either `input_ids` or `inputs_embeds` must be provided.
        """
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        # Ensure all inputs are in float32
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.to(torch.float32)
        if image_embeddings is not None:
            image_embeddings = image_embeddings.to(torch.float32)
        if inputs_embeds is None:
            if image_embeddings is not None:
                # (batch, clip_dim) -> (batch, hidden) -> inserted as seq position 0.
                projected_embeddings = self.projection_layer(image_embeddings)
                input_embeds = self.language_model.get_input_embeddings()(input_ids)
                combined_embeds = torch.cat([projected_embeddings.unsqueeze(1), input_embeds], dim=1)
            else:
                combined_embeds = self.language_model.get_input_embeddings()(input_ids)
        else:
            # Caller supplied ready-made embeddings; use them unchanged.
            combined_embeds = inputs_embeds
        if attention_mask is not None:
            if image_embeddings is not None:
                # One extra attended position for the prepended image embedding.
                image_attention = torch.ones((batch_size, 1), dtype=torch.long, device=combined_embeds.device)
                combined_attention_mask = torch.cat([image_attention, attention_mask], dim=1)
            else:
                combined_attention_mask = attention_mask
        else:
            combined_attention_mask = None
        # Ensure combined_embeds is in float32
        combined_embeds = combined_embeds.to(torch.float32)
        if labels is not None and image_embeddings is not None:
            # -100 = ignore index: no loss is computed for the image slot.
            pad_labels = torch.full((batch_size, 1), -100, dtype=labels.dtype, device=labels.device)
            labels = torch.cat([pad_labels, labels], dim=1)
        outputs = self.language_model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention_mask,
            labels=labels,
            **kwargs
        )
        return outputs
    def generate(self, input_ids, attention_mask, image_embeddings=None, **kwargs):
        """Generate text, optionally conditioned on an image embedding.

        NOTE(review): when `inputs_embeds` is used, HF `generate` typically
        returns only the newly generated token ids (no prompt echo) — confirm
        the caller's decoding assumes this.
        """
        if image_embeddings is not None:
            batch_size = input_ids.shape[0]
            projected_embeddings = self.projection_layer(image_embeddings)
            # Prepend projected image embeddings to the input sequence
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            combined_embeds = torch.cat([projected_embeddings.unsqueeze(1), input_embeds], dim=1)
            # Adjust attention mask for the added image token
            image_attention = torch.ones((batch_size, 1), dtype=torch.long, device=input_ids.device)
            combined_attention_mask = torch.cat([image_attention, attention_mask], dim=1)
            # Call the base model's generate method with the combined embeddings
            return self.language_model.generate(
                inputs_embeds=combined_embeds,
                attention_mask=combined_attention_mask,
                **kwargs
            )
        else:
            # Handle text-only input
            return self.language_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kwargs
            )
    def prepare_inputs_for_generation(self, *args, **kwargs):
        # Delegate to the wrapped LM so HF's generation loop works unchanged.
        return self.language_model.prepare_inputs_for_generation(*args, **kwargs)
    def count_trainable_parameters(self):
        """Number of parameters with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    def count_total_parameters(self):
        """Total number of parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())
# Dimensionality of CLIP ViT-B/32 image embeddings (input to the projection layer).
clip_embedding_dim = 512 #clip_embeddings.shape[1]
# Then, after the import:
processor, model_clip = load_clip()
# Load your Phi model
# if os.path.exists("local_phi2_model"):
#     base_model = AutoModelForCausalLM.from_pretrained("local_phi2_model").to(device)
# else:
#     base_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2").to(device)
#     base_model.save_pretrained("local_phi2_model")
from transformers import PhiForCausalLM, PhiConfig
# Prefer a locally cached Phi-2; otherwise download from the Hub.
# NOTE(review): the download branch never persists the model (the
# save_pretrained call is commented out), so it re-downloads on every start —
# confirm whether that is intentional.
if os.path.exists("local_phi2_model"):
    base_model = PhiForCausalLM.from_pretrained("local_phi2_model").to(device)
else:
    config = PhiConfig.from_pretrained("microsoft/phi-2")
    base_model = PhiForCausalLM.from_pretrained("microsoft/phi-2", config=config).to(device)
    #base_model.save_pretrained("local_phi2_model")
# Tokenizer uses the same local-cache-first pattern, and this one does save.
if os.path.exists("local_phi2_tokenizer"):
    tokenizer = AutoTokenizer.from_pretrained("local_phi2_tokenizer")
    print("tokenizer loaded from local")
else:
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    tokenizer.save_pretrained("local_phi2_tokenizer")
    print("tokenizer loaded from HF")
# Assemble the multimodal model: LoRA-adapted Phi-2 + trained projection layer.
peft_config = PeftConfig.from_pretrained("118k_12k_itr_fine_tuned_phi_lora")  # NOTE(review): loaded but never used below
projection_layer = ProjectionLayer(clip_embedding_dim, base_model.config.hidden_size).to(device).to(torch.float32)
projection_layer.load_state_dict(torch.load("118k_12k_itr_projection_layer.pt", map_location=device))
fine_tuned_model = PeftModel.from_pretrained(base_model, "118k_12k_itr_fine_tuned_phi_lora").to(device)
whole_model = MultimodalPhiWithAdapter(fine_tuned_model, projection_layer).to(device).to(torch.float32)
# Replace the load_file call with this:
def load_safetensors(path):
    """Read every tensor plus the metadata dict from a .safetensors file.

    Returns:
        tuple: (dict mapping tensor name -> torch.Tensor, metadata dict or None).
    """
    with safe_open(path, framework="pt") as reader:
        meta = reader.metadata()  # already a dict — no JSON decoding required
        weights = {name: reader.get_tensor(name) for name in reader.keys()}
    return weights, meta
# NOTE(review): this loads the *4k*-iteration adapter weights over the model
# that was just assembled from the *12k* checkpoint above — confirm which
# checkpoint is actually intended.
state_dict, metadata = load_safetensors("118k_4k_itr_fine_tuned_phi_lora/adapter_model.safetensors")
state_dict = {k: v.to(device).to(torch.float32) for k, v in state_dict.items()}
# strict=False: the safetensors file contains only adapter weights, not the
# full model's parameters.
whole_model.load_state_dict(state_dict, strict=False)
# Set the model to evaluation mode
whole_model.eval()
whisper_model_path = "whisper-small-local"
# Load Whisper (speech-to-text) from the local cache when present; otherwise
# download it AND cache it, so the local branch is actually taken on the next
# start.  (The original download branch never saved, so the os.path.exists
# check could never become true and the model re-downloaded every run.)
if os.path.exists(whisper_model_path):
    whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_path)
    whisper_processor = WhisperProcessor.from_pretrained(whisper_model_path)
else:
    print("Downloading Whisper model from Hugging Face...")
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    whisper_model.save_pretrained(whisper_model_path)
    whisper_processor.save_pretrained(whisper_model_path)
def transcribe_audio(audio):
    """Transcribe speech to text with the Whisper model.

    Args:
        audio: either a file path (str), or a Gradio-style
            (sample_rate, samples) tuple of already-loaded audio.

    Returns:
        str: the decoded transcription.
    """
    if isinstance(audio, str):  # file path: librosa resamples to 16 kHz for us
        audio, sr = librosa.load(audio, sr=16000)
    else:  # (sample_rate, samples) tuple
        sr, audio = audio
        # NOTE(review): integer PCM is cast, not rescaled, here — if callers
        # pass int16 data the amplitudes will be huge; confirm inputs are
        # already normalized floats.
        audio = audio.astype(np.float32)
        # BUG FIX: the tuple branch previously ignored the real sample rate
        # and told the processor the audio was 16 kHz. Whisper expects 16 kHz,
        # so resample when necessary.
        if sr != 16000:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    # Pad/trim to 480000 samples (30 s at 16 kHz), Whisper's input window.
    audio = librosa.util.fix_length(audio, size=480000)
    input_features = whisper_processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    with torch.no_grad():  # inference only
        generated_ids = whisper_model.generate(input_features)
    transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return transcription
def process_input(text_input, file_input):
    """Build a prompt from text plus an optional image/audio file and generate.

    Images are embedded with CLIP and fed to the model through the projection
    layer; audio is transcribed with Whisper and folded into the text prompt.

    Args:
        text_input: user's message (may be empty/None).
        file_input: Gradio file object with a ``.name`` path, or None.

    Returns:
        str: the model's generated reply (or an apology if generation was empty).
    """
    input_text = text_input or ""
    image_features = None
    file_description = ""
    if file_input is not None:
        file_path = file_input.name
        file_extension = os.path.splitext(file_path)[1].lower()
        if file_extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
            image = Image.open(file_path)
            print(f"image loaded, {file_path} {file_extension}")
            inputs = processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                image_features = model_clip.get_image_features(**inputs)
            file_description = "An image has been uploaded. "
            # Steer the model toward describing the image when no text given.
            if not input_text:
                input_text = "Please describe and analyze the content of the uploaded image."
            else:
                input_text = f"Regarding the uploaded image: {input_text}"
        elif file_extension in ['.mp3', '.wav', '.ogg']:
            transcribed_text = transcribe_audio(file_path)
            file_description = f"Transcribed audio: '{transcribed_text}'. "
        else:
            file_description = f"Unsupported file type: {os.path.basename(file_path)}. "
    if not input_text and file_description:
        input_text = f"{file_description}Please describe or analyze the content."
    elif file_description:
        input_text = f"{file_description} {input_text}"
    encoded_text = tokenizer(input_text, return_tensors="pt").to(device)
    # Remember whether a real image was embedded BEFORE zero-filling below
    # (the original re-checked `is not None` afterwards, which was always true).
    has_image = image_features is not None
    if has_image:
        image_features = image_features.to(device).to(torch.float32)
    else:
        # The multimodal wrapper always expects an embedding slot; use zeros
        # so text-only prompts still flow through the same generate path.
        image_features = torch.zeros((1, clip_embedding_dim), device=device, dtype=torch.float32)
    with torch.no_grad():
        # Image prompts get longer, more exploratory generation settings.
        if has_image and abs(image_features.sum().item()) > 1e-6:
            generation_params = {
                "max_new_tokens": 400,
                "num_beams": 4,
                "do_sample": True,
                "temperature": 0.8,
                "top_k": 40,
                "top_p": 0.92,
                "no_repeat_ngram_size": 2,
                "repetition_penalty": 1.2,
            }
        else:
            generation_params = {
                "max_new_tokens": 150,
                "num_beams": 5,
                "do_sample": True,
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.95,
                "no_repeat_ngram_size": 3,
                "repetition_penalty": 1.1,
            }
        # BUG FIX: phi-2's tokenizer may define no unk token; passing
        # bad_words_ids=[[None]] would crash generate(). Only suppress the
        # unk token when it exists.
        extra_kwargs = {}
        if tokenizer.unk_token_id is not None:
            extra_kwargs["bad_words_ids"] = [[tokenizer.unk_token_id]]
        outputs = whole_model.generate(
            input_ids=encoded_text.input_ids,
            attention_mask=encoded_text.attention_mask,
            image_embeddings=image_features,
            renormalize_logits=True,
            **extra_kwargs,
            **generation_params
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Prompt-stripping was tried and disabled; return the raw decode.
    response = generated_text  # [len(input_text):].strip()
    if not response:
        response = "I apologize, but I couldn't generate a meaningful response. Could you please provide more context or try rephrasing your input?"
    return response
def chat(message, history, file):
    """Legacy chat callback: run the model and append one exchange to history.

    Returns the updated history and an empty string (to clear the input box).
    Errors are caught and surfaced as a chat message rather than raised.
    """
    try:
        response = process_input(message, file)
        if file:
            history.append((f"[File uploaded: {file.name}]", response))
        elif message:
            history.append((message, response))
    except Exception as exc:
        history.append((message, f"An error occurred: {str(exc)}"))
    return history, ""  # Clear the input box after sending
def new_chat():
    """Reset the UI: return an empty chat history and a blank file indicator."""
    empty_history, empty_indicator = [], ""
    return empty_history, empty_indicator
# Gradio UI: fixed-position chat area, pinned input bar, and a "New Chat"
# button; all layout is done through the custom CSS below (kept verbatim —
# it is runtime content).
with gr.Blocks(css="""
body { margin: 0; padding: 0; } /* Removes default body margins */
#chatbot {
position: fixed; /* Fixed position relative to the viewport */
top: 50px; /* Space for the New Chat button */
left: 0;
right: 0;
bottom: 60px; /* Space for the input area */
overflow-y: auto; /* Adds vertical scrollbar when content overflows */
padding: 20px; /* Inner spacing */
}
#chatbot > div > div {
margin-bottom: 15px; /* Space between chat messages */
max-width: 80%; /* Maximum width of chat bubbles */
}
#chatbot > div > div:nth-child(odd) {
margin-left: auto; /* Aligns user messages to the right */
text-align: right;
}
#chatbot > div > div:nth-child(even) {
margin-right: auto; /* Aligns AI messages to the left */
text-align: left;
}
#input-area {
position: fixed; /* Keeps input area at the bottom of the screen */
bottom: 0;
left: 0;
right: 0;
background-color: #2c2c2c; /* Background color of input area */
padding: 10px; /* Inner spacing of input area */
height: 60px; /* Height of input area */
box-sizing: border-box; /* Includes padding in the element's total width and height */
}
#input-box {
max-width: 90%; /* Maximum width of the input box, using percentage for responsiveness */
width: 100%;
margin: 0 auto; /* Centers the input box */
display: flex; /* Allows flexible alignment of child elements */
align-items: center; /* Vertically centers items in the input box */
height: 100%; /* Takes full height of parent */
}
#file-indicator {
font-size: 12px; /* Keeping original font size */
color: #aaa; /* Color of file indicator text */
position: absolute;
top: -20px; /* Positions indicator above the input area */
left: 0;
right: 0;
text-align: center;
}
#new-chat-btn {
position: fixed; /* Keeps button in a fixed position */
top: 10px; /* Distance from top of viewport */
left: 10px; /* Distance from left of viewport */
width: auto; /* Allow button to size based on content */
height: auto;
padding: 5px 10px;
}
#file-upload {
width: 40px; /* Width of file upload button */
height: 40px; /* Height of file upload button */
padding: 0;
flex-shrink: 0; /* Prevents button from shrinking */
}
#file-upload > button {
border-radius: 50%; /* Makes the button circular */
padding: 8px;
width: 100%;
height: 100%;
}
#msg {
flex-grow: 1; /* Allows the text input to grow and fill available space */
height: 40px; /* Height of text input */
margin-right: 10px; /* Space between text input and file upload button */
}
/* Media query for smaller screens */
@media (max-width: 600px) {
#chatbot {
top: 40px; /* Slightly less space for the button on small screens */
bottom: 50px; /* Slightly less space for the input area on small screens */
padding: 10px; /* Reduced padding on small screens */
}
#input-area {
height: 50px; /* Slightly reduced height on small screens */
}
#input-box {
max-width: 95%; /* Wider input box on small screens */
}
#file-upload, #msg {
height: 30px; /* Slightly smaller height for inputs on small screens */
}
#new-chat-btn {
font-size: 14px; /* Keeping font size consistent, adjust if needed */
}
}
#app-description {
position: fixed;
top: 10px;
left: 50%;
transform: translateX(-50%);
text-align: center;
width: 90%;
max-width: 800px;
z-index: 1000;
background-color: #ffffff; /* Darker background color */
padding: 10px;
border-radius: 5px;
color: #1a1a1a; /* Pure white text color */
font-weight: bold; /* Make the text bold */
}
#new-chat-btn {
top: 60px; /* Adjusted to make room for the description */
}
#chatbot {
top: 100px; /* Adjusted to make room for the description and button */
}
@media (max-width: 600px) {
#app-description {
font-size: 14px;
padding: 5px;
}
#new-chat-btn {
top: 50px;
}
#chatbot {
top: 90px;
}
}
""") as demo:
    # Static usage banner pinned to the top of the page (styled by #app-description).
    gr.Markdown(
        """
MultiModal phi-2: This app takes text from bottom input bar.
Audio and image files can be attached with attachment button.
(After attaching audio or image file, click in text box and enter to upload file)
""",
        elem_id="app-description"
    )
    new_chat_btn = gr.Button("New Chat", elem_id="new-chat-btn")
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, show_label=False)
    with gr.Column(elem_id="input-area"):
        # Small status line showing the name of the attached file.
        file_indicator = gr.Markdown(elem_id="file-indicator")
        with gr.Row(elem_id="input-box"):
            msg = gr.Textbox(
                show_label=False,
                placeholder="Send a message...",
                container=False,
                elem_id="msg"
            )
            file_input = gr.UploadButton("๐Ÿ“Ž", file_types=["image", "audio"], elem_id="file-upload")
    def handle_file_upload(file):
        """On upload: clear the text box and display the file's basename."""
        if file:
            return gr.update(value=""), f"๐Ÿ“Ž Uploaded: {os.path.basename(file.name)}" # just actual file name, not full path here
        return gr.update(value=""), ""
    def process_message(message, history, file):
        """Submit callback: echo the user's turn(s), run the model, fill in the reply.

        Errors are caught and shown as a chat message so the UI never crashes.
        """
        try:
            if file:
                file_name = file.name
                # Show the upload as its own user turn; reply is filled in below.
                history.append((f"๐Ÿ“Ž [File uploaded: {os.path.basename(file_name)}]", None))
            if message:
                history.append((message, None))
            response = process_input(message, file)
            if message or file:
                # NOTE(review): only the most recent turn gets the reply — when
                # both a file and a message were appended, the file row keeps
                # None; confirm this is the intended display.
                history[-1] = (history[-1][0], response)
        except Exception as e:
            history.append((message, f"An error occurred: {str(e)}"))
        return history, "", None # Clear the input box and file input after sending
    # Event wiring: submit sends a message, upload updates the indicator,
    # the button resets the conversation.
    msg.submit(process_message, inputs=[msg, chatbot, file_input], outputs=[chatbot, msg, file_indicator])
    file_input.upload(handle_file_upload, inputs=[file_input], outputs=[msg, file_indicator])
    new_chat_btn.click(new_chat, outputs=[chatbot, file_indicator])
# share=True exposes a temporary public Gradio link in addition to localhost.
demo.launch(share=True)
# Test model loading
# test_input = "Human: Hello, how are you? Explain share market to me: " # Add a prompt format
# test_encoded = tokenizer(test_input, return_tensors="pt").to(device)
# test_image_features = torch.zeros((1, clip_embedding_dim), device=device, dtype=torch.float32)
# with torch.no_grad():
# test_output = whole_model.generate(
# input_ids=test_encoded.input_ids,
# attention_mask=test_encoded.attention_mask,
# image_embeddings=test_image_features,
# max_new_tokens=200,
# num_beams=5,
# do_sample=True,
# temperature=0.8,
# top_k=50,
# top_p=0.92,
# no_repeat_ngram_size=3,
# repetition_penalty=1.2,
# )
# test_response = tokenizer.decode(test_output[0], skip_special_tokens=True)
# # Handle the case where 'AI:' is not in the response
# if "AI:" in test_response:
# ai_response = test_response.split("AI:")[1].strip()
# else:
# ai_response = test_response.replace(test_input, "").strip()
# print("Test model output:", ai_response)