# app.py
import spaces  # Hugging Face Spaces GPU support
import os
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import timm
from torchvision import transforms
from peft import PeftModel
import traceback
import warnings

warnings.filterwarnings('ignore')
# 1. Model Definitions (same as in the training script)
class SigLIPImageEncoder(torch.nn.Module):
    """Image encoder: a timm backbone plus a linear projection to embed_dim."""

    def __init__(self, model_name='resnet50', embed_dim=512, pretrained_path=None):
        super().__init__()
        # pretrained=False: the backbone weights come from pretrained_path, not timm.
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0, global_pool='avg')
        self.embed_dim = embed_dim
        self.projection = torch.nn.Linear(self.model.num_features, embed_dim)
        if pretrained_path:
            # map_location='cpu' so the checkpoint also loads on CPU-only machines.
            self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
            print(f"Loaded SigLIP image encoder from {pretrained_path}")
        else:
            print("Initialized SigLIP image encoder without pretrained weights.")

    def forward(self, image):
        features = self.model(image)           # [batch, num_features]
        embedding = self.projection(features)  # [batch, embed_dim]
        return embedding
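# A minimal shape sanity check (illustrative only, not executed by the app;
# assumes a 224x224 RGB input as produced by image_transform below):
#   enc = SigLIPImageEncoder('resnet50', embed_dim=512)
#   enc(torch.randn(1, 3, 224, 224)).shape  # -> torch.Size([1, 512])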
class Phi3WithImage(torch.nn.Module):
    """Phi-3 with a single projected image embedding prepended to the token sequence."""

    def __init__(self, phi3_model_name, image_encoder, image_embed_dim=512, image_token_id=1, bnb_config=None):
        super().__init__()
        self.phi3 = AutoModelForCausalLM.from_pretrained(
            phi3_model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            trust_remote_code=True,
            # quantization_config=bnb_config  # enable to load the base model in 4-bit
        )
        self.image_encoder = image_encoder
        self.image_embed_dim = image_embed_dim
        self.phi3_embed_dim = self.phi3.config.hidden_size
        self.image_projection = torch.nn.Linear(image_embed_dim, self.phi3_embed_dim)
        self.image_token_id = image_token_id

    def forward(self, image, question_input_ids, question_attention_mask, answer_input_ids, answer_attention_mask):
        batch_size = question_input_ids.size(0)
        # Encode the image and project it into the text embedding space.
        image_embeddings = self.image_encoder(image)                                       # [batch, image_embed_dim]
        projected_image_embeddings = self.image_projection(image_embeddings).unsqueeze(1)  # [batch, 1, hidden_dim]
        # Token embeddings for question + answer from the model's embedding layer.
        token_embeddings = self.phi3.get_input_embeddings()(
            torch.cat([question_input_ids, answer_input_ids], dim=1)
        )  # [batch, seq_len, hidden_dim]
        # Prepend the image embedding to the token sequence.
        inputs_embeds = torch.cat([projected_image_embeddings, token_embeddings], dim=1)
        # Extend the attention mask with one position for the image embedding.
        image_attention_mask = torch.ones((batch_size, 1), device=question_attention_mask.device)
        full_attention_mask = torch.cat([image_attention_mask, question_attention_mask, answer_attention_mask], dim=1)
        # During training, labels would mask the image and question positions with -100
        # so only the answer tokens are supervised; here only the logits are needed.
        outputs = self.phi3(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention_mask
        )
        return outputs.logits
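# Per batch item, the sequence Phi-3 sees is:
#   position 0      : the projected image embedding (a single "soft" visual token)
#   positions 1..q  : question token embeddings
#   positions q+1.. : answer token embeddings
# The attention mask gets one extra leading 1, so the visual position is attended
# to like any ordinary token; image_token_id is kept for compatibility but unused here.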
# 2. Load Models and Tokenizer
phi3_model_name = "microsoft/Phi-3-mini-4k-instruct"  # or your specific Phi-3 variant
lora_model_path = "./qlora-phi3-model-new"
image_model_name = 'resnet50'
image_embed_dim = 512
siglip_pretrained_path = "image_encoder.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer.
text_tokenizer = AutoTokenizer.from_pretrained(phi3_model_name, trust_remote_code=True)
text_tokenizer.pad_token = text_tokenizer.eos_token  # Phi-3 defines no pad token by default
# BitsAndBytesConfig for QLoRA (takes effect only if passed as quantization_config at load time)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
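# Roughly: weights are stored in 4-bit NF4 with a second quantization pass over the
# quantization constants (double quant), while compute runs in bfloat16. This is the
# standard QLoRA loading recipe; in this app the base model is loaded in float32 instead.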
# Image transformations (must match the training pipeline)
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
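# For example, a 32x32 CIFAR-style PIL image becomes a float tensor of shape
# [3, 224, 224], normalized with the ImageNet mean/std used by the ResNet-50 backbone.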
# Load the models.
image_encoder = SigLIPImageEncoder(model_name=image_model_name, embed_dim=image_embed_dim, pretrained_path=siglip_pretrained_path).to(device)
model = Phi3WithImage(phi3_model_name, image_encoder, image_embed_dim, bnb_config=bnb_config).to(device)

# Disable the cache flag manually, matching the training setup
# (generate() manages its own KV caching during inference).
if hasattr(model.phi3.config, 'use_cache'):
    model.phi3.config.use_cache = False

# Load the LoRA adapter and fold it into the base weights.
model.phi3 = PeftModel.from_pretrained(model.phi3, lora_model_path, device_map="auto", offload_dir='./offload')
model.phi3 = model.phi3.merge_and_unload()
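# merge_and_unload() folds the LoRA update (W + B @ A * scaling) into each target
# weight matrix and returns a plain transformers model, so inference afterwards
# needs neither the peft wrapper nor the adapter files.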
model.eval()  # evaluation mode
# 3. Inference Function
def predict(image_input, question):
    """Takes an image and a question as input and returns a generated answer."""
    if image_input is None or not question:
        return "Please provide both an image and a question."
    try:
        image = Image.fromarray(image_input).convert("RGB")
        image = image_transform(image).unsqueeze(0).to(device)
        prompt = f"Question: {question}\nAnswer:"
        encoded = text_tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Encode the image and project it into the text embedding space.
            image_embeddings = model.image_encoder(image)
            projected_image_embeddings = model.image_projection(image_embeddings).unsqueeze(1)  # [1, 1, hidden_dim]
            # Embed the prompt tokens and prepend the image embedding,
            # mirroring the training-time forward pass.
            token_embeddings = model.phi3.get_input_embeddings()(encoded["input_ids"])
            inputs_embeds = torch.cat([projected_image_embeddings, token_embeddings], dim=1)
            extended_attention_mask = torch.cat(
                [torch.ones(projected_image_embeddings.shape[:2], device=encoded["attention_mask"].device),
                 encoded["attention_mask"]],
                dim=1
            )
            # Generate the answer. With inputs_embeds and no input_ids, recent
            # transformers versions return only the newly generated token ids,
            # so no prompt stripping is needed afterwards.
            generated_tokens = model.phi3.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=extended_attention_mask,
                max_new_tokens=128,
                pad_token_id=text_tokenizer.eos_token_id,
            )
            answer = text_tokenizer.decode(generated_tokens[0], skip_special_tokens=True).strip()
        return answer
    except Exception:
        return f"An error occurred: {traceback.format_exc()}"
# 4. Launch the App
if __name__ == "__main__":
    iface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(label="Upload an Image"),
            gr.Textbox(label="Ask a Question about the Image", placeholder="What is in the image?")
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Image Question Answering with Phi-3 and SigLIP",
        description="Ask questions about an image and get answers powered by Phi-3 and SigLIP.",
        examples=[
            ["cat_0006.png", "Create an interesting story about this image."],
            ["bird_0004.png", "Can you describe this image?"],
            ["truck_0003.png", "Elaborate on the setting of the image."],
            ["ship_0007.png", "Explain the purpose of the image."]
        ]
    )
    iface.launch()