# app.py
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import timm
from torchvision import transforms
from peft import PeftModel
import traceback
import warnings
warnings.filterwarnings('ignore')
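# Rough environment for the imports above (a sketch; this file pins no versions):
#   pip install torch torchvision transformers peft timm gradio accelerate bitsandbytes
# On Hugging Face Spaces, the `spaces` package provides the @spaces.GPU decorator.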
# 1. Model Definitions (same as in the training script)
class SigLIPImageEncoder(torch.nn.Module):
    def __init__(self, model_name='resnet50', embed_dim=512, pretrained_path=None):
        super().__init__()
        # pretrained=False: weights come from the fine-tuned checkpoint, not timm
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0, global_pool='avg')
        self.embed_dim = embed_dim
        self.projection = torch.nn.Linear(self.model.num_features, embed_dim)
        if pretrained_path:
            # Load to CPU first; the module is moved to the target device later via .to(device)
            self.load_state_dict(torch.load(pretrained_path, map_location="cpu"))
            print(f"Loaded SigLIP image encoder from {pretrained_path}")
        else:
            print("Initialized SigLIP image encoder without pretrained weights.")

    def forward(self, image):
        features = self.model(image)           # [batch_size, num_features]
        embedding = self.projection(features)  # [batch_size, embed_dim]
        return embedding
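# Quick shape check for the encoder (a hypothetical standalone snippet; the
# defaults match the configuration used below):
#
#   enc = SigLIPImageEncoder()               # 'resnet50', embed_dim=512, random weights
#   out = enc(torch.randn(2, 3, 224, 224))   # two normalized 224x224 RGB images
#   assert out.shape == (2, 512)             # [batch_size, embed_dim]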
class Phi3WithImage(torch.nn.Module):
    def __init__(self, phi3_model_name, image_encoder, image_embed_dim=512, image_token_id=1, bnb_config=None):
        super().__init__()
        self.phi3 = AutoModelForCausalLM.from_pretrained(
            phi3_model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            trust_remote_code=True,
        )
        self.image_encoder = image_encoder
        self.image_embed_dim = image_embed_dim
        self.phi3_embed_dim = self.phi3.config.hidden_size
        self.image_projection = torch.nn.Linear(image_embed_dim, self.phi3_embed_dim)
        self.image_token_id = image_token_id

    def forward(self, image, question_input_ids, question_attention_mask, answer_input_ids, answer_attention_mask):
        batch_size = question_input_ids.size(0)
        # Encode the image and project it into the text embedding space
        image_embeddings = self.image_encoder(image)  # [batch_size, image_embed_dim]
        projected_image_embeddings = self.image_projection(image_embeddings).unsqueeze(1)  # [batch_size, 1, hidden_dim]
        # Token embeddings for question + answer from the model's embedding layer
        token_embeddings = self.phi3.get_input_embeddings()(
            torch.cat([question_input_ids, answer_input_ids], dim=1)
        )  # [batch_size, seq_len, hidden_dim]
        # Prepend the image embedding to the token sequence
        inputs_embeds = torch.cat([projected_image_embeddings, token_embeddings], dim=1)
        # Extend the attention mask with a 1 for the image position
        image_attention_mask = torch.ones((batch_size, 1), dtype=question_attention_mask.dtype, device=question_attention_mask.device)
        full_attention_mask = torch.cat([image_attention_mask, question_attention_mask, answer_attention_mask], dim=1)
        # During training, labels mask the image + question positions so that only
        # answer tokens are supervised; at inference we only need the logits.
        outputs = self.phi3(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention_mask
        )
        return outputs.logits
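# Sequence layout produced by forward(), for Q question tokens and A answer
# tokens (a sketch of the training-time layout):
#
#   position:  [   0   ][  1 ... Q  ][ Q+1 ... Q+A ]
#   embeds:    [ image ][ question  ][   answer    ]
#   attention: [   1   ][  q mask   ][   a mask    ]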
# 2. Load Models and Tokenizer
phi3_model_name = "microsoft/Phi-3-mini-4k-instruct"  # Or your specific Phi-3 variant
lora_model_path = "./qlora-phi3-model-new"
image_model_name = 'resnet50'
image_embed_dim = 512
siglip_pretrained_path = "image_encoder.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer; Phi-3 defines no pad token, so reuse EOS for padding
text_tokenizer = AutoTokenizer.from_pretrained(phi3_model_name, trust_remote_code=True)
text_tokenizer.pad_token = text_tokenizer.eos_token
# BitsAndBytesConfig carried over from the QLoRA training setup. Note that
# Phi3WithImage above loads the base model in fp32, so this config is not
# applied at load time in this app.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
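# If 4-bit loading were re-enabled, the config would be passed when loading the
# base model, e.g. (a sketch, not how this app currently loads Phi-3):
#
#   AutoModelForCausalLM.from_pretrained(
#       phi3_model_name,
#       quantization_config=bnb_config,
#       device_map="auto",
#       trust_remote_code=True,
#   )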
# Image transformations; mean/std are the standard ImageNet statistics used by
# the ResNet-50 backbone
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
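# e.g. (sketch): image_transform(Image.open("cat_0006.png").convert("RGB"))
# returns a FloatTensor of shape [3, 224, 224]; predict() below adds the batch
# dimension with unsqueeze(0).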
# Load models
image_encoder = SigLIPImageEncoder(model_name=image_model_name, embed_dim=image_embed_dim, pretrained_path=siglip_pretrained_path).to(device)
model = Phi3WithImage(phi3_model_name, image_encoder, image_embed_dim, bnb_config=bnb_config).to(device)

# Disable the KV-cache flag carried over from training (important!)
if hasattr(model.phi3.config, 'use_cache'):
    model.phi3.config.use_cache = False

# Load the LoRA adapter, then merge it into the base weights so inference runs
# on a single plain model
model.phi3 = PeftModel.from_pretrained(model.phi3, lora_model_path, device_map="auto", offload_dir='./offload')
model.phi3 = model.phi3.merge_and_unload()

model.eval()  # Set to evaluation mode
# 3. Inference Function
@spaces.GPU
def predict(image_input, question):
    """
    Takes an image and a question as input and returns an answer.
    """
    if image_input is None or not question:
        return "Please provide both an image and a question."
    try:
        image = Image.fromarray(image_input).convert("RGB")
        image = image_transform(image).unsqueeze(0).to(device)
        prompt = f"Question: {question}\nAnswer:"
        encoded = text_tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Project the image into the Phi-3 embedding space
            image_embeddings = model.image_encoder(image)
            projected_image_embeddings = model.image_projection(image_embeddings).unsqueeze(1)  # [1, 1, hidden_dim]
            # Embed the prompt tokens and prepend the image embedding, mirroring
            # the training-time layout in Phi3WithImage.forward
            prompt_embeddings = model.phi3.get_input_embeddings()(encoded["input_ids"])
            inputs_embeds = torch.cat([projected_image_embeddings, prompt_embeddings], dim=1)
            extended_attention_mask = torch.cat(
                [torch.ones(projected_image_embeddings.shape[:2],
                            dtype=encoded["attention_mask"].dtype,
                            device=encoded["attention_mask"].device),
                 encoded["attention_mask"]],
                dim=1
            )
            # Generate the answer; when called with inputs_embeds, generate()
            # returns only the newly generated token ids, so no prompt stripping
            # is needed afterwards
            generated_tokens = model.phi3.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=extended_attention_mask,
                max_new_tokens=128,
                pad_token_id=text_tokenizer.eos_token_id,
            )
        answer = text_tokenizer.decode(generated_tokens[0], skip_special_tokens=True).strip()
        return answer
    except Exception:
        return f"An error occurred: {traceback.format_exc()}"
# 4. Launch the App
if __name__ == "__main__":
    iface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(label="Upload an Image"),
            gr.Textbox(label="Ask a Question about the Image", placeholder="What is in the image?")
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Image Question Answering with Phi-3 and SigLIP",
        description="Ask questions about an image and get answers powered by Phi-3 and SigLIP.",
        examples=[
            ["cat_0006.png", "Create an interesting story about this image."],
            ["bird_0004.png", "Can you describe this image?"],
            ["truck_0003.png", "Elaborate on the setting of the image."],
            ["ship_0007.png", "Explain the purpose of the image."]
        ]
    )
    iface.launch()