Monimoy committed · Commit 575a2af (verified) · Parent: 7bb0ea1

Update app.py

Files changed (1): app.py (+40 -20)
app.py CHANGED
@@ -4,7 +4,7 @@ import os
  import gradio as gr
  import torch
  from PIL import Image
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
  import timm
  from torchvision import transforms
  #from llama_cpp import Llama
@@ -31,42 +31,53 @@ class SigLIPImageEncoder(torch.nn.Module):
          embedding = self.projection(features)
          return embedding

+
  class Phi3WithImage(torch.nn.Module):
-     def __init__(self, phi3_model_name, lora_model_path, image_encoder, image_embed_dim=512):
+
+     def __init__(self, phi3_model_name, image_encoder, image_embed_dim=512, image_token_id=1, bnb_config=None):
          super().__init__()
          self.phi3 = AutoModelForCausalLM.from_pretrained(
              phi3_model_name,
              torch_dtype=torch.bfloat16,
              device_map="auto",
-             trust_remote_code=True # Important for some Phi-3 variants
+             trust_remote_code=True,
+             quantization_config=bnb_config
          )
-
          self.image_encoder = image_encoder
          self.image_embed_dim = image_embed_dim
          self.phi3_embed_dim = self.phi3.config.hidden_size
-
-         # Project image embeddings to Phi-3's embedding space
          self.image_projection = torch.nn.Linear(image_embed_dim, self.phi3_embed_dim)
+         self.image_token_id = image_token_id
+
-     def forward(self, image, question_input_ids, question_attention_mask):
-         image_embeddings = self.image_encoder(image)
-         projected_image_embeddings = self.image_projection(image_embeddings)
-
-         # Concatenate image embeddings to the input sequence
-         # This is a simplified approach. More sophisticated methods exist.
-         # Assumes image embeddings are prepended to the sequence.
-
-         # Reshape image embeddings to (batch_size, 1, phi3_embed_dim)
-         projected_image_embeddings = projected_image_embeddings.unsqueeze(1)
-
-         # Concatenate along the sequence dimension (dim=1)
-         extended_attention_mask = torch.cat([torch.ones(projected_image_embeddings.shape[:2], device=question_attention_mask.device), question_attention_mask], dim=1)
-         extended_input_ids = torch.cat([torch.zeros(projected_image_embeddings.shape[:2], dtype=torch.long, device=question_input_ids.device), question_input_ids], dim=1)
-
-         # Pass the concatenated input to Phi-3
-         outputs = self.phi3(input_ids=extended_input_ids, attention_mask=extended_attention_mask) # No labels during inference
+     def forward(self, image, question_input_ids, question_attention_mask, answer_input_ids, answer_attention_mask):
+         batch_size = question_input_ids.size(0)
+
+         # Encode image and project to text embedding space
+         image_embeddings = self.image_encoder(image)  # Shape: [batch_size, image_embed_dim]
+         projected_image_embeddings = self.image_projection(image_embeddings).unsqueeze(1)  # Shape: [batch_size, 1, hidden_dim]
+
+         # Get the token embeddings for question + answer from the model's embedding layer
+         token_embeddings = self.phi3.get_input_embeddings()(
+             torch.cat([question_input_ids, answer_input_ids], dim=1)
+         )  # Shape: [batch_size, seq_len, hidden_dim]
+
+         # Concatenate image embeddings at the start
+         inputs_embeds = torch.cat([projected_image_embeddings, token_embeddings], dim=1)
+
+         # Create combined attention mask
+         image_attention_mask = torch.ones((batch_size, 1), device=question_attention_mask.device)
+         full_attention_mask = torch.cat([image_attention_mask, question_attention_mask, answer_attention_mask], dim=1)
+
+         # Prepare labels: mask image + question parts so only answers are supervised
+         #labels = torch.cat([torch.full((batch_size, 1 + question_input_ids.size(1)), -100, device=question_input_ids.device), answer_input_ids], dim=1)
+
+         outputs = self.phi3(
+             inputs_embeds=inputs_embeds,
+             attention_mask=full_attention_mask
+         )
          return outputs.logits

+
  # 2. Load Models and Tokenizer
  phi3_model_name = "microsoft/Phi-3-mini-4k-instruct" # Or your specific Phi-3 variant
  lora_model_path = "./qlora-phi3-model-new"
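
Note on this hunk: the old forward computed projected_image_embeddings but never passed them to the model; it spliced zero-valued input_ids in front of the question, which Phi-3 would simply re-embed as token 0, so the image information never reached the language model. The new forward looks up token embeddings itself and passes inputs_embeds, letting the projected image vector ride along as one extra pseudo-token. A minimal shape-level sketch of that concatenation, with made-up sizes (hidden size 3072 as in Phi-3-mini; everything else hypothetical):

import torch

batch_size, seq_len, hidden_dim, image_dim = 2, 8, 3072, 512   # hypothetical sizes

image_features = torch.randn(batch_size, image_dim)            # stand-in for SigLIPImageEncoder output
projection = torch.nn.Linear(image_dim, hidden_dim)            # same role as self.image_projection
image_token = projection(image_features).unsqueeze(1)          # [2, 1, 3072]: one pseudo-token

token_embeds = torch.randn(batch_size, seq_len, hidden_dim)    # stand-in for get_input_embeddings()(ids)
inputs_embeds = torch.cat([image_token, token_embeds], dim=1)  # [2, 9, 3072]
attention_mask = torch.ones(batch_size, 1 + seq_len)           # the image pseudo-token is always visible

assert inputs_embeds.shape == (batch_size, 1 + seq_len, hidden_dim)
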
@@ -81,6 +92,14 @@ print(f"Using device: {device}")
  text_tokenizer = AutoTokenizer.from_pretrained(phi3_model_name, trust_remote_code=True)
  text_tokenizer.pad_token = text_tokenizer.eos_token # Important for training

+ # 7. BitsAndBytesConfig for QLoRA
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
  # Image Transformations
  image_transform = transforms.Compose([
      transforms.Resize((224, 224)),
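
Note on this hunk: these are the standard QLoRA settings: base weights stored in 4-bit NF4, the quantization constants themselves quantized (double quantization), and matmuls run in bfloat16. Passing the config through quantization_config is what lets the full Phi-3 base fit alongside the adapters. A small sketch of the effect, assuming bitsandbytes is installed and a CUDA device is available:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store base weights in 4-bit
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4, the QLoRA default
    bnb_4bit_compute_dtype=torch.bfloat16,  # dequantize to bf16 for matmuls
)

base = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
print(f"{base.get_memory_footprint() / 1e9:.1f} GB")  # roughly a quarter of the bf16 footprint
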
@@ -90,7 +109,8 @@ image_transform = transforms.Compose([

  # Load Models
  image_encoder = SigLIPImageEncoder(model_name=image_model_name, embed_dim=image_embed_dim, pretrained_path=siglip_pretrained_path).to(device)
- model = Phi3WithImage(phi3_model_name, lora_model_path, image_encoder, image_embed_dim).to(device)
+ #model = Phi3WithImage(phi3_model_name, lora_model_path, image_encoder, image_embed_dim, bnb_config=bnb_config).to(device)
+ model = Phi3WithImage(phi3_model_name, image_encoder, image_embed_dim, bnb_config=bnb_config).to(device)

  # Load LoRA model
  model.phi3 = PeftModel.from_pretrained(model.phi3, lora_model_path, device_map="auto", offload_dir='./offload')
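
Note on this hunk: PeftModel.from_pretrained attaches the QLoRA adapters on top of the now-quantized base rather than merging them. One caveat worth flagging: image_projection is created in fp32 while the Phi-3 token embeddings are bf16, so the torch.cat in forward needs matching dtypes. Since forward only returns teacher-forced logits, the Gradio handler still needs an autoregressive decode; a hedged sketch using the variables defined above (the file name and bare-question prompt format are placeholders, not from the diff):

from PIL import Image
import torch

question = "What is in this image?"  # placeholder prompt format
enc = text_tokenizer(question, return_tensors="pt").to(device)
img = image_transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)  # hypothetical file

with torch.no_grad():
    text_embeds = model.phi3.get_input_embeddings()(enc.input_ids)
    image_token = model.image_projection(model.image_encoder(img)).unsqueeze(1)
    image_token = image_token.to(text_embeds.dtype)  # cast fp32 projection to bf16 so torch.cat succeeds
    inputs_embeds = torch.cat([image_token, text_embeds], dim=1)
    attention_mask = torch.cat([torch.ones_like(enc.attention_mask[:, :1]), enc.attention_mask], dim=1)
    new_ids = model.phi3.generate(
        inputs_embeds=inputs_embeds,           # generate() accepts embeddings for decoder-only models
        attention_mask=attention_mask,
        max_new_tokens=64,
        pad_token_id=text_tokenizer.eos_token_id,
    )
print(text_tokenizer.decode(new_ids[0], skip_special_tokens=True))
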
 
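
For reference, the diff only shows the last two lines of SigLIPImageEncoder, so here is a plausible reconstruction of the whole class, assuming a timm SigLIP backbone. The constructor signature matches the call site above, but everything except the final two lines of forward is an assumption:

import timm
import torch

class SigLIPImageEncoder(torch.nn.Module):
    """Hypothetical reconstruction; only the last two lines of forward appear in the diff."""
    def __init__(self, model_name, embed_dim=512, pretrained_path=None):
        super().__init__()
        # num_classes=0 makes timm return pooled features instead of classification logits
        self.backbone = timm.create_model(model_name, pretrained=(pretrained_path is None), num_classes=0)
        self.projection = torch.nn.Linear(self.backbone.num_features, embed_dim)
        if pretrained_path:
            self.load_state_dict(torch.load(pretrained_path, map_location="cpu"))

    def forward(self, image):
        features = self.backbone(image)
        embedding = self.projection(features)
        return embedding
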