| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import torch |
|
|
# Checkpoint identifier shared by the tokenizer and the model so the two can
# never drift apart. Both objects are loaded once, at import time.
_CHECKPOINT = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)
|
|
def _format_objects(objects):
    """Render *objects* the way the few-shot examples do ("person, dog").

    Lists and tuples are joined with ", " so the prompt does not contain a
    Python repr such as "['person', 'dog']". Any other value (typically an
    already formatted string) is returned unchanged.
    """
    if isinstance(objects, (list, tuple)):
        return ", ".join(str(item) for item in objects)
    return objects


def _build_prompt(objects, caption, question):
    """Return the few-shot prompt text for one objects/scene/question triple."""
    objects = _format_objects(objects)
    # The literal is kept flush-left on purpose: everything between the
    # triple quotes, including leading whitespace, is sent to the model.
    return f"""
You are a visual reasoning system.

Use ONLY the given objects and scene.
Do NOT invent new events or actions.

If an action is visible, describe it.
If no clear action is visible, describe the scene simply.

Example:
Objects: person, dog
Scene: a man walking a dog on a path
Question: What is happening in this image?
Answer: A person is walking a dog outdoors.

Objects: car
Scene: a car on a race track
Question: What is happening in this image?
Answer: A car is driving on a race track.

Now answer:

Objects: {objects}
Scene: {caption}
Question: {question}

Answer:
"""


def _extract_answer(raw):
    """Return the first line of the text following the last "Answer:" marker.

    If the marker is absent, the whole decoded text is used. Surrounding
    whitespace is stripped before the first line is taken.
    """
    # rpartition(...)[2] is the text after the last marker, or all of *raw*
    # when the marker does not occur.
    tail = raw.rpartition("Answer:")[2].strip()
    return tail.split("\n", 1)[0]


def reason(objects, caption, question):
    """Answer *question* about an image from its detected objects and caption.

    Builds a few-shot prompt, runs it through the module-level ``model``
    with the module-level ``tokenizer``, and returns a single-line answer.

    Args:
        objects: Detected objects, either as a ready-made string
            ("person, dog") or as a list/tuple of labels, which is joined
            with ", ".
        caption: Scene description inserted after "Scene:".
        question: Question inserted after "Question:".

    Returns:
        The first line of the decoded model output, with any text up to and
        including a trailing "Answer:" marker removed.
    """
    prompt = _build_prompt(objects, caption, question)

    # Keep the input tensors on the same device as the model so generation
    # also works when the model has been moved to an accelerator.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Inference only: gradient tracking is unnecessary.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=40)

    raw = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return _extract_answer(raw)