Monimoy's picture
Update app.py
d1d625f verified
# app.py
import spaces
import os
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM,AutoConfig, BitsAndBytesConfig
import timm
from torchvision import transforms
#from llama_cpp import Llama
from peft import PeftModel, prepare_model_for_kbit_training, LoraConfig, get_peft_model, TaskType
import traceback
import warnings
warnings.filterwarnings('ignore')
# 1. Model Definitions (Same as in training script)
class SigLIPImageEncoder(torch.nn.Module):
    """Image encoder: a timm backbone followed by a linear projection.

    Despite the class name, the backbone is whatever ``model_name`` selects in
    timm (``'resnet50'`` by default). It is created without timm's own
    pretrained weights. Trained weights for the whole module (backbone and
    projection) are optionally restored from ``pretrained_path``.

    Args:
        model_name: timm architecture name.
        embed_dim: output width of the projection layer.
        pretrained_path: optional path to a ``state_dict`` saved from this
            module. It is loaded onto CPU, so a checkpoint saved from CUDA
            tensors also loads on a CPU-only host.
    """

    def __init__(self, model_name='resnet50', embed_dim=512, pretrained_path=None):
        super().__init__()
        # num_classes=0 removes the classifier head and global_pool='avg' pools
        # the feature map, so the backbone yields one feature vector per image.
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0, global_pool='avg')
        self.embed_dim = embed_dim
        self.projection = torch.nn.Linear(self.model.num_features, embed_dim)
        if pretrained_path:
            # map_location='cpu': the checkpoint may hold CUDA tensors; the caller
            # moves the module to its target device afterwards.
            # weights_only=True: a plain state_dict of tensors is all that is
            # expected, so refuse to unpickle arbitrary objects.
            state_dict = torch.load(pretrained_path, map_location='cpu', weights_only=True)
            self.load_state_dict(state_dict)
            print(f"Loaded SigLIP image encoder from {pretrained_path}")
        else:
            print("Initialized SigLIP image encoder without pretrained weights.")

    def forward(self, image):
        """Return ``[batch, embed_dim]`` embeddings for a batch of image tensors."""
        features = self.model(image)
        embedding = self.projection(features)
        return embedding
class Phi3WithImage(torch.nn.Module):
    """Phi-3 causal LM conditioned on one projected image embedding.

    The image encoder output is linearly projected to the LM hidden size and
    prepended to the token embeddings as a single extra sequence position.

    Args:
        phi3_model_name: model id/path passed to
            ``AutoModelForCausalLM.from_pretrained``.
        image_encoder: module mapping an image batch to
            ``[batch, image_embed_dim]`` features.
        image_embed_dim: width of the image encoder output.
        image_token_id: stored on the instance; not used by this class.
        bnb_config: accepted for signature compatibility with the training
            script but currently NOT forwarded. The LM is loaded un-quantized
            in float32.
    """

    def __init__(self, phi3_model_name, image_encoder, image_embed_dim=512, image_token_id=1, bnb_config=None):
        super().__init__()
        # Quantization is intentionally disabled here: bnb_config is not passed
        # as quantization_config, and weights are loaded in float32.
        self.phi3 = AutoModelForCausalLM.from_pretrained(
            phi3_model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            trust_remote_code=True,
        )
        self.image_encoder = image_encoder
        self.image_embed_dim = image_embed_dim
        self.phi3_embed_dim = self.phi3.config.hidden_size
        # Freshly initialised; trained weights must be loaded separately.
        self.image_projection = torch.nn.Linear(image_embed_dim, self.phi3_embed_dim)
        self.image_token_id = image_token_id

    def forward(self, image, question_input_ids, question_attention_mask, answer_input_ids, answer_attention_mask):
        """Return LM logits for the sequence ``[image, question tokens, answer tokens]``.

        Position 0 of the returned per-position logits corresponds to the image
        embedding. No labels or loss are computed here; only logits are returned.
        """
        batch_size = question_input_ids.size(0)

        # Token embeddings for question + answer from the LM's own embedding table.
        token_embeddings = self.phi3.get_input_embeddings()(
            torch.cat([question_input_ids, answer_input_ids], dim=1)
        )

        # Encode and project the image to one "token" of width hidden_size. Cast it
        # to the text embeddings' dtype/device so the concat does not silently
        # promote (e.g. float32 image + bf16 text) or mix devices.
        image_embeddings = self.image_encoder(image)
        projected_image_embeddings = self.image_projection(image_embeddings).unsqueeze(1).to(
            device=token_embeddings.device, dtype=token_embeddings.dtype
        )
        inputs_embeds = torch.cat([projected_image_embeddings, token_embeddings], dim=1)

        # The image position is always attended to. Use the same dtype as the
        # text masks so the concatenated mask keeps their integer dtype.
        image_attention_mask = torch.ones(
            (batch_size, 1),
            dtype=question_attention_mask.dtype,
            device=question_attention_mask.device,
        )
        full_attention_mask = torch.cat(
            [image_attention_mask, question_attention_mask, answer_attention_mask], dim=1
        )

        outputs = self.phi3(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention_mask,
        )
        return outputs.logits
# 2. Load Models and Tokenizer
# Module-level setup: everything below runs once at import time and defines the
# globals (device, text_tokenizer, image_transform, model) that predict() reads.
phi3_model_name = "microsoft/Phi-3-mini-4k-instruct" # Or your specific Phi-3 variant
# Local directory holding the LoRA adapter that is merged into Phi-3 below.
lora_model_path = "./qlora-phi3-model-new"
# timm backbone name and projection width for the image encoder. Both must match
# the checkpoint at siglip_pretrained_path, or load_state_dict will fail.
image_model_name = 'resnet50'
image_embed_dim = 512
siglip_pretrained_path = "image_encoder.pth"
#device = torch.device("cpu") # Force CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load Tokenizer
text_tokenizer = AutoTokenizer.from_pretrained(phi3_model_name, trust_remote_code=True)
# Use EOS as the padding token. The original note says this matters for training;
# presumably it is kept here for parity with the training setup.
text_tokenizer.pad_token = text_tokenizer.eos_token # Important for training
# 4-bit NF4 quantization config for QLoRA (double quantization, bf16 compute).
# NOTE(review): this is passed to Phi3WithImage, but that constructor has its
# quantization_config argument commented out. So this config currently has no
# effect, and Phi-3 is loaded un-quantized in float32.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
# Image preprocessing: resize to 224x224, convert to a tensor, and normalize with
# the ImageNet mean/std. Presumably this mirrors the training preprocessing;
# confirm against the training script.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load Models
# Image encoder with weights restored from siglip_pretrained_path, moved to device.
image_encoder = SigLIPImageEncoder(model_name=image_model_name, embed_dim=image_embed_dim, pretrained_path=siglip_pretrained_path).to(device)
# Earlier constructor call, kept for reference:
#model = Phi3WithImage(phi3_model_name, lora_model_path, image_encoder, image_embed_dim, bnb_config=bnb_config).to(device)
# Wrapper combining Phi-3 with the image encoder.
# NOTE(review): Phi3WithImage.__init__ creates image_projection as a freshly
# initialised Linear, and nothing in this section loads trained weights into it.
# Confirm where the trained projection is meant to come from; otherwise the image
# signal reaching the LM is random.
# NOTE(review): phi3 is loaded with device_map="auto" and the whole wrapper is then
# moved with .to(device). Confirm this stays safe if accelerate ever splits or
# offloads layers across devices.
model = Phi3WithImage(phi3_model_name, image_encoder, image_embed_dim, bnb_config=bnb_config).to(device)
#model.phi3 = prepare_model_for_kbit_training(model.phi3)
# Disable cache manually (important!)
# NOTE(review): the reason is not recorded here. With the KV cache off, every
# decoding step recomputes attention over the whole sequence.
if hasattr(model.phi3.config, 'use_cache'):
    model.phi3.config.use_cache = False
# Load LoRA model: attach the adapter to Phi-3, then fold its weights into the base
# model so inference runs on a plain (non-PEFT) model.
# NOTE(review): confirm the installed PEFT version honours `offload_dir` here
# (some versions read `offload_folder` instead).
model.phi3 = PeftModel.from_pretrained(model.phi3, lora_model_path, device_map="auto", offload_dir='./offload')
model.phi3 = model.phi3.merge_and_unload()
# Earlier experiment (save the merged model, reload it quantized), kept for reference:
#model.phi3.save_pretrained("merged_model_fp16")
#model.phi3 = PeftModel.from_pretrained(model.phi3, lora_model_path, device_map="auto")
#model.phi3 = AutoModelForCausalLM.from_pretrained(
#    "merged_model_fp16",
#    quantization_config=bnb_config,
#    torch_dtype=torch.bfloat16,
#    device_map="auto"
#)
model.eval() # Set to evaluation mode
# 3. Inference Function
def _stop_token_ids():
    """Return the set of token ids that end generation.

    Combines the tokenizer's EOS id with the EOS id(s) in the language model's
    generation config (an int or a list, when present). These are the ids that
    ``generate`` would have stopped on.
    """
    stop_ids = {text_tokenizer.eos_token_id}
    generation_config = getattr(model.phi3, "generation_config", None)
    config_eos = getattr(generation_config, "eos_token_id", None)
    if isinstance(config_eos, int):
        stop_ids.add(config_eos)
    elif config_eos is not None:
        stop_ids.update(config_eos)
    stop_ids.discard(None)
    return stop_ids


@torch.no_grad()
def _greedy_decode(inputs_embeds, attention_mask, max_new_tokens):
    """Greedily generate up to ``max_new_tokens`` token ids from embeddings.

    The whole sequence is re-run at every step without a KV cache, because
    the model's cache is disabled at load time. This keeps the decoding
    independent of ``generate``'s ``inputs_embeds`` handling, which would
    otherwise need a working cache. Assumes batch size 1.

    Returns:
        List of newly generated token ids, excluding the stop token.
    """
    embed_layer = model.phi3.get_input_embeddings()
    stop_ids = _stop_token_ids()
    new_ids = []
    for _ in range(max_new_tokens):
        logits = model.phi3(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=False,
        ).logits
        next_id = int(logits[0, -1].argmax())
        if next_id in stop_ids:
            break
        new_ids.append(next_id)
        # Feed the chosen token back in as an embedding for the next step.
        next_embed = embed_layer(torch.tensor([[next_id]], device=inputs_embeds.device))
        inputs_embeds = torch.cat([inputs_embeds, next_embed.to(inputs_embeds.dtype)], dim=1)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))], dim=1
        )
    return new_ids


@spaces.GPU
def predict(image_input, question, max_new_tokens=128):
    """Answer ``question`` about ``image_input``.

    The image is encoded, projected to the LM hidden size, and prepended to the
    prompt's token embeddings. This is the same layout as
    ``Phi3WithImage.forward``: ``[image, question tokens, ...]``. The answer is
    then decoded greedily.

    Args:
        image_input: image array as delivered by the Gradio image component
            (passed to ``PIL.Image.fromarray``), or ``None``.
        question: question text.
        max_new_tokens: maximum number of answer tokens to generate.

    Returns:
        The generated answer text, a prompt to supply both inputs, or an
        error message containing the traceback if inference fails.
    """
    if image_input is None or not question or not question.strip():
        return "Please provide both an image and a question."
    try:
        image = Image.fromarray(image_input).convert("RGB")
        pixel_values = image_transform(image).unsqueeze(0).to(device)
        prompt = f"Question: {question}\nAnswer:"
        encoded = text_tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            embed_layer = model.phi3.get_input_embeddings()
            embed_device = embed_layer.weight.device
            input_ids = encoded["input_ids"].to(embed_device)
            text_mask = encoded["attention_mask"].to(embed_device)
            token_embeddings = embed_layer(input_ids)

            # Image -> [1, 1, hidden_size], matched to the text embeddings.
            # NOTE(review): image_projection is freshly initialised in
            # Phi3WithImage.__init__; confirm its trained weights are loaded.
            image_embeddings = model.image_projection(model.image_encoder(pixel_values)).unsqueeze(1)
            image_embeddings = image_embeddings.to(device=embed_device, dtype=token_embeddings.dtype)

            # The image must actually reach the LM. It goes in via inputs_embeds,
            # not as a placeholder token id.
            inputs_embeds = torch.cat([image_embeddings, token_embeddings], dim=1)
            attention_mask = torch.cat([text_mask.new_ones((text_mask.size(0), 1)), text_mask], dim=1)

            new_ids = _greedy_decode(inputs_embeds, attention_mask, max_new_tokens)
        # Only the generated tokens are decoded, so no prompt stripping is needed.
        return text_tokenizer.decode(new_ids, skip_special_tokens=True).strip()
    except Exception:
        return f"An error occurred: {traceback.format_exc()}"
# Launch the Gradio app when run as a script.
if __name__ == "__main__":
    # UI components, created in the order they appear: image in, question in,
    # answer out.
    image_box = gr.Image(label="Upload an Image")
    question_box = gr.Textbox(
        label="Ask a Question about the Image",
        placeholder="What is in the image?",
    )
    answer_box = gr.Textbox(label="Answer")

    # Clickable (image, question) samples. The image paths are relative, so
    # these files are expected to ship alongside the app.
    sample_rows = [
        ["cat_0006.png", "Create a interesting story about this image?"],
        ["bird_0004.png", "Can you describe this image?"],
        ["truck_0003.png", "Elaborate the setting of the image"],
        ["ship_0007.png", "Explain the purpose of image"],
    ]

    iface = gr.Interface(
        fn=predict,
        inputs=[image_box, question_box],
        outputs=answer_box,
        title="Image Question Answering with Phi-3 and SigLIP",
        description="Ask questions about an image and get answers powered by Phi-3 and SigLIP.",
        examples=sample_rows,
    )
    iface.launch()