Monimoy committed
Commit 10a574a · verified · 1 Parent(s): 2237baa

Update app.py

Files changed (1)
  1. app.py +156 -0
app.py CHANGED
@@ -0,0 +1,156 @@
# app.py
import spaces
import os
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import timm
from torchvision import transforms
# from llama_cpp import Llama
from peft import PeftModel

import traceback
# 1. Model Definitions (same as in the training script)
class SigLIPImageEncoder(torch.nn.Module):
    def __init__(self, model_name='resnet50', embed_dim=512, pretrained_path=None):
        super().__init__()
        # pretrained=False: the backbone weights come from pretrained_path below
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0, global_pool='avg')
        self.embed_dim = embed_dim
        self.projection = torch.nn.Linear(self.model.num_features, embed_dim)

        if pretrained_path:
            # Load to CPU first; the module is moved to the target device later
            self.load_state_dict(torch.load(pretrained_path, map_location=torch.device('cpu')))
            print(f"Loaded SigLIP image encoder from {pretrained_path}")
        else:
            print("Initialized SigLIP image encoder without pretrained weights.")

    def forward(self, image):
        features = self.model(image)
        embedding = self.projection(features)
        return embedding


class Phi3WithImage(torch.nn.Module):
    def __init__(self, phi3_model_name, image_encoder, image_embed_dim=512):
        super().__init__()
        self.phi3 = AutoModelForCausalLM.from_pretrained(
            phi3_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True  # Important for some Phi-3 variants
        )
        self.image_encoder = image_encoder
        self.image_embed_dim = image_embed_dim
        self.phi3_embed_dim = self.phi3.config.hidden_size

        # Project image embeddings into Phi-3's embedding space
        self.image_projection = torch.nn.Linear(image_embed_dim, self.phi3_embed_dim)

    def build_inputs(self, image, question_input_ids, question_attention_mask):
        # Prepend the projected image embedding as a single pseudo-token in front
        # of the text embeddings. This is a simplified approach; more
        # sophisticated fusion methods exist. Shared by forward() and predict().
        image_embeddings = self.image_encoder(image)
        projected = self.image_projection(image_embeddings).unsqueeze(1)  # (B, 1, hidden)
        text_embeds = self.phi3.get_input_embeddings()(question_input_ids)
        inputs_embeds = torch.cat([projected.to(text_embeds.dtype), text_embeds], dim=1)
        attention_mask = torch.cat(
            [torch.ones(projected.shape[:2],
                        dtype=question_attention_mask.dtype,
                        device=question_attention_mask.device),
             question_attention_mask],
            dim=1,
        )
        return inputs_embeds, attention_mask

    def forward(self, image, question_input_ids, question_attention_mask):
        inputs_embeds, attention_mask = self.build_inputs(image, question_input_ids, question_attention_mask)
        # Feed the image-prefixed embeddings to Phi-3 (no labels during inference)
        outputs = self.phi3(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return outputs.logits
# 2. Load Models and Tokenizer
phi3_model_name = "microsoft/Phi-3-mini-4k-instruct"  # Or your specific Phi-3 variant
lora_model_path = "./qlora-phi3-model-new"
image_model_name = 'resnet50'
image_embed_dim = 512
siglip_pretrained_path = "image_encoder.pth"

device = torch.device("cpu")  # Force CPU
print(f"Using device: {device}")

# Load tokenizer
text_tokenizer = AutoTokenizer.from_pretrained(phi3_model_name, trust_remote_code=True)
text_tokenizer.pad_token = text_tokenizer.eos_token  # Use EOS as the padding token

# Image transformations (must match the training pipeline)
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load models
image_encoder = SigLIPImageEncoder(
    model_name=image_model_name,
    embed_dim=image_embed_dim,
    pretrained_path=siglip_pretrained_path,
).to(device)
model = Phi3WithImage(phi3_model_name, image_encoder, image_embed_dim).to(device)

# Attach the LoRA adapter to the base Phi-3 weights
model.phi3 = PeftModel.from_pretrained(model.phi3, lora_model_path, device_map="auto")
model.eval()  # Set to evaluation mode
# 3. Inference Function
@spaces.GPU
def predict(image_input, question):
    """
    Takes an image and a question as input and returns an answer.
    """
    if image_input is None or not question:
        return "Please provide both an image and a question."

    try:
        image = Image.fromarray(image_input).convert("RGB")
        image = image_transform(image).unsqueeze(0).to(device)

        prompt = f"Question: {question}\nAnswer:"
        encoded = text_tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            # Build the image-prefixed embedding sequence and generate from it.
            # With inputs_embeds and no input_ids, generate() returns only the
            # newly generated tokens, so the prompt never appears in the output.
            inputs_embeds, attention_mask = model.build_inputs(
                image, encoded["input_ids"], encoded["attention_mask"]
            )
            generated_tokens = model.phi3.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_new_tokens=128,
                pad_token_id=text_tokenizer.eos_token_id,
            )

        answer = text_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        return answer.strip()

    except Exception:
        return f"An error occurred: {traceback.format_exc()}"

# 4. Launch the App
if __name__ == "__main__":
    iface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(label="Upload an Image"),
            gr.Textbox(label="Ask a Question about the Image", placeholder="What is in the image?")
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Image Question Answering with Phi-3 and SigLIP",
        description="Ask questions about an image and get answers powered by Phi-3 and SigLIP.",
        examples=[
            ["cat_0006.png", "Create an interesting story about this image."],
            ["bird_0004.png", "Can you describe this image?"],
            ["truck_0003.png", "Elaborate on the setting of the image."],
            ["ship_0007.png", "Explain the purpose of the image."]
        ]
    )

    iface.launch()
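
For anyone reviewing the image-prefix concatenation in Phi3WithImage.build_inputs, here is a minimal, self-contained shape test. It is a hypothetical sketch, not part of this commit: the file name check_shapes.py is made up, the hidden and vocabulary sizes are assumed Phi-3-mini values, and a plain torch.nn.Embedding stands in for Phi-3's token embedding table so the check runs offline without downloading any weights.

# check_shapes.py — hypothetical smoke test, not part of this commit.
# Mirrors Phi3WithImage.build_inputs with a dummy embedding table: the
# projected image embedding is prepended as one pseudo-token to the text.
import timm
import torch

embed_dim, hidden_size, vocab_size = 512, 3072, 32064  # assumed Phi-3-mini sizes

backbone = timm.create_model("resnet50", pretrained=False, num_classes=0, global_pool="avg")
projection = torch.nn.Linear(backbone.num_features, embed_dim)   # 2048 -> 512
image_projection = torch.nn.Linear(embed_dim, hidden_size)       # 512 -> LM hidden
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)       # stand-in for Phi-3's table

image = torch.randn(1, 3, 224, 224)                # one normalized 224x224 image
input_ids = torch.randint(0, vocab_size, (1, 7))   # seven dummy question tokens
attention_mask = torch.ones(1, 7, dtype=torch.long)

with torch.no_grad():
    image_embed = image_projection(projection(backbone(image))).unsqueeze(1)  # (1, 1, 3072)
    text_embeds = embed_tokens(input_ids)                                     # (1, 7, 3072)
    inputs_embeds = torch.cat([image_embed, text_embeds], dim=1)
    extended_mask = torch.cat(
        [torch.ones(1, 1, dtype=attention_mask.dtype), attention_mask], dim=1
    )

assert inputs_embeds.shape == (1, 8, hidden_size)
assert extended_mask.shape == (1, 8)
print("image-prefix shapes OK:", tuple(inputs_embeds.shape))

The sequence grows by exactly one position per image, which is why the attention mask is extended with a single column of ones; any mismatch between these two shapes is the first thing to check if generate() rejects the inputs.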