# Hugging Face Spaces page residue (Space status: "Sleeping") — kept as a
# comment so the file parses as Python.
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import librosa | |
| import numpy as np | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor, CLIPProcessor, CLIPModel | |
| from peft import PeftConfig, PeftModel | |
| from safetensors.torch import load_file | |
| import sys | |
| import os | |
| import torch.nn as nn | |
| from safetensors import safe_open | |
| import json | |
# Anomaly detection is a debugging aid that adds heavy overhead to every
# autograd op; this app only runs inference, so keep it disabled.
torch.autograd.set_detect_anomaly(False)
# Prefer the GPU when one is visible; all models/tensors are moved here later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| # ---- | |
def load_clip():
    """Load the CLIP image processor and vision model.

    The processor is cached under ``data/`` after the first download so later
    startups can reuse it; the model itself is always fetched from the hub
    (no local cache is written for it).

    Returns:
        tuple: ``(CLIPProcessor, CLIPModel)`` with the model on ``device``,
        cast to float32, and switched to eval mode.
    """
    processor_cache = "data/processor_clip_embeddings_vit_base_patch32.pt"
    if os.path.exists(processor_cache):
        processor = CLIPProcessor.from_pretrained(processor_cache)
    else:
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # Cache the processor locally so the next startup avoids the download.
        processor.save_pretrained(processor_cache)
    # Single device/dtype move (the original moved the model to `device` twice).
    model_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    model_clip = model_clip.to(device).to(torch.float32)
    model_clip.eval()  # inference only: disables dropout / training behavior
    return processor, model_clip
class ProjectionLayer(nn.Module):
    """Maps CLIP image embeddings into the language model's hidden space."""

    def __init__(self, clip_embedding_dim, phi_hidden_dim):
        super().__init__()
        # A single affine map; layer normalization was deliberately removed.
        # (attribute name `linear` must stay: checkpoints load by this key)
        self.linear = nn.Linear(clip_embedding_dim, phi_hidden_dim)

    def forward(self, image_embeddings):
        # Project (..., clip_embedding_dim) -> (..., phi_hidden_dim).
        projected = self.linear(image_embeddings)
        return projected
class MultimodalPhiWithAdapter(nn.Module):
    """Causal LM wrapper that conditions generation on one CLIP image embedding.

    The image embedding is projected into the LM's hidden space and prepended
    as a single pseudo-token in front of the text embeddings. The wrapper
    mirrors the HF interface (``config``, ``generate``,
    ``prepare_inputs_for_generation``) so generation utilities keep working.
    """

    def __init__(self, language_model, projection_layer, freeze_language_model=True, freeze_projection_layer=False):
        super().__init__()
        self.language_model = language_model
        self.projection_layer = projection_layer
        # Expose the wrapped model's config for HF generation machinery.
        self.config = language_model.config
        self.set_trainable_params(freeze_language_model, freeze_projection_layer)
        # Force a single dtype so CLIP features and LM weights always match.
        self.to(torch.float32)

    def set_trainable_params(self, freeze_language_model, freeze_projection_layer):
        """Set requires_grad on the two sub-modules according to the flags."""
        for param in self.language_model.parameters():
            param.requires_grad = not freeze_language_model
        for param in self.projection_layer.parameters():
            param.requires_grad = not freeze_projection_layer

    def forward(self, input_ids=None, attention_mask=None, image_embeddings=None, labels=None, inputs_embeds=None, **kwargs):
        """Run the LM, optionally prepending one projected image embedding.

        The extra image position gets attention 1 and label -100 so it is
        attended to but never contributes to the loss.
        """
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        # Normalize incoming dtypes to float32 (CLIP may emit float16).
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.to(torch.float32)
        if image_embeddings is not None:
            image_embeddings = image_embeddings.to(torch.float32)
        # The image pseudo-token is only inserted on the input_ids path. Bug
        # fix: the original also padded attention_mask/labels when the caller
        # supplied ready-made inputs_embeds, producing a length mismatch.
        image_prepended = inputs_embeds is None and image_embeddings is not None
        if inputs_embeds is None:
            text_embeds = self.language_model.get_input_embeddings()(input_ids)
            if image_prepended:
                projected_embeddings = self.projection_layer(image_embeddings)
                combined_embeds = torch.cat([projected_embeddings.unsqueeze(1), text_embeds], dim=1)
            else:
                combined_embeds = text_embeds
        else:
            combined_embeds = inputs_embeds
        if attention_mask is not None:
            if image_prepended:
                image_attention = torch.ones((batch_size, 1), dtype=torch.long, device=combined_embeds.device)
                combined_attention_mask = torch.cat([image_attention, attention_mask], dim=1)
            else:
                combined_attention_mask = attention_mask
        else:
            combined_attention_mask = None
        combined_embeds = combined_embeds.to(torch.float32)
        if labels is not None and image_prepended:
            # -100 is the ignore_index of the LM's cross-entropy loss.
            pad_labels = torch.full((batch_size, 1), -100, dtype=labels.dtype, device=labels.device)
            labels = torch.cat([pad_labels, labels], dim=1)
        return self.language_model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention_mask,
            labels=labels,
            **kwargs
        )

    def generate(self, input_ids, attention_mask, image_embeddings=None, **kwargs):
        """Generate text, optionally conditioned on an image embedding.

        NOTE: when inputs_embeds are used, HF generate returns only the newly
        generated token ids (no prompt echo).
        """
        if image_embeddings is not None:
            batch_size = input_ids.shape[0]
            projected_embeddings = self.projection_layer(image_embeddings)
            # Prepend the projected image embedding as one extra position.
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            combined_embeds = torch.cat([projected_embeddings.unsqueeze(1), input_embeds], dim=1)
            # Extend the attention mask for the added image position.
            image_attention = torch.ones((batch_size, 1), dtype=torch.long, device=input_ids.device)
            combined_attention_mask = torch.cat([image_attention, attention_mask], dim=1)
            return self.language_model.generate(
                inputs_embeds=combined_embeds,
                attention_mask=combined_attention_mask,
                **kwargs
            )
        # Text-only path delegates straight to the wrapped model.
        return self.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate to the wrapped LM so HF generation machinery works."""
        return self.language_model.prepare_inputs_for_generation(*args, **kwargs)

    def count_trainable_parameters(self):
        """Number of parameters with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_total_parameters(self):
        """Total parameter count across wrapper and sub-modules."""
        return sum(p.numel() for p in self.parameters())
# ---- Module-level model assembly (runs once at import time) ----
clip_embedding_dim = 512  # CLIP ViT-B/32 image-feature width (clip_embeddings.shape[1])
# Instantiate the CLIP processor/model defined above.
processor, model_clip = load_clip()
# Load your Phi model
# if os.path.exists("local_phi2_model"):
#     base_model = AutoModelForCausalLM.from_pretrained("local_phi2_model").to(device)
# else:
#     base_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2").to(device)
#     base_model.save_pretrained("local_phi2_model")
from transformers import PhiForCausalLM, PhiConfig
if os.path.exists("local_phi2_model"):
    base_model = PhiForCausalLM.from_pretrained("local_phi2_model").to(device)
else:
    config = PhiConfig.from_pretrained("microsoft/phi-2")
    base_model = PhiForCausalLM.from_pretrained("microsoft/phi-2", config=config).to(device)
    # base_model.save_pretrained("local_phi2_model")
# Tokenizer is cached locally after the first download (unlike the model above).
if os.path.exists("local_phi2_tokenizer"):
    tokenizer = AutoTokenizer.from_pretrained("local_phi2_tokenizer")
    print("tokenizer loaded from local")
else:
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    tokenizer.save_pretrained("local_phi2_tokenizer")
    print("tokenizer loaded from HF")
# NOTE(review): peft_config is assigned but never referenced below — confirm it
# is still needed.
peft_config = PeftConfig.from_pretrained("118k_12k_itr_fine_tuned_phi_lora")
# Projection layer maps CLIP features (512) into Phi's hidden size.
projection_layer = ProjectionLayer(clip_embedding_dim, base_model.config.hidden_size).to(device).to(torch.float32)
projection_layer.load_state_dict(torch.load("118k_12k_itr_projection_layer.pt", map_location=device))
# Attach the LoRA adapter, then wrap LM + projection into the multimodal model.
fine_tuned_model = PeftModel.from_pretrained(base_model, "118k_12k_itr_fine_tuned_phi_lora").to(device)
whole_model = MultimodalPhiWithAdapter(fine_tuned_model, projection_layer).to(device).to(torch.float32)
# Replace the load_file call with this:
def load_safetensors(path):
    """Read a ``.safetensors`` checkpoint from disk.

    Returns:
        tuple: (dict of tensors keyed by parameter name, metadata dict).
    """
    with safe_open(path, framework="pt") as handle:
        meta = handle.metadata()  # already a dict — no json.loads() needed
        weights = {name: handle.get_tensor(name) for name in handle.keys()}
    return weights, meta
# NOTE(review): adapter weights come from the 4k-iteration checkpoint while the
# LoRA/projection loaded above use the 12k one — confirm this mix is intentional.
state_dict, metadata = load_safetensors("118k_4k_itr_fine_tuned_phi_lora/adapter_model.safetensors")
# Move every tensor to the target device in float32 before loading.
state_dict = {k: v.to(device).to(torch.float32) for k, v in state_dict.items()}
# strict=False: the file holds only adapter weights, not the full model's keys.
whole_model.load_state_dict(state_dict, strict=False)
# Set the model to evaluation mode
whole_model.eval()
# Whisper (speech-to-text) model + processor, cached locally after first run.
whisper_model_path = "whisper-small-local"
if os.path.exists(whisper_model_path):
    whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_path)
    whisper_processor = WhisperProcessor.from_pretrained(whisper_model_path)
else:
    print("Downloading Whisper model from Hugging Face...")
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    # Bug fix: cache after downloading — otherwise the os.path.exists() branch
    # above could never be taken (mirrors the tokenizer-caching pattern).
    whisper_model.save_pretrained(whisper_model_path)
    whisper_processor.save_pretrained(whisper_model_path)
def transcribe_audio(audio):
    """Transcribe speech to text with the module-level Whisper model.

    Args:
        audio: either a file path (str) or a ``(sample_rate, samples)`` tuple
            as produced by ``gr.Audio`` — presumably mono or stereo PCM.

    Returns:
        str: the transcription of (at most) the first 30 seconds of audio.
    """
    if isinstance(audio, str):  # file path: librosa decodes + resamples
        audio, sr = librosa.load(audio, sr=16000)
    else:  # already-loaded (sample_rate, samples) tuple
        sr, audio = audio
        if np.issubdtype(audio.dtype, np.integer):
            # Bug fix: integer PCM (e.g. int16) must be normalized to [-1, 1];
            # the original fed raw integer amplitudes to Whisper.
            audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
        else:
            audio = audio.astype(np.float32)
        if audio.ndim > 1:
            audio = audio.mean(axis=1)  # down-mix stereo to mono
        if sr != 16000:
            # Bug fix: Whisper expects 16 kHz; the original skipped resampling
            # on the tuple path while claiming sampling_rate=16000 below.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    # Pad/trim to exactly 30 s (480000 samples at 16 kHz), Whisper's window.
    audio = librosa.util.fix_length(audio, size=480000)
    input_features = whisper_processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    with torch.no_grad():
        generated_ids = whisper_model.generate(input_features)
    transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return transcription
def process_input(text_input, file_input):
    # Turn a text message plus an optional uploaded file into a model response.
    # Relies on module-level globals: processor/model_clip (CLIP), tokenizer,
    # whole_model, device, clip_embedding_dim, and transcribe_audio().
    input_text = text_input or ""
    image_features = None
    file_description = ""
    if file_input is not None:
        file_path = file_input.name
        file_extension = os.path.splitext(file_path)[1].lower()
        if file_extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
            # Image path: extract CLIP features for multimodal conditioning.
            image = Image.open(file_path)
            print(f"image loaded, {file_path} {file_extension}")
            inputs = processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                image_features = model_clip.get_image_features(**inputs)
            file_description = "An image has been uploaded. "
            # Add a specific prompt for image analysis
            if not input_text:
                input_text = "Please describe and analyze the content of the uploaded image."
            else:
                input_text = f"Regarding the uploaded image: {input_text}"
        elif file_extension in ['.mp3', '.wav', '.ogg']:
            # Audio path: transcription becomes part of the text prompt.
            transcribed_text = transcribe_audio(file_path)
            file_description = f"Transcribed audio: '{transcribed_text}'. "
        else:
            file_description = f"Unsupported file type: {os.path.basename(file_path)}. "
    # Fold the file description into the prompt text.
    if not input_text and file_description:
        input_text = f"{file_description}Please describe or analyze the content."
    elif file_description:
        input_text = f"{file_description} {input_text}"
    encoded_text = tokenizer(input_text, return_tensors="pt").to(device)
    # Handle case when there's no image: a zero vector stands in for "no image".
    if image_features is not None:
        image_features = image_features.to(device).to(torch.float32)
    else:
        image_features = torch.zeros((1, clip_embedding_dim), device=device, dtype=torch.float32)
    with torch.no_grad():
        # The zero-vector sentinel is detected by its (near-)zero sum; a real
        # CLIP embedding essentially never sums to ~0. NOTE(review): heuristic —
        # an explicit boolean flag would be clearer.
        if image_features is not None and abs(image_features.sum().item()) > 1e-6:
            # Image present: longer, slightly more exploratory decoding.
            generation_params = {
                "max_new_tokens": 400,
                "num_beams": 4,
                "do_sample": True,
                "temperature": 0.8,
                "top_k": 40,
                "top_p": 0.92,
                "no_repeat_ngram_size": 2,
                "repetition_penalty": 1.2,
            }
        else:
            # Text-only: shorter, more conservative decoding.
            generation_params = {
                "max_new_tokens": 150,
                "num_beams": 5,
                "do_sample": True,
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.95,
                "no_repeat_ngram_size": 3,
                "repetition_penalty": 1.1,
            }
        # NOTE(review): tokenizer.unk_token_id may be None for some tokenizers,
        # which would make bad_words_ids invalid — confirm for the Phi tokenizer.
        outputs = whole_model.generate(
            input_ids=encoded_text.input_ids,
            attention_mask=encoded_text.attention_mask,
            image_embeddings=image_features,
            bad_words_ids=[[tokenizer.unk_token_id]],
            renormalize_logits=True,
            **generation_params
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Remove the input text from the generated output
    response = generated_text  # [len(input_text):].strip()
    if not response:
        response = "I apologize, but I couldn't generate a meaningful response. Could you please provide more context or try rephrasing your input?"
    return response
def chat(message, history, file):
    """Legacy chat handler: run the model and append one (user, reply) turn.

    Returns the mutated history plus an empty string to clear the input box.
    """
    try:
        reply = process_input(message, file)
        if file:
            history.append((f"[File uploaded: {file.name}]", reply))
        elif message:
            history.append((message, reply))
    except Exception as exc:
        history.append((message, f"An error occurred: {str(exc)}"))
    return history, ""  # Clear the input box after sending
def new_chat():
    """Reset the UI: empty chat history and a blank file indicator."""
    return [], ""
# ---- Gradio UI layout. The CSS below is passed verbatim to gr.Blocks and is
# user-facing content, so it is left untouched. ----
with gr.Blocks(css="""
body { margin: 0; padding: 0; } /* Removes default body margins */
#chatbot {
position: fixed; /* Fixed position relative to the viewport */
top: 50px; /* Space for the New Chat button */
left: 0;
right: 0;
bottom: 60px; /* Space for the input area */
overflow-y: auto; /* Adds vertical scrollbar when content overflows */
padding: 20px; /* Inner spacing */
}
#chatbot > div > div {
margin-bottom: 15px; /* Space between chat messages */
max-width: 80%; /* Maximum width of chat bubbles */
}
#chatbot > div > div:nth-child(odd) {
margin-left: auto; /* Aligns user messages to the right */
text-align: right;
}
#chatbot > div > div:nth-child(even) {
margin-right: auto; /* Aligns AI messages to the left */
text-align: left;
}
#input-area {
position: fixed; /* Keeps input area at the bottom of the screen */
bottom: 0;
left: 0;
right: 0;
background-color: #2c2c2c; /* Background color of input area */
padding: 10px; /* Inner spacing of input area */
height: 60px; /* Height of input area */
box-sizing: border-box; /* Includes padding in the element's total width and height */
}
#input-box {
max-width: 90%; /* Maximum width of the input box, using percentage for responsiveness */
width: 100%;
margin: 0 auto; /* Centers the input box */
display: flex; /* Allows flexible alignment of child elements */
align-items: center; /* Vertically centers items in the input box */
height: 100%; /* Takes full height of parent */
}
#file-indicator {
font-size: 12px; /* Keeping original font size */
color: #aaa; /* Color of file indicator text */
position: absolute;
top: -20px; /* Positions indicator above the input area */
left: 0;
right: 0;
text-align: center;
}
#new-chat-btn {
position: fixed; /* Keeps button in a fixed position */
top: 10px; /* Distance from top of viewport */
left: 10px; /* Distance from left of viewport */
width: auto; /* Allow button to size based on content */
height: auto;
padding: 5px 10px;
}
#file-upload {
width: 40px; /* Width of file upload button */
height: 40px; /* Height of file upload button */
padding: 0;
flex-shrink: 0; /* Prevents button from shrinking */
}
#file-upload > button {
border-radius: 50%; /* Makes the button circular */
padding: 8px;
width: 100%;
height: 100%;
}
#msg {
flex-grow: 1; /* Allows the text input to grow and fill available space */
height: 40px; /* Height of text input */
margin-right: 10px; /* Space between text input and file upload button */
}
/* Media query for smaller screens */
@media (max-width: 600px) {
#chatbot {
top: 40px; /* Slightly less space for the button on small screens */
bottom: 50px; /* Slightly less space for the input area on small screens */
padding: 10px; /* Reduced padding on small screens */
}
#input-area {
height: 50px; /* Slightly reduced height on small screens */
}
#input-box {
max-width: 95%; /* Wider input box on small screens */
}
#file-upload, #msg {
height: 30px; /* Slightly smaller height for inputs on small screens */
}
#new-chat-btn {
font-size: 14px; /* Keeping font size consistent, adjust if needed */
}
}
#app-description {
position: fixed;
top: 10px;
left: 50%;
transform: translateX(-50%);
text-align: center;
width: 90%;
max-width: 800px;
z-index: 1000;
background-color: #ffffff; /* Darker background color */
padding: 10px;
border-radius: 5px;
color: #1a1a1a; /* Pure white text color */
font-weight: bold; /* Make the text bold */
}
#new-chat-btn {
top: 60px; /* Adjusted to make room for the description */
}
#chatbot {
top: 100px; /* Adjusted to make room for the description and button */
}
@media (max-width: 600px) {
#app-description {
font-size: 14px;
padding: 5px;
}
#new-chat-btn {
top: 50px;
}
#chatbot {
top: 90px;
}
}
""") as demo:
    # Fixed description banner pinned to the top of the page.
    gr.Markdown(
        """
MultiModal phi-2: This app takes text from bottom input bar.
Audio and image files can be attached with attachment button.
(After attaching audio or image file, click in text box and enter to upload file)
""",
        elem_id="app-description"
    )
    new_chat_btn = gr.Button("New Chat", elem_id="new-chat-btn")
    # Main conversation view.
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, show_label=False)
    # Bottom input bar: file-name indicator above a text box + upload button.
    with gr.Column(elem_id="input-area"):
        file_indicator = gr.Markdown(elem_id="file-indicator")
        with gr.Row(elem_id="input-box"):
            msg = gr.Textbox(
                show_label=False,
                placeholder="Send a message...",
                container=False,
                elem_id="msg"
            )
            # NOTE(review): "๐" looks like a mojibake'd emoji (likely the 📎
            # paperclip) — confirm against the original source before changing.
            file_input = gr.UploadButton("๐", file_types=["image", "audio"], elem_id="file-upload")
| def handle_file_upload(file): | |
| if file: | |
| return gr.update(value=""), f"๐ Uploaded: {os.path.basename(file.name)}" # just actual file name, not full path here | |
| return gr.update(value=""), "" | |
| def process_message(message, history, file): | |
| try: | |
| if file: | |
| file_name = file.name | |
| history.append((f"๐ [File uploaded: {os.path.basename(file_name)}]", None)) | |
| if message: | |
| history.append((message, None)) | |
| response = process_input(message, file) | |
| if message or file: | |
| history[-1] = (history[-1][0], response) | |
| except Exception as e: | |
| history.append((message, f"An error occurred: {str(e)}")) | |
| return history, "", None # Clear the input box and file input after sending | |
    # Submitting the textbox drives the full pipeline (text + attached file).
    msg.submit(process_message, inputs=[msg, chatbot, file_input], outputs=[chatbot, msg, file_indicator])
    # Choosing a file only updates the indicator; generation waits for submit.
    file_input.upload(handle_file_upload, inputs=[file_input], outputs=[msg, file_indicator])
    new_chat_btn.click(new_chat, outputs=[chatbot, file_indicator])

# share=True exposes a temporary public Gradio link in addition to localhost.
demo.launch(share=True)
| # Test model loading | |
| # test_input = "Human: Hello, how are you? Explain share market to me: " # Add a prompt format | |
| # test_encoded = tokenizer(test_input, return_tensors="pt").to(device) | |
| # test_image_features = torch.zeros((1, clip_embedding_dim), device=device, dtype=torch.float32) | |
| # with torch.no_grad(): | |
| # test_output = whole_model.generate( | |
| # input_ids=test_encoded.input_ids, | |
| # attention_mask=test_encoded.attention_mask, | |
| # image_embeddings=test_image_features, | |
| # max_new_tokens=200, | |
| # num_beams=5, | |
| # do_sample=True, | |
| # temperature=0.8, | |
| # top_k=50, | |
| # top_p=0.92, | |
| # no_repeat_ngram_size=3, | |
| # repetition_penalty=1.2, | |
| # ) | |
| # test_response = tokenizer.decode(test_output[0], skip_special_tokens=True) | |
| # # Handle the case where 'AI:' is not in the response | |
| # if "AI:" in test_response: | |
| # ai_response = test_response.split("AI:")[1].strip() | |
| # else: | |
| # ai_response = test_response.replace(test_input, "").strip() | |
| # print("Test model output:", ai_response) |