Spaces:
Running
Running
MinAA commited on
Commit ·
693999e
1
Parent(s): 5cccb5d
init
Browse files
app.py
CHANGED
|
@@ -955,13 +955,15 @@ def visual_qa(image, question, model_name):
|
|
| 955 |
# Для BLIP VQA используем формат "Question: {question} Answer:"
|
| 956 |
prompt = f"Question: {question} Answer:"
|
| 957 |
inputs = processor(image, prompt, return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
| 958 |
# Используем параметры генерации, которые помогают получить ответ, а не вопрос
|
| 959 |
out = model.generate(
|
| 960 |
**inputs,
|
| 961 |
max_length=50,
|
| 962 |
num_beams=3,
|
| 963 |
-
do_sample=False
|
| 964 |
-
pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
|
| 965 |
)
|
| 966 |
answer = processor.decode(out[0], skip_special_tokens=True)
|
| 967 |
# Убираем промпт из ответа, если он там остался
|
|
@@ -990,70 +992,13 @@ def visual_qa(image, question, model_name):
|
|
| 990 |
**inputs,
|
| 991 |
max_length=50,
|
| 992 |
num_beams=5,
|
| 993 |
-
do_sample=False
|
| 994 |
-
pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
|
| 995 |
)
|
| 996 |
answer = processor.decode(out[0], skip_special_tokens=True)
|
| 997 |
# Убираем вопрос из ответа
|
| 998 |
if question.lower() in answer.lower():
|
| 999 |
answer = answer.replace(question, "").replace("?", "").strip()
|
| 1000 |
return f"Ответ: {answer}"
|
| 1001 |
-
elif "git" in model_name.lower():
|
| 1002 |
-
# GIT модели для VQA требуют специальный формат
|
| 1003 |
-
# Внимание: microsoft/git-base - это модель для captioning, не для VQA
|
| 1004 |
-
# Но можно попробовать использовать её для VQA с правильным форматом
|
| 1005 |
-
cache_key = f"vqa_git_{model_name}"
|
| 1006 |
-
cached = model_cache.get(cache_key)
|
| 1007 |
-
if cached is None:
|
| 1008 |
-
processor = AutoProcessor.from_pretrained(model_name)
|
| 1009 |
-
from transformers import AutoModelForCausalLM
|
| 1010 |
-
model = AutoModelForCausalLM.from_pretrained(model_name)
|
| 1011 |
-
cached = (processor, model)
|
| 1012 |
-
model_cache.put(cache_key, cached)
|
| 1013 |
-
|
| 1014 |
-
processor, model = cached
|
| 1015 |
-
# Для GIT используем формат "Question: {question} Answer:"
|
| 1016 |
-
prompt = f"Question: {question} Answer:"
|
| 1017 |
-
inputs = processor(images=image, text=prompt, return_tensors="pt")
|
| 1018 |
-
generated_ids = model.generate(
|
| 1019 |
-
pixel_values=inputs.pixel_values,
|
| 1020 |
-
input_ids=inputs.input_ids,
|
| 1021 |
-
max_length=50,
|
| 1022 |
-
num_beams=3,
|
| 1023 |
-
do_sample=False,
|
| 1024 |
-
pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
|
| 1025 |
-
)
|
| 1026 |
-
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 1027 |
-
# Извлекаем только ответ (часть после "Answer:")
|
| 1028 |
-
if "Answer:" in generated_text:
|
| 1029 |
-
answer = generated_text.split("Answer:")[-1].strip()
|
| 1030 |
-
else:
|
| 1031 |
-
# Убираем промпт из ответа
|
| 1032 |
-
answer = generated_text.replace(prompt, "").strip()
|
| 1033 |
-
# Убираем вопрос, если он там остался
|
| 1034 |
-
if question.lower() in answer.lower():
|
| 1035 |
-
answer = answer.replace(question, "").strip()
|
| 1036 |
-
# Проверяем, не является ли ответ вопросом
|
| 1037 |
-
if answer.lower().strip().startswith(("which", "what", "where", "when", "who", "how", "why")):
|
| 1038 |
-
# Если ответ начинается с вопросительного слова, это может быть вопрос
|
| 1039 |
-
# Пробуем еще раз с другим форматом
|
| 1040 |
-
prompt = f"{question}?"
|
| 1041 |
-
inputs = processor(images=image, text=prompt, return_tensors="pt")
|
| 1042 |
-
generated_ids = model.generate(
|
| 1043 |
-
pixel_values=inputs.pixel_values,
|
| 1044 |
-
input_ids=inputs.input_ids,
|
| 1045 |
-
max_length=50,
|
| 1046 |
-
num_beams=5,
|
| 1047 |
-
do_sample=False,
|
| 1048 |
-
pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
|
| 1049 |
-
)
|
| 1050 |
-
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 1051 |
-
# Убираем вопрос из ответа
|
| 1052 |
-
if question.lower() in generated_text.lower():
|
| 1053 |
-
answer = generated_text.replace(question, "").replace("?", "").strip()
|
| 1054 |
-
else:
|
| 1055 |
-
answer = generated_text.strip()
|
| 1056 |
-
return f"Ответ: {answer}"
|
| 1057 |
else:
|
| 1058 |
vqa = get_pipeline("visual-question-answering", model_name)
|
| 1059 |
result = vqa(image=image, question=question)
|
|
@@ -1654,8 +1599,7 @@ with gr.Blocks(title="Трансформеры Hugging Face", theme=gr.themes.So
|
|
| 1654 |
vqa_model = gr.Dropdown(
|
| 1655 |
choices=[
|
| 1656 |
"dandelin/vilt-b32-finetuned-vqa",
|
| 1657 |
-
"Salesforce/blip-vqa-base"
|
| 1658 |
-
"microsoft/git-base"
|
| 1659 |
],
|
| 1660 |
value="dandelin/vilt-b32-finetuned-vqa",
|
| 1661 |
label="Выберите модель"
|
|
|
|
| 955 |
# Для BLIP VQA используем формат "Question: {question} Answer:"
|
| 956 |
prompt = f"Question: {question} Answer:"
|
| 957 |
inputs = processor(image, prompt, return_tensors="pt")
|
| 958 |
+
# Устанавливаем pad_token_id в модели, если его нет
|
| 959 |
+
if model.config.pad_token_id is None:
|
| 960 |
+
model.config.pad_token_id = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
|
| 961 |
# Используем параметры генерации, которые помогают получить ответ, а не вопрос
|
| 962 |
out = model.generate(
|
| 963 |
**inputs,
|
| 964 |
max_length=50,
|
| 965 |
num_beams=3,
|
| 966 |
+
do_sample=False
|
|
|
|
| 967 |
)
|
| 968 |
answer = processor.decode(out[0], skip_special_tokens=True)
|
| 969 |
# Убираем промпт из ответа, если он там остался
|
|
|
|
| 992 |
**inputs,
|
| 993 |
max_length=50,
|
| 994 |
num_beams=5,
|
| 995 |
+
do_sample=False
|
|
|
|
| 996 |
)
|
| 997 |
answer = processor.decode(out[0], skip_special_tokens=True)
|
| 998 |
# Убираем вопрос из ответа
|
| 999 |
if question.lower() in answer.lower():
|
| 1000 |
answer = answer.replace(question, "").replace("?", "").strip()
|
| 1001 |
return f"Ответ: {answer}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1002 |
else:
|
| 1003 |
vqa = get_pipeline("visual-question-answering", model_name)
|
| 1004 |
result = vqa(image=image, question=question)
|
|
|
|
| 1599 |
vqa_model = gr.Dropdown(
|
| 1600 |
choices=[
|
| 1601 |
"dandelin/vilt-b32-finetuned-vqa",
|
| 1602 |
+
"Salesforce/blip-vqa-base"
|
|
|
|
| 1603 |
],
|
| 1604 |
value="dandelin/vilt-b32-finetuned-vqa",
|
| 1605 |
label="Выберите модель"
|