Spaces:
Paused
Paused
| import gradio as gr | |
| import os | |
| import time | |
| from PIL import Image | |
| import torch | |
| import whisperx | |
| from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer | |
| from models.vision_projector_model import VisionProjector | |
| from config import VisionProjectorConfig, app_config as cfg | |
from peft import PeftModel

# Run on the GPU when one is available; phi-2 and whisper are placed on `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# CLIP vision encoder + processor used to turn an uploaded image into features.
# These are never moved to `device`: preprocess_fn runs them on the CPU and
# bot() moves the projected features to `device` afterwards.
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Trained projector mapping CLIP features into phi-2's input-embedding space.
vision_projector = VisionProjector(VisionProjectorConfig())
# NOTE: torch.load unpickles the file -- only point this at a trusted checkpoint.
ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
vision_projector.load_state_dict(ckpt['model_state_dict'])
# Inference only: disable any training-time behaviour (e.g. dropout) the
# projector may contain (it is used as a torch module via load_state_dict).
vision_projector.eval()

# phi-2 language model with the fine-tuned adapter merged into its weights.
phi_base_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/phi-2',
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
phi_new_model = "models/phi_adapter"
phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
phi_model = phi_model.merge_and_unload().to(device)

# Speech-to-text model. The whisper backend only supports efficient float16 on
# GPU; requesting it on a CPU-only host makes load_model fail, so fall back to
# float32 there.
compute_type = 'float16' if device != 'cpu' else 'float32'
audi_model = whisperx.load_model("small", device, compute_type=compute_type)

tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
# Give the tokenizer a pad token by reusing its unk token (presumably because
# the phi-2 tokenizer does not define one).
tokenizer.pad_token = tokenizer.unk_token
# ---------------------------------------------------------------------------
# Conversation state shared by the Gradio callbacks below.
# NOTE(review): these are module-level globals, so every browser session using
# this app reads and overwrites the same context/query.
# ---------------------------------------------------------------------------
# Current context: context text, an uploaded file, projected image features,
# or an error message to display.
context = None
# '' (none yet), 'error', 'text'/'image'/'audio' (new, not yet acknowledged),
# or the bracketed '[text]'/'[image]'/'[audio]' once bot() has acknowledged it.
context_type = ''
# True when a new context (or error) is waiting to be processed / shown.
context_added = False
# The pending question, already wrapped as 'Human### ... AI### '.
query = ''
# True when a question is waiting to be answered by the model.
query_added = False
# True while bot() is streaming a reply.
bot_active = False
def print_like_dislike(x: gr.LikeData):
    """Log a like/dislike vote on a chatbot message to stdout.

    Prints the voted message's index, its value, and whether it was liked.
    """
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)
def add_text(history, text):
    """Handle a submitted text message: set a text context or queue a query.

    Text containing ``</context>`` becomes the new text context (the tag is
    replaced by a space). Otherwise, if a context has already been
    acknowledged (``context_type`` is ``'[text]'``, ``'[image]'`` or
    ``'[audio]'``), the text is wrapped in the ``Human### ... AI### `` prompt
    format and queued as a query. In every other case an error message is
    stored as the context (with ``context_type = 'error'``) so bot() shows it.

    Args:
        history: chat history; a ``(text, None)`` turn is appended to a copy.
        text: the raw textbox contents.

    Returns:
        The extended history and a cleared, non-interactive textbox.
    """
    global context, context_type, context_added, query, query_added
    context_added = False
    if '</context>' in text:
        # The message defines a new text context; drop the marker tag.
        text = text.replace('</context>', ' ')
        context = text
        context_type = 'text'
        context_added = True
        query_added = False
    elif not context_type:
        # No context set up yet: ask the user for one first.
        context = r"**Please add context (upload image/audio or enter text followed by \</context\>)**"
        context_type = 'error'
        context_added = True
        query_added = False
    elif context_type in ['[text]', '[image]', '[audio]']:
        # A context has been acknowledged by bot(); treat the text as a question.
        query = 'Human### ' + text + '\n' + 'AI### '
        query_added = True
        context_added = False
    else:
        # The current context is invalid or not yet acknowledged. The type must
        # become 'error' too: leaving e.g. 'image' would make preprocess_fn try
        # to open this message as a file instead of bot() displaying it.
        context = "**Please provide a valid context**"
        context_type = 'error'
        context_added = True
        query_added = False
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)
def add_file(history, file):
    """Register an uploaded image as the new conversation context.

    The uploaded file object itself is stored as ``context`` (preprocess_fn
    opens and encodes it later) and the file is shown as a media turn.

    Args:
        history: chat history; the new turn is appended to a copy.
        file: the uploaded file; only its ``name`` attribute is read here.

    Returns:
        The history with a ``((file.name,), None)`` turn appended.
    """
    global context_added, context, context_type, query_added
    context, context_type = file, 'image'
    context_added, query_added = True, False
    return history + [((file.name,), None)]
def audio_upload(history, audio_file):
    """Register a recorded or uploaded audio file as the new context.

    A falsy ``audio_file`` (nothing recorded / component cleared) leaves the
    state untouched and returns ``history`` as-is.

    Args:
        history: chat history; the new turn is appended to a copy.
        audio_file: audio file path from the Gradio Audio component.

    Returns:
        The history, with an ``((audio_file,), None)`` turn appended when a
        file was given.
    """
    global context, context_type, context_added, query, query_added
    if not audio_file:
        return history
    context, context_type = audio_file, 'audio'
    context_added, query_added = True, False
    return history + [((audio_file,), None)]
def _encode_image(image_file):
    """Encode an image with CLIP and project it into phi-2's embedding space.

    CLIP's penultimate hidden layer (``hidden_states[-2]``) is the projector
    input. Raises ``OSError`` (which includes PIL's UnidentifiedImageError)
    when the file cannot be opened as an image.
    """
    # `with` releases the file handle PIL keeps open for lazy loading;
    # no_grad because nothing is trained here, so no autograd graph is needed
    # (the result is kept alive in the global context).
    with Image.open(image_file) as image, torch.no_grad():
        inputs = clip_processor(images=image, return_tensors="pt")
        outputs = clip_model(**inputs, output_hidden_states=True)
        return vision_projector(outputs.hidden_states[-2])


def _transcribe_audio(audio_file):
    """Transcribe an audio file with whisperx and return the whole text.

    Every transcribed segment is joined, so recordings longer than one segment
    keep their full transcript. Word-level alignment is not run: only the text
    is used, and alignment would add a model load per call without changing it.
    Returns '' when the file cannot be decoded or no speech was recognised.
    """
    try:
        audio = whisperx.load_audio(audio_file)
    except RuntimeError:
        # whisperx.load_audio reports ffmpeg decoding failures as RuntimeError.
        return ''
    result = audi_model.transcribe(audio, batch_size=1)
    segments = result.get('segments') or []
    return ' '.join(seg.get('text', '').strip() for seg in segments).strip()


def preprocess_fn(history):
    """Convert a freshly added image/audio context into what bot() consumes.

    * image: ``context`` becomes the projected CLIP features (a tensor).
    * audio: ``context`` becomes the transcript text.
    * undecodable image/audio: ``context_type`` becomes ``'error'`` and
      ``context`` the message bot() will display.

    Text/error contexts, or calls with no new context, leave the state as-is.

    Args:
        history: chat history, returned unchanged.

    Returns:
        ``history``.
    """
    global context, context_added, query, context_type, query_added
    if not context_added:
        return history
    if context_type == 'image':
        try:
            context = _encode_image(context)
        except OSError:
            context_type = 'error'
            context = "**Please provide a valid image file / context**"
            query_added = False
    elif context_type == 'audio':
        transcript = _transcribe_audio(context)
        if transcript:
            context = transcript
        else:
            context_type = 'error'
            context = "**Please provide a valid audio file / context**"
        query_added = False
    return history
def _embed_text(text):
    """Tokenize ``text`` and return its phi-2 input embeddings, batch size 1."""
    token_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long).unsqueeze(0).to(device)
    return phi_model.get_input_embeddings()(token_ids)


def _generate_reply(inputs_embeds):
    """Generate 10-50 new phi-2 tokens after ``inputs_embeds`` and decode them."""
    out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
                             bos_token_id=tokenizer.bos_token_id)
    return tokenizer.decode(out[0], skip_special_tokens=True)


def bot(history):
    """Stream the assistant's reply for the current state into the chat.

    Generator run as the last step of every Gradio event chain. Depending on
    the state left by the add_* callbacks and preprocess_fn it:

    * shows the stored error message (``context_type == 'error'``);
    * acknowledges a new text/image/audio context (echoing an audio
      transcript) and marks it accepted by bracketing ``context_type``
      (e.g. ``'image'`` -> ``'[image]'``), which lets add_text queue queries;
    * answers a queued query with phi-2: for an image context the projected
      image features are prepended to the query embeddings; for text/audio
      contexts the context text is prepended to the query text.

    The reply is typed into ``history[-1][1]`` one character at a time,
    yielding ``history`` after each character.
    """
    global context, context_added, query, context_type, query_added, bot_active
    response = ''
    if context_added:
        context_added = False
        if context_type == 'error':
            response = context
            query = ''
        elif context_type in ['image', 'audio', 'text']:
            if context_type == 'audio':
                response = 'Context: \n🗣 ' + '"_' + context.strip() + '_"\n\n'
            response += "**Please proceed with your queries**"
            query = ''
            context_type = '[' + context_type + ']'
    elif query_added:
        query_added = False
        # Inference only. Kept out of the streaming loop below so grad mode is
        # not held across the generator's yields.
        with torch.no_grad():
            if context_type == '[image]':
                # Image features act as a soft prompt in front of the query.
                inputs_embeds = torch.cat([context.to(device), _embed_text(query)], dim=1)
                response = _generate_reply(inputs_embeds)
            elif context_type in ['[text]', '[audio]']:
                response = _generate_reply(_embed_text(context + query))
            else:
                query = ''
                response = "**Please provide a valid context**"
    if response:
        bot_active = True
        if history and len(history[-1]) > 1:
            # Rebuild the last turn as a list so it can be filled in place even
            # when it arrives as an immutable (text, None) tuple.
            history[-1] = [history[-1][0], ""]
            for character in response:
                history[-1][1] += character
                time.sleep(0.05)  # typewriter effect
                yield history
        time.sleep(0.5)
        bot_active = False
def clear_fn():
    """Forget the current context and query and empty the chatbot.

    Returns:
        A Gradio update mapping that sets the ``chatbot`` component to None.
    """
    global context_added, context_type, context, query, query_added
    context, context_type = None, ''
    query = ''
    context_added = query_added = False
    return {chatbot: None}
def _respond_after(event, chat):
    """Chain context preprocessing and the streamed bot reply after ``event``."""
    preprocessed = event.then(preprocess_fn, chat, chat)
    return preprocessed.then(bot, chat, chat, api_name="bot_response")


with gr.Blocks() as app:
    gr.Markdown(
        r"""
        # ContextGPT - A Multimodal chatbot
        ### Upload image or audio to add a context. And then ask questions.
        ### You can also enter text followed by \</context\> to set the context.
        """
    )
    chatbot = gr.Chatbot([], elem_id="chatbot", bubble_full_width=False)

    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Press enter to send ",
            container=False,
        )
    with gr.Row():
        aud = gr.Audio(
            sources=['microphone', 'upload'],
            type='filepath',
            max_length=100,
            show_download_button=True,
            show_share_button=True,
        )
        btn = gr.UploadButton("📷", file_types=["image"])
    with gr.Row():
        clear = gr.Button("Clear")

    # Text: record the message, preprocess + reply, then unlock the textbox.
    txt_msg = _respond_after(
        txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False),
        chatbot,
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    # Image upload.
    file_msg = _respond_after(
        btn.upload(add_file, [chatbot, btn], [chatbot], queue=False),
        chatbot,
    )

    chatbot.like(print_like_dislike, None, None)
    clear.click(clear_fn, None, chatbot, queue=False)

    # Audio: finished microphone recordings and uploaded files both set context.
    _respond_after(
        aud.stop_recording(audio_upload, [chatbot, aud], [chatbot], queue=False),
        chatbot,
    )
    _respond_after(
        aud.upload(audio_upload, [chatbot, aud], [chatbot], queue=False),
        chatbot,
    )

app.queue()
app.launch()