bhargavi909 commited on
Commit
d45d51f
·
verified ·
1 Parent(s): f17bf67

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import streamlit as st
4
+ from transformers import (
5
+ CLIPProcessor, CLIPModel,
6
+ DistilBertTokenizer, DistilBertModel,
7
+ GPT2LMHeadModel, GPT2Tokenizer
8
+ )
# -------- Load Models --------
# Run on GPU when available; every sub-model must live on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CLIP vision tower only — used as a frozen image feature extractor.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").vision_model.to(device)

# DistilBERT encodes the user prompt; forward() pools its [CLS] vector.
text_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased").to(device)

# GPT-2 decoder from the project's fine-tuned checkpoint.
decoder_tokenizer = GPT2Tokenizer.from_pretrained("sreebhargavibalija/sreebhargavibalija-multimodal-gen")
# GPT-2 ships without a pad token; reuse EOS so padding/decoding work.
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
decoder = GPT2LMHeadModel.from_pretrained("sreebhargavibalija/sreebhargavibalija-multimodal-gen").to(device)
# BUG FIX: only the tokenizer's pad token was set; the model config's
# pad_token_id stayed None (GPT-2 default), yet MultimodalGenerator.forward
# passes decoder.config.pad_token_id straight into generate(). Mirror the
# EOS id onto the config so generation never receives pad_token_id=None.
decoder.config.pad_token_id = decoder_tokenizer.eos_token_id
# -------- Fusion Wrapper --------
class MultimodalGenerator(torch.nn.Module):
    """Fuse a CLIP image embedding with a DistilBERT prompt embedding and
    feed the fused vector to GPT-2 as its sole input embedding for text
    generation.

    NOTE(review): project_image / project_text / fusion are freshly
    initialized in __init__ — unless trained weights are loaded into this
    module elsewhere, those layers run with random weights. Confirm.
    """

    def __init__(self):
        super().__init__()
        # Pretrained encoders/decoder defined at module load time.
        self.image_encoder = clip_model
        self.text_encoder = text_encoder
        self.decoder = decoder

        # CLIP ViT-B/32 vision and DistilBERT both use a 768-d hidden size.
        self.project_image = torch.nn.Linear(768, 768)
        self.project_text = torch.nn.Linear(768, 768)
        # Concatenated (image ++ text) features -> one 768-d fused vector.
        self.fusion = torch.nn.Linear(768 * 2, 768)

    def forward(self, image_tensor, prompt_input_ids, prompt_attention_mask, max_length=50):
        """Generate token ids conditioned on an image and a text prompt.

        Args:
            image_tensor: pixel values from CLIPProcessor — assumed
                (batch, 3, H, W); TODO confirm against caller.
            prompt_input_ids: DistilBERT-tokenized prompt ids.
            prompt_attention_mask: matching attention mask.
            max_length: maximum length passed to generate().

        Returns:
            LongTensor of generated token ids from GPT-2's generate().
        """
        # Pool each encoder's position-0 ([CLS]) hidden state.
        img_feat = self.image_encoder(pixel_values=image_tensor).last_hidden_state[:, 0, :]
        img_feat = self.project_image(img_feat)

        txt_feat = self.text_encoder(input_ids=prompt_input_ids, attention_mask=prompt_attention_mask).last_hidden_state[:, 0, :]
        txt_feat = self.project_text(txt_feat)

        # (batch, 1, 768): a single "virtual token" embedding for GPT-2.
        fused = self.fusion(torch.cat([img_feat, txt_feat], dim=-1)).unsqueeze(1)

        # BUG FIX: GPT-2's config defaults pad_token_id to None; passing
        # None into generate() triggers a fallback warning (and can break
        # batched generation). Fall back to EOS explicitly.
        pad_id = self.decoder.config.pad_token_id
        if pad_id is None:
            pad_id = self.decoder.config.eos_token_id

        generated = self.decoder.generate(
            inputs_embeds=fused,
            max_length=max_length,
            do_sample=True,          # sampling, not greedy decoding
            top_k=50,
            top_p=0.95,
            num_return_sequences=1,
            pad_token_id=pad_id
        )
        return generated
# Initialize model (fusion layers are randomly initialized here; the
# pretrained encoders/decoder were loaded at module level above).
model = MultimodalGenerator().to(device)
model.eval()  # inference only: disables dropout and similar train-time behavior
# -------- Streamlit UI --------
st.set_page_config(page_title="Multimodal LLM", layout="centered")
st.title("🧠 Multimodal LLM: Image + Prompt → Text")

# Input widgets: an image file and a free-text prompt.
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
prompt_text = st.text_input("Enter your prompt (e.g. 'Describe this scene'):")

# Run inference only once both an image and a non-blank prompt exist.
if uploaded_file is not None and prompt_text.strip():
    # Force 3-channel RGB so CLIP preprocessing sees a consistent format.
    image = Image.open(uploaded_file).convert("RGB")
    # NOTE(review): use_column_width is deprecated in newer Streamlit
    # releases (use_container_width replaces it) — confirm target version.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Preprocess both modalities and move tensors to the model's device.
    image_tensor = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)
    # Fixed-length (64) padded/truncated prompt encoding for DistilBERT.
    prompt_inputs = text_tokenizer(prompt_text, return_tensors="pt", padding="max_length", truncation=True, max_length=64)
    prompt_ids = prompt_inputs["input_ids"].to(device)
    prompt_mask = prompt_inputs["attention_mask"].to(device)

    with st.spinner("Generating..."):
        with torch.no_grad():  # inference only — no gradients needed
            generated_ids = model(image_tensor, prompt_ids, prompt_mask, max_length=64)
        # Decode the first (and only) returned sequence to plain text.
        output_text = decoder_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    st.markdown("### ✨ Generated Text")
    st.success(output_text)
else:
    # Default state before the user supplies both inputs.
    st.info("👆 Upload an image and enter a prompt to get started!")