imeesam committed on
Commit
b3d99de
·
1 Parent(s): 337c7d0

Updated the code for my custom Space

Browse files
Files changed (4) hide show
  1. Dockerfile +18 -0
  2. README.md +3 -2
  3. app.py +120 -128
  4. requirements.txt +4 -1
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Spaces expects port 7860
2
+ FROM python:3.10-slim
3
+
4
+ WORKDIR /app
5
+
6
+ # System deps (for Pillow / torch)
7
+ RUN apt-get update && apt-get install -y --no-install-recommends \
8
+ libgl1 libglib2.0-0 && \
9
+ rm -rf /var/lib/apt/lists/*
10
+
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy everything in the repo
15
+ COPY . .
16
+
17
+ EXPOSE 7860
18
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,9 +1,10 @@
1
  ---
2
- title: Vocal_Eyes
 
 
3
  emoji: 😎🕶👨‍🦯
4
  colorFrom: blue
5
  colorTo: purple
6
- sdk: gradio
7
  sdk_version: 6.5.1
8
  app_file: app.py
9
  license: mit
 
1
  ---
2
+ title: VocalEyes
3
+ sdk: docker
4
+ app_port: 7860
5
  emoji: 😎🕶👨‍🦯
6
  colorFrom: blue
7
  colorTo: purple
 
8
  sdk_version: 6.5.1
9
  app_file: app.py
10
  license: mit
app.py CHANGED
@@ -1,67 +1,67 @@
1
  import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
  import torchvision.transforms as T
5
- import os
6
  from PIL import Image
 
 
 
 
7
  from RelTR_build import Reltr_model
8
  from T5_build import T5_model
9
- import gradio as gr
10
-
11
-
12
- CLASSES = [ 'N/A', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike',
13
- 'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building',
14
- 'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup',
15
- 'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence',
16
- 'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy',
17
- 'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean',
18
- 'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men',
19
- 'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw',
20
- 'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post',
21
- 'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt',
22
- 'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow',
23
- 'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel',
24
- 'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle',
25
- 'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']
26
-
27
- REL_CLASSES = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind',
28
- 'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for',
29
- 'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on',
30
- 'looking at', 'lying on', 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over',
31
- 'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on',
32
- 'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']
33
-
34
- model=Reltr_model()
35
- tokenizer_2,model_text_2=T5_model()
36
-
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  transform = T.Compose([
39
  T.Resize(800),
40
  T.ToTensor(),
41
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
42
  ])
43
 
44
 
 
45
  def normalize_triplets(triplets):
46
- seen = set()
47
- cleaned = []
48
- for s, r, o in triplets:
49
- key = tuple(sorted((s, r, o)))
50
- if key not in seen:
51
- seen.add(key)
52
- cleaned.append((s, r, o))
53
- return cleaned
54
 
55
 
56
  class RelTRSceneGraphExtractor:
57
- def __init__(self, model, obj_classes, rel_classes, device="cuda"):
58
  self.device = device
59
- self.model = model.to(self.device)
60
-
61
- self.model.eval()
62
  self.obj_classes = obj_classes
63
  self.rel_classes = rel_classes
64
- self.device = device
65
 
66
  @torch.no_grad()
67
  def extract_triplets(self, image_tensor, conf_thresh=0.3):
@@ -74,138 +74,130 @@ class RelTRSceneGraphExtractor:
74
 
75
  triplets = []
76
  for i in range(len(probas_rel)):
77
- score = probas_rel[i].max().item()
78
- if score < conf_thresh:
79
  continue
80
-
81
  rel = self.rel_classes[probas_rel[i].argmax()]
82
  sub = self.obj_classes[probas_sub[i].argmax()]
83
  obj = self.obj_classes[probas_obj[i].argmax()]
84
-
85
  triplets.append((sub, rel, obj))
86
 
 
87
 
88
- triplets = normalize_triplets(triplets)
89
-
90
- return triplets
91
 
92
  class SceneGraphToText:
93
- def __init__(self):
94
- pass
95
-
96
  def convert(self, triplets):
97
- if len(triplets) == 0:
98
  return "No clear relationships detected in the image."
 
99
 
100
- sentences = []
101
- for sub, rel, obj in triplets:
102
- sentences.append(f"A {sub} is {rel} a {obj}.")
103
- return " ".join(sentences)
104
 
105
  class T5TextGenerator:
106
- def __init__(self, model, tokenizer, device="cuda"):
107
  self.tokenizer = tokenizer
108
- self.model = model.to(device)
109
  self.device = device
110
 
111
  @torch.no_grad()
112
- def generate(self, prompt, **gen_kwargs):
113
- # Tokenize and move each tensor to the right device
114
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
115
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
116
-
117
- # Pass through the model with any generation kwargs
118
- outputs = self.model.generate(
119
- **inputs,
120
-
121
- **gen_kwargs
122
- )
123
-
124
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
125
 
126
 
127
  class RelTR_T5_Pipeline:
128
- def __init__(self, reltr_model,model_text,tokenizer, obj_classes, rel_classes, device="cuda"):
129
- self.scene_graph = RelTRSceneGraphExtractor(
130
- reltr_model, obj_classes, rel_classes, device
131
- )
132
  self.graph_to_text = SceneGraphToText()
133
- self.t5 = T5TextGenerator(model_text,tokenizer,device=device)
134
 
135
  def run(self, image_tensor):
136
  triplets = self.scene_graph.extract_triplets(image_tensor)
137
  scene_text = self.graph_to_text.convert(triplets)
138
 
139
-
140
- prompt = f"""
141
- Convert the following relationship facts into a short, coherent scene description.
142
-
143
  Rules:
144
  - Use only the provided facts.
145
  - Combine related facts into natural sentences.
146
  - Do not invent new objects or actions.
147
  - Describe the main subject first.
148
  - Maximum 3 sentences.
149
-
150
  Facts:
151
-
152
- {scene_text}
153
- """
154
-
155
-
156
 
157
  output = self.t5.generate(
158
- prompt,
159
-
160
-
161
- max_new_tokens=60,
162
- do_sample=False, # still deterministic
163
- num_beams=5,
164
- length_penalty=1.2, # encourages longer summaries
165
- repetition_penalty=1.3,
166
- early_stopping=True
167
- )
168
-
169
- return {
170
- "triplets": triplets,
171
- "scene_text": scene_text,
172
- "generated_text": output
173
- }
174
-
175
-
176
 
177
 
 
 
 
 
178
 
179
  pipeline = RelTR_T5_Pipeline(
180
- reltr_model=model,
181
- model_text=model_text_2,
182
- tokenizer=tokenizer_2,
183
  obj_classes=CLASSES,
184
  rel_classes=REL_CLASSES,
185
- device="cpu"
186
  )
 
187
 
188
- def Scene_to_text(image):
189
- img = transform(image).unsqueeze(0)
190
- result = pipeline.run(img)
191
- return result["generated_text"]
192
-
193
- #Build Gradio app
194
 
195
 
 
 
 
196
 
197
- title="Vocal_Eyes😎🕶👨‍🦯"
198
- description="Converts images into scene graph triplets then into a short factual description."
199
- # Create examples list from "examples/" directory
200
- example_list = [["examples/" + example] for example in os.listdir("examples")]
201
 
202
- demo = gr.Interface(
203
- fn=Scene_to_text,
204
- inputs=gr.Image(type="pil"),
205
- outputs=[gr.Textbox(label="Description", lines=3)],
206
- title=title,
207
- description=description,
208
- examples=example_list
 
209
 
210
- )
211
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
 
 
2
  import torchvision.transforms as T
3
+ import io
4
  from PIL import Image
5
+ from fastapi import FastAPI, File, UploadFile, HTTPException
6
+ from fastapi.responses import JSONResponse
7
+ import uvicorn
8
+
9
  from RelTR_build import Reltr_model
10
  from T5_build import T5_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # ── Label sets ────────────────────────────────────────────────────────────────
13
+ CLASSES = [
14
+ 'N/A', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike',
15
+ 'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building',
16
+ 'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup',
17
+ 'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence',
18
+ 'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy',
19
+ 'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean',
20
+ 'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men',
21
+ 'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw',
22
+ 'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post',
23
+ 'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt',
24
+ 'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow',
25
+ 'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel',
26
+ 'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle',
27
+ 'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra',
28
+ ]
29
+
30
+ REL_CLASSES = [
31
+ '__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind',
32
+ 'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for',
33
+ 'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on',
34
+ 'looking at', 'lying on', 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over',
35
+ 'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on',
36
+ 'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with',
37
+ ]
38
+
39
+ # ── Image preprocessing ────────────────────────────────────────────────────────
40
  transform = T.Compose([
41
  T.Resize(800),
42
  T.ToTensor(),
43
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
44
  ])
45
 
46
 
47
+ # ── Pipeline components ────────────────────────────────────────────────────────
48
  def normalize_triplets(triplets):
49
+ seen = set()
50
+ cleaned = []
51
+ for s, r, o in triplets:
52
+ key = tuple(sorted((s, r, o)))
53
+ if key not in seen:
54
+ seen.add(key)
55
+ cleaned.append((s, r, o))
56
+ return cleaned
57
 
58
 
59
  class RelTRSceneGraphExtractor:
60
+ def __init__(self, model, obj_classes, rel_classes, device="cpu"):
61
  self.device = device
62
+ self.model = model.to(self.device).eval()
 
 
63
  self.obj_classes = obj_classes
64
  self.rel_classes = rel_classes
 
65
 
66
  @torch.no_grad()
67
  def extract_triplets(self, image_tensor, conf_thresh=0.3):
 
74
 
75
  triplets = []
76
  for i in range(len(probas_rel)):
77
+ if probas_rel[i].max().item() < conf_thresh:
 
78
  continue
 
79
  rel = self.rel_classes[probas_rel[i].argmax()]
80
  sub = self.obj_classes[probas_sub[i].argmax()]
81
  obj = self.obj_classes[probas_obj[i].argmax()]
 
82
  triplets.append((sub, rel, obj))
83
 
84
+ return normalize_triplets(triplets)
85
 
 
 
 
86
 
87
  class SceneGraphToText:
 
 
 
88
  def convert(self, triplets):
89
+ if not triplets:
90
  return "No clear relationships detected in the image."
91
+ return " ".join(f"A {s} is {r} a {o}." for s, r, o in triplets)
92
 
 
 
 
 
93
 
94
  class T5TextGenerator:
95
+ def __init__(self, model, tokenizer, device="cpu"):
96
  self.tokenizer = tokenizer
97
+ self.model = model.to(device).eval()
98
  self.device = device
99
 
100
  @torch.no_grad()
101
+ def generate(self, prompt, **gen_kwargs):
 
102
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
103
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
104
+ outputs = self.model.generate(**inputs, **gen_kwargs)
 
 
 
 
 
 
 
105
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
106
 
107
 
108
  class RelTR_T5_Pipeline:
109
+ def __init__(self, reltr_model, model_text, tokenizer, obj_classes, rel_classes, device="cpu"):
110
+ self.scene_graph = RelTRSceneGraphExtractor(reltr_model, obj_classes, rel_classes, device)
 
 
111
  self.graph_to_text = SceneGraphToText()
112
+ self.t5 = T5TextGenerator(model_text, tokenizer, device=device)
113
 
114
  def run(self, image_tensor):
115
  triplets = self.scene_graph.extract_triplets(image_tensor)
116
  scene_text = self.graph_to_text.convert(triplets)
117
 
118
+ prompt = f"""Convert the following relationship facts into a short, coherent scene description.
 
 
 
119
  Rules:
120
  - Use only the provided facts.
121
  - Combine related facts into natural sentences.
122
  - Do not invent new objects or actions.
123
  - Describe the main subject first.
124
  - Maximum 3 sentences.
 
125
  Facts:
126
+ {scene_text}"""
 
 
 
 
127
 
128
  output = self.t5.generate(
129
+ prompt,
130
+ max_new_tokens=60,
131
+ do_sample=False,
132
+ num_beams=5,
133
+ length_penalty=1.2,
134
+ repetition_penalty=1.3,
135
+ early_stopping=True,
136
+ )
137
+ return {"triplets": triplets, "scene_text": scene_text, "generated_text": output}
 
 
 
 
 
 
 
 
 
138
 
139
 
140
+ # ── Load models once at startup ────────────────────────────────────────────────
141
+ print("Loading models...")
142
+ _reltr = Reltr_model()
143
+ _tokenizer, _t5 = T5_model()
144
 
145
  pipeline = RelTR_T5_Pipeline(
146
+ reltr_model=_reltr,
147
+ model_text=_t5,
148
+ tokenizer=_tokenizer,
149
  obj_classes=CLASSES,
150
  rel_classes=REL_CLASSES,
151
+ device="cpu",
152
  )
153
+ print("Models loaded.")
154
 
155
+ # ── FastAPI app ────────────────────────────────────────────────────────────────
156
+ app = FastAPI(
157
+ title="VocalEyes API",
158
+ description="Converts an uploaded image into a short scene description via RelTR + T5.",
159
+ version="1.0.0",
160
+ )
161
 
162
 
163
+ @app.get("/")
164
+ def root():
165
+ return {"status": "ok", "message": "VocalEyes API is running. POST an image to /predict"}
166
 
 
 
 
 
167
 
168
+ @app.post("/predict")
169
+ async def predict(file: UploadFile = File(...)):
170
+ # ── Validate content type ──────────────────────────────────────────────────
171
+ if file.content_type not in ("image/jpeg", "image/png", "image/webp", "image/bmp"):
172
+ raise HTTPException(
173
+ status_code=415,
174
+ detail=f"Unsupported file type '{file.content_type}'. Send JPEG, PNG, WEBP, or BMP.",
175
+ )
176
 
177
+ # ── Read & preprocess ──────────────────────────────────────────────────────
178
+ try:
179
+ raw = await file.read()
180
+ image = Image.open(io.BytesIO(raw)).convert("RGB")
181
+ except Exception as e:
182
+ raise HTTPException(status_code=400, detail=f"Could not read image: {e}")
183
+
184
+ img_tensor = transform(image).unsqueeze(0) # (1, 3, H, W)
185
+
186
+ # ── Run pipeline ───────────────────────────────────────────────────────────
187
+ try:
188
+ result = pipeline.run(img_tensor)
189
+ except Exception as e:
190
+ raise HTTPException(status_code=500, detail=f"Pipeline error: {e}")
191
+
192
+ return JSONResponse({
193
+ "description": result["generated_text"],
194
+ "triplets": [
195
+ {"subject": s, "relation": r, "object": o}
196
+ for s, r, o in result["triplets"]
197
+ ],
198
+ })
199
+
200
+
201
+ # ── Entry point ────────────────────────────────────────────────────────────────
202
+ if __name__ == "__main__":
203
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -7,4 +7,7 @@ opencv-python>=4.7.0,<5
7
  pillow>=9.5.0,<13
8
  numpy>=1.25.0,<2
9
  scipy>=1.10.0,<1.18
10
- huggingface_hub<1.18
 
 
 
 
7
  pillow>=9.5.0,<13
8
  numpy>=1.25.0,<2
9
  scipy>=1.10.0,<1.18
10
+ huggingface_hub<1.18
11
+ fastapi>=0.111.0
12
+ uvicorn[standard]>=0.29.0
13
+ python-multipart # required for UploadFile