"""Multimodal Phi-2 demo: CLIP image features are projected into Phi-2's
embedding space and prepended to the text prompt as a single extra token."""

import json
import os
import sys

import gradio as gr
import librosa
import numpy as np
import torch
import torch.nn as nn
from peft import PeftConfig, PeftModel
from PIL import Image
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    PhiConfig,
    PhiForCausalLM,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

# NOTE(review): anomaly detection is a debugging aid for backward passes and
# this script only runs inference -- consider removing it.
torch.autograd.set_detect_anomaly(True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_clip():
    """Load the CLIP processor and model used to embed uploaded images.

    The processor is saved under ``data/`` the first time it is downloaded
    and read from there afterwards. The model is always obtained through
    ``CLIPModel.from_pretrained`` with the hub identifier.

    Returns:
        tuple: ``(processor, model_clip)``; the model is on ``device``, in
        float32 and in evaluation mode.
    """
    # Despite the ".pt" suffix this is the directory written by save_pretrained.
    processor_dir = "data/processor_clip_embeddings_vit_base_patch32.pt"
    if os.path.exists(processor_dir):
        processor = CLIPProcessor.from_pretrained(processor_dir)
    else:
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        processor.save_pretrained(processor_dir)

    model_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    model_clip = model_clip.to(device).to(torch.float32)
    model_clip.eval()
    return processor, model_clip


class ProjectionLayer(nn.Module):
    """Linear map from the CLIP embedding width to the language model width."""

    def __init__(self, clip_embedding_dim, phi_hidden_dim):
        super().__init__()
        self.linear = nn.Linear(clip_embedding_dim, phi_hidden_dim)
        # Layer normalisation is intentionally disabled for now:
        # self.ln = nn.LayerNorm(phi_hidden_dim)

    def forward(self, image_embeddings):
        """Return the linear projection of ``image_embeddings``."""
        return self.linear(image_embeddings)


class MultimodalPhiWithAdapter(nn.Module):
    """Language model wrapper that accepts an optional image embedding.

    When image embeddings are supplied they are projected by
    ``projection_layer`` and prepended to the token embeddings as one extra
    position; the attention mask (and labels, in ``forward``) are extended by
    one position to match.
    """

    def __init__(self, language_model, projection_layer,
                 freeze_language_model=True, freeze_projection_layer=False):
        super().__init__()
        self.language_model = language_model
        self.projection_layer = projection_layer
        self.config = language_model.config
        self.set_trainable_params(freeze_language_model, freeze_projection_layer)
        # Keep every parameter in float32.
        self.to(torch.float32)

    def set_trainable_params(self, freeze_language_model, freeze_projection_layer):
        """Enable or disable gradients for the two sub-modules."""
        for param in self.language_model.parameters():
            param.requires_grad = not freeze_language_model
        for param in self.projection_layer.parameters():
            param.requires_grad = not freeze_projection_layer

    def _embed_with_image(self, input_ids, image_embeddings):
        """Return token embeddings with the projected image token prepended."""
        # The projection layer is float32, so the features are cast to match.
        image_token = self.projection_layer(image_embeddings.to(torch.float32))
        text_embeds = self.language_model.get_input_embeddings()(input_ids)
        return torch.cat([image_token.unsqueeze(1), text_embeds], dim=1)

    @staticmethod
    def _image_mask(batch_size, mask_device):
        """Return the one-position attention mask covering the image token."""
        return torch.ones((batch_size, 1), dtype=torch.long, device=mask_device)

    def forward(self, input_ids=None, attention_mask=None, image_embeddings=None,
                labels=None, inputs_embeds=None, **kwargs):
        """Run the language model on text, optionally preceded by an image token.

        Args:
            input_ids: Token ids; used when ``inputs_embeds`` is not given.
            attention_mask: Mask for the text positions, or ``None``.
            image_embeddings: Optional image features to project and prepend.
            labels: Optional labels; the image position is filled with -100.
            inputs_embeds: Pre-computed embeddings, used as-is when given.
            **kwargs: Forwarded to the wrapped language model.

        Returns:
            The output of the wrapped language model.
        """
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        has_image = image_embeddings is not None

        if inputs_embeds is not None:
            combined_embeds = inputs_embeds
        elif has_image:
            combined_embeds = self._embed_with_image(input_ids, image_embeddings)
        else:
            combined_embeds = self.language_model.get_input_embeddings()(input_ids)
        combined_embeds = combined_embeds.to(torch.float32)

        # NOTE(review): the mask and labels grow by one position whenever
        # image_embeddings is given, even if inputs_embeds was supplied and no
        # image token was prepended here -- confirm callers never pass both.
        if attention_mask is not None and has_image:
            image_attention = self._image_mask(batch_size, combined_embeds.device)
            combined_attention_mask = torch.cat([image_attention, attention_mask], dim=1)
        else:
            combined_attention_mask = attention_mask

        if labels is not None and has_image:
            # -100 keeps the image position out of the loss.
            pad_labels = torch.full((batch_size, 1), -100, dtype=labels.dtype, device=labels.device)
            labels = torch.cat([pad_labels, labels], dim=1)

        return self.language_model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention_mask,
            labels=labels,
            **kwargs
        )

    def generate(self, input_ids, attention_mask, image_embeddings=None, **kwargs):
        """Generate text, optionally conditioned on an image embedding.

        With ``image_embeddings`` the wrapped model generates from
        ``inputs_embeds`` (image token + text); otherwise it generates from
        ``input_ids`` directly. ``**kwargs`` go to the wrapped ``generate``.
        """
        if image_embeddings is None:
            # Text-only input.
            return self.language_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kwargs
            )

        batch_size = input_ids.shape[0]
        combined_embeds = self._embed_with_image(input_ids, image_embeddings)
        image_attention = self._image_mask(batch_size, input_ids.device)
        combined_attention_mask = torch.cat([image_attention, attention_mask], dim=1)
        return self.language_model.generate(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention_mask,
            **kwargs
        )

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate to the wrapped language model."""
        return self.language_model.prepare_inputs_for_generation(*args, **kwargs)

    def count_trainable_parameters(self):
        """Return the number of parameters that require gradients."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_total_parameters(self):
        """Return the total number of parameters."""
        return sum(p.numel() for p in self.parameters())


# Width of the CLIP image features fed to the projection layer
# (originally taken from clip_embeddings.shape[1]).
clip_embedding_dim = 512

processor, model_clip = load_clip()

# Load the Phi-2 base model, preferring a local copy when one exists.
if os.path.exists("local_phi2_model"):
    base_model = PhiForCausalLM.from_pretrained("local_phi2_model").to(device)
else:
    config = PhiConfig.from_pretrained("microsoft/phi-2")
    base_model = PhiForCausalLM.from_pretrained("microsoft/phi-2", config=config).to(device)
# base_model.save_pretrained("local_phi2_model")  # disabled: base model is not cached locally

if os.path.exists("local_phi2_tokenizer"):
    tokenizer = AutoTokenizer.from_pretrained("local_phi2_tokenizer")
    print("tokenizer loaded from local")
else:
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    tokenizer.save_pretrained("local_phi2_tokenizer")
    print("tokenizer loaded from HF")

peft_config = PeftConfig.from_pretrained("118k_12k_itr_fine_tuned_phi_lora")

projection_layer = ProjectionLayer(clip_embedding_dim, base_model.config.hidden_size).to(device).to(torch.float32)
projection_layer.load_state_dict(torch.load("118k_12k_itr_projection_layer.pt", map_location=device))

fine_tuned_model = PeftModel.from_pretrained(base_model, "118k_12k_itr_fine_tuned_phi_lora").to(device)
whole_model = MultimodalPhiWithAdapter(fine_tuned_model, projection_layer).to(device).to(torch.float32)


def load_safetensors(path):
    """Read every tensor and the metadata from a safetensors file.

    Args:
        path: Path of the ``.safetensors`` file.

    Returns:
        tuple: ``(tensors, metadata)`` where ``tensors`` maps names to tensors
        and ``metadata`` is whatever ``f.metadata()`` returns.
    """
    with safe_open(path, framework="pt") as f:
        metadata = f.metadata()
        tensors = {k: f.get_tensor(k) for k in f.keys()}
    return tensors, metadata


# NOTE(review): this reads the "118k_4k_itr" adapter directory while the PEFT
# model above is loaded from "118k_12k_itr" -- confirm which one is intended.
state_dict, metadata = load_safetensors("118k_4k_itr_fine_tuned_phi_lora/adapter_model.safetensors")
state_dict = {k: v.to(device).to(torch.float32) for k, v in state_dict.items()}
# strict=False ignores non-matching keys; report them instead of dropping
# them silently so a naming mismatch is visible at start-up.
load_result = whole_model.load_state_dict(state_dict, strict=False)
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} of {len(state_dict)} "
        "adapter tensors matched no model parameter and were ignored"
    )

whole_model.eval()

whisper_model_path = "whisper-small-local"
if os.path.exists(whisper_model_path):
    whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_path)
    whisper_processor = WhisperProcessor.from_pretrained(whisper_model_path)
else:
    print("Downloading Whisper model from Hugging Face...")
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

WHISPER_SAMPLE_RATE = 16000
WHISPER_MAX_SAMPLES = 480000  # 30 seconds at 16 kHz

IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.bmp'}
AUDIO_EXTENSIONS = {'.mp3', '.wav', '.ogg'}


def _uploaded_path(file_obj):
    """Return the path of an upload given either an object with ``.name`` or a path string."""
    return getattr(file_obj, "name", file_obj)


def transcribe_audio(audio):
    """Transcribe audio with the Whisper model.

    Args:
        audio: Either a file path (``str``) or a ``(sample_rate, samples)``
            pair of already-loaded audio.

    Returns:
        str: The decoded transcription of the first (only) batch item.
    """
    if isinstance(audio, str):
        # librosa resamples to the requested rate while loading.
        audio, _ = librosa.load(audio, sr=WHISPER_SAMPLE_RATE)
    else:
        sr, audio = audio
        audio = np.asarray(audio)
        if np.issubdtype(audio.dtype, np.integer):
            # Integer PCM must be scaled into [-1, 1] before feature extraction.
            scale = float(np.iinfo(audio.dtype).max)
            audio = audio.astype(np.float32) / scale
        else:
            audio = audio.astype(np.float32)
        if audio.ndim > 1:
            # assumes (samples, channels) layout -- TODO confirm against caller
            audio = audio.mean(axis=1)
        if sr != WHISPER_SAMPLE_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=WHISPER_SAMPLE_RATE)

    # Pad or truncate to the fixed 30-second window.
    audio = librosa.util.fix_length(audio, size=WHISPER_MAX_SAMPLES)
    input_features = whisper_processor(
        audio, sampling_rate=WHISPER_SAMPLE_RATE, return_tensors="pt"
    ).input_features

    with torch.no_grad():
        generated_ids = whisper_model.generate(input_features)

    return whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def process_input(text_input, file_input):
    """Build a prompt from text and an optional upload, then generate a reply.

    Images are embedded with CLIP and passed to the model as image features;
    audio files are transcribed and the transcription is added to the prompt;
    other file types are mentioned in the prompt as unsupported.

    Args:
        text_input: The user's message, possibly empty or ``None``.
        file_input: ``None``, an uploaded-file object exposing ``.name``, or a
            path string.

    Returns:
        str: The generated text, or an apology when nothing was generated.
    """
    input_text = text_input or ""
    image_features = None
    file_description = ""

    if file_input is not None:
        file_path = _uploaded_path(file_input)
        file_extension = os.path.splitext(file_path)[1].lower()

        if file_extension in IMAGE_EXTENSIONS:
            # The context manager closes the file handle once CLIP has the pixels.
            with Image.open(file_path) as image:
                print(f"image loaded, {file_path} {file_extension}")
                inputs = processor(images=image.convert("RGB"), return_tensors="pt").to(device)
            with torch.no_grad():
                image_features = model_clip.get_image_features(**inputs)
            file_description = "An image has been uploaded. "
            if not input_text:
                input_text = "Please describe and analyze the content of the uploaded image."
            else:
                input_text = f"Regarding the uploaded image: {input_text}"
        elif file_extension in AUDIO_EXTENSIONS:
            transcribed_text = transcribe_audio(file_path)
            file_description = f"Transcribed audio: '{transcribed_text}'. "
        else:
            file_description = f"Unsupported file type: {os.path.basename(file_path)}. "

    if not input_text and file_description:
        input_text = f"{file_description}Please describe or analyze the content."
    elif file_description:
        input_text = f"{file_description} {input_text}"

    encoded_text = tokenizer(input_text, return_tensors="pt").to(device)

    # Track image presence explicitly instead of inferring it from the
    # feature sum; without an image an all-zero embedding is used.
    has_image = image_features is not None
    if has_image:
        image_features = image_features.to(device).to(torch.float32)
    else:
        image_features = torch.zeros((1, clip_embedding_dim), device=device, dtype=torch.float32)

    if has_image:
        generation_params = {
            "max_new_tokens": 400,
            "num_beams": 4,
            "do_sample": True,
            "temperature": 0.8,
            "top_k": 40,
            "top_p": 0.92,
            "no_repeat_ngram_size": 2,
            "repetition_penalty": 1.2,
        }
    else:
        generation_params = {
            "max_new_tokens": 150,
            "num_beams": 5,
            "do_sample": True,
            "temperature": 0.7,
            "top_k": 50,
            "top_p": 0.95,
            "no_repeat_ngram_size": 3,
            "repetition_penalty": 1.1,
        }

    # Only ban the unknown token when it exists and is distinct from EOS:
    # if both share one id (believed to be the case for the phi-2 tokenizer --
    # confirm), banning it stops the model from ever ending a reply.
    unk_id = tokenizer.unk_token_id
    if unk_id is not None and unk_id != tokenizer.eos_token_id:
        generation_params["bad_words_ids"] = [[unk_id]]

    with torch.no_grad():
        outputs = whole_model.generate(
            input_ids=encoded_text.input_ids,
            attention_mask=encoded_text.attention_mask,
            image_embeddings=image_features,
            renormalize_logits=True,
            **generation_params
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not response:
        response = "I apologize, but I couldn't generate a meaningful response. Could you please provide more context or try rephrasing your input?"
    return response


def chat(message, history, file):
    """Append one exchange to ``history`` and clear the text box.

    Returns:
        tuple: ``(history, "")``.
    """
    try:
        response = process_input(message, file)
        if file:
            file_name = _uploaded_path(file)
            history.append((f"[File uploaded: {file_name}]", response))
        elif message:
            history.append((message, response))
    except Exception as e:  # UI boundary: show the failure in the chat
        history.append((message, f"An error occurred: {str(e)}"))
    return history, ""


def new_chat():
    """Return an empty chat history and an empty file indicator."""
    return [], ""


def handle_file_upload(file):
    """Clear the text box and show the base name of the uploaded file."""
    if file:
        return gr.update(value=""), f"📎 Uploaded: {os.path.basename(_uploaded_path(file))}"
    return gr.update(value=""), ""


def process_message(message, history, file):
    """Handle a submitted message and/or attachment.

    Returns:
        tuple: ``(history, "", None, None)`` -- updated chat, cleared text
        box, cleared file indicator and cleared file input.
    """
    if not message and not file:
        # Nothing to send: skip the (expensive) generation entirely.
        return history, "", None, None

    if file:
        history.append((f"📎 [File uploaded: {os.path.basename(_uploaded_path(file))}]", None))
    if message:
        history.append((message, None))

    try:
        response = process_input(message, file)
    except Exception as e:  # UI boundary: show the failure in the chat
        response = f"An error occurred: {str(e)}"

    # Attach the reply (or the error) to the most recent pending entry.
    history[-1] = (history[-1][0], response)
    return history, "", None, None


CUSTOM_CSS = """
body { margin: 0; padding: 0; } /* Removes default body margins */
#chatbot {
    position: fixed; /* Fixed position relative to the viewport */
    top: 50px; /* Space for the New Chat button */
    left: 0;
    right: 0;
    bottom: 60px; /* Space for the input area */
    overflow-y: auto; /* Adds vertical scrollbar when content overflows */
    padding: 20px; /* Inner spacing */
}
#chatbot > div > div {
    margin-bottom: 15px; /* Space between chat messages */
    max-width: 80%; /* Maximum width of chat bubbles */
}
#chatbot > div > div:nth-child(odd) {
    margin-left: auto; /* Aligns user messages to the right */
    text-align: right;
}
#chatbot > div > div:nth-child(even) {
    margin-right: auto; /* Aligns AI messages to the left */
    text-align: left;
}
#input-area {
    position: fixed; /* Keeps input area at the bottom of the screen */
    bottom: 0;
    left: 0;
    right: 0;
    background-color: #2c2c2c; /* Background color of input area */
    padding: 10px; /* Inner spacing of input area */
    height: 60px; /* Height of input area */
    box-sizing: border-box; /* Includes padding in the element's total width and height */
}
#input-box {
    max-width: 90%; /* Maximum width of the input box, using percentage for responsiveness */
    width: 100%;
    margin: 0 auto; /* Centers the input box */
    display: flex; /* Allows flexible alignment of child elements */
    align-items: center; /* Vertically centers items in the input box */
    height: 100%; /* Takes full height of parent */
}
#file-indicator {
    font-size: 12px; /* Keeping original font size */
    color: #aaa; /* Color of file indicator text */
    position: absolute;
    top: -20px; /* Positions indicator above the input area */
    left: 0;
    right: 0;
    text-align: center;
}
#new-chat-btn {
    position: fixed; /* Keeps button in a fixed position */
    top: 10px; /* Distance from top of viewport */
    left: 10px; /* Distance from left of viewport */
    width: auto; /* Allow button to size based on content */
    height: auto;
    padding: 5px 10px;
}
#file-upload {
    width: 40px; /* Width of file upload button */
    height: 40px; /* Height of file upload button */
    padding: 0;
    flex-shrink: 0; /* Prevents button from shrinking */
}
#file-upload > button {
    border-radius: 50%; /* Makes the button circular */
    padding: 8px;
    width: 100%;
    height: 100%;
}
#msg {
    flex-grow: 1; /* Allows the text input to grow and fill available space */
    height: 40px; /* Height of text input */
    margin-right: 10px; /* Space between text input and file upload button */
}
/* Media query for smaller screens */
@media (max-width: 600px) {
    #chatbot {
        top: 40px; /* Slightly less space for the button on small screens */
        bottom: 50px; /* Slightly less space for the input area on small screens */
        padding: 10px; /* Reduced padding on small screens */
    }
    #input-area {
        height: 50px; /* Slightly reduced height on small screens */
    }
    #input-box {
        max-width: 95%; /* Wider input box on small screens */
    }
    #file-upload, #msg {
        height: 30px; /* Slightly smaller height for inputs on small screens */
    }
    #new-chat-btn {
        font-size: 14px; /* Keeping font size consistent, adjust if needed */
    }
}
#app-description {
    position: fixed;
    top: 10px;
    left: 50%;
    transform: translateX(-50%);
    text-align: center;
    width: 90%;
    max-width: 800px;
    z-index: 1000;
    background-color: #ffffff; /* White background */
    padding: 10px;
    border-radius: 5px;
    color: #1a1a1a; /* Near-black text color */
    font-weight: bold; /* Make the text bold */
}
#new-chat-btn {
    top: 60px; /* Adjusted to make room for the description */
}
#chatbot {
    top: 100px; /* Adjusted to make room for the description and button */
}
@media (max-width: 600px) {
    #app-description {
        font-size: 14px;
        padding: 5px;
    }
    #new-chat-btn {
        top: 50px;
    }
    #chatbot {
        top: 90px;
    }
}
"""

with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown(
        "MultiModal phi-2: This app takes text from bottom input bar. "
        "Audio and image files can be attached with attachment button. "
        "(After attaching audio or image file, click in text box and enter to upload file)",
        elem_id="app-description"
    )

    new_chat_btn = gr.Button("New Chat", elem_id="new-chat-btn")
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, show_label=False)

    with gr.Column(elem_id="input-area"):
        file_indicator = gr.Markdown(elem_id="file-indicator")
        with gr.Row(elem_id="input-box"):
            msg = gr.Textbox(
                show_label=False,
                placeholder="Send a message...",
                container=False,
                elem_id="msg"
            )
            file_input = gr.UploadButton("📎", file_types=["image", "audio"], elem_id="file-upload")

    # file_input is an output as well, so the attachment is really cleared
    # after sending instead of being resent with every later message.
    msg.submit(
        process_message,
        inputs=[msg, chatbot, file_input],
        outputs=[chatbot, msg, file_indicator, file_input],
    )
    file_input.upload(handle_file_upload, inputs=[file_input], outputs=[msg, file_indicator])
    new_chat_btn.click(new_chat, outputs=[chatbot, file_indicator])
    # A new chat also discards any pending attachment.
    new_chat_btn.click(lambda: None, outputs=[file_input])

demo.launch(share=True)

# Manual smoke test for model loading (disabled):
# test_input = "Human: Hello, how are you? Explain share market to me: "  # Add a prompt format
# test_encoded = tokenizer(test_input, return_tensors="pt").to(device)
# test_image_features = torch.zeros((1, clip_embedding_dim), device=device, dtype=torch.float32)
# with torch.no_grad():
#     test_output = whole_model.generate(
#         input_ids=test_encoded.input_ids,
#         attention_mask=test_encoded.attention_mask,
#         image_embeddings=test_image_features,
#         max_new_tokens=200,
#         num_beams=5,
#         do_sample=True,
#         temperature=0.8,
#         top_k=50,
#         top_p=0.92,
#         no_repeat_ngram_size=3,
#         repetition_penalty=1.2,
#     )
# test_response = tokenizer.decode(test_output[0], skip_special_tokens=True)
# # Handle the case where 'AI:' is not in the response
# if "AI:" in test_response:
#     ai_response = test_response.split("AI:")[1].strip()
# else:
#     ai_response = test_response.replace(test_input, "").strip()
# print("Test model output:", ai_response)