Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoTokenizer, pipeline | |
| from transformers import AutoModelForCausalLM | |
| from torchvision import transforms | |
| from transformers import CLIPProcessor, CLIPModel | |
| from model import build_mlp_vector_projector | |
# All models in this app are placed on the CPU.
device = "cpu"

# CLIP checkpoint used to embed images. The processor and the model are both
# created once at import time and reused by every request.
clip_model_name = "openai/clip-vit-base-patch16"
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)

# Applied to the PIL image before it is handed to clip_processor:
# resize to 224x224, then convert to a tensor.
clip_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Projection head shared by every call to process_image(); built on first use
# so the weights file is read from disk only once instead of per request.
_IMG_PROJ_HEAD = None


def _get_img_proj_head():
    """Return the image projection head, creating and loading it on first use."""
    global _IMG_PROJ_HEAD
    if _IMG_PROJ_HEAD is None:
        head = build_mlp_vector_projector().to(device)
        head.load_state_dict(torch.load(
            'stage_2_proj_head_v3.pth', map_location=torch.device(device)))
        # Inference only: switch off training-time layers (dropout etc.),
        # should the projector contain any.
        head.eval()
        _IMG_PROJ_HEAD = head
    return _IMG_PROJ_HEAD


def process_image(img_path):
    """Embed an image with CLIP and project it with the stage-2 head.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The output of the projection head applied to CLIP's ``image_embeds``
        for the image. The tensor is detached from autograd because it is
        computed under ``torch.no_grad()``.
    """
    # Use a context manager so the file handle is released; convert() returns
    # a fully loaded copy, so the image stays usable after the file is closed.
    with Image.open(img_path) as img_file:
        image = img_file.convert("RGB")
    image = clip_transform(image)

    # NOTE(review): the image is already resized and converted to a tensor by
    # clip_transform before clip_processor preprocesses it again. Presumably
    # the projection head was trained with this exact pipeline, so it is kept
    # unchanged -- confirm before simplifying.
    # The empty text prompt is only there because the full CLIP forward pass
    # is used to obtain image_embeds.
    inputs = clip_processor(text=[""], images=image,
                            return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # No gradients are needed at inference time; skipping them saves memory.
    with torch.no_grad():
        img_embedding = clip_model(**inputs).image_embeds
        img_tokens = _get_img_proj_head()(img_embedding)
    return img_tokens
# Tokenizer of the base phi-2 checkpoint; shared by all text handling below.
phi_model_name = "microsoft/phi-2"
text_tokenizer = AutoTokenizer.from_pretrained(
    phi_model_name,
    trust_remote_code=True,
)

with torch.no_grad():
    # NOTE(review): base_phi2_text is not referenced anywhere else in the
    # visible source; it is kept because it is a module-level name that other
    # code may rely on -- confirm whether this load is still needed.
    base_phi2_text = AutoModelForCausalLM.from_pretrained(
        phi_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Fine-tuned checkpoint loaded from "stage2_adaptor"; this is the model
    # actually used for generation.
    tuned_phi2 = AutoModelForCausalLM.from_pretrained(
        "stage2_adaptor",
        trust_remote_code=True,
    ).to("cpu")
print("phi2 model loaded")
# Speech-to-text pipeline (Whisper) used to turn recorded or uploaded audio
# into a text prompt. Audio is processed in 30-second chunks.
audio_model_name = "openai/whisper-small"
audio_pipe = pipeline(
    "automatic-speech-recognition",
    model=audio_model_name,
    device=device,
    chunk_length_s=30,
)
def process_text(text, count):
    """Generate an answer for a text-only prompt with the fine-tuned phi-2 model.

    Args:
        text: Prompt string to feed to the model.
        count: Maximum number of new tokens to generate; any value accepted
            by ``int()``.

    Returns:
        The decoded generation with trailing ``<|endoftext|>`` markers and
        trailing newlines removed.
    """
    token_ids = text_tokenizer.encode(text, return_tensors="pt")
    # Generation is driven by embeddings rather than token ids, so look the
    # ids up in the model's own embedding table first.
    embed_layer = tuned_phi2.get_submodule('model.embed_tokens')
    input_embeds = embed_layer(token_ids).to(device)

    generated = tuned_phi2.generate(
        inputs_embeds=input_embeds,
        # Honour the caller's token budget (previously hard-coded to 30).
        max_new_tokens=int(count),
        bos_token_id=text_tokenizer.bos_token_id,
        eos_token_id=text_tokenizer.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id
    )
    answer = text_tokenizer.batch_decode(generated)[0]

    # str.rstrip('<|endoftext|>') would treat its argument as a *set of
    # characters* and eat real trailing letters such as 'e', 'd' or 't'.
    # Remove the marker only when it is an actual suffix.
    eos_marker = '<|endoftext|>'
    while answer.endswith(eos_marker):
        answer = answer[:-len(eos_marker)]
    return answer.rstrip("\n")
def process_audio(audio):
    """Transcribe an audio input to text with the speech-recognition pipeline.

    Args:
        audio: Audio input forwarded to ``audio_pipe`` (the UI supplies a
            file path).

    Returns:
        The ``"text"`` field of the pipeline result.

    Raises:
        gr.Error: If ``audio`` is ``None``.
    """
    if audio is None:
        raise gr.Error("Please provide an audio file or record your input")

    transcription = audio_pipe(
        audio,
        return_timestamps=True,
        generate_kwargs={"task": "transcribe"},
        batch_size=8,
    )
    return transcription["text"]
def generate_response(image, audio, text, count):
    """Answer a query given any combination of image, audio and text.

    Audio, when present, is transcribed and appended to the typed text. With
    an image, the projected image tokens are prepended to the embedded
    "Question: ...Answer:" prompt; without one the combined text is passed to
    ``process_text``.

    Args:
        image: Path of the image file, or a falsy value when absent.
        audio: Path of the audio file, or a falsy value when absent.
        text: Typed question, or a falsy value when absent.
        count: Maximum number of new tokens to generate; arrives from a
            textbox, so it may be a string.

    Returns:
        The decoded model answer with trailing ``<|endoftext|>`` markers and
        trailing newlines removed.

    Raises:
        gr.Error: If ``count`` is not a positive whole number, or if no input
            of any kind was provided.
    """
    # The value comes from a free-form textbox: surface bad input as a
    # user-facing error instead of an unhandled ValueError.
    try:
        count = int(count)
    except (TypeError, ValueError):
        raise gr.Error("Count must be a positive whole number") from None
    if count < 1:
        raise gr.Error("Count must be a positive whole number")

    overall_input = ""
    if audio:
        overall_input = process_audio(audio)
    if text:
        overall_input = text + overall_input

    if not image:
        # Nothing to tokenise would make generation fail with an obscure
        # error further down; report it clearly instead.
        if not overall_input:
            raise gr.Error("Please provide text, audio or an image")
        return process_text(overall_input, count)

    img_tokens = process_image(image)
    overall_input = "Question: " + overall_input + "Answer:"
    q_tokens = text_tokenizer.encode(
        overall_input,
        return_tensors='pt').to(device)
    question_token_embeddings = tuned_phi2.get_submodule(
        'model.embed_tokens')(q_tokens).to(device)
    # Image tokens get a leading dimension and are joined in front of the
    # question embeddings along the second-to-last axis (presumably the
    # sequence axis -- confirm against the projector's output shape).
    inputs = torch.concat(
        (img_tokens.unsqueeze(0), question_token_embeddings),
        axis=-2).to(device)

    generated = tuned_phi2.generate(
        inputs_embeds=inputs,
        # Honour the requested token budget (previously hard-coded to 30).
        max_new_tokens=count,
        bos_token_id=text_tokenizer.bos_token_id,
        eos_token_id=text_tokenizer.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id
    )
    answer = text_tokenizer.batch_decode(generated)[0]

    # str.rstrip('<|endoftext|>') strips a *character set*, not a suffix, and
    # would eat real trailing letters; remove the marker only when it is an
    # actual suffix.
    eos_marker = '<|endoftext|>'
    while answer.endswith(eos_marker):
        answer = answer[:-len(eos_marker)]
    return answer.rstrip("\n")
# Gradio UI: inputs (text, image, audio, token count) in the left column, the
# model's answer in the right column, then a submit button and one example.
with gr.Blocks() as demo:
    gr.Markdown("# **AnyModeAssistant**")
    gr.Markdown("Use any mode text/image/audio to interact with AI assistant")
    with gr.Row():
        with gr.Column(scale=4):
            # gr.Row takes keyword-only arguments, so the former positional
            # labels ("Text", "Audio mode", "Image") are kept as comments.
            # Text
            with gr.Row():
                text_input = gr.Textbox(placeholder="Enter your question here",
                                        label="Input")
            # Image
            with gr.Row():
                image_input = gr.Image(type="filepath")
            # Audio mode
            with gr.Row():
                audio_input = gr.Audio(type="filepath")
            # Token count
            with gr.Row():
                response_count = gr.Textbox(
                    placeholder="Number of tokens to respond",
                    value=20,
                    label="Count")
        with gr.Column(scale=2):
            response = gr.Textbox(label="AI Response")
    with gr.Row():
        submit_button = gr.Button("Submit")
        submit_button.click(generate_response,
                            inputs=[
                                image_input, audio_input,
                                text_input, response_count
                            ],
                            outputs=response)
    # The example row supplies image, audio and text only; the count textbox
    # keeps whatever value it currently holds.
    gr.Examples(
        examples=[
            ["dog_man_forest.jpg", "audio.wav", "Is there a dog present in the image?"],
        ],
        inputs=[image_input, audio_input, text_input, response_count],
        outputs=[response],
        fn=generate_response,
    )

demo.launch(share=True)