Monimoy committed
Commit 9023859 · verified · 1 Parent(s): 54fad30

Update app.py

Files changed (1): app.py +21 -9
app.py CHANGED
@@ -115,21 +115,33 @@ def predict(image_input, question):
         encoded = text_tokenizer(prompt, return_tensors="pt").to(device)
 
         with torch.no_grad():
-            logits = model(image, encoded["input_ids"], encoded["attention_mask"])
-
-            # Generate answer
-            generated_tokens = model.phi3.generate(
-                inputs=None,  # Remove input_ids and attention_mask
-                inputs_embeds=logits,  # Use the logits from the forward pass as input embeddings
-                max_length=128,
-                pad_token_id=text_tokenizer.eos_token_id
-            )
+            # Get image embeddings
+            image_embeddings = model.image_encoder(image)
+            projected_image_embeddings = model.image_projection(image_embeddings)
+
+            # Reshape image embeddings to (batch_size, 1, phi3_embed_dim)
+            projected_image_embeddings = projected_image_embeddings.unsqueeze(1)
+
+            # Concatenate along the sequence dimension (dim=1)
+            extended_attention_mask = torch.cat([torch.ones(projected_image_embeddings.shape[:2], device=encoded["attention_mask"].device), encoded["attention_mask"]], dim=1)
+            extended_input_ids = torch.cat([torch.zeros(projected_image_embeddings.shape[:2], dtype=torch.long, device=encoded["input_ids"].device), encoded["input_ids"]], dim=1)
+
+            # Generate answer
+            generated_tokens = model.phi3.generate(
+                input_ids=extended_input_ids,
+                attention_mask=extended_attention_mask,
+                max_length=200,
+                pad_token_id=text_tokenizer.eos_token_id
+            )
 
         answer = text_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
         answer = answer.replace(prompt, "").strip()  # Remove prompt from answer
 
         return answer
 
+    except Exception as e:
+        return f"An error occurred: {str(e)}"
+
     except Exception as e:
         #return f"An error occurred: {str(e)}"
         return f"An error occurred: {traceback.format_exc()}"