Update app.py
Browse files
app.py
CHANGED
|
@@ -8,10 +8,10 @@ from PIL import Image
|
|
| 8 |
|
| 9 |
|
| 10 |
class _MLPVectorProjector(nn.Module):
|
| 11 |
-
def __init__(
|
| 12 |
self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
|
| 13 |
):
|
| 14 |
-
super(_MLPVectorProjector, self).__init__()
|
| 15 |
self.mlps = nn.ModuleList()
|
| 16 |
for _ in range(width):
|
| 17 |
mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
|
|
@@ -59,8 +59,13 @@ def encode_image(image_path):
|
|
| 59 |
return img_embedding
|
| 60 |
|
| 61 |
#Get the projection model
|
|
|
|
|
|
|
| 62 |
|
| 63 |
#Get the fine-tuned phi-2 model
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
def example_inference(input_text, count): #, image, img_qn, audio):
|
|
@@ -87,6 +92,7 @@ def textMode(text, count):
|
|
| 87 |
|
| 88 |
def imageMode(image, question):
|
| 89 |
image_embedding = encode_image(image)
|
|
|
|
| 90 |
return "In progress"
|
| 91 |
|
| 92 |
def audioMode(audio):
|
|
@@ -120,7 +126,7 @@ with gr.Blocks() as demo:
|
|
| 120 |
text_output = gr.Textbox(label="Chat GPT like text")
|
| 121 |
with gr.Tab("Image mode"):
|
| 122 |
with gr.Row():
|
| 123 |
-
image_input = gr.Image()
|
| 124 |
image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt")
|
| 125 |
image_button = gr.Button("Submit")
|
| 126 |
image_text_output = gr.Textbox(label="Answer")
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class _MLPVectorProjector(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
|
| 13 |
):
|
| 14 |
+
super(_MLPVectorProjector, self).__init__()
|
| 15 |
self.mlps = nn.ModuleList()
|
| 16 |
for _ in range(width):
|
| 17 |
mlp = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
|
|
|
|
| 59 |
return img_embedding
|
| 60 |
|
| 61 |
#Get the projection model
|
| 62 |
+
img_proj_head = _MLPVectorProjector(512, 2560, 1, 4).to("cuda")
|
| 63 |
+
img_proj_head.load_state_dict(torch.load('projection_finetuned.pth'))
|
| 64 |
|
| 65 |
#Get the fine-tuned phi-2 model
|
| 66 |
+
phi2_finetuned = AutoModelForCausalLM.from_pretrained(
|
| 67 |
+
"phi2_adaptor_fineTuned", trust_remote_code=True,
|
| 68 |
+
torch_dtype = torch.float32).to("cuda")
|
| 69 |
|
| 70 |
|
| 71 |
def example_inference(input_text, count): #, image, img_qn, audio):
|
|
|
|
| 92 |
|
| 93 |
def imageMode(image, question):
|
| 94 |
image_embedding = encode_image(image)
|
| 95 |
+
imgToTextEmb = img_proj_head(image_embedding)
|
| 96 |
return "In progress"
|
| 97 |
|
| 98 |
def audioMode(audio):
|
|
|
|
| 126 |
text_output = gr.Textbox(label="Chat GPT like text")
|
| 127 |
with gr.Tab("Image mode"):
|
| 128 |
with gr.Row():
|
| 129 |
+
image_input = gr.Image(type="filepath")
|
| 130 |
image_text_input = gr.Textbox(placeholder="Enter a question/prompt around the image", label="Question/Prompt")
|
| 131 |
image_button = gr.Button("Submit")
|
| 132 |
image_text_output = gr.Textbox(label="Answer")
|