# VRE / src/reasoning.py
# nasim-raj-laskar
# initial deploy
# aeba2d0
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Hugging Face checkpoint used for both the tokenizer and the model, so the
# two can never drift apart.
_MODEL_NAME = "google/flan-t5-small"

# Loaded once at import time; every reason() call reuses these objects.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_NAME)
# Few-shot prompt. The two worked examples fix the expected "Objects:" format
# (comma-separated names) and the one-sentence answer style.
_PROMPT_TEMPLATE = """
You are a visual reasoning system.
Use ONLY the given objects and scene.
Do NOT invent new events or actions.
If an action is visible, describe it.
If no clear action is visible, describe the scene simply.
Example:
Objects: person, dog
Scene: a man walking a dog on a path
Question: What is happening in this image?
Answer: A person is walking a dog outdoors.
Objects: car
Scene: a car on a race track
Question: What is happening in this image?
Answer: A car is driving on a race track.
Now answer:
Objects: {objects}
Scene: {caption}
Question: {question}
Answer:
"""


def _format_objects(objects):
    """Render *objects* in the comma-separated form the few-shot examples use.

    A string is returned unchanged. Any other iterable is joined with ", "
    (an empty one becomes "none"), so ["person", "dog"] renders as
    "person, dog" rather than its Python repr "['person', 'dog']".
    Non-iterable values fall back to str().
    """
    if isinstance(objects, str):
        return objects
    try:
        names = [str(name) for name in objects]
    except TypeError:
        return str(objects)
    return ", ".join(names) if names else "none"


def _clean_answer(raw):
    """Return the first line of text after the last "Answer:" marker in *raw*.

    If *raw* contains no marker, the whole (stripped) text is used.
    """
    # A seq2seq decoder should not echo the prompt, but strip any echoed
    # "Answer:" prefix defensively.
    answer = raw.rpartition("Answer:")[2].strip()
    # Keep only the first line to drop accidental extra parts.
    return answer.split("\n", 1)[0]


def reason(objects, caption, question, *, max_new_tokens=40):
    """Answer *question* about an image using only its detected objects and caption.

    Args:
        objects: Detected object names, either a ready-made string or an
            iterable of names (rendered as "a, b, c").
        caption: Scene caption inserted into the prompt's "Scene:" line.
        question: Question about the image.
        max_new_tokens: Upper bound on generated tokens (default 40).

    Returns:
        The model's answer as a single line of text (possibly empty).
    """
    prompt = _PROMPT_TEMPLATE.format(
        objects=_format_objects(objects),
        caption=caption,
        question=question,
    )
    # Place inputs on the model's device so this also works if the model
    # has been moved off the CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    raw = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return _clean_answer(raw)