# app.py — Visual Object Analyzer (SmolVLM + YOLO), Hugging Face Space, commit ec8c15a
import gradio as gr
from transformers import AutoProcessor
from transformers import Idefics3ForConditionalGeneration
from ultralytics import YOLO
from PIL import Image, ImageDraw
import torch
# System message
# Prompt injected as the "system" turn of every chat sample built in predict().
system_message = """
You are a Vision-Language Model specialized in understanding real-world object images with annotated bounding boxes.
Your task is to analyze the visual content, including detected objects, their locations, and appearances, and respond accurately to natural language queries.
Possible queries include:
- Describing the objects in the image in detail
- Estimating market prices based on visual and contextual cues
- Answering specific factual or contextual questions about the image
Your responses should be concise, accurate, and directly based on the visual information. Use the image content, object positions, and any visual clues to inform your answers. Avoid unnecessary explanation unless explicitly requested.
"""
# (width, height) every uploaded image is resized to before boxes are drawn
# and the image is handed to the VLM; detected boxes are rescaled to match.
RESIZE_SIZE = (512, 512)
# Base VLM checkpoint on the Hugging Face Hub.
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct"
# Adapter weights applied on top of the base model via model.load_adapter().
ADAPTER_PATH = "smolvlm-instruct-trl-sft-ChartQA"
# Local weights file for the YOLO object detector.
YOLO_MODEL_PATH = "saved_model.pt"
# Load the base VLM; device_map="auto" lets Accelerate place it on whatever
# hardware is available (GPU, CPU, or sharded).
model = Idefics3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    _attn_implementation="eager",  # NOTE(review): private kwarg — confirm still required by this transformers version
)
# Apply the fine-tuned adapter weights on top of the base checkpoint.
model.load_adapter(ADAPTER_PATH)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# YOLO detector used to find bounding boxes before the VLM is queried.
yolo = YOLO(YOLO_MODEL_PATH)
def generate_text_from_sample(sample, max_new_tokens=1024, device=None):
    """Run the VLM on one chat-formatted sample and return the decoded answer.

    Args:
        sample: Chat-style message list; index 1 must be the user turn whose
            ``content`` starts with ``{"type": "image", "image": PIL.Image}``.
        max_new_tokens: Generation budget for the answer.
        device: Device for the input tensors. Defaults to ``model.device`` so
            the code also runs on CPU-only hosts and honors the placement
            chosen by ``device_map="auto"`` (the previous hard-coded "cuda"
            crashed without a GPU).

    Returns:
        str: The generated answer with the echoed prompt tokens stripped.
    """
    if device is None:
        device = model.device  # follow wherever from_pretrained placed the model
    # NOTE(review): only the user turn (sample[1:2]) is templated; the system
    # message at index 0 is silently dropped here — confirm this is intended.
    text_input = processor.apply_chat_template(
        sample[1:2], add_generation_prompt=True
    )
    image = sample[1]["content"][0]["image"]
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_inputs = [[image]]
    model_inputs = processor(
        text=text_input,
        images=image_inputs,
        return_tensors="pt",
    ).to(device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt prefix from each sequence so only new tokens are decoded.
    trimmed_generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]
def predict(msg, img):
    """Answer a text query about an image, with YOLO boxes drawn on it.

    Args:
        msg: Natural-language query from the textbox.
        img: Uploaded PIL image, or None if the user submitted without one.

    Returns:
        tuple[str, PIL.Image | None]: (model response, annotated resized image).
    """
    # Gradio passes None when no image was uploaded — fail gracefully instead
    # of raising AttributeError on .convert().
    if img is None:
        return "Please upload an image first.", None
    image = img.convert("RGB")
    original_size = image.size
    resized_image = image.resize(RESIZE_SIZE)
    # Detection runs on the full-size image, so boxes must be scaled down to
    # the resized copy before drawing.
    scale_x = RESIZE_SIZE[0] / original_size[0]
    scale_y = RESIZE_SIZE[1] / original_size[1]
    # Run YOLO on the RGB-converted image (previously the raw upload was used,
    # inconsistently with the rest of the pipeline).
    pred = yolo(image)
    draw = ImageDraw.Draw(resized_image)
    for x1, y1, x2, y2 in pred[0].boxes.xyxy.tolist():
        draw.rectangle(
            [
                int(x1 * scale_x),
                int(y1 * scale_y),
                int(x2 * scale_x),
                int(y2 * scale_y),
            ],
            outline="red",
            width=5,
        )
    input_sample = [
        {'role': 'system', 'content': [{'type': 'text', 'text': system_message}]},
        {'role': 'user', 'content': [{'type': 'image', 'image': resized_image}, {'type': 'text', 'text': msg}]},
        {'role': 'assistant', 'content': [{'type': 'text', 'text': ""}]}
    ]
    response = generate_text_from_sample(input_sample)
    return response, resized_image
# Gradio UI: a text query plus an image in; the model's answer and the
# box-annotated image out.
query_input = gr.Textbox(label="Query", placeholder="What objects are in this image?")
image_input = gr.Image(type="pil", label="Upload Image")
response_output = gr.Textbox(label="Response")
annotated_output = gr.Image(label="Annotated Image")

demo = gr.Interface(
    fn=predict,
    inputs=[query_input, image_input],
    outputs=[response_output, annotated_output],
    title="Visual Object Analyzer with VLM",
    description="Upload an image and ask a question. The model will detect objects, draw bounding boxes, and generate an answer.",
)

if __name__ == "__main__":
    demo.launch()