diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b178bc83323aca973522e58b9666a43a8e1ee4e
--- /dev/null
+++ b/app.py
@@ -0,0 +1,194 @@
+import gradio as gr
+import os
+from PIL import Image
+import numpy as np
+import torch
+import pickle
+from transformers import AutoProcessor
+from src.model import MMEBModel
+from src.arguments import ModelArguments
+
+# The image library is assumed to live in local folders
+QUERY_DIR = "imgs/queries"
+IMAGE_DIR = "imgs/candidates"
+image_paths = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.endswith((".jpg", ".png"))]
+IMAGE_TOKEN = "<|image_1|>"
+TOP_N = 5
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"device: {device}")
+
+# Model loading and initialization
+def load_model():
+    global IMAGE_TOKEN
+    # Model arguments
+    model_args = ModelArguments(
+        model_name="/fs-computility/ai-shen/kilab-shared/liubangwei/ckpt/my_hf/IDMR-2B",  # replace with your model name or path
+        model_backbone="internvl_2_5",  # replace with your model backbone
+    )
+
+    # Load the processor
+    if model_args.model_backbone == "phi35v":
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name,
+            trust_remote_code=True,
+            num_crops=model_args.num_crops,
+        )
+        processor.tokenizer.padding_side = "right"
+    elif model_args.model_backbone == "internvl_2_5":
+        from src.vlm_backbone.intern_vl import InternVLProcessor
+        from transformers import AutoTokenizer, AutoImageProcessor
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name,
+            trust_remote_code=True
+        )
+        image_processor = AutoImageProcessor.from_pretrained(
+            model_args.model_name,
+            trust_remote_code=True,
+            use_fast=False
+        )
+        processor = InternVLProcessor(
+            image_processor=image_processor,
+            tokenizer=tokenizer
+        )
+        IMAGE_TOKEN = "<image>"
+
+    # Load the model
+    model = MMEBModel.load(model_args)
+    model = model.to(device, dtype=torch.bfloat16)
+    model.eval()
+
+    return model, processor
+
+# Load the model and processor
+model, processor = load_model()
+
+def get_inputs(processor, text, image_path=None, image=None):
+    if image_path:
+        image = Image.open(image_path)
+
+    if image is None:
+        text = text.replace(IMAGE_TOKEN, "")
+
+    inputs = processor(
+        text=text,
+        images=[image] if image is not None else None,
+        return_tensors="pt",
+        max_length=1024,
+        truncation=True
+    )
+    inputs = {key: value.to(device) for key, value in inputs.items()}
+    inputs["image_flags"] = torch.tensor([1 if image is not None else 0], dtype=torch.long).to(device)
+    if image is None:
+        inputs.pop('pixel_values', None)  # drop image tensors for text-only queries
+    return inputs
+
+
+# Encode every image in the library into an embedding
+def encode_image_library(image_paths):
+    embeddings = []
+    for img_path in image_paths:
+        text = f"{IMAGE_TOKEN}\n Represent the given image."
+ print(f"text: {text}") + inputs = get_inputs(processor, text, image_path=img_path) + with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16): + output = model(tgt=inputs) + embeddings.append(output["tgt_reps"].float().cpu().numpy()) + return np.stack(embeddings) + +# 保存 embedding 到文件 +def save_embeddings(embeddings, file_path="image_embeddings.pkl"): + with open(file_path, "wb") as f: + pickle.dump(embeddings, f) + +# 加载 embedding 从文件 +def load_embeddings(file_path="image_embeddings.pkl"): + with open(file_path, "rb") as f: + return pickle.load(f) + +# 计算相似度(余弦相似度) +def cosine_similarity(query_embedding, embeddings): + similarity = np.sum(query_embedding * embeddings, axis=-1) + return similarity + +# 检索逻辑 +def retrieve_images(query_text, query_image, top_n=TOP_N): + if query_text: + query_text = f"{IMAGE_TOKEN}\n {query_text}" + else: + query_text = f"{IMAGE_TOKEN}\n Represent the given image." + + if query_image is not None: + image = Image.fromarray(query_image) + else: + image = None + inputs = get_inputs(processor, query_text, image=image) + print(f"inputs: {inputs}") + # with torch.no_grad(): + with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16): + query_embedding = model(qry=inputs)["qry_reps"].float().cpu().numpy() + + + # 加载图片库的 embedding + embeddings = load_embeddings() + + # 计算相似度 + similarity = cosine_similarity(query_embedding, embeddings) + similarity = similarity.T + print(f"cosine_similarity: {similarity}") + top_indices = np.argsort(-similarity).squeeze(0)[:top_n] + print(f"top_indices: {top_indices}") + + # similarity = model.compute_similarity(np.expand_dims(query_embedding.squeeze(0), axis=1), embeddings.squeeze(1)) + # print(f"model.compute_similarity: {similarity}") + + return [image_paths[i] for i in top_indices] + +# 界面逻辑 +def demo(query_text, query_image): + # 执行检索 + # print(f"query_text: {query_text}, query_image: {query_image}, type(query_image): {type(query_image)}, image shape: {query_image.shape if query_image is not None else 'None'}") + + retrieved_images = retrieve_images(query_text, query_image) + # 返回检索结果(图片列表) + return [Image.open(img) for img in retrieved_images] + +# 预置示例 +def load_examples(): + examples = [] + # 获取QUERY_DIR中的所有图片文件 + image_files = [f for f in os.listdir(QUERY_DIR) if f.endswith((".jpg", ".png"))] + + for img_file in image_files: + # 构建图片完整路径 + img_path = os.path.join(QUERY_DIR, img_file) + # 获取对应的txt文件名(将图片扩展名替换为.txt) + txt_file = os.path.splitext(img_file)[0] + ".txt" + txt_path = os.path.join(QUERY_DIR, txt_file) + + # 如果存在对应的txt文件,读取查询文本 + if os.path.exists(txt_path): + with open(txt_path, 'r', encoding='utf-8') as f: + query_text = f.read().strip().replace("<|image_1|>\n", "") + examples.append([query_text, img_path]) + + return examples + +# 构建 Gradio 界面 +iface = gr.Interface( + fn=demo, + inputs=["text", "image"], + outputs=gr.Gallery(label=f"Retrieved Images (Top {TOP_N})"), + examples=load_examples(), # 使用动态加载的示例 + title="Multimodal Retrieval Demo", + description="Enter a query and upload an image to retrieve relevant images from the library. 
diff --git a/image_embeddings.pkl b/image_embeddings.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..89ea1409110ce4086ec54c2ea74fba2ea15d2499
--- /dev/null
+++ b/image_embeddings.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8dcedaab4e3bcc555795f56b15a7d830b74ffc707260c3b0152ba8d99a992bd
+size 409764
diff --git a/imgs/candidates/000000007574.jpg b/imgs/candidates/000000007574.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b47cf6a611b689bc4ab3ac3d2f2e5a7754816716
Binary files /dev/null and b/imgs/candidates/000000007574.jpg differ
diff --git a/imgs/candidates/000000009448.jpg b/imgs/candidates/000000009448.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b68d9331e23c70c60c02a5cee436224f1b5e34d
Binary files /dev/null and b/imgs/candidates/000000009448.jpg differ
diff --git a/imgs/candidates/000000014007.jpg b/imgs/candidates/000000014007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..86961634015e26e5b973f99622c8f67b51d26cc4
Binary files /dev/null and b/imgs/candidates/000000014007.jpg differ
diff --git a/imgs/candidates/000000021839.jpg b/imgs/candidates/000000021839.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0142fb9a5f0d2b6d82b4afa7c4488aef83151922
Binary files /dev/null and b/imgs/candidates/000000021839.jpg differ
diff --git a/imgs/candidates/000000022892.jpg b/imgs/candidates/000000022892.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1f3b8be10e96e44335f178cd3b27259973b06624
Binary files /dev/null and b/imgs/candidates/000000022892.jpg differ
diff --git a/imgs/candidates/000000024610.jpg b/imgs/candidates/000000024610.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fbbef1f51c9525bad832e04d7cdd50e89fba9e48
Binary files /dev/null and b/imgs/candidates/000000024610.jpg differ
diff --git a/imgs/candidates/000000025593.jpg b/imgs/candidates/000000025593.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f3853bd9ef49361d67578c8ef8287ad39559b614
Binary files /dev/null and b/imgs/candidates/000000025593.jpg differ
diff --git a/imgs/candidates/000000044068.jpg b/imgs/candidates/000000044068.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..58dfc0fc93de8511df6a17d57ab53a30166de32c
Binary files /dev/null and b/imgs/candidates/000000044068.jpg differ
diff --git a/imgs/candidates/000000084362.jpg b/imgs/candidates/000000084362.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e118cac0455d5c18ff8d051efcb9e7290e700798
Binary files /dev/null and b/imgs/candidates/000000084362.jpg differ
diff --git a/imgs/candidates/000000098839.jpg b/imgs/candidates/000000098839.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..50c5993b2afe834081aef086438ed429e438cc7f
Binary files /dev/null and b/imgs/candidates/000000098839.jpg differ
diff --git a/imgs/candidates/000000107339.jpg b/imgs/candidates/000000107339.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..17068a639777e34cf05a35900a72f8747772e85b
Binary files /dev/null and b/imgs/candidates/000000107339.jpg differ
diff --git a/imgs/candidates/000000144333.jpg
b/imgs/candidates/000000144333.jpg new file mode 100644 index 0000000000000000000000000000000000000000..51652a8439a35d1255cb6e70a52ff5315fc61e76 Binary files /dev/null and b/imgs/candidates/000000144333.jpg differ diff --git a/imgs/candidates/000000159791.jpg b/imgs/candidates/000000159791.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a741c3020eb7c7fbbdb892f04eb3812616e7a430 Binary files /dev/null and b/imgs/candidates/000000159791.jpg differ diff --git a/imgs/candidates/000000168593.jpg b/imgs/candidates/000000168593.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d49657c9bc5cc11a8bdf452dc35eb3f916b07d0f Binary files /dev/null and b/imgs/candidates/000000168593.jpg differ diff --git a/imgs/candidates/000000182155.jpg b/imgs/candidates/000000182155.jpg new file mode 100644 index 0000000000000000000000000000000000000000..14c9c41e5b84368c074e2cf615ea5090eac632fe Binary files /dev/null and b/imgs/candidates/000000182155.jpg differ diff --git a/imgs/candidates/000000186449.jpg b/imgs/candidates/000000186449.jpg new file mode 100644 index 0000000000000000000000000000000000000000..157090609b0202d2c66137f483643c38bce1e95d Binary files /dev/null and b/imgs/candidates/000000186449.jpg differ diff --git a/imgs/candidates/000000191845.jpg b/imgs/candidates/000000191845.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd383af024a511ea3ec63bc7765a9e9ffceca7fc Binary files /dev/null and b/imgs/candidates/000000191845.jpg differ diff --git a/imgs/candidates/000000210299.jpg b/imgs/candidates/000000210299.jpg new file mode 100644 index 0000000000000000000000000000000000000000..998ddca78d93efc7efa13e5c8475c932b3eff5eb Binary files /dev/null and b/imgs/candidates/000000210299.jpg differ diff --git a/imgs/candidates/000000221708.jpg b/imgs/candidates/000000221708.jpg new file mode 100644 index 0000000000000000000000000000000000000000..055df134ed1bc4606533f5f5d7f5576f58aac08a Binary files /dev/null and b/imgs/candidates/000000221708.jpg differ diff --git a/imgs/candidates/000000223747.jpg b/imgs/candidates/000000223747.jpg new file mode 100644 index 0000000000000000000000000000000000000000..99043ec9e33705f0e3757749ce85598b28bd1fa2 Binary files /dev/null and b/imgs/candidates/000000223747.jpg differ diff --git a/imgs/candidates/000000226111.jpg b/imgs/candidates/000000226111.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b020594a9ef3fbc565a361adf3a40c3bc6262e70 Binary files /dev/null and b/imgs/candidates/000000226111.jpg differ diff --git a/imgs/candidates/000000226984.jpg b/imgs/candidates/000000226984.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20bc71a613d9d7efc5acddc7cfd4345f79ccb928 Binary files /dev/null and b/imgs/candidates/000000226984.jpg differ diff --git a/imgs/candidates/000000252294.jpg b/imgs/candidates/000000252294.jpg new file mode 100644 index 0000000000000000000000000000000000000000..be7ce02528ef5cfd3376d7b9b3fb495cfa906199 Binary files /dev/null and b/imgs/candidates/000000252294.jpg differ diff --git a/imgs/candidates/000000256941.jpg b/imgs/candidates/000000256941.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b7030b8ee30a79f71f5d2b75eafd7eac7fd9425a Binary files /dev/null and b/imgs/candidates/000000256941.jpg differ diff --git a/imgs/candidates/000000280710.jpg b/imgs/candidates/000000280710.jpg new file mode 100644 index 0000000000000000000000000000000000000000..856d602fdff493fbba906b5090e4d88a1db285a2 Binary files /dev/null and 
b/imgs/candidates/000000280710.jpg differ diff --git a/imgs/candidates/000000281179.jpg b/imgs/candidates/000000281179.jpg new file mode 100644 index 0000000000000000000000000000000000000000..90bb0a7341a2006c4e9572631d186f4ebc015c7c Binary files /dev/null and b/imgs/candidates/000000281179.jpg differ diff --git a/imgs/candidates/000000283717.jpg b/imgs/candidates/000000283717.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5e20cd3433bd2f55f69c9a2df2e9a07db5c67d78 Binary files /dev/null and b/imgs/candidates/000000283717.jpg differ diff --git a/imgs/candidates/000000284445.jpg b/imgs/candidates/000000284445.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ed6a2b51afb0e57300ef6367acccd101aa07096a Binary files /dev/null and b/imgs/candidates/000000284445.jpg differ diff --git a/imgs/candidates/000000287649.jpg b/imgs/candidates/000000287649.jpg new file mode 100644 index 0000000000000000000000000000000000000000..309a92e1b956482c4ff37b64ee942cca9d8224c6 Binary files /dev/null and b/imgs/candidates/000000287649.jpg differ diff --git a/imgs/candidates/000000289343.jpg b/imgs/candidates/000000289343.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8c9307218b97d37c3f6ef1bfe1c335dbf4e5bcf8 Binary files /dev/null and b/imgs/candidates/000000289343.jpg differ diff --git a/imgs/candidates/000000295809.jpg b/imgs/candidates/000000295809.jpg new file mode 100644 index 0000000000000000000000000000000000000000..66200c3d6b1754ab3b3904c0972b8d16352a8117 Binary files /dev/null and b/imgs/candidates/000000295809.jpg differ diff --git a/imgs/candidates/000000334371.jpg b/imgs/candidates/000000334371.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b47ac6d2c8382203bdeff8efcc9e7c84cc4c209c Binary files /dev/null and b/imgs/candidates/000000334371.jpg differ diff --git a/imgs/candidates/000000350054.jpg b/imgs/candidates/000000350054.jpg new file mode 100644 index 0000000000000000000000000000000000000000..21aa55221ec0642eeb1f609c76d402d9372f108d Binary files /dev/null and b/imgs/candidates/000000350054.jpg differ diff --git a/imgs/candidates/000000361621.jpg b/imgs/candidates/000000361621.jpg new file mode 100644 index 0000000000000000000000000000000000000000..45fb69ee870346bf592368dd4d5d3cb64eaee8ba Binary files /dev/null and b/imgs/candidates/000000361621.jpg differ diff --git a/imgs/candidates/000000369503.jpg b/imgs/candidates/000000369503.jpg new file mode 100644 index 0000000000000000000000000000000000000000..83f4e9dad380a4f2702a90e8b560ebb8f6da005c Binary files /dev/null and b/imgs/candidates/000000369503.jpg differ diff --git a/imgs/candidates/000000384661.jpg b/imgs/candidates/000000384661.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1c39d5333caeab4b668c24e7980df23fc72b2bec Binary files /dev/null and b/imgs/candidates/000000384661.jpg differ diff --git a/imgs/candidates/000000385997.jpg b/imgs/candidates/000000385997.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d908e725033c81ac891919f1529a3cfc48c26d04 Binary files /dev/null and b/imgs/candidates/000000385997.jpg differ diff --git a/imgs/candidates/000000398377.jpg b/imgs/candidates/000000398377.jpg new file mode 100644 index 0000000000000000000000000000000000000000..33900cabc587a7fe24888297aa6d9295663104fe Binary files /dev/null and b/imgs/candidates/000000398377.jpg differ diff --git a/imgs/candidates/000000402473.jpg b/imgs/candidates/000000402473.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..05dd2cff1d2b1f538f2f31a5a2c68d42c92b7570 Binary files /dev/null and b/imgs/candidates/000000402473.jpg differ diff --git a/imgs/candidates/000000426166.jpg b/imgs/candidates/000000426166.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e48e2ad70a4bdd943eed05b77b036de4a7c8a051 Binary files /dev/null and b/imgs/candidates/000000426166.jpg differ diff --git a/imgs/candidates/000000441247.jpg b/imgs/candidates/000000441247.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9026cf75d7f14c2be32d771811aa7b47c1571c91 Binary files /dev/null and b/imgs/candidates/000000441247.jpg differ diff --git a/imgs/candidates/000000455157.jpg b/imgs/candidates/000000455157.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6c8acb1c1a77f3f6101ea525e56527f9cf7781f Binary files /dev/null and b/imgs/candidates/000000455157.jpg differ diff --git a/imgs/candidates/000000492077.jpg b/imgs/candidates/000000492077.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fba26f0c25ae9cf8055391532ef4686577bfcc15 Binary files /dev/null and b/imgs/candidates/000000492077.jpg differ diff --git a/imgs/candidates/000000496854.jpg b/imgs/candidates/000000496854.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ec1802a8992a63055c9cbd262ba5a8ca63ea0801 Binary files /dev/null and b/imgs/candidates/000000496854.jpg differ diff --git a/imgs/candidates/000000501523.jpg b/imgs/candidates/000000501523.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ce9058c1539d2ea574fe6b9379c1ea327eab3ac9 Binary files /dev/null and b/imgs/candidates/000000501523.jpg differ diff --git a/imgs/candidates/000000530099.jpg b/imgs/candidates/000000530099.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f8b2fff168e86ad44a30a0a5f97904207536477f Binary files /dev/null and b/imgs/candidates/000000530099.jpg differ diff --git a/imgs/candidates/000000530162.jpg b/imgs/candidates/000000530162.jpg new file mode 100644 index 0000000000000000000000000000000000000000..55fb073d806ff9d5d9b6f8d71d295b9a158e1052 Binary files /dev/null and b/imgs/candidates/000000530162.jpg differ diff --git a/imgs/candidates/000000530836.jpg b/imgs/candidates/000000530836.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e0b603c4883da89ce2193670cb7555badf9ff2ea Binary files /dev/null and b/imgs/candidates/000000530836.jpg differ diff --git a/imgs/candidates/000000535306.jpg b/imgs/candidates/000000535306.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b1869aed23639c2e03842acbc0cec40d90f14042 Binary files /dev/null and b/imgs/candidates/000000535306.jpg differ diff --git a/imgs/candidates/000000578489.jpg b/imgs/candidates/000000578489.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6089bb7765c48ddaf03065c3fea4ad2f3eb52113 Binary files /dev/null and b/imgs/candidates/000000578489.jpg differ diff --git a/imgs/queries/191845_288306.jpg b/imgs/queries/191845_288306.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c6dcddb605c34348a5936b31d42ac547a8827b3f Binary files /dev/null and b/imgs/queries/191845_288306.jpg differ diff --git a/imgs/queries/191845_288306.txt b/imgs/queries/191845_288306.txt new file mode 100644 index 0000000000000000000000000000000000000000..1cc22e9c15a5891a419cf7f1cbdbff2639fc75bb --- /dev/null +++ b/imgs/queries/191845_288306.txt @@ -0,0 +1,2 @@ +<|image_1|> + Find me an image 
containing the object in the given image with the following caption: The umbrella is surrounded by other pedestrians crossing the street. \ No newline at end of file diff --git a/imgs/queries/210299_218627.jpg b/imgs/queries/210299_218627.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a57ece71e6100aaa55f15b2fa5c1909800a14e65 Binary files /dev/null and b/imgs/queries/210299_218627.jpg differ diff --git a/imgs/queries/210299_218627.txt b/imgs/queries/210299_218627.txt new file mode 100644 index 0000000000000000000000000000000000000000..7cee7bcf6688992e2e1c2ef7830c0646fc210cc9 --- /dev/null +++ b/imgs/queries/210299_218627.txt @@ -0,0 +1,2 @@ +<|image_1|> + Find me an image containing the object in the given image with the following caption: The person is riding a bicycle on a paved surface, with a shadow cast on the ground. \ No newline at end of file diff --git a/imgs/queries/221708_330943.jpg b/imgs/queries/221708_330943.jpg new file mode 100644 index 0000000000000000000000000000000000000000..01f40c06a718a9b08008f938a0977d8bd64d9407 Binary files /dev/null and b/imgs/queries/221708_330943.jpg differ diff --git a/imgs/queries/221708_330943.txt b/imgs/queries/221708_330943.txt new file mode 100644 index 0000000000000000000000000000000000000000..61a94b809c6b1f358fa10636a6e0a751b661ad6c --- /dev/null +++ b/imgs/queries/221708_330943.txt @@ -0,0 +1,2 @@ +<|image_1|> + Find me an image containing the object in the given image with the following caption: The refrigerator is situated near the sink, with its side facing a wooden dining table. \ No newline at end of file diff --git a/imgs/queries/402473_273417.jpg b/imgs/queries/402473_273417.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3e3effb8fe1ea583fb5d935898a5adfc60103fd8 Binary files /dev/null and b/imgs/queries/402473_273417.jpg differ diff --git a/imgs/queries/402473_273417.txt b/imgs/queries/402473_273417.txt new file mode 100644 index 0000000000000000000000000000000000000000..de68b61dd19f3fe535759391ed59ee60b7b175d2 --- /dev/null +++ b/imgs/queries/402473_273417.txt @@ -0,0 +1,2 @@ +<|image_1|> + Find me an image containing the object in the given image with the following caption: The cat is lying on a white surface near a black and white cat. \ No newline at end of file diff --git a/imgs/queries/441247_101686.jpg b/imgs/queries/441247_101686.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8c6cd3d9931bf1e9d5b8ce137a57602e8539ac26 Binary files /dev/null and b/imgs/queries/441247_101686.jpg differ diff --git a/imgs/queries/441247_101686.txt b/imgs/queries/441247_101686.txt new file mode 100644 index 0000000000000000000000000000000000000000..9cce5689d19282039672b8fac14392d4b927e752 --- /dev/null +++ b/imgs/queries/441247_101686.txt @@ -0,0 +1,2 @@ +<|image_1|> + Find me an image containing the object in the given image with the following caption: A woman sits on the chair in the middle of the room. 
\ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..00bcf5b9756d067093bd3aab48e07b6577b7fc0d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +torch==2.4.0 +transformers==4.46.3 +accelerate==0.29.1 +datasets==2.20.0 +numpy==1.26.4 +sentencepiece==0.1.99 +timm==0.9.12 +tqdm==4.67.1 +peft==0.11.1 +einops==0.6.1 +gradio==5.21.0 diff --git a/src/__pycache__/arguments.cpython-310.pyc b/src/__pycache__/arguments.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8642cc519bf4689803eb8e198537a04ddaa01f1c Binary files /dev/null and b/src/__pycache__/arguments.cpython-310.pyc differ diff --git a/src/__pycache__/collator.cpython-310.pyc b/src/__pycache__/collator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..137926f5dd2e6684607e7075a90482a6ccd959bf Binary files /dev/null and b/src/__pycache__/collator.cpython-310.pyc differ diff --git a/src/__pycache__/dataset.cpython-310.pyc b/src/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3981d84db0731e73a3d85f3395e8b08dc70e0026 Binary files /dev/null and b/src/__pycache__/dataset.cpython-310.pyc differ diff --git a/src/__pycache__/loss.cpython-310.pyc b/src/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd64ef02641739164834d486f02f0ceb0f67a793 Binary files /dev/null and b/src/__pycache__/loss.cpython-310.pyc differ diff --git a/src/__pycache__/model.cpython-310.pyc b/src/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14adfb9a2fd5e1548514990a2560d79d19a50cde Binary files /dev/null and b/src/__pycache__/model.cpython-310.pyc differ diff --git a/src/__pycache__/trainer.cpython-310.pyc b/src/__pycache__/trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ace96fb54102d721dfa365d0c7927229f9af11e8 Binary files /dev/null and b/src/__pycache__/trainer.cpython-310.pyc differ diff --git a/src/arguments.py b/src/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0c877f201e1050e90ea6ebebd1b6a2aa66ab1e --- /dev/null +++ b/src/arguments.py @@ -0,0 +1,134 @@ +from dataclasses import dataclass, field +from transformers import TrainingArguments +from typing import List + + +@dataclass +class ModelArguments: + model_name: str = field( + metadata={"help": "huggingface model name or path"} + ) + model_backbone: str = field( + metadata={"help": "vlm backbone"} + ) + processor_name: str = field( + default=None, metadata={"help": "processor_name, huggingface model name or path"} + ) + model_type: str = field( + default=None, metadata={"help": "lavis model type"} + ) + checkpoint_path: str = field( + default=None, metadata={"help": "a local model path"} + ) + pooling: str = field( + default='last', + metadata={"help": "pooling method for encoder"} + ) + normalize: bool = field( + default=False, + metadata={"help": "normalize query and passage representations"} + ) + temperature: float = field( + default=0.02, + metadata={"help": "temperature for softmax"} + ) + lora: bool = field( + default=False, metadata={"help": "do parameter-efficient fine-tuning with lora"} + ) + lora_r: int = field( + default=16, + metadata={"help": "lora r"} + ) + lora_alpha: int = field( + default=64, + metadata={"help": "lora alpha"} + ) + lora_dropout: float = field( + default=0.1, + metadata={"help": 
"lora dropout"} + ) + lora_target_modules: str = field( + default="qkv_proj,o_proj,gate_up_proj,down_proj,k_proj,q_proj,out_proj,v_proj", + metadata={"help": "lora target modules"} + ) + num_crops: int = field( + default=16, + metadata={"help": "number of crops used in image encoder"} + ) + + +@dataclass +class DataArguments: + dataset_name: str = field( + default=None, metadata={"help": "huggingface dataset name"} + ) + subset_name: List[str] = field( + default=None, metadata={"help": "Useful for datasets with subsets"} + ) + dataset_split: str = field( + default='train', metadata={"help": "dataset split"} + ) + num_sample_per_subset: int = field( + default=100, metadata={"help": "number of training samples per subset"} + ) + num_samples: int = field( + default=None, metadata={"help": "number of total training samples"} + ) + image_dir: str = field( + default=None, metadata={"help": "Image directory path"} + ) + encode_output_path: str = field( + default=None, metadata={"help": "encode output path"} + ) + max_len: int = field( + default=128, metadata={"help": "The maximum total input sequence length after tokenization."}, + ) + embedding_type: str = field( + default="", metadata={"help": "embedding type"} + ) + randaugment: bool = field( + default=False, metadata={"help": "use randaugment"} + ) + +@dataclass +class TrainingArguments(TrainingArguments): + image_encoder_freeze: bool = field( + default=False, metadata={"help": "huggingface model name"} + ) + output_dir: str = field( + default=None, metadata={"help": "directory for saving trained models"} + ) + project_name: str = field( + default=None, metadata={"help": "project name"} + ) + + logging_steps: int = field( + default=1, metadata={"help": "logging steps"} + ) + num_train_epochs: int = field( + default=1, metadata={"help": "number of training epochs"} + ) + grad_cache: bool = field( + default=False, metadata={"help": "Use gradient cache update"}) + gc_q_chunk_size: int = field( + default=2, metadata={"help": "query side subset size"}) + gc_p_chunk_size: int = field( + default=2, metadata={"help": "target side subset size"}) + hard_neg: bool = field( + default=False, metadata={"help": "Use hard negative samples"} + ) + wandb: bool = field( + default=False, metadata={"help": "Use weight and bias"} + ) + resume_from_checkpoint: str = field( + default=None, metadata={"help": "resume ckpt path if needed"} + ) + +@dataclass +class MTEBArguments: + task_types: List[str] = field( + default=None, metadata={"help": ""} + ) + tasks: List[str] = field( + default=None, metadata={"help": ""} + ) diff --git a/src/biencoder_gc.py b/src/biencoder_gc.py new file mode 100644 index 0000000000000000000000000000000000000000..01c25aff55445b996e46277e8f38acb0855ed7dc --- /dev/null +++ b/src/biencoder_gc.py @@ -0,0 +1,354 @@ +from typing import List, Union, Callable, Any, Dict +from contextlib import nullcontext +from itertools import repeat +from collections import UserDict +import logging + +import torch +from torch import nn, Tensor +from torch.cuda.amp import GradScaler, autocast + +from grad_cache.context_managers import RandContext +from src.model.biencoder import BiEncoder +from utils import dist_utils +logger = logging.getLogger(__name__) + + +def is_binary_tensor(tensor): + unique_elements = torch.unique(tensor) + return torch.equal(unique_elements, torch.tensor([0, 1], dtype=tensor.dtype).to(unique_elements.device)) + + +class BiEncoderGradCache(nn.Module): + """ + Gradient Cache class. 
diff --git a/src/biencoder_gc.py b/src/biencoder_gc.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c25aff55445b996e46277e8f38acb0855ed7dc
--- /dev/null
+++ b/src/biencoder_gc.py
@@ -0,0 +1,354 @@
+from typing import List, Union, Callable, Any, Dict
+from contextlib import nullcontext
+from itertools import repeat
+from collections import UserDict
+import logging
+
+import torch
+from torch import nn, Tensor
+from torch.cuda.amp import GradScaler, autocast
+
+from grad_cache.context_managers import RandContext
+from src.model.biencoder import BiEncoder
+from utils import dist_utils
+logger = logging.getLogger(__name__)
+
+
+def is_binary_tensor(tensor):
+    unique_elements = torch.unique(tensor)
+    return torch.equal(unique_elements, torch.tensor([0, 1], dtype=tensor.dtype).to(unique_elements.device))
+
+
+class BiEncoderGradCache(nn.Module):
+    """
+    Gradient Cache class. Implements input chunking, the first graph-less forward pass, gradient cache creation, and
+    the second forward & backward pass for gradient computation. The optimizer step is not included. Native torch
+    automatic mixed precision is supported; the user needs to handle gradient unscaling and the scaler update after a
+    gradient cache step.
+    """
+    def __init__(
+        self,
+        models: List[nn.Module],
+        chunk_sizes: Union[int, List[int]],
+        loss_fns,
+        split_input_fn: Callable[[Any, int], Any] = None,
+        get_rep_fn: Callable[..., Tensor] = None,
+        fp16_or_bf16: bool = False,
+        dtype=torch.float32,
+        scaler: GradScaler = None,
+    ):
+        """
+        Initialize the Gradient Cache class instance.
+        :param models: A list of all encoder models to be updated by the current cache.
+        :param chunk_sizes: An integer chunk size, or a list of chunk sizes, one per model.
+        :param loss_fns: A dict of loss functions, each taking an arbitrary number of representation tensors and
+        keyword arguments as input. They must not modify the input tensors' relations in the autograd graph, which
+        are later relied upon to create the gradient cache.
+        :param split_input_fn: An optional function that splits generic model input into chunks. If not provided, this
+        class will try its best to split inputs of supported types. See the `split_inputs` function.
+        :param get_rep_fn: An optional function that takes generic model output and returns representation tensors. If
+        not provided, the generic output is assumed to be the representation tensor.
+        :param fp16_or_bf16: If True, run mixed precision training.
+        :param dtype: The autocast dtype used when fp16_or_bf16 is True.
+        :param scaler: A GradScaler object for automatic mixed precision training.
+        """
+        super(BiEncoderGradCache, self).__init__()
+        self.models = models
+        self.q_encoder = models[0]
+        self.k_encoder = models[1]
+
+        if isinstance(chunk_sizes, int):
+            self.chunk_sizes = [chunk_sizes for _ in range(len(models))]
+        else:
+            self.chunk_sizes = chunk_sizes
+
+        self.split_input_fn = split_input_fn
+        self.get_rep_fn = get_rep_fn
+        self.loss_fns = loss_fns
+
+        self.fp16_or_bf16 = fp16_or_bf16
+        self.dtype = dtype
+        self.scaler = scaler
+
+        self._get_input_tensors_strict = False
+
+    def __call__(self, *args, **kwargs):
+        """
+        Call the cache_step function.
+        :return: Current step loss.
+        """
+        return self.cache_step(*args, **kwargs)
+
+    def split_inputs(self, model_input, chunk_size: int) -> List:
+        """
+        Split input into chunks. Will call the user-provided `split_input_fn` if specified. Otherwise,
+        it can handle inputs that are tensors, lists of tensors, or dictionaries of tensors.
+        :param model_input: Generic pytorch input.
+        :param chunk_size: Size of each chunk.
+        :return: A list of chunked pytorch inputs.
+        """
+        # delegate splitting to the user-provided function
+        if self.split_input_fn is not None:
+            return self.split_input_fn(model_input, chunk_size)
+
+        if isinstance(model_input, (dict, UserDict)) and all(isinstance(x, Tensor) for x in model_input.values()):
+            keys = list(model_input.keys())
+            chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
+            return [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]
+
+        elif isinstance(model_input, list) and all(isinstance(x, Tensor) for x in model_input):
+            chunked_x = [t.split(chunk_size, dim=0) for t in model_input]
+            return [list(s) for s in zip(*chunked_x)]
+
+        elif isinstance(model_input, Tensor):
+            return list(model_input.split(chunk_size, dim=0))
+
+        elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]:
+            args_chunks = self.split_inputs(model_input[0], chunk_size)
+            kwargs_chunks = self.split_inputs(model_input[1], chunk_size)
+            return list(zip(args_chunks, kwargs_chunks))
+
+        else:
+            raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}')
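+    # Example: for a dict batch {"input_ids": LongTensor(4, 8),
+    # "attention_mask": LongTensor(4, 8)}, split_inputs(batch, chunk_size=2)
+    # returns two dicts, each holding tensors of shape (2, 8), so every chunk
+    # keeps the full key set of the original batch.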
+ """ + # delegate splitting to user provided function + if self.split_input_fn is not None: + return self.split_input_fn(model_input, chunk_size) + + if isinstance(model_input, (dict, UserDict)) and all(isinstance(x, Tensor) for x in model_input.values()): + keys = list(model_input.keys()) + chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys] + return [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))] + + elif isinstance(model_input, list) and all(isinstance(x, Tensor) for x in model_input): + chunked_x = [t.split(chunk_size, dim=0) for t in model_input] + return [list(s) for s in zip(*chunked_x)] + + elif isinstance(model_input, Tensor): + return list(model_input.split(chunk_size, dim=0)) + + elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]: + args_chunks = self.split_inputs(model_input[0], chunk_size) + kwargs_chunks = self.split_inputs(model_input[1], chunk_size) + return list(zip(args_chunks, kwargs_chunks)) + + else: + raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}') + + def get_input_tensors(self, model_input) -> List[Tensor]: + """ + Recursively go through model input and grab all tensors, which are then used to record current device random + states. This method will do its best to parse types of Tensor, tuple, list, dict and UserDict. Other types will + be ignored unless self._get_input_tensors_strict is set to True, in which case an exception will be raised. + :param model_input: input to model + :return: all torch tensors in model_input + """ + if isinstance(model_input, Tensor): + return [model_input] + + elif isinstance(model_input, (list, tuple)): + return sum((self.get_input_tensors(x) for x in model_input), []) + + elif isinstance(model_input, (dict, UserDict)): + return sum((self.get_input_tensors(x) for x in model_input.values()), []) + + elif self._get_input_tensors_strict: + raise NotImplementedError(f'get_input_tensors not implemented for type {type(model_input)}') + + else: + return [] + + def model_call(self, model: nn.Module, model_input): + """ + Literally call the model's __call__ method. + :param model: model to be called + :param model_input: input to the model call + :return: model output + """ + with autocast('cuda', dtype=self.dtype) if self.fp16_or_bf16 else nullcontext(): + if isinstance(model_input, Tensor): + return model(model_input) + elif isinstance(model_input, list): + return model(*model_input) + elif isinstance(model_input, (dict, UserDict)): + return model(**model_input) + elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]: + model_args, model_kwargs = model_input + return model(*model_args, **model_kwargs) + elif isinstance(model_input, tuple): + return model(*model_input) + else: + raise NotImplementedError + + def get_reps(self, model_out) -> Tensor: + """ + Return representation tensor from generic model output + :param model_out: generic model output + :return: a single tensor corresponding to the model representation output + """ + if self.get_rep_fn is not None: + return self.get_rep_fn(model_out) + else: + return model_out + + def compute_loss(self, loss_mapping=None, *reps: Tensor, **loss_kwargs) -> Tensor: + """ + Compute the loss based on the representation tensors. The tensors should be ordered same as the list of models + registered in this GradCache class instance. + :param reps: Representations for computing the loss. 
+            reps[0]: query vectors, shape=[B,H]
+            reps[1]: doc vectors, shape=[B*num_neg,H]
+        :param loss_kwargs: Keyword arguments passed to the loss function.
+        :return: the loss tensor and a dict of loss details.
+        """
+        if loss_mapping is None:
+            loss_fn = self.loss_fns["distributed_inbatch_contrastive"]
+            loss, loss_details = loss_fn(*reps, **loss_kwargs)
+        else:
+            bsz, hdim = reps[0].shape
+            loss, loss_details = 0.0, {}
+            preds = torch.zeros(bsz * dist_utils.get_world_size(), dtype=torch.long, device=reps[0].device)
+            labels = torch.zeros(bsz * dist_utils.get_world_size(), dtype=torch.long, device=reps[0].device)
+            for loss_name, data_idxs in loss_mapping.items():
+                data_idxs = torch.tensor(data_idxs).to(reps[0].device)
+                q = reps[0].index_select(0, index=data_idxs)
+                if len(reps[1].shape) == 1 or is_binary_tensor(reps[1]):
+                    # d is a one-hot label tensor for the classification loss
+                    d = reps[1]
+                else:
+                    d = reps[1].view(bsz, -1, hdim).index_select(0, index=data_idxs)
+                    d = d.view(-1, hdim)
+                _loss, _loss_details = self.loss_fns[loss_name](q, d, **loss_kwargs)
+                loss += _loss
+                if "labels" in _loss_details:
+                    # losses are computed per group/loss-type (keyed by loss_mapping), so the data is
+                    # reordered by group and preds/labels must be gathered back into batch order
+                    if torch.distributed.is_initialized():
+                        data_idxs = data_idxs + bsz * dist_utils.get_rank()
+                        data_idxs = dist_utils.dist_gather(data_idxs)
+                    # TODO: this might not work correctly for the classification loss
+                    preds.index_copy_(0, data_idxs, _loss_details["preds"])
+                    labels.index_copy_(0, data_idxs, _loss_details["labels"])
+                    loss_details["preds"] = preds
+                    loss_details["labels"] = labels
+        return loss, loss_details
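+    # Example: for a mixed batch of 6 rows where rows 0-3 use the in-batch
+    # contrastive loss and rows 4-5 use another loss registered in loss_fns,
+    # loss_mapping would look like
+    #   {"distributed_inbatch_contrastive": [0, 1, 2, 3], "<other_loss>": [4, 5]}
+    # (loss names assumed); compute_loss then selects q = reps[0][idxs] plus the
+    # matching candidate rows of reps[1] per group, and sums the per-group losses.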
+ """ + rnd_states = [] + model_reps = [] + + with torch.no_grad(): + for x in zip(*model_inputs): + rnd_states.append(RandContext(*self.get_input_tensors(x))) + y = self.model_call(model, x) + model_reps.append(self.get_reps(y)) + + # concatenate all sub-batch representations + model_reps = torch.cat(model_reps, dim=0) + return model_reps, rnd_states + + def build_cache(self, deepspeed=None, loss_mapping=None, *reps: Tensor, **loss_kwargs) -> [List[Tensor], Tensor]: + """ + Compute the gradient cache + :param reps: Computed representations from all encoder models + :param loss_kwargs: Extra keyword arguments to the loss function + :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor + """ + new_reps = [] + for r in reps: + if isinstance(r, torch.Tensor) and r.ndim == 2: + new_reps.append(r.detach().requires_grad_()) + elif isinstance(r, list): + new_reps.append(torch.cat(r, dim=0)) + # reps = [r.detach().requires_grad_() for r in reps] + reps = tuple(new_reps) + with autocast(dtype=self.dtype) if self.fp16_or_bf16 else nullcontext(): + loss, loss_details = self.compute_loss(loss_mapping, *reps, **loss_kwargs) + + if deepspeed is None: + if self.scaler: + self.scaler.scale(loss).backward() + else: + loss.backward() + else: + deepspeed.backward(loss) + + cache = [r.grad for r in reps if len(r.shape) > 1 and not is_binary_tensor(r[0])] + + return cache, loss.detach(), loss_details + + def forward_backward( + self, + model: nn.Module, + model_inputs, + cached_gradients: List[Tensor], + random_states: List[RandContext], + no_sync_except_last: bool = False, + deepspeed = None, + ): + """ + Run the second forward and the backward pass to compute gradient for a model. + :param model: Encoder model. + :param model_inputs: Chunked input to the encoder model. + :param cached_gradients: Chunked gradient cache tensor for each input. + :param random_states: Each input's device random state during the first forward. + :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes + for the last sub-batch's forward-backward pass. + """ + if no_sync_except_last and deepspeed is None: + sync_contexts = [model.no_sync for _ in range(len(model_inputs) - 1)] + [nullcontext] + else: + sync_contexts = [nullcontext for _ in range(len(model_inputs))] + + for x, state, gradient, sync_context in zip(model_inputs, random_states, cached_gradients, sync_contexts): + with sync_context(): + with state: + y = self.model_call(model, x) + reps = self.get_reps(y) + + surrogate = torch.dot(reps.flatten(), gradient.flatten()) + if deepspeed is None: + surrogate.backward() + else: + deepspeed.backward(surrogate) + + def cache_step( + self, + inputs, + masks, + no_sync_except_last: bool = False, + deepspeed: object = None, + loss_mapping = None, + **loss_kwargs + ) -> Tensor: + """ + Run a cached step to compute gradient over the inputs. + :param model_inputs: Input to each encoder model. Should be in similar order as the class's model. + :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction + across processes for the last sub-batch's forward-backward pass. + :param loss_kwargs: Additional keyword arguments to the loss function. + :return: The current's loss. 
+ """ + all_reps = [] + all_rnd_states = [] + + inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(inputs, self.chunk_sizes)] + masks = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(masks, self.chunk_sizes)] + + for model, input, mask in zip(self.models, inputs, masks): + if len(input[0].shape) == 1 or is_binary_tensor(input[0]): + # input is label + all_reps.append(input) + all_rnd_states.append(input) + else: + model_reps, rnd_states = self.forward_no_grad(model, model_inputs=(input, mask)) + all_reps.append(model_reps) + all_rnd_states.append(rnd_states) + + # print('start to build cache') + cache, loss, loss_details = self.build_cache(deepspeed, loss_mapping, *all_reps, **loss_kwargs) + cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)] + + for model, input, mask, model_cache, rnd_states in zip(self.models, inputs, masks, cache, all_rnd_states): + self.forward_backward(model, model_inputs=list(zip(input, mask)), + cached_gradients=model_cache, random_states=rnd_states, + no_sync_except_last=no_sync_except_last, + deepspeed=deepspeed, + ) + + # print('finish forward backward') + log_stats = BiEncoder._report_train_metrics(q=all_reps[0], k=all_reps[1], + preds=loss_details["preds"], labels=loss_details["labels"], + loss_details=loss_details) + return loss, log_stats diff --git a/src/collator.py b/src/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..69f1b585028689434b9749549fca7a83125ff629 --- /dev/null +++ b/src/collator.py @@ -0,0 +1,421 @@ +import logging +from typing import List, Tuple +from dataclasses import dataclass +from transformers import ProcessorMixin, AutoProcessor, AutoTokenizer +from src.arguments import DataArguments, ModelArguments +import torch + + +logger = logging.getLogger(__name__) + + +@dataclass +class TrainCollator: + data_args: DataArguments + model_args: ModelArguments + processor: ProcessorMixin + + def __call__(self, examples): + """ + :param examples: [{qry:..., qry_image:..., pos_text:..., pos_image:...}] * batch_size + """ + # import pdb; pdb.set_trace() + qry_inputs = self._get_batch_inputs(examples, 0, 1) # qry_inputs: {'input_ids': tensor(batch_size, max_len), 'attention_mask': tensor(batch_size, max_len), 'pixel_values': tensor(batch_size, 4, 224, 224), 'image_sizes': tensor(batch_size, 2)} + pos_inputs = self._get_batch_inputs(examples, 2, 3) + if "hard_neg" in self.data_args.dataset_name: + hard_neg_inputs = self._get_batch_inputs(examples, 4, 5) + return qry_inputs, pos_inputs, hard_neg_inputs + return qry_inputs, pos_inputs + + def _get_batch_inputs(self, examples, text_idx, image_idx): + input_ids, pixel_values = [], [] + image_mask, image_sizes, image_grid_thw = [], [], [] + + for example in examples: + text, image = example[text_idx], example[image_idx] + has_image = image is not None + image_mask.append(1 if has_image else 0) + + # 统一processor调用逻辑 + if self.model_args.model_backbone == "llava_next": + inputs = self.processor( + text=text, + images=image if has_image else None, + return_tensors="pt", + max_length=self.data_args.max_len, + truncation=True + ) + elif self.model_args.model_backbone in ["qwen", "qwen2_vl"]: # Qwen系列 + inputs = self.processor( + text=[text], # Qwen需要列表输入 + images=[image] if has_image else None, + return_tensors="pt", + max_length=self.data_args.max_len, + truncation=True + ) + else: # Phi3/InternVL通用处理 + inputs = self.processor( + text=text, + images=[image] if has_image else None, + return_tensors="pt", + 
+
+            # Collect image tensors in a uniform layout
+            if has_image:
+                if self.model_args.model_backbone == "qwen":
+                    pixel_values.append(inputs['pixel_values'].unsqueeze(0))
+                else:
+                    pixel_values.append(inputs['pixel_values'])
+
+            # keep the (len, 1) layout expected by pad_sequence below
+            input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
+
+            # collect multimodal metadata
+            if "image_sizes" in inputs:
+                image_sizes.append(inputs['image_sizes'])
+            if "image_grid_thw" in inputs:
+                image_grid_thw.append(inputs['image_grid_thw'])
+
+        # pad to the longest sequence in the batch
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids,
+            batch_first=True,
+            padding_value=self.processor.tokenizer.pad_token_id
+        ).squeeze(2)
+
+        attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
+
+        # assemble the batch dict
+        inputs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'image_mask': torch.tensor(image_mask, dtype=torch.float)  # field name kept from the original interface
+        }
+
+        # image tensors
+        if any(image_mask):
+            if pixel_values:
+                inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
+            if image_sizes:  # LLaVA/Phi-style models
+                inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
+            if image_grid_thw:  # Qwen2-VL-style models
+                inputs['image_grid_thw'] = torch.cat(image_grid_thw, dim=0)
+
+        # InternVL-specific field
+        if self.model_args.model_backbone == "internvl_2_5":
+            inputs['image_flags'] = inputs['image_mask'].to(torch.long)  # the model expects long dtype
+
+        return inputs
+
+
+@dataclass
+class EvalCollator:
+    data_args: DataArguments
+    model_args: ModelArguments
+    processor: ProcessorMixin
+
+    def __call__(self, examples):
+        """
+        :param examples: (text, image) pairs
+        """
+        inputs = self._get_batch_inputs(examples)
+        return inputs
+
+    def _get_batch_inputs(self, examples):
+        input_ids, pixel_values, image_sizes = [], [], []
+        image_mask = []  # needed for internvl_2_5
+        image_exist = False
+        for example in examples:
+            text, image = example
+            has_image = image is not None
+            image_mask.append(1 if has_image else 0)
+
+            if self.model_args.model_backbone == "internvl_2_5":
+                inputs = self.processor(
+                    text=text,
+                    images=[image] if has_image else None,
+                    return_tensors="pt",
+                    max_length=self.data_args.max_len,
+                    truncation=True
+                )
+                input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
+                if has_image:
+                    pixel_values.append(inputs['pixel_values'])
+                    if 'image_sizes' in inputs:
+                        image_sizes.append(inputs['image_sizes'])
+                continue
+
+            if image is None:
+                if self.model_args.model_backbone == "llava_next":
+                    inputs = self.processor(images=None, text=text, return_tensors="pt")
return_tensors="pt") + else: + inputs = self.processor(text, None, return_tensors="pt", max_length=self.data_args.max_len, + truncation=True) + input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1)) + pixel_values.append(None) + image_sizes.append(None) + else: + image_exist = True + if self.model_args.model_backbone == "llava_next": + inputs = self.processor(images=image, text=text, return_tensors="pt") + else: + inputs = self.processor(text, [image], return_tensors="pt", max_length=self.data_args.max_len, truncation=True) + input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1)) + pixel_values.append(inputs['pixel_values']) + image_sizes.append(inputs['image_sizes']) + + + input_ids = torch._C._nn.pad_sequence( + input_ids, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id + ).squeeze(2) + attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id) + + if self.model_args.model_backbone == "internvl_2_5": + # 构建返回字典 + inputs = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'image_mask': torch.tensor(image_mask, dtype=torch.float) + } + + # 处理图像数据 + if any(image_mask): + if pixel_values: + inputs['pixel_values'] = torch.cat(pixel_values, dim=0) + if image_sizes: + inputs['image_sizes'] = torch.cat(image_sizes, dim=0) + # InternVL专用字段适配 + inputs['image_flags'] = inputs['image_mask'].to(torch.long) + del inputs['image_mask'] # 根据模型接口调整字段名 + else: + if not image_exist: + dummy_pixel_values = torch.zeros(input_ids.shape[0], 1) + dummy_image_sizes = torch.ones(input_ids.shape[0], 1) + inputs = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'pixel_values': dummy_pixel_values, + 'image_sizes': dummy_image_sizes, + } + else: + pixel_values_shape = list(set(v.shape for v in pixel_values if v is not None))[0] + pixel_values = [v if v is not None else torch.zeros(pixel_values_shape) for v in pixel_values] + pixel_values = torch.cat(pixel_values, dim=0) + image_sizes_shape = list(set(v.shape for v in image_sizes if v is not None))[0] + image_sizes = [v if v is not None else torch.ones(image_sizes_shape) for v in image_sizes] + image_sizes = torch.cat(image_sizes, dim=0) + inputs = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'pixel_values': pixel_values, + 'image_sizes': image_sizes, + } + + return inputs + +@dataclass +class CLIPCollator: + data_args: DataArguments + vis_processors: AutoProcessor + txt_processors: AutoTokenizer + + def __call__(self, examples): + """ + :param examples: qry, qry_image, pos_text, pos_image + """ + inputs = self._get_batch_inputs(examples) + return inputs + + def _get_batch_inputs(self, examples): + input_ids, pixel_values, attention_mask = [], [], [] + image_exist, text_exist = False, False + for example in examples: + text, image = example + if image is not None: + if image.mode == 'L': + image = image.convert('RGB') + image_inputs = self.vis_processors(images=image, return_tensors="pt") + image_exist = True + pixel_values.append(image_inputs['pixel_values']) + if text: + text_exist = True + text_inputs = self.txt_processors(text, padding=getattr(self.data_args, "padding", True), max_length=self.data_args.max_len, truncation=True, return_tensors="pt") + input_ids.append(text_inputs["input_ids"].squeeze(0)) + if text_exist: + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.txt_processors.pad_token_id + ) + attention_mask = input_ids.ne(self.txt_processors.pad_token_id) + if image_exist: + pixel_values = 
diff --git a/src/dataset.py b/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6960664601b3baa1895ce872237a1318afb00dfd
--- /dev/null
+++ b/src/dataset.py
@@ -0,0 +1,318 @@
+import random
+from typing import List, Tuple
+from itertools import islice
+import datasets
+from datasets import load_dataset, concatenate_datasets
+from torch.utils.data import Dataset
+from PIL import Image
+import os
+from torchvision.transforms import RandAugment
+
+# RandAugment is used purely for data augmentation
+def get_randaugment_transform(n=2, m=9):
+    """
+    Create a RandAugment transform.
+
+    Args:
+        n: number of augmentation ops sampled per call.
+        m: magnitude of each augmentation op.
+
+    Returns:
+        A RandAugment object.
+    """
+    return RandAugment(num_ops=n, magnitude=m)
+
+
+def add_prompt_template(data):
+    data["qry"] = f"<|image_1|>{data['qry']}"
+    data["pos_text"] = f"<|image_1|>{data['pos_text']}"
+    data["hard_neg_text"] = f"<|image_1|>{data['hard_neg_text']}"
+    return data
+
+Phi_Image_token = "<|image_1|>"
+Llava_Image_token = "<image>"
+Qwen_Image_token = "<|image_pad|>"
+Internvl_Image_token = "<image>"
+class TrainDataset(Dataset):
+    def __init__(self, data_args, model_args):
+        self.data_args = data_args
+        self.model_args = model_args
+        self.transform = None
+        if self.data_args.randaugment:
+            self.transform = get_randaugment_transform()  # RandAugment or another augmenter
+        train_data = []
+
+        if data_args.subset_name is not None:
+            print(f"Loading {len(data_args.subset_name)} datasets: {data_args.subset_name}")
+            for subset in data_args.subset_name:
+                dataset_name = os.path.join(self.data_args.dataset_name, subset)
+                subset_data = load_dataset(
+                    dataset_name,
+                    split=f"{self.data_args.dataset_split}",
+                )
+                train_data.append(subset_data)
+            self.train_data = concatenate_datasets(train_data)
+            self.train_data = self.train_data.shuffle(seed=42)
+        else:
+            train_data = load_dataset(
+                self.data_args.dataset_name,
+                split=f"{self.data_args.dataset_split}",
+            )
+            if "hard_neg" in self.data_args.dataset_name:
+                print(train_data)
+            self.train_data = train_data
+    def __len__(self):
+        return len(self.train_data)
+
+    def _process_image(self, image, resolution):
+        if image is None:
+            return None
+        if resolution == "high":
+            image = image.resize((1344, 1344))
+        elif resolution == "low":
+            image = image.resize((336, 336))
+        elif resolution == "clip":
+            image = image.resize((224, 224))
+
+        return image
+
+    def _get_image(self, img_path):
+        if img_path == "":
+            return None
+        if img_path.startswith('/'):
+            full_img_path = img_path
+        else:
+            full_img_path = os.path.join(self.data_args.image_dir, img_path)
+        image = Image.open(full_img_path)
+        if self.model_args.model_backbone == "llava_next":
+            # TODO: make it configurable
+            return self._process_image(image, "high")
+        elif self.model_args.model_backbone == "qwen":
+            return self._process_image(image, "low")
+        elif self.model_args.model_backbone == "internvl_2_5":
+            # TODO: make it configurable
+            return self._process_image(image, "high")
+        else:
+            return image
+
+    def __getitem__(self, item):
+        data_item = self.train_data[item]
+        qry_text, qry_image_path, pos_text, pos_image_path = (
+            data_item["qry"], data_item["qry_image_path"],
+            data_item["pos_text"], data_item["pos_image_path"],
+        )
+
+        qry_image = self._get_image(qry_image_path)
+        if self.transform and qry_image is not None:
+            qry_image = self.transform(qry_image)
+
+        if self.model_args.model_backbone == "llava_next":
+            # Update image token
+            qry_text = qry_text.replace(Phi_Image_token, Llava_Image_token)
+            pos_text = pos_text.replace(Phi_Image_token, Llava_Image_token)
+        elif self.model_args.model_backbone == "qwen":
+            qry_text = qry_text.replace(Phi_Image_token, Qwen_Image_token)
+            pos_text = pos_text.replace(Phi_Image_token, Qwen_Image_token)
+        elif self.model_args.model_backbone == "internvl_2_5":
+            qry_text = qry_text.replace(Phi_Image_token, Internvl_Image_token)
+            pos_text = pos_text.replace(Phi_Image_token, Internvl_Image_token)
+
+        if "hard_neg" in self.data_args.dataset_name:
+            hard_neg_text, hard_neg_image_path = (
+                data_item["hard_neg_text"], data_item["hard_neg_image_path"],
+            )
+            if self.model_args.model_backbone == "llava_next":
+                # Update image token
+                hard_neg_text = hard_neg_text.replace(Phi_Image_token, Llava_Image_token)
+            elif self.model_args.model_backbone == "internvl_2_5":
+                hard_neg_text = hard_neg_text.replace(Phi_Image_token, Internvl_Image_token)
+            return (
+                qry_text, qry_image,
+                pos_text, self._get_image(pos_image_path),
+                hard_neg_text, self._get_image(hard_neg_image_path)
+            )
+
+        return (
+            qry_text, qry_image,
+            pos_text, self._get_image(pos_image_path)
+        )
+
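+
+# Usage sketch (illustrative; data_args/model_args come from src.arguments and
+# the collate_fn is whichever collator matches the chosen backbone):
+#
+#   dataset = TrainDataset(data_args, model_args)
+#   qry_text, qry_image, pos_text, pos_image = dataset[0]  # 6-tuple with hard negatives
+#   loader = DataLoader(dataset, batch_size=..., collate_fn=collator)
+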
+class EvalDataset(Dataset):
+    def __init__(self, data_args, model_args, subset, text_field, img_path_field):
+        """
+        (text_field, image_field) -> ("qry_text", "qry_img_path") or ("tgt_text", "tgt_img_path")
+        """
+        self.data_args = data_args
+        self.model_args = model_args
+
+        if data_args.subset_name is not None:
+            self.eval_data = load_dataset(
+                self.data_args.dataset_name,
+                subset,
+                split=self.data_args.dataset_split,
+            )
+        else:
+            self.eval_data = load_dataset(
+                self.data_args.dataset_name,
+                split=self.data_args.dataset_split,
+            )
+        print(f"len of eval_data: {len(self.eval_data)}")
+        self.paired_data = self.get_paired_data(text_field, img_path_field)
+        self.paired_dataset = datasets.Dataset.from_dict({
+            "text": [pair["text"] for pair in self.paired_data],
+            "img_path": [pair["img_path"] for pair in self.paired_data]
+        })
+
+    def __len__(self):
+        return len(self.paired_dataset)
+
+    def __getitem__(self, item):
+        text, img_path = self.paired_dataset[item]["text"], self.paired_dataset[item]["img_path"]
+        if self.model_args.model_backbone == "llava_next":
+            # Update llava image token
+            text = text.replace(Phi_Image_token, Llava_Image_token)
+        elif self.model_args.model_backbone == "qwen":
+            text = text.replace(Phi_Image_token, Qwen_Image_token)
+        elif self.model_args.model_backbone == "internvl_2_5":
+            text = text.replace(Phi_Image_token, Internvl_Image_token)
+
+        return text, self._get_image(img_path)
+
+    def _process_image(self, image, resolution):
+        if image is None:
+            return None
+        if resolution == "high":
+            image = image.resize((1344, 1344))
+        else:
+            image = image.resize((336, 336))
+        return image
+
+    def _get_image(self, img_path):
+        if img_path == "":
+            return None
+        if img_path.startswith("/"):
+            full_img_path = img_path
+        else:
+            full_img_path = os.path.join(self.data_args.image_dir, img_path)
+        image = Image.open(full_img_path)
+        if self.model_args.model_backbone == "llava_next":
+            return self._process_image(image, "high")
+        elif self.model_args.model_backbone == "internvl_2_5":
+            return self._process_image(image, "high")
+        else:
+            return image
+
+    def get_paired_data(self, text_field, img_path_field):
+        """
+        (text_field, image_field) -> ("qry_text", "qry_img_path") or ("tgt_text", "tgt_img_path")
+        """
+        unique_pair = set()
+        for row in self.eval_data:
+            if isinstance(row[text_field], str):
+                if row[text_field]:
+                    unique_pair.add((row[text_field], row[img_path_field]))
+                else:
+                    if isinstance(row[img_path_field], List):
+                        for img_path in row[img_path_field]:
+                            unique_pair.add((row[text_field], img_path))
+                    else:
+                        unique_pair.add((row[text_field], row[img_path_field]))
+            elif isinstance(row[text_field], List):
+                assert isinstance(row[img_path_field], List) and len(row[img_path_field]) == len(row[text_field])
+                for text, img_path in zip(row[text_field], row[img_path_field]):
+                    unique_pair.add((text, img_path))
+
+        paired_data = [{"text": text, "img_path": img_path} for text, img_path in unique_pair]
+        return paired_data
+
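+
+# Note: get_paired_data deduplicates (text, image_path) pairs through a set, so
+# the row order of paired_dataset can change between runs (string hashing is
+# salted per process); downstream code should key results by content rather
+# than by row index when comparing runs.
+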
+class FlickrDataset(Dataset):
+    def __init__(self, modality, model_backbone):
+        self.model_backbone = model_backbone
+        self.modality = modality
+        self.raw_data = load_dataset("nlphuji/flickr_1k_test_image_text_retrieval", split="test")
+        if modality == "image":
+            self.eval_data, self.image_names = self.get_image_data()
+        else:
+            self.eval_data, self.image_names = self.get_text_data()
+
+    def __len__(self):
+        return len(self.eval_data)
+
+    def __getitem__(self, idx):
+        text, image = self.eval_data[idx]
+        if self.model_backbone == "llava_next":
+            # Update llava image token
+            text = text.replace(Phi_Image_token, Llava_Image_token)
+            image = self._process_image(image, "high")
+        return text, image
+
+    def _process_image(self, image, resolution):
+        if image is None:
+            return None
+        if resolution == "high":
+            image = image.resize((1344, 1344))
+        else:
+            image = image.resize((336, 336))
+        return image
+
+    def _get_image(self, img_path):
+        # NOTE: currently unused; raw_data rows already carry PIL images, and
+        # this class is not given a data_args object providing image_dir.
+        if img_path == "":
+            return None
+        full_img_path = os.path.join(self.data_args.image_dir, img_path)
+        image = Image.open(full_img_path)
+        if self.model_backbone == "llava_next":
+            return self._process_image(image, "high")
+        else:
+            return image
+
+    def get_image_data(self):
+        eval_data, image_names = [], []
+        # i2t
+        inst = "<|image_1|> Find an image caption describing the given image."  # llava-1344-step1k4, i2t=94.0, t2i=80.26
+        # inst = "<|image_1|> Represent the given image for image caption retrieval."  # llava-1344-step1k4, i2t=94.6, t2i=78.98
+        # t2i
+        # inst = "<|image_1|> Represent the given image."  # MSCOCO t2i
+
+        for row in self.raw_data:
+            eval_data.append((inst, row["image"]))
+            image_names.append(row["filename"])
+        return eval_data, image_names
+
+    def get_text_data(self):
+        eval_data, image_names = [], []
+        # i2t
+        inst = ""
+        # t2i
+        # inst = "Retrieve an image that matches the given caption: "
+        # inst = "Find me an everyday image that matches the given caption."  # MSCOCO t2i
+        for row in self.raw_data:
+            for caption in row["caption"]:
+                # eval_data.append((caption, None))
+                eval_data.append((inst + caption, None))
+            image_names.append(row["filename"])
+        return eval_data, image_names
diff --git a/src/dist_utils.py b/src/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..be6655edff2680523f91bc581fa67ce334ece96d
--- /dev/null
+++ b/src/dist_utils.py
@@ -0,0 +1,92 @@
+# Code adapted from SimCSE (https://github.com/princeton-nlp/SimCSE) governed by MIT license.
+
+# Copyright (c) 2023, Salesforce, Inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+import torch
+import torch.distributed as dist
+
+class GatherLayer(torch.autograd.Function):
+    """
+    Gather tensors from all processes, supporting backward propagation.
+    https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py
+    """
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
+        dist.all_gather(output, input)
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx, *grads):
+        (input,) = ctx.saved_tensors
+        grad_out = torch.zeros_like(input)
+        grad_out[:] = grads[dist.get_rank()]
+        return grad_out
+
+
+def dist_gather(x: torch.Tensor):
+    if not dist.is_initialized(): return x
+    if len(x.shape) == 0:
+        x = x.reshape(1)
+    x_gather = GatherLayer.apply(x)
+    x_gather = torch.cat(x_gather, dim=0)
+    return x_gather
+
+
+@torch.no_grad()
+def dist_gather_nograd(x: torch.Tensor):
+    if not dist.is_initialized(): return x
+    x_gather = [torch.ones_like(x) for _ in range(get_world_size())]
+    dist.all_gather(x_gather, x, async_op=False)
+    x_gather = torch.cat(x_gather, dim=0)
+    return x_gather
+
+
+def get_rank():
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main():
+    return get_rank() == 0
+
+
+def get_world_size():
+    if not dist.is_initialized():
+        return 1
+    else:
+        return dist.get_world_size()
+
+def barrier():
+    if dist.is_initialized():
+        dist.barrier()
+
+
+@torch.no_grad()
+def varsize_gather_nograd(x: torch.Tensor):
+    """gather tensors of different sizes along the first dimension"""
+    if not dist.is_initialized():
+        return x
+
+    # determine max size
+    size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
+    allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
+    dist.all_gather(allsizes, size)
+    max_size = max([size.cpu().max() for size in allsizes])
+
+    padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
+    padded[: x.shape[0]] = x
+    output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
+    dist.all_gather(output, padded)
+
+    output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
+    output = torch.cat(output, dim=0)
+
+    return output
diff --git a/src/loss.py b/src/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4534775ec29c39526cd9054759b131c5532f94ed
--- /dev/null
+++ b/src/loss.py
@@ -0,0 +1,103 @@
+from torch import Tensor
+import torch.distributed as dist
+import torch
+import torch.nn.functional as F
+
+
+class SimpleContrastiveLoss:
+    def __init__(self, temperature: float = 0.02):
+        self.temperature = temperature
+
+    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
+        if target is None:
+            target_per_qry = y.size(0) // x.size(0)
+            target = torch.arange(
+                0, x.size(0) * target_per_qry, target_per_qry, device=x.device, dtype=torch.long)
+        logits = torch.matmul(x, y.transpose(0, 1))
+        loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
+        return loss
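+
+# Usage sketch (shapes illustrative): with k candidates per query, the positive
+# for query i is assumed to sit at column i * k of the logit matrix.
+#
+#   loss_fn = SimpleContrastiveLoss(temperature=0.02)
+#   loss = loss_fn(qry_reps, tgt_reps)   # qry_reps: (B, d), tgt_reps: (B*k, d)
+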
+
+class DistributedContrastiveLoss(SimpleContrastiveLoss):
+    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02):
+        assert dist.is_initialized(), "Distributed training has not been properly initialized."
+        super().__init__()
+        self.world_size = dist.get_world_size()
+        self.rank = dist.get_rank()
+        self.scale_loss = scale_loss
+        self.temperature = temperature
+
+    def __call__(self, x: Tensor, y: Tensor, **kwargs):
+        dist_x = self.gather_tensor(x)
+        dist_y = self.gather_tensor(y)
+        loss = super().__call__(dist_x, dist_y, **kwargs)
+        if self.scale_loss:
+            loss = loss * self.world_size
+        return loss
+
+    def gather_tensor(self, t):
+        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
+        dist.all_gather(gathered, t)
+        gathered[self.rank] = t
+        return torch.cat(gathered, dim=0)
+
+
+class HardNegativeContrastiveLoss:
+    def __init__(self, temperature: float = 0.02):
+        self.temperature = temperature
+
+    def __call__(self, x: Tensor, y: Tensor, z: Tensor = None, reduction: str = 'mean') -> Tensor:
+        # x: query embeddings
+        # y: positive embeddings
+        # z: negative embeddings (optional)
+
+        if z is None:  # without negatives, fall back to plain in-batch contrastive loss
+            target_per_qry = y.size(0) // x.size(0)
+            target = torch.arange(
+                0, x.size(0) * target_per_qry, target_per_qry,
+                device=x.device, dtype=torch.long)
+            logits = torch.matmul(x, y.transpose(0, 1))
+            loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
+            return loss
+
+        # Similarity between queries and positives
+        pos_logits = torch.matmul(x, y.transpose(0, 1))  # [batch_size, batch_size]
+        # Similarity between queries and hard negatives
+        neg_logits = torch.matmul(x, z.transpose(0, 1))  # [batch_size, num_negs]
+
+        # Concatenate positive and negative similarities
+        logits = torch.cat([pos_logits, neg_logits], dim=1)  # [batch_size, batch_size + num_negs]
+
+        # Targets are the indices of the positives (the diagonal)
+        target = torch.arange(x.size(0), device=x.device)
+
+        # Cross-entropy over the combined logits
+        loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
+        return loss
+
+
+class DistributedHardNegativeContrastiveLoss(HardNegativeContrastiveLoss):
+    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02):
+        assert dist.is_initialized(), "Distributed training has not been properly initialized."
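+        # gather_tensor below keeps the local rank's slice differentiable by
+        # writing the local tensor back over the all_gather result (all_gather
+        # itself does not propagate gradients).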
+        super().__init__(temperature=temperature)
+        self.world_size = dist.get_world_size()
+        self.rank = dist.get_rank()
+        self.scale_loss = scale_loss
+
+    def __call__(self, x: Tensor, y: Tensor, z: Tensor = None, **kwargs):
+        dist_x = self.gather_tensor(x)
+        dist_y = self.gather_tensor(y)
+        dist_z = self.gather_tensor(z) if z is not None else None
+
+        loss = super().__call__(dist_x, dist_y, dist_z, **kwargs)
+        if self.scale_loss:
+            loss = loss * self.world_size
+        return loss
+
+    def gather_tensor(self, t):
+        if t is None:
+            return None
+        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
+        dist.all_gather(gathered, t)
+        gathered[self.rank] = t
+        return torch.cat(gathered, dim=0)
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b69af0978f71a21268a5b02accc4606a584241c
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,286 @@
+from typing import Dict, Optional
+import torch
+import torch.distributed as dist
+from torch import nn, Tensor
+from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
+from peft import LoraConfig, get_peft_model, PeftModel
+from src.arguments import ModelArguments
+from src.vlm_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
+from src.vlm_backbone.llava_next import LlavaNextForConditionalGeneration
+from transformers import Qwen2VLForConditionalGeneration
+
+
+class MMEBModel(nn.Module):
+    TRANSFORMER_CLS = AutoModelForCausalLM
+
+    def __init__(self,
+                 encoder: PreTrainedModel,
+                 pooling: str = 'cls',
+                 normalize: bool = False,
+                 temperature: float = 1.0,
+                 ):
+        super().__init__()
+        self.config = encoder.config
+        self.encoder = encoder
+        self.pooling = pooling
+        self.normalize = normalize
+        self.temperature = temperature
+        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
+        self.is_ddp = dist.is_initialized()
+        if self.is_ddp:
+            self.process_rank = dist.get_rank()
+            self.world_size = dist.get_world_size()
+
+    def encode_input(self, input):
+        hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
+        hidden_states = hidden_states.hidden_states[-1]
+        pooled_output = self._pooling(hidden_states, input['attention_mask'])
+        return pooled_output
+
+    def _pooling(self, last_hidden_state, attention_mask):
+        if self.pooling == 'last' or self.pooling == 'eos':
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden_state.shape[0]
+            reps = last_hidden_state[
+                torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
+        else:
+            raise NotImplementedError
+        if self.normalize:
+            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
+        return reps
+
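+    # Pooling note: 'last'/'eos' takes the hidden state at each sequence's
+    # final non-padding position (assumes right padding), then optionally
+    # L2-normalizes so dot products in compute_similarity become cosines.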
+    @classmethod
+    def build(cls, model_args: ModelArguments, **hf_kwargs):
+        # Loading the base model
+        lora_target_modules = None
+        if model_args.model_backbone == "llava_next":
+            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
+            config.use_cache = False
+            config.padding_side = "left"
+            base_model = LlavaNextForConditionalGeneration.from_pretrained(
+                model_args.model_name,
+                config=config,
+                torch_dtype=torch.bfloat16,
+                low_cpu_mem_usage=True,
+            )
+        elif model_args.model_backbone == "qwen":
+            base_model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_args.model_name,
+                torch_dtype=torch.bfloat16,
+                low_cpu_mem_usage=True,
+            )
+            base_model.padding_side = "right"
+        elif model_args.model_backbone == "phi35v":
+            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
+            # config._attn_implementation = "eager"
+            config.attn_implementation = "flash_attention_2"
+            config.padding_side = "right"
+            config.use_cache = False
+            base_model = Phi3VForCausalLM.from_pretrained(
+                model_args.model_name,
+                config=config,
+                torch_dtype=torch.bfloat16,
+                low_cpu_mem_usage=True,
+            )
+        elif model_args.model_backbone == "internvl_2_5":
+            # from transformers import InternVLChatConfig, InternVLChatModel
+            from src.vlm_backbone.intern_vl import InternVLChatConfig, InternVLChatModel
+            from transformers import AutoTokenizer
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_args.model_name,
+                trust_remote_code=True
+            )
+            config = InternVLChatConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
+            # config.vision_config.image_size = data_args.force_image_size  # assumes data_args carries the image size
+            config.use_flash_attn = False
+            base_model = InternVLChatModel.from_pretrained(
+                model_args.model_name,
+                config=config,
+                tokenizer=tokenizer,
+                # attn_implementation="flash_attention_2",
+                torch_dtype=torch.bfloat16
+            )
+            lora_target_modules = base_model.get_lora_target_modules()
+        else:
+            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
+            config.use_cache = False
+            config.padding_side = "right"
+            base_model = cls.TRANSFORMER_CLS.from_pretrained(
+                model_args.model_name, **hf_kwargs, config=config,
+                attn_implementation="flash_attention_2",
+                torch_dtype=torch.bfloat16,
+                trust_remote_code=True)
+            base_model.padding_side = "right"
+
+        if model_args.lora:
+            if lora_target_modules is None:
+                lora_target_modules = model_args.lora_target_modules.split(',')
+            lora_config = LoraConfig(
+                r=model_args.lora_r,
+                lora_alpha=model_args.lora_alpha,
+                target_modules=lora_target_modules,
+                lora_dropout=model_args.lora_dropout,
+                init_lora_weights="gaussian",
+                use_dora=True,
+                inference_mode=False
+            )
+            lora_model = get_peft_model(base_model, lora_config)
+            model = cls(
+                encoder=lora_model,
+                pooling=model_args.pooling,
+                normalize=model_args.normalize,
+                temperature=model_args.temperature
+            )
+        else:
+            model = cls(
+                encoder=base_model,
+                pooling=model_args.pooling,
+                normalize=model_args.normalize,
+                temperature=model_args.temperature
+            )
+        return model
+
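+    # build() is the training-time constructor: it instantiates the backbone
+    # and, when model_args.lora is set, wraps it with a DoRA/LoRA adapter.
+    # load() below is the inference-time path: it reloads the base model and
+    # merges any trained adapter weights back in via merge_and_unload().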
attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, trust_remote_code=True) + base_model.padding_side = "right" + elif model_args.model_backbone == "internvl_2_5": + print("loading model") + from src.vlm_backbone.intern_vl import InternVLChatConfig, InternVLChatModel + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name, + trust_remote_code=True + ) + config = InternVLChatConfig.from_pretrained(model_args.model_name) + # config.vision_config.image_size = data_args.force_image_size # 假设data_args包含图像尺寸 + config.use_flash_attn = False + base_model = InternVLChatModel.from_pretrained( + model_args.model_name, + config=config, + tokenizer=tokenizer, + # attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16 + ) + else: + # Loading the base model + config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) + config.use_cache = False + config.padding_side = "right" + + base_model = cls.TRANSFORMER_CLS.from_pretrained( + checkpoint_path, **hf_kwargs, config=config, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + trust_remote_code=True) + base_model.padding_side = "right" + + # Building the model on top of the base + if model_args.lora: + print("loading lora parameters") + lora_config = LoraConfig.from_pretrained(checkpoint_path) + lora_model = PeftModel.from_pretrained(base_model, checkpoint_path, config=lora_config) + + merged_model = lora_model.merge_and_unload() + model = cls( + encoder=merged_model, + pooling=model_args.pooling, + normalize=model_args.normalize + ) + else: + model = cls( + encoder=base_model, + pooling=model_args.pooling, + normalize=model_args.normalize + ) + return model + + def save(self, output_dir: str): + self.encoder.save_pretrained(output_dir) + + def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, neg: Dict[str, Tensor] = None): + qry_reps = self.encode_input(qry) if qry else None # (bsz_per_device, dim) + tgt_reps = self.encode_input(tgt) if tgt else None # (bsz_per_device, dim) + neg_reps = self.encode_input(neg) if neg else None # (bsz_per_device, dim) + + if qry_reps is None or tgt_reps is None: + return {"qry_reps": qry_reps, "tgt_reps": tgt_reps} + + # Gather representations if using DDP + if self.is_ddp: + all_qry_reps = self._dist_gather_tensor(qry_reps) + all_tgt_reps = self._dist_gather_tensor(tgt_reps) + all_neg_reps = self._dist_gather_tensor(neg_reps) if neg_reps is not None else None + else: + all_qry_reps = qry_reps + all_tgt_reps = tgt_reps + all_neg_reps = neg_reps + + # Compute similarity scores + scores = self.compute_similarity(all_qry_reps, all_tgt_reps) + scores = scores.view(all_qry_reps.size(0), -1) + + # Add negative scores if available + if all_neg_reps is not None: + qry_neg_cos = self.compute_similarity(all_qry_reps, all_neg_reps) + scores = torch.cat([scores, qry_neg_cos], dim=1) + + # Compute loss + target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long) + target = target * (all_qry_reps.size(0) // all_tgt_reps.size(0)) + loss = self.cross_entropy(scores / self.temperature, target) + + if self.is_ddp: + loss = loss * self.world_size + + return loss + + def _dist_gather_tensor(self, t: Tensor): + t = t.contiguous() + all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] + dist.all_gather(all_tensors, t) + all_tensors[self.process_rank] = t + all_tensors = torch.cat(all_tensors, dim=0) + return all_tensors + + def compute_similarity(self, 
+    def _dist_gather_tensor(self, t: Tensor):
+        t = t.contiguous()
+        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
+        dist.all_gather(all_tensors, t)
+        all_tensors[self.process_rank] = t
+        all_tensors = torch.cat(all_tensors, dim=0)
+        return all_tensors
+
+    def compute_similarity(self, q_reps, p_reps):
+        return torch.matmul(q_reps, p_reps.transpose(0, 1))
diff --git a/src/trainer.py b/src/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2934acaae9cfdbf1ac60cabb4cc650d48cf2cae7
--- /dev/null
+++ b/src/trainer.py
@@ -0,0 +1,205 @@
+from transformers.trainer import Trainer, TRAINING_ARGS_NAME
+import torch.distributed as dist
+from typing import Optional
+import os
+import torch
+from src.loss import SimpleContrastiveLoss, DistributedContrastiveLoss, HardNegativeContrastiveLoss, DistributedHardNegativeContrastiveLoss
+from itertools import repeat
+from grad_cache.grad_cache import GradCache
+
+
+MAX_INPUT_ID = int(1e9)
+LLAVA_IMAGE_TOKEN_ID = 32000
+
+class MMEBTrainer(Trainer):
+    def __init__(self, *args, **kwargs):
+        super(MMEBTrainer, self).__init__(*args, **kwargs)
+        self.is_ddp = dist.is_initialized()
+        self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1
+
+    def compute_loss(self, model, inputs, *args, **kwargs):
+        if self.args.hard_neg:
+            qry_inputs, tgt_inputs, neg_inputs = inputs
+            return model(qry=qry_inputs, tgt=tgt_inputs, neg=neg_inputs)
+
+        qry_inputs, tgt_inputs = inputs
+        return model(qry=qry_inputs, tgt=tgt_inputs)
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        os.makedirs(output_dir, exist_ok=True)
+
+        if state_dict is None:
+            state_dict = self.model.state_dict()
+        prefix = 'encoder.'
+        assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
+        self.model.encoder.save_pretrained(
+            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+        )
+
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+
+def split_dense_inputs(model_input: dict, chunk_size: int):
+    assert len(model_input) == 1
+    arg_key = list(model_input.keys())[0]
+    arg_val = model_input[arg_key]
+
+    keys = list(arg_val.keys())
+    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]
+    chunked_arg_val = [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]
+
+    return [{arg_key: c} for c in chunked_arg_val]
+
+
+def split_vlm_inputs(model_input: dict, chunk_size: int):
+    assert len(model_input) == 1
+    arg_key = list(model_input.keys())[0]
+    arg_val = model_input[arg_key]
+    keys = list(arg_val.keys())
+
+    # input_ids and attention_mask are split directly
+    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in ["input_ids", "attention_mask"]]
+
+    # pixel_values, image_sizes and other image-related fields must be split
+    # according to which rows actually contain images
+    if "image_mask" in keys:
+        # indices of the rows in input_ids that contain images
+        row_contain_image = torch.nonzero(arg_val["image_mask"], as_tuple=False).squeeze()
+        keys.remove("image_mask")
+        num_chunks = len(chunked_tensors[0])
+        chunk_image_count = []
+        for chunk_idx in range(num_chunks):
+            chunk_image_count.append(torch.sum(
+                (row_contain_image >= chunk_idx * chunk_size) & (row_contain_image < (chunk_idx + 1) * chunk_size)).item())
+
+    if "pixel_values" in keys:
+        pixel_values = arg_val["pixel_values"]
+        chunked_tensors.append(torch.split(pixel_values, chunk_image_count))
+    if "image_sizes" in keys:
+        image_sizes = arg_val["image_sizes"]
+        chunked_tensors.append(torch.split(image_sizes, chunk_image_count))
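+    # pixel_values/image_sizes only have entries for rows that carry an image,
+    # so they are split by per-chunk image counts rather than by chunk_size.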
+ if "image_grid_thw" in keys: + image_grid_thw = arg_val["image_grid_thw"] + chunked_tensors.append(torch.split(image_grid_thw, chunk_image_count)) + # 修改这里:image_flags 应该按照 chunk_size 分割,而不是 chunk_image_count + if "image_flags" in keys: + image_flags = arg_val["image_flags"] + chunked_tensors.append(torch.split(image_flags, chunk_size)) + keys.remove("image_flags") # 从keys中移除,后面单独处理 + + + chunked_arg_val = [] + for kk, tt in zip(repeat(keys), zip(*chunked_tensors)): + chunk_dict = {} + # 先添加基本字段 + if "pixel_values" in keys and tt[2].numel() == 0: # this chunk doesn't contain image + chunk_dict.update(dict(zip(kk[:2], tt[:2]))) + else: + chunk_dict.update(dict(zip(kk, tt))) + + # 如果有image_flags,添加对应的chunk + if "image_flags" in arg_val: + chunk_idx = len(chunked_arg_val) + chunk_dict["image_flags"] = chunked_tensors[-1][chunk_idx] + + chunked_arg_val.append(chunk_dict) + + return [{arg_key: c} for c in chunked_arg_val] + + +def get_dense_rep(x): + """ + Get either qry_reps or tgt_reps. + """ + if x["qry_reps"] is None: + return x["tgt_reps"] + else: + return x["qry_reps"] + + +class GradCacheTrainer(Trainer): + """ + Adapted from gradcache repo. + """ + def __init__(self, *args, **kwargs): + super(GradCacheTrainer, self).__init__(*args, **kwargs) + self.is_ddp = dist.is_initialized() + self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1 + # loss_fn_cls = DistributedContrastiveLoss if self.is_ddp else SimpleContrastiveLoss + # 使用新的损失函数 + loss_fn_cls = DistributedHardNegativeContrastiveLoss if self.is_ddp else HardNegativeContrastiveLoss + loss_fn = loss_fn_cls(temperature=self.model.temperature) + + self.gc = GradCache( + models=[self.model, self.model], + chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size], + loss_fn=loss_fn, + split_input_fn=split_vlm_inputs, + get_rep_fn=get_dense_rep, + fp16=self.args.fp16, + scaler=self.scaler if self.args.fp16 else None + ) + + def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor: + model.train() + # 支持 hard negative 样本 + if self.args.hard_neg: + queries, passages, negatives = inputs + queries, passages, negatives = {'qry': queries}, {'tgt': passages}, {'neg': negatives} + + if self.args.local_rank == 0: + print(f"qry.shape={queries['qry']['input_ids'].shape}") + print(f"tgt.shape={passages['tgt']['input_ids'].shape}") + print(f"neg.shape={negatives['neg']['input_ids'].shape}") + if 'pixel_values' in queries['qry']: + print(f"qry_img.shape={queries['qry']['pixel_values'].shape}") + if 'pixel_values' in passages['tgt']: + print(f"tgt_img.shape={passages['tgt']['pixel_values'].shape}") + if 'pixel_values' in negatives['neg']: + print(f"neg_img.shape={negatives['neg']['pixel_values'].shape}") + + _distributed = self.args.local_rank > -1 + self.gc.models = [model, model, model] # 为 negative 样本添加一个模型 + loss = self.gc(queries, passages, negatives, no_sync_except_last=_distributed) + else: + queries, passages = inputs + queries, passages = {'qry': queries}, {'tgt': passages} + + if self.args.local_rank == 0: + print(f"qry.shape={queries['qry']['input_ids'].shape}") + print(f"tgt.shape={passages['tgt']['input_ids'].shape}") + if 'pixel_values' in queries['qry']: + print(f"qry_img.shape={queries['qry']['pixel_values'].shape}") + if 'pixel_values' in passages['tgt']: + print(f"tgt_img.shape={passages['tgt']['pixel_values'].shape}") + + _distributed = self.args.local_rank > -1 + self.gc.models = [model, model] + loss = self.gc(queries, passages, no_sync_except_last=_distributed) + + return loss / 
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        print(f"Saving model to {output_dir}")
+        os.makedirs(output_dir, exist_ok=True)
+
+        if state_dict is None:
+            state_dict = self.model.state_dict()
+        prefix = 'encoder.'
+        assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
+        self.model.encoder.save_pretrained(
+            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+        )
+
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+        self.model.encoder.config.to_json_file(os.path.join(output_dir, 'config.json'))
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..41df36ad1ba564af48f5bc7fd54b4cb469cdd637
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,21 @@
+def load_processor(model_args):
+    if model_args.model_backbone == "llava_next":
+        from src.vlm_backbone.llava_next.processing_llava_next import LlavaNextProcessor
+        processor = LlavaNextProcessor.from_pretrained(
+            model_args.processor_name if model_args.processor_name else model_args.model_name,
+            trust_remote_code=True)
+    elif model_args.model_backbone == "phi3_v":
+        from src.vlm_backbone.phi3_v.processing_phi3_v import Phi3VProcessor
+        processor = Phi3VProcessor.from_pretrained(
+            model_args.processor_name if model_args.processor_name else model_args.model_name,
+            trust_remote_code=True,
+            num_crops=model_args.num_crops,
+        )
+    else:
+        from transformers import AutoProcessor
+        processor = AutoProcessor.from_pretrained(
+            model_args.processor_name if model_args.processor_name else model_args.model_name,
+            trust_remote_code=True,
+        )
+    processor.tokenizer.padding_side = "right"
+    return processor
diff --git a/src/vlm_backbone/intern_vl/__init__.py b/src/vlm_backbone/intern_vl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e82c53a2ef8131e883439728ac99fe31150211af
--- /dev/null
+++ b/src/vlm_backbone/intern_vl/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from src.vlm_backbone.intern_vl.configuration_internvl_chat import * +from src.vlm_backbone.intern_vl.modeling_intern_vit import * +from src.vlm_backbone.intern_vl.modeling_internvl_chat import * +from src.vlm_backbone.intern_vl.processing_internvl import * +# from src.vlm_backbone.intern_vl.modeling_internlm2 import * +# from src.vlm_backbone.intern_vl.modeling_internvl_chat_hico2 import * +# from src.vlm_backbone.intern_vl.tokenization_internlm2_fast import * +# from src.vlm_backbone.intern_vl.tokenization_internlm2 import * \ No newline at end of file diff --git a/src/vlm_backbone/intern_vl/__pycache__/__init__.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e93680371722458804516722509a0fefeb54ca5c Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/configuration_intern_vit.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/configuration_intern_vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c48481c868ddfd0151cc86cadbbaea49a6afea32 Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/configuration_intern_vit.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/configuration_internlm2.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/configuration_internlm2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f34a7d6f42390c39e4e29674f7df14d86ba6108 Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/configuration_internlm2.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/configuration_internvl_chat.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/configuration_internvl_chat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e092f7f2312c20a560eb25469bc94446e3ba738 Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/configuration_internvl_chat.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/configuration_phi3.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/configuration_phi3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41edaf00f31333b991ce115d9f465a5f30384165 Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/configuration_phi3.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/conversation.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/conversation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54e60b0d89090f9d8076a9e45782e19dba989e17 Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/conversation.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/modeling_intern_vit.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/modeling_intern_vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63d2a49830a2b985c9135593f01adcd0ecd93288 Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/modeling_intern_vit.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/modeling_internlm2.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/modeling_internlm2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba22488b3c9aacc2bf0a1744a672ff4208f414a2 Binary files 
/dev/null and b/src/vlm_backbone/intern_vl/__pycache__/modeling_internlm2.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/modeling_internvl_chat.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/modeling_internvl_chat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db64cb45fe549d30aba4456a33315c9d4a7abbe Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/modeling_internvl_chat.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/modeling_phi3.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/modeling_phi3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d91ab7bc97a59271f89cffc41fe723ba2a316044 Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/modeling_phi3.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/__pycache__/processing_internvl.cpython-310.pyc b/src/vlm_backbone/intern_vl/__pycache__/processing_internvl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a40a4d069287cb4a9d829daad4a208441ed7b96a Binary files /dev/null and b/src/vlm_backbone/intern_vl/__pycache__/processing_internvl.cpython-310.pyc differ diff --git a/src/vlm_backbone/intern_vl/configuration_intern_vit.py b/src/vlm_backbone/intern_vl/configuration_intern_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..714d8e758c6ccacff6d07b1900a9d7999bb30a57 --- /dev/null +++ b/src/vlm_backbone/intern_vl/configuration_intern_vit.py @@ -0,0 +1,120 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class InternVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to + instantiate a vision encoder according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + Number of color channels in the input images (e.g., 3 for RGB). + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries and values in the self-attention layers. + hidden_size (`int`, *optional*, defaults to 3200): + Dimensionality of the encoder layers and the pooler layer. + num_attention_heads (`int`, *optional*, defaults to 25): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 12800): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + qk_normalization (`bool`, *optional*, defaults to `True`): + Whether to normalize the queries and keys in the self-attention layers. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. 
+ use_flash_attn (`bool`, *optional*, defaults to `True`): + Whether to use flash attention mechanism. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Dropout rate for stochastic depth. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 0.1): + A factor for layer scale. + """ + + model_type = 'intern_vit_6b' + + def __init__( + self, + num_channels=3, + patch_size=14, + image_size=224, + qkv_bias=False, + hidden_size=3200, + num_attention_heads=25, + intermediate_size=12800, + qk_normalization=True, + num_hidden_layers=48, + use_flash_attn=True, + hidden_act='gelu', + norm_type='rms_norm', + layer_norm_eps=1e-6, + dropout=0.0, + drop_path_rate=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=0.1, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.dropout = dropout + self.drop_path_rate = drop_path_rate + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.norm_type = norm_type + self.qkv_bias = qkv_bias + self.qk_normalization = qk_normalization + self.use_flash_attn = use_flash_attn + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if 'vision_config' in config_dict: + config_dict = config_dict['vision_config'] + + if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' + ) + + return cls.from_dict(config_dict, **kwargs) \ No newline at end of file diff --git a/src/vlm_backbone/intern_vl/configuration_internlm2.py b/src/vlm_backbone/intern_vl/configuration_internlm2.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c75d8efc04c9cf318a4f8545b9bb24d4b2704c --- /dev/null +++ b/src/vlm_backbone/intern_vl/configuration_internlm2.py @@ -0,0 +1,150 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on transformers/src/transformers/models/llama/configuration_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" InternLM2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +# Modified from transformers.model.llama.configuration_llama.LlamaConfig +class InternLM2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate + an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`InternLM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + + """ + model_type = 'internlm2' + _auto_class = 'AutoConfig' + + def __init__( # pylint: disable=W0102 + self, + vocab_size=103168, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act='silu', + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + bias=True, + rope_theta=10000, + rope_scaling=None, + attn_implementation='eager', + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.bias = bias + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + self.attn_implementation = attn_implementation + if self.attn_implementation is None: + self.attn_implementation = 'eager' + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' + f'got {self.rope_scaling}' + ) + rope_scaling_type = self.rope_scaling.get('type', None) + rope_scaling_factor = self.rope_scaling.get('factor', None) + if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}") \ No newline at end of file diff --git a/src/vlm_backbone/intern_vl/configuration_internvl_chat.py b/src/vlm_backbone/intern_vl/configuration_internvl_chat.py new file mode 100644 index 0000000000000000000000000000000000000000..32a803e60670e899ea582c4c860a687bb4be2f91 --- /dev/null +++ b/src/vlm_backbone/intern_vl/configuration_internvl_chat.py @@ -0,0 +1,109 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import copy + +from .configuration_internlm2 import InternLM2Config +from .configuration_phi3 import Phi3Config +from transformers import AutoConfig, LlamaConfig, Qwen2Config +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +from .configuration_intern_vit import InternVisionConfig + +logger = logging.get_logger(__name__) + + +class InternVLChatConfig(PretrainedConfig): + model_type = 'internvl_chat' + is_composition = True + + def __init__( + self, + vision_config=None, + llm_config=None, + use_backbone_lora=0, + use_llm_lora=0, + pad2square=False, + select_layer=-1, + force_image_size=None, + downsample_ratio=0.5, + template=None, + dynamic_image_size=False, + use_thumbnail=False, + ps_version='v1', + min_dynamic_patch=1, + max_dynamic_patch=6, + **kwargs): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {'architectures': ['InternVisionModel']} + logger.info('vision_config is None. Initializing the InternVisionConfig with default values.') + + if llm_config is None: + # TODO: There might still be a bug in transformers version 4.44 and above. + llm_config = {'architectures': ['']} + logger.info('llm_config is None. 
Initializing the LlamaConfig config with default values (`LlamaConfig`).') + + self.vision_config = InternVisionConfig(**vision_config) + if llm_config['architectures'][0] == 'LlamaForCausalLM': + self.llm_config = LlamaConfig(**llm_config) + elif llm_config['architectures'][0] == 'InternLM2ForCausalLM': + self.llm_config = InternLM2Config(**llm_config) + elif llm_config['architectures'][0] == 'Phi3ForCausalLM': + self.llm_config = Phi3Config(**llm_config) + elif llm_config['architectures'][0] == 'Qwen2ForCausalLM': + self.llm_config = Qwen2Config(**llm_config) + else: + raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0])) + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.pad2square = pad2square + self.select_layer = select_layer + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.ps_version = ps_version # pixel shuffle version + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch + + self.hidden_size = self.llm_config.hidden_size + # By default, we use tie_word_embeddings=False for models of all sizes. + self.tie_word_embeddings = False + self.llm_config.tie_word_embeddings = self.tie_word_embeddings + + logger.info(f'vision_select_layer: {self.select_layer}') + logger.info(f'ps_version: {self.ps_version}') + logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}') + logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}') + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output['vision_config'] = self.vision_config.to_dict() + output['llm_config'] = self.llm_config.to_dict() + output['model_type'] = self.__class__.model_type + output['use_backbone_lora'] = self.use_backbone_lora + output['use_llm_lora'] = self.use_llm_lora + output['select_layer'] = self.select_layer + output['force_image_size'] = self.force_image_size + output['downsample_ratio'] = self.downsample_ratio + output['template'] = self.template + output['dynamic_image_size'] = self.dynamic_image_size + output['use_thumbnail'] = self.use_thumbnail + output['ps_version'] = self.ps_version + output['min_dynamic_patch'] = self.min_dynamic_patch + output['max_dynamic_patch'] = self.max_dynamic_patch + + return output \ No newline at end of file diff --git a/src/vlm_backbone/intern_vl/configuration_phi3.py b/src/vlm_backbone/intern_vl/configuration_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..43402b1eeffa2fe12953e2e77b5065fae2d45975 --- /dev/null +++ b/src/vlm_backbone/intern_vl/configuration_phi3.py @@ -0,0 +1,211 @@ +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License atd +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" Phi-3 model configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json', + 'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json', +} + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon value used for the RMSNorm.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings or not.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`dict`, *optional*):
+            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
+            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
+            divided by the number of attention heads divided by 2.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 32000):
+            The id of the "end-of-sequence" token.
+        pad_token_id (`int`, *optional*, defaults to 32000):
+            The id of the padding token.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If `None`, no sliding window is applied.
+
+    Example:
+
+    ```python
+    >>> from transformers import Phi3Model, Phi3Config
+
+    >>> # Initializing a Phi-3 style configuration
+    >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+    >>> # Initializing a model from the configuration
+    >>> model = Phi3Model(configuration)

+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = 'phi3'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        vocab_size=32064,
+        hidden_size=3072,
+        intermediate_size=8192,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        resid_pdrop=0.0,
+        embd_pdrop=0.0,
+        attention_dropout=0.0,
+        hidden_act='silu',
+        max_position_embeddings=4096,
+        original_max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        bos_token_id=1,
+        eos_token_id=32000,
+        pad_token_id=32000,
+        sliding_window=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.resid_pdrop = resid_pdrop
+        self.embd_pdrop = embd_pdrop
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()
+        self.sliding_window = sliding_window
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
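+
+        As an illustration only (hypothetical factor values), a valid long-context
+        configuration would look like::
+
+            {'type': 'su',
+             'short_factor': [1.0] * (hidden_size // num_attention_heads // 2),
+             'long_factor': [1.0] * (hidden_size // num_attention_heads // 2)}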
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + '`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, ' + f'got {self.rope_scaling}' + ) + rope_scaling_type = self.rope_scaling.get('type', None) + rope_scaling_short_factor = self.rope_scaling.get('short_factor', None) + rope_scaling_long_factor = self.rope_scaling.get('long_factor', None) + if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']: + raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) \ No newline at end of file diff --git a/src/vlm_backbone/intern_vl/conversation.py b/src/vlm_backbone/intern_vl/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..45bdefe0b32c075bb7d2a19606ca691128b66978 --- /dev/null +++ b/src/vlm_backbone/intern_vl/conversation.py @@ -0,0 +1,402 @@ +""" +Conversation prompt templates. + +We kindly request that you import fastchat instead of copying this file if you wish to use it. +If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. +""" + +import dataclasses +from enum import IntEnum, auto +from typing import Any, Dict, List, Tuple, Union + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + FALCON_CHAT = auto() + CHATGLM3 = auto() + INTERNVL_ZH = auto() + MPT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = '{system_message}' + # The system message + system_message: str = '' + # The names of two roles + roles: Tuple[str] = ('USER', 'ASSISTANT') + # All messages. Each item is (role, message). 
+    messages: List[List[str]] = ()
+    # The number of few shot examples
+    offset: int = 0
+    # The separator style and configurations
+    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+    sep: str = '\n'
+    sep2: str = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Union[str, List[str]] = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: List[int] = None
+
+    def get_prompt(self) -> str:
+        """Get the prompt for generation."""
+        system_prompt = self.system_template.format(system_message=self.system_message)
+        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ': ' + message + self.sep
+                else:
+                    ret += role + ':'
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
+            seps = [self.sep, self.sep2]
+            ret = system_prompt + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ': ' + message + seps[i % 2]
+                else:
+                    ret += role + ':'
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ': ' + message + self.sep
+                else:
+                    ret += role + ': '  # must end with a space
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
+            ret = '' if system_prompt == '' else system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + '\n' + message + self.sep
+                else:
+                    ret += role + '\n'
+            return ret
+        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
+            ret = system_prompt
+            for role, message in self.messages:
+                if message:
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
+            seps = [self.sep, self.sep2]
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + message + seps[i % 2]
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.RWKV:
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += (
+                        role
+                        + ': '
+                        + message.replace('\r\n', '\n').replace('\n\n', '\n')
+                    )
+                    ret += '\n\n'
+                else:
+                    ret += role + ':'
+            return ret
+        elif self.sep_style == SeparatorStyle.LLAMA2:
+            seps = [self.sep, self.sep2]
+            if self.system_message:
+                ret = system_prompt
+            else:
+                ret = '[INST] '
+            for i, (role, message) in enumerate(self.messages):
+                tag = self.roles[i % 2]
+                if message:
+                    if i == 0:
+                        ret += message + ' '
+                    else:
+                        ret += tag + ' ' + message + seps[i % 2]
+                else:
+                    ret += tag
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATGLM:
+            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
+            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
+            round_add_n = 1 if self.name == 'chatglm2' else 0
+            if system_prompt:
+                ret = system_prompt + self.sep
+            else:
+                ret = ''
+
+            for i, (role, message) in enumerate(self.messages):
+                if i % 2 == 0:
+                    ret += f'[Round {i//2 + round_add_n}]{self.sep}'
+
+                if message:
+                    ret += f'{role}:{message}{self.sep}'
+                else:
+                    ret += f'{role}:'
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATML:
+            ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
+            for role, message in self.messages:
+                if message:
+                    ret += role + '\n' + message + self.sep + '\n'
+                else:
+                    ret += role + '\n'
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATGLM3:
+            ret = ''
+            if self.system_message:
+                ret += system_prompt
+            for role, message in self.messages:
+                if message:
+                    ret += role + '\n' + ' ' + message
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATINTERN:
+            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
+            seps = [self.sep, self.sep2]
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                # if i % 2 == 0:
+                #     ret += "<s>"
+                if message:
+                    ret += role + ':' + message + seps[i % 2] + '\n'
+                else:
+                    ret += role + ':'
+            return ret
+        elif self.sep_style == SeparatorStyle.DOLLY:
+            seps = [self.sep, self.sep2]
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ':\n' + message + seps[i % 2]
+                    if i % 2 == 1:
+                        ret += '\n\n'
+                else:
+                    ret += role + ':\n'
+            return ret
+        elif self.sep_style == SeparatorStyle.PHOENIX:
+            ret = system_prompt
+            for role, message in self.messages:
+                if message:
+                    ret += role + ': ' + '<s>' + message + '</s>'
+                else:
+                    ret += role + ': ' + '<s>'
+            return ret
+        elif self.sep_style == SeparatorStyle.ROBIN:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ':\n' + message + self.sep
+                else:
+                    ret += role + ':\n'
+            return ret
+        elif self.sep_style == SeparatorStyle.FALCON_CHAT:
+            ret = ''
+            if self.system_message:
+                ret += system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ': ' + message + self.sep
+                else:
+                    ret += role + ':'
+
+            return ret
+        elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
+            seps = [self.sep2, self.sep]
+            ret = self.system_message + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ': ' + message + seps[i % 2]
+                else:
+                    ret += role + ':'
+            return ret
+        elif self.sep_style == SeparatorStyle.MPT:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        else:
+            raise ValueError(f'Invalid style: {self.sep_style}')
+
+    def set_system_message(self, system_message: str):
+        """Set the system message."""
+        self.system_message = system_message
+
+    def append_message(self, role: str, message: str):
+        """Append a new message."""
+        self.messages.append([role, message])
+
+    def update_last_message(self, message: str):
+        """Update the last output.
+
+        The last message is typically set to be None when constructing the prompt,
+        so we need to update it in-place after getting the response from a model.
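+
+        Typical flow: call `append_message(roles[1], None)` before building the prompt,
+        then `update_last_message(response)` once the model has replied.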
+ """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + ret = [{'role': 'system', 'content': self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({'role': 'user', 'content': msg}) + else: + if msg is not None: + ret.append({'role': 'assistant', 'content': msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + 'template_name': self.name, + 'system_message': self.system_message, + 'roles': self.roles, + 'messages': self.messages, + 'offset': self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f'{template.name} has been registered.' + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# InternVL-Chat-V1-1 template +register_conv_template( + Conversation( + name='internvl_zh', + system_template='', + roles=('', ''), + sep_style=SeparatorStyle.INTERNVL_ZH, + sep='', + sep2=' ', + ) +) + + +# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference +# is that during training, the preprocessing function for the Hermes-2 template doesn't add +# at the beginning of the tokenized sequence, while the internlm2-chat template does. +# Therefore, they are completely equivalent during inference. +register_conv_template( + Conversation( + name='Hermes-2', + system_template='<|im_start|>system\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>', + stop_str='<|endoftext|>', + ) +) + + +register_conv_template( + Conversation( + name='internlm2-chat', + system_template='<|im_start|>system\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>', + ) +) + + +register_conv_template( + Conversation( + name='phi3-chat', + system_template='<|system|>\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. 
+        # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
+        roles=('<|user|>\n', '<|assistant|>\n'),
+        sep_style=SeparatorStyle.MPT,
+        sep='<|end|>',
+    )
+)
+
+
+register_conv_template(
+    Conversation(
+        name='internvl2_5',
+        system_template='<|im_start|>system\n{system_message}',
+        system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
+        sep_style=SeparatorStyle.MPT,
+        sep='<|im_end|>\n',
+    )
+)
\ No newline at end of file
diff --git a/src/vlm_backbone/intern_vl/image_processor.py b/src/vlm_backbone/intern_vl/image_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f2125695eb27d14d5e55b5f603aa91bb5ba465e
--- /dev/null
+++ b/src/vlm_backbone/intern_vl/image_processor.py
@@ -0,0 +1,99 @@
+from typing import List, Union
+
+# NOTE: this processor follows MindSpore's dataset.vision API (ms.Tensor,
+# vision.ToTensor), so the imports below assume a MindSpore environment.
+import mindspore as ms
+import numpy as np
+from mindspore.dataset import vision
+from mindspore.dataset.vision import Inter, Resize
+from PIL import Image
+from transformers.image_processing_utils import BaseImageProcessor
+
+# Standard ImageNet normalization constants used by InternVL preprocessing.
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class InternVLImageProcessor(BaseImageProcessor):
+    """InternVL Image Processor"""
+    def __init__(self, input_size=448, max_num=12, use_thumbnail=False, **kwargs):
+        super().__init__(**kwargs)
+        self.input_size = input_size
+        self.max_num = max_num
+        self.use_thumbnail = use_thumbnail
+
+    @staticmethod
+    def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+        """Find the aspect ratio closest to the original proportion."""
+        best_ratio_diff = float('inf')
+        best_ratio = (1, 1)
+        area = width * height
+        for ratio in target_ratios:
+            target_aspect_ratio = ratio[0] / ratio[1]
+            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+            if ratio_diff < best_ratio_diff:
+                best_ratio_diff = ratio_diff
+                best_ratio = ratio
+            elif ratio_diff == best_ratio_diff:
+                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                    best_ratio = ratio
+        return best_ratio
+
+    def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+        """Get split and thumbnail images."""
+        orig_width, orig_height = image.size
+        aspect_ratio = orig_width / orig_height
+
+        # Calculate the target aspect ratios based on min_num and max_num
+        target_ratios = set()
+        for i in range(1, max_num + 1):
+            min_j = max((min_num + i - 1) // i, 1)
+            max_j = min(max_num // i, max_num)
+            for j in range(min_j, max_j + 1):
+                target_ratios.add((i, j))
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # Find the closest aspect ratio to the target
+        target_aspect_ratio = self.find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height,
+                                                             image_size)
+
+        # Calculate target width and height based on aspect ratio
+        target_width = image_size * target_aspect_ratio[0]
+        target_height = image_size * target_aspect_ratio[1]
+        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+        # Resize the image
+        resized_img = image.resize((target_width, target_height), Image.BICUBIC)
+
+        # Split the image into blocks based on the correct logic
+        processed_images = []
+        for i in range(blocks):
+            box = (
+                (i % (target_width // image_size)) * image_size,
+                (i // (target_width // image_size)) * image_size,
+                ((i % (target_width // image_size)) + 1) * image_size,
+                ((i // (target_width // image_size)) + 1) * image_size
+            )
+            # Crop the image using the calculated box
+            split_img = resized_img.crop(box)
+            processed_images.append(split_img)
+
+        if use_thumbnail and len(processed_images) != 1:
+            thumbnail_img = image.resize((image_size, image_size), Image.BICUBIC)
+            processed_images.append(thumbnail_img)
+
+        return processed_images
+
+    def build_transform(self, input_size):
+        """Build the transformation pipeline."""
+        means, stds = IMAGENET_MEAN, IMAGENET_STD
+        transform = [
+            lambda img: img.convert('RGB') if img.mode != 'RGB' else img,
+            # MindSpore's Resize expects an Inter enum rather than a string.
+            Resize((input_size, input_size), interpolation=Inter.BICUBIC),
+            vision.ToTensor(),
+            vision.Normalize(mean=means, std=stds, is_hwc=False)
+        ]
+        return transform
+
+    def apply_transform(self, image, transform):
+        for t in transform:
+            image = t(image)
+        return np.array(image)
+
+    def load_image(self, image_file, input_size=448, max_num=12):
+        image = Image.open(image_file).convert('RGB')
+        transform = self.build_transform(input_size=input_size)
+        images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+        pixel_values = [self.apply_transform(image, transform) for image in images]
+        pixel_values = ms.Tensor(np.stack(pixel_values))
+        return pixel_values
+
+    # pylint: disable=W0613
+    def preprocess(self, images: Union[ms.Tensor, Image.Image, np.ndarray, List[Image.Image]], **kwargs):
+        return self.load_image(images)
\ No newline at end of file
diff --git a/src/vlm_backbone/intern_vl/modeling_intern_vit.py b/src/vlm_backbone/intern_vl/modeling_intern_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..28cf001db48c0da19ee60e6f579e7dbf3896595e
--- /dev/null
+++ b/src/vlm_backbone/intern_vl/modeling_intern_vit.py
@@ -0,0 +1,430 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from timm.models.layers import DropPath
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (BaseModelOutput,
+                                           BaseModelOutputWithPooling)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from .configuration_intern_vit import InternVisionConfig
+
+try:
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn.flash_attn_interface import \
+        flash_attn_varlen_qkvpacked_func
+    has_flash_attn = True
+except:
+    print('FlashAttention2 is not installed.')
+    has_flash_attn = False
+
+logger = logging.get_logger(__name__)
+
+
+class FlashAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+        super().__init__()
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+                max_s=None, need_weights=False):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value.
(B, S, 3, H, D) if key_padding_mask is None + if unpadded: (nnz, 3, h, d) + key_padding_mask: a bool tensor of shape (B, S) + """ + assert not need_weights + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = seqlen + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=qkv.device) + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, batch_size, seqlen), + 'b s (h d) -> b s h d', h=nheads) + else: + assert max_s is not None + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + + return output, None + + +class InternRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +try: + from apex.normalization import FusedRMSNorm + + InternRMSNorm = FusedRMSNorm # noqa + + logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm') +except ImportError: + # using the normal InternRMSNorm + pass +except Exception: + logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm') + pass + + +NORM2FN = { + 'rms_norm': InternRMSNorm, + 'layer_norm': nn.LayerNorm, +} + + +class InternVisionEmbeddings(nn.Module): + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.randn(1, 1, self.embed_dim), + ) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). 
\ + reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, width) + ], dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_flash_attn = config.use_flash_attn and has_flash_attn + if config.use_flash_attn and not has_flash_attn: + print('Warning: Flash Attention is not available, use_flash_attn is set to False.') + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).' + ) + + self.scale = self.head_dim ** -0.5 + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias) + self.attn_drop = nn.Dropout(config.attention_dropout) + self.proj_drop = nn.Dropout(config.dropout) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + + if self.use_flash_attn: + self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout) + self.proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _naive_attn(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + if self.qk_normalization: + B_, H_, N_, D_ = q.shape + q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + + attn = ((q * self.scale) @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _flash_attn(self, x, key_padding_mask=None, need_weights=False): + qkv = self.qkv(x) + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) + + if self.qk_normalization: + q, k, v = qkv.unbind(2) + q = self.q_norm(q.flatten(-2, -1)).view(q.shape) + k = self.k_norm(k.flatten(-2, -1)).view(k.shape) + qkv = torch.stack([q, k, v], dim=2) + + context, _ = self.inner_attn( + qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False + ) + outs = self.proj(rearrange(context, 'b s h d -> b s (h d)')) + outs = self.proj_drop(outs) + return outs + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x = self._naive_attn(hidden_states) if not 
self.use_flash_attn else self._flash_attn(hidden_states) + return x + + +class InternMLP(nn.Module): + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.act = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + def __init__(self, config: InternVisionConfig, drop_path_rate: float): + super().__init__() + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = InternAttention(config) + self.mlp = InternMLP(config) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]: + """ + Args: + hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1) + + hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2) + + return hidden_states + + +class InternVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`InternEncoderLayer`]. + + Args: + config (`InternConfig`): + The corresponding vision configuration for the `InternEncoder`. + """ + + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)] + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]) + self.gradient_checkpointing = True + + def forward( + self, + inputs_embeds, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
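+
+        Returns:
+            A [`BaseModelOutput`] if `return_dict=True`, otherwise a plain tuple of the
+            available tensors `(last_hidden_state, encoder_states)`.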
+ """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + hidden_states = inputs_embeds + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + encoder_layer, + hidden_states) + else: + layer_outputs = encoder_layer( + hidden_states, + ) + hidden_states = layer_outputs + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states + ) + + +class InternVisionModel(PreTrainedModel): + main_input_name = 'pixel_values' + _supports_flash_attn_2 = True + config_class = InternVisionConfig + _no_split_modules = ['InternVisionEncoderLayer'] + + def __init__(self, config: InternVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = InternVisionEmbeddings(config) + self.encoder = InternVisionEncoder(config) + + def resize_pos_embeddings(self, old_size, new_size, patch_size): + pos_emb = self.embeddings.position_embedding + _, num_positions, embed_dim = pos_emb.shape + cls_emb = pos_emb[:, :1, :] + pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2) + pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False) + pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1) + pos_emb = torch.cat([cls_emb, pos_emb], dim=1) + self.embeddings.position_embedding = nn.Parameter(pos_emb) + self.embeddings.image_size = new_size + logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size)) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_embeds: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and pixel_embeds is None: + raise ValueError('You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + else: + if len(pixel_values.shape) == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs.last_hidden_state + pooled_output = last_hidden_state[:, 0, :] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) \ No newline at end of file diff 
--git a/src/vlm_backbone/intern_vl/modeling_internlm2.py b/src/vlm_backbone/intern_vl/modeling_internlm2.py new file mode 100644 index 0000000000000000000000000000000000000000..d46efc2759e1fc822ab03da3c959627f1ff32692 --- /dev/null +++ b/src/vlm_backbone/intern_vl/modeling_internlm2.py @@ -0,0 +1,1429 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/modeling_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch InternLM2 model.""" +import math +import queue +import threading +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) + +try: + from transformers.generation.streamers import BaseStreamer +except: # noqa # pylint: disable=bare-except + BaseStreamer = None + +from .configuration_internlm2 import InternLM2Config + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'InternLM2Config' + +flash_attn_func, flash_attn_varlen_func = None, None +pad_input, index_first_axis, unpad_input = None, None, None +try: + from flash_attn import flash_attn_func as _flash_attn_func + from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis as _index_first_axis + from flash_attn.bert_padding import pad_input as _pad_input + from flash_attn.bert_padding import unpad_input as _unpad_input + + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func + pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input + has_flash_attn = True +except: + has_flash_attn = False + + +def _import_flash_attn(): + global flash_attn_func, flash_attn_varlen_func + global pad_input, index_first_axis, unpad_input + try: + from flash_attn import flash_attn_func as _flash_attn_func + from flash_attn import \ + flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import \ + index_first_axis as _index_first_axis + from flash_attn.bert_padding import pad_input as _pad_input + from flash_attn.bert_padding import unpad_input as _unpad_input + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func + pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input + except ImportError: + raise ImportError('flash_attn is not installed.') + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def 
_get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
+class InternLM2RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        InternLM2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+try:
+    from functools import partial
+
+    from apex.normalization import FusedRMSNorm
+    InternLM2RMSNorm = partial(FusedRMSNorm, eps=1e-6)  # noqa
+    print('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternLM2RMSNorm')
+except ImportError:
+    # using the normal InternLM2RMSNorm
+    pass
+except Exception:
+    print('discovered apex but it failed to load, falling back to InternLM2RMSNorm')
+    pass
+
+
+# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
+class InternLM2RotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
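+        # The cos/sin tables are precomputed here for max_position_embeddings;
+        # forward() regrows the cache transparently if a longer sequence shows up.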
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2 +class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2 +class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with Dynamic NTK scaling. + Credits to the Reddit users /u/bloc97 and /u/emozilla. 
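+    Once the sequence length exceeds `max_position_embeddings`, the RoPE base is rescaled
+    on the fly so the effective context window grows without retraining.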
+ """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.model.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class InternLM2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + + return down_proj + + +# Copied from transformers.model.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Modified from transformers.model.llama.modeling_llama.LlamaAttention +class InternLM2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternLM2Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).' + ) + + self.wqkv = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=config.bias, + ) + + self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = InternLM2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['type'] + scaling_factor = self.config.rope_scaling['factor'] + if scaling_type == 'dynamic': + self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + scaling_factor=scaling_factor, + ) + elif scaling_type == 'linear': + self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + scaling_factor=scaling_factor, + ) + else: + raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.") + return self.rotary_emb + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. 
'
+                'Please make sure to use `attention_mask` instead.'
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.wqkv(hidden_states)
+
+        qkv_states = rearrange(
+            qkv_states,
+            'b q (h gs d) -> b q h gs d',
+            gs=2 + self.num_key_value_groups,
+            d=self.head_dim,
+        )
+
+        query_states = qkv_states[..., : self.num_key_value_groups, :]
+        query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+        key_states = qkv_states[..., -2, :]
+        value_states = qkv_states[..., -1, :]
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
+                f' {attn_weights.size()}'
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+                f' {attn_output.size()}'
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.wo(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+# Modified from transformers.model.llama.modeling_llama.LlamaFlashAttention2
+class InternLM2FlashAttention2(InternLM2Attention):
+    """
+    InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
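+
+    This class is selected through `INTERNLM2_ATTENTION_CLASSES` when the config's
+    `attn_implementation` is set to 'flash_attention_2'.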
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # InternLM2FlashAttention2 attention does not support output_attentions + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. ' + 'Please make sure use `attention_mask` instead.`' + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len + ) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
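+            # With left padding, the live query tokens occupy the last
+            # `query_length` positions of every row, so slicing the mask here
+            # keeps exactly the entries that `unpad_input` should retain.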
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q.to(torch.int64),
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+INTERNLM2_ATTENTION_CLASSES = {
+    'eager': InternLM2Attention,
+    'flash_attention_2': InternLM2FlashAttention2,
+}
+
+
+# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
+class InternLM2DecoderLayer(nn.Module):
+    def __init__(self, config: InternLM2Config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
+
+        self.feed_forward = InternLM2MLP(config)
+        self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+        if 'padding_mask' in kwargs:
+            warnings.warn(
+                'Passing `padding_mask` is deprecated and will be removed in v4.37. '
+                'Please make sure to use `attention_mask` instead.'
+            )
+
+        residual = hidden_states
+
+        hidden_states = self.attention_norm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.attention(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.ffn_norm(hidden_states)
+        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+InternLM2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`InternLM2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2 +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2PreTrainedModel(PreTrainedModel): + config_class = InternLM2Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['InternLM2DecoderLayer'] + _skip_keys_device_placement = 'past_key_values' + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +InternLM2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Modified from transformers.model.llama.modeling_llama.LlamaModel +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2Model(InternLM2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`InternLM2DecoderLayer`] + + Args: + config: InternLM2Config + """ + + _auto_class = 'AutoModel' + + def __init__(self, config: InternLM2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + if not has_flash_attn: + self.config.attn_implementation = 'eager' + print('Warning: Flash attention is not available, using eager attention instead.') + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.tok_embeddings + + def set_input_embeddings(self, value): + self.tok_embeddings = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.attn_implementation == 'flash_attention_2': + _import_flash_attn() + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time') + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError('You have to specify either input_ids or inputs_embeds') + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = 
input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.tok_embeddings(input_ids) + + if self.config.attn_implementation == 'flash_attention_2': + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM +class InternLM2ForCausalLM(InternLM2PreTrainedModel): + _auto_class = 'AutoModelForCausalLM' + + _tied_weights_keys = ['output.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = InternLM2Model(config) + self.vocab_size = config.vocab_size + self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + return self.output + + def set_output_embeddings(self, new_embeddings): + 
self.output = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, InternLM2ForCausalLM + + >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.output(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + device = input_ids.device if input_ids is not None else inputs_embeds.device + output = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + output['logits'] = output['logits'].to(device) + return output + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update( + { + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=''): + if tokenizer.add_bos_token: + prompt = '' + else: + prompt = tokenizer.bos_token + if meta_instruction: + prompt += 
f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n""" + for record in history: + prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n""" + prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n""" + return tokenizer([prompt], return_tensors='pt') + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + streamer: Optional[BaseStreamer] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + meta_instruction: str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n' + '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n' + '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.', + **kwargs, + ): + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)} + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]] + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + **kwargs, + ) + outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split('<|im_end|>')[0] + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + **kwargs, + ): + """ + Return a generator in format: (response, history) + Eg. + ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) + ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) + """ + if BaseStreamer is None: + raise ModuleNotFoundError( + 'The version of `transformers` is too low. Please make sure ' + 'that you have installed `transformers>=4.28.0`.' 
+ ) + + response_queue = queue.Queue(maxsize=20) + + class ChatStreamer(BaseStreamer): + def __init__(self, tokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + self.queue = response_queue + self.query = query + self.history = history + self.response = '' + self.cache = [] + self.received_inputs = False + self.queue.put((self.response, history + [(self.query, self.response)])) + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError('ChatStreamer only supports batch size 1') + elif len(value.shape) > 1: + value = value[0] + + if not self.received_inputs: + # The first received value is input_ids, ignore here + self.received_inputs = True + return + + self.cache.extend(value.tolist()) + token = self.tokenizer.decode(self.cache, skip_special_tokens=True) + if token.strip() != '<|im_end|>': + self.response = self.response + token + history = self.history + [(self.query, self.response)] + self.queue.put((self.response, history)) + self.cache = [] + else: + self.end() + + def end(self): + self.queue.put(None) + + def stream_producer(): + return self.chat( + tokenizer=tokenizer, + query=query, + streamer=ChatStreamer(tokenizer=tokenizer), + history=history, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + def consumer(): + producer = threading.Thread(target=stream_producer) + producer.start() + while True: + res = response_queue.get() + if res is None: + return + yield res + + return consumer() + + +# Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2 +@add_start_docstrings( + """ + The InternLM2 Model transformer with a sequence classification head on top (linear layer). + + [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, + as other causal models (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
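+
+    A minimal usage sketch (the paths are placeholders, and the classification head
+    is randomly initialized until it has been fine-tuned):
+
+    ```python
+    >>> from transformers import AutoTokenizer
+
+    >>> model = InternLM2ForSequenceClassification.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, num_labels=2)
+    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+    >>> inputs = tokenizer("I enjoyed this movie a lot!", return_tensors="pt")
+    >>> logits = model(**inputs).logits  # shape: (batch_size, num_labels)
+    ```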
+ """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForSequenceClassification(InternLM2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLM2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.') + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = 
(pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
\ No newline at end of file
diff --git a/src/vlm_backbone/intern_vl/modeling_internvl_chat.py b/src/vlm_backbone/intern_vl/modeling_internvl_chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4c832af055616dbf1a894160543cf67d8ea3608
--- /dev/null
+++ b/src/vlm_backbone/intern_vl/modeling_internvl_chat.py
@@ -0,0 +1,479 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch.distributed as dist
+import torch.utils.checkpoint
+import transformers
+from .conversation import get_conv_template
+from .modeling_internlm2 import InternLM2ForCausalLM
+from .modeling_phi3 import Phi3ForCausalLM
+from peft import LoraConfig, get_peft_model
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
+                          LlamaTokenizer, Qwen2ForCausalLM)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import ModelOutput, logging
+
+from .configuration_internvl_chat import InternVLChatConfig
+from .modeling_intern_vit import InternVisionModel, has_flash_attn
+
+logger = logging.get_logger(__name__)
+
+IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+
+
+def version_cmp(v1, v2, op='eq'):
+    import operator
+
+    from packaging import version
+    op_func = getattr(operator, op)
+    return op_func(version.parse(v1), version.parse(v2))
+
+
+class InternVLChatModel(PreTrainedModel):
+    config_class = InternVLChatConfig
+    main_input_name = 'pixel_values'
+    base_model_prefix = 'language_model'
+    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer',
+                         'Phi3DecoderLayer', 'Qwen2DecoderLayer']
+    _supports_flash_attn_2 = True
+    supports_gradient_checkpointing = True
+
+    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True, tokenizer=None):
+        super().__init__(config)
+
+        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
+        image_size = config.force_image_size or config.vision_config.image_size
+        patch_size = config.vision_config.patch_size
+        self.patch_size = patch_size
+        self.select_layer = config.select_layer
+        self.template = config.template
+        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
+        self.downsample_ratio = config.downsample_ratio
+        self.ps_version = config.ps_version
+        self.llm_arch_name = config.llm_config.architectures[0]
+        # Enable Flash Attention if supported, otherwise fall back to eager attention.
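+        # The caller can request flash attention, but it is only honored when the
+        # flash-attn package imported successfully (`has_flash_attn`); the decision
+        # is then propagated to both the vision tower and the LLM config below.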
+        use_flash_attn = use_flash_attn if has_flash_attn else False
+        config.vision_config.use_flash_attn = True if use_flash_attn else False
+        config.llm_config.attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
+
+        logger.info(f'num_image_token: {self.num_image_token}')
+        logger.info(f'ps_version: {self.ps_version}')
+        if vision_model is not None:
+            self.vision_model = vision_model
+        else:
+            self.vision_model = InternVisionModel(config.vision_config)
+        if language_model is not None:
+            self.language_model = language_model
+        else:
+            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
+                self.language_model = LlamaForCausalLM(config.llm_config)
+            elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
+                self.language_model = InternLM2ForCausalLM(config.llm_config)
+            elif config.llm_config.architectures[0] == 'Phi3ForCausalLM':
+                self.language_model = Phi3ForCausalLM(config.llm_config)
+            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
+                self.language_model = Qwen2ForCausalLM(config.llm_config)
+            else:
+                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
+
+        vit_hidden_size = config.vision_config.hidden_size
+        llm_hidden_size = config.llm_config.hidden_size
+
+        self.mlp1 = nn.Sequential(
+            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
+            nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size)
+        )
+
+        self.img_context_token_id = None
+        if tokenizer is not None:
+            self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.conv_template = get_conv_template(self.template)
+        if hasattr(config, 'system_message'):
+            self.system_message = config.system_message
+        else:
+            self.system_message = self.conv_template.system_message
+        self.num_samples = 0
+
+        if config.use_backbone_lora:
+            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
+
+        if config.use_llm_lora:
+            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
+
+    def get_lora_target_modules(self):
+        if self.llm_arch_name == 'InternLM2ForCausalLM':
+            target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
+        elif self.llm_arch_name == 'Phi3ForCausalLM':
+            target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
+        elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
+            target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
+                              'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
+        else:
+            raise NotImplementedError
+        return ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'] + target_modules
+
+    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+        lora_config = LoraConfig(
+            r=r,
+            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
+        self.vision_model = get_peft_model(self.vision_model, lora_config)
+        self.vision_model.print_trainable_parameters()
+
+    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+        # Determine the target modules based on the architecture of the language model
+        if self.llm_arch_name == 'InternLM2ForCausalLM':
+            target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
+        elif self.llm_arch_name == 'Phi3ForCausalLM':
+            target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
+        elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
+            target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
+                              'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
+        else:
+            raise NotImplementedError
+        lora_config = LoraConfig(
+            r=r,
+            target_modules=target_modules,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            task_type='CAUSAL_LM'
+        )
+        self.language_model = get_peft_model(self.language_model, lora_config)
+        self.language_model.enable_input_require_grads()
+        self.language_model.print_trainable_parameters()
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor = None,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        image_flags: Optional[torch.LongTensor] = None,  # [batch_size]; 1 = the sample has an image, 0 = text-only
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        statistics: Optional[torch.LongTensor] = None,
+        loss_weight: Optional[List] = None,
+        loss_reduction_all_gather: Optional[bool] = False,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        # Original batch size and per-sample sequence length
+        B, N = input_ids.shape
+        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()  # [B, N, C]
+
+        if pixel_values is not None:
+            vit_embeds = self.extract_feature(pixel_values)  # [num_images, num_patches, C]
+
+            # Locate the image-context tokens in input_ids that must be replaced
+            selected = torch.eq(input_ids, self.img_context_token_id)  # [B, N]
+
+            # Make sure image_flags has shape [B]
+            image_flags = image_flags.squeeze(-1)  # [B]
+
+            # Vectorized replacement: scatter the ViT patch embeddings into the
+            # image-context positions of samples that actually carry an image.
+            mask = selected & (image_flags.unsqueeze(-1) == 1)
+            input_embeds[mask] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
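+            # The scatter above assumes the data collator guarantees that the number
+            # of selected image-context positions equals the total number of ViT patch
+            # embeddings (mask.sum() == vit_embeds.shape[0] * vit_embeds.shape[1]);
+            # otherwise the reshape-assignment raises a shape-mismatch error.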
+ # print(f"vit_embeds.shape = {vit_embeds.shape}") # [num_images, num_patches, C] + # print(f"image_flags.sum() = {image_flags.sum()}") # 应该等于num_images + + # print(torch.allclose(input_embeds2, input_embeds, rtol=1e-7)) + # assert torch.allclose(input_embeds2, input_embeds, rtol=1e-5), "input_embeds2 and input_embeds should have the same values" + + + outputs = self.language_model( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits + + loss = None + if labels is not None and loss_weight is not None: + loss_weight = torch.tensor(loss_weight, dtype=torch.float32, device=labels.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_weights = loss_weight[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction='none') + shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size) + shift_labels = shift_labels.view(-1) + shift_weights = shift_weights.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + shift_weights = shift_weights.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + shift_weights_sum = shift_weights.sum() + if loss_reduction_all_gather: + dist.all_reduce(shift_weights_sum, op=dist.ReduceOp.AVG) + + loss = loss * shift_weights + loss = loss.sum() / shift_weights_sum + if ignore_flag: + loss = loss * 0.0 + elif labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if ignore_flag: + loss = loss * 0.0 + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + warnings.warn("In ps_version 'v1', the height and width have not been swapped back, " + 'which results in a transposed image.') + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + if self.select_layer == -1: + vit_embeds = self.vision_model( + pixel_values=pixel_values, + output_hidden_states=False, + return_dict=True).last_hidden_state + else: + vit_embeds = self.vision_model( + pixel_values=pixel_values, + output_hidden_states=True, + return_dict=True).hidden_states[self.select_layer] + vit_embeds = 
+
+        h = w = int(vit_embeds.shape[1] ** 0.5)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+        vit_embeds = self.mlp1(vit_embeds)
+        return vit_embeds
+
+    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
+                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
+        if history is not None or return_history:
+            print('Now multi-turn chat is not supported in batch_chat.')
+            raise NotImplementedError
+
+        if image_counts is not None:
+            num_patches_list = image_counts
+            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
+
+        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.img_context_token_id = img_context_token_id
+
+        if verbose and pixel_values is not None:
+            image_bs = pixel_values.shape[0]
+            print(f'dynamic ViT batch size: {image_bs}')
+
+        queries = []
+        for idx, num_patches in enumerate(num_patches_list):
+            question = questions[idx]
+            if pixel_values is not None and '<image>' not in question:
+                question = '<image>\n' + question
+            template = get_conv_template(self.template)
+            template.system_message = self.system_message
+            template.append_message(template.roles[0], question)
+            template.append_message(template.roles[1], None)
+            query = template.get_prompt()
+
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+            queries.append(query)
+
+        tokenizer.padding_side = 'left'
+        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        input_ids = model_inputs['input_ids'].to(device)
+        attention_mask = model_inputs['attention_mask'].to(device)
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
+        generation_config['eos_token_id'] = eos_token_id
+        generation_output = self.generate(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **generation_config
+        )
+        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
+        responses = [response.split(template.sep.strip())[0].strip() for response in responses]
+        return responses
+
+    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
+             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
+             verbose=False):
+
+        if history is None and pixel_values is not None and '<image>' not in question:
+            question = '<image>\n' + question
+
+        if num_patches_list is None:
+            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.img_context_token_id = img_context_token_id
+
+        template = get_conv_template(self.template)
+        template.system_message = self.system_message
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
+
+        history = [] if history is None else history
+        for (old_question, old_answer) in history:
+            template.append_message(template.roles[0], old_question)
+            template.append_message(template.roles[1], old_answer)
+        template.append_message(template.roles[0], question)
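+        # Appending `None` for the assistant role leaves the prompt open at the
+        # assistant tag, so generation continues from that point.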
+        template.append_message(template.roles[1], None)
+        query = template.get_prompt()
+
+        if verbose and pixel_values is not None:
+            image_bs = pixel_values.shape[0]
+            print(f'dynamic ViT batch size: {image_bs}')
+
+        for num_patches in num_patches_list:
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+
+        model_inputs = tokenizer(query, return_tensors='pt')
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        input_ids = model_inputs['input_ids'].to(device)
+        attention_mask = model_inputs['attention_mask'].to(device)
+        generation_config['eos_token_id'] = eos_token_id
+        generation_output = self.generate(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **generation_config
+        )
+        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+        response = response.split(template.sep.strip())[0].strip()
+        history.append((question, response))
+        if return_history:
+            return response, history
+        else:
+            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '')
+            if verbose:
+                print(query_to_print, response)
+            return response
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        visual_features: Optional[torch.FloatTensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        output_hidden_states: Optional[bool] = None,
+        **generate_kwargs,
+    ) -> torch.LongTensor:
+
+        assert self.img_context_token_id is not None
+        if pixel_values is not None:
+            if visual_features is not None:
+                vit_embeds = visual_features
+            else:
+                vit_embeds = self.extract_feature(pixel_values)
+            input_embeds = self.language_model.get_input_embeddings()(input_ids)
+            B, N, C = input_embeds.shape
+            input_embeds = input_embeds.reshape(B * N, C)
+
+            input_ids = input_ids.reshape(B * N)
+            selected = (input_ids == self.img_context_token_id)
+            assert selected.sum() != 0
+            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+
+            input_embeds = input_embeds.reshape(B, N, C)
+        else:
+            input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        outputs = self.language_model.generate(
+            inputs_embeds=input_embeds,
+            attention_mask=attention_mask,
+            generation_config=generation_config,
+            output_hidden_states=output_hidden_states,
+            use_cache=True,
+            **generate_kwargs,
+        )
+
+        return outputs
+
+    @property
+    def lm_head(self):
+        return self.language_model.get_output_embeddings()
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.language_model.get_output_embeddings()
\ No newline at end of file
diff --git a/src/vlm_backbone/intern_vl/modeling_phi3.py b/src/vlm_backbone/intern_vl/modeling_phi3.py
new file mode 100644
index 0000000000000000000000000000000000000000..67dbad993a4bdbed2f3277b56c207bc19ae5c23e
--- /dev/null
+++ b/src/vlm_backbone/intern_vl/modeling_phi3.py
@@ -0,0 +1,1610 @@
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" PyTorch Phi-3 model."""
+
+import inspect
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import \
+    _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+                                           CausalLMOutputWithPast,
+                                           SequenceClassifierOutputWithPast,
+                                           TokenClassifierOutput)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (add_code_sample_docstrings,
+                                add_start_docstrings,
+                                add_start_docstrings_to_model_forward,
+                                is_flash_attn_2_available,
+                                is_flash_attn_greater_or_equal_2_10, logging,
+                                replace_return_docstrings)
+
+from .configuration_phi3 import Phi3Config
+
+logger = logging.get_logger(__name__)
+
+# Transformers scans dependencies in the modeling file, causing issues on conditional loading.
+# The regex only ignores try/except blocks, but not if statements:
+# if is_flash_attn_2_available():
+_flash_supports_window_size = False
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import (index_first_axis, pad_input,  # noqa
+                                         unpad_input)
+
+    _flash_supports_window_size = 'window_size' in list(inspect.signature(flash_attn_func).parameters)
+    has_flash_attn = True
+except ImportError as error:
+    logger.warning(
+        f'`flash-attention` package not found, consider installing for better performance: {error}.'
+    )
+    if not _flash_supports_window_size:
+        logger.warning(
+            "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
+ ) + has_flash_attn = False + +_CHECKPOINT_FOR_DOC = 'microsoft/Phi-3-mini-4k-instruct' +_CONFIG_FOR_DOC = 'Phi3Config' + +PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'microsoft/Phi-3-mini-4k-instruct', + 'microsoft/Phi-3-mini-128k-instruct', + # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3 +] + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.register_buffer('inv_freq', None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling['short_factor'] + self.long_factor = config.rope_scaling['long_factor'] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + 
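+        # The `rope_scaling` config supplies two per-dimension rescale vectors:
+        # `long_factor` is applied once the running sequence exceeds the original
+        # pre-training context, `short_factor` otherwise, stretching the rotary
+        # frequencies so the model can extrapolate to longer contexts.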
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling['short_factor'] + self.long_factor = config.rope_scaling['long_factor'] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = 0.1 * math.log(scale) + 1.0 + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. 
+ cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will ' + 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` ' + 'when creating this class.' 
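+ # Editor's sketch with assumed Phi-3-mini numbers (hidden_size=3072, 32 query
+ # heads, 32 KV heads, head_dim=96): the fused qkv_proj defined below emits
+ # 32*96 + 2*(32*96) = 9216 features per token, which forward() slices back into
+ # query, key and value blocks; with fewer KV heads (grouped-query attention),
+ # repeat_kv() above expands keys/values to one copy per query-head group.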
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.original_max_position_embeddings = config.original_max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).' + ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.rope_scaling is None: + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['type'] + if scaling_type == 'su': + self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config) + elif scaling_type == 'yarn': + self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config) + else: + raise ValueError(f'Unknown RoPE scaling type {scaling_type}') + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once('You are not running the flash-attention implementation, expect numerical differences.') + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ' + 'with a layer index.' 
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
+ f' {attn_weights.size()}'
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+ f' {attn_output.size()}'
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Phi3FlashAttention2(Phi3Attention):
+ """
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
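+ # Illustrative sketch (editor's addition, not upstream code): for q_len=2 and
+ # kv_len=4, a bottom-right aligned causal mask is
+ # [[1, 1, 1, 0],
+ # [1, 1, 1, 1]]
+ # (query i sees keys <= i + kv_len - q_len), whereas a top-left aligned mask
+ # would only allow keys <= i, which is wrong when the queries are the suffix of
+ # a longer cached sequence.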
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # Phi3FlashAttention2 attention does not support output_attentions
+
+ if not _flash_supports_window_size:
+ logger.warning_once(
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade the flash-attn library."
+ )
+ raise ValueError('The current flash attention version does not support sliding window attention.')
+
+ output_attentions = False
+
+ if 'padding_mask' in kwargs:
+ warnings.warn(
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop('padding_mask')
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
+ query_pos = self.num_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
+ 'with a layer index.'
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
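+ # Editor's sketch (assumed values): with left padding, position_ids might be
+ # [[0, 0, 1, 2], [0, 1, 2, 3]]; the rotary cos/sin table must cover
+ # max(position_ids) + 1 entries, which can differ from kv_seq_len (e.g. after
+ # sliding-window cache slicing), hence the max() below.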
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and getattr(self.config, 'sliding_window', None) is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+
+ if past_key_value is not None:
+ # Only activate cache slicing if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, 'sliding_window', None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f'past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got'
+ f' {past_key.shape}'
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_dropout = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the
+ # LayerNorms to fp32.
+
+ if query_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.qkv_proj.weight.dtype
+
+ logger.warning_once(
+ f'The input hidden states seem to have been silently cast to float32, this might be related to'
+ f' the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to'
+ f' {target_dtype}.'
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=attn_dropout,
+ use_sliding_windows=use_sliding_windows,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
+ the input is first unpadded, the attention scores are computed, and the final attention scores are padded
+ back.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
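+ # Editor's note (illustrative): during cached decoding query_length == 1, and
+ # with flash_attn < 2.1 a (1, kv_len) causal mask would be top-left aligned,
+ # letting the single query see only key 0. Dropping `causal` here is safe:
+ # one query attending to all past keys is causal by construction.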
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
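+ # Worked example (editor's addition): for attention_mask [[1, 1, 0], [1, 1, 1]],
+ # _get_unpad_data() returns indices [0, 1, 3, 4, 5] into the flattened batch,
+ # cu_seqlens [0, 2, 5] and max_seqlen_in_batch 3, so the varlen kernel only
+ # touches the 5 real tokens.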
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 +# TODO @Arthur no longer copied from LLama after static cache +class Phi3SdpaAttention(Phi3Attention): + """ + Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Phi3Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + 'Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, ' + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with 
custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == 'cuda' and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+PHI3_ATTENTION_CLASSES = {
+ 'eager': Phi3Attention,
+ 'flash_attention_2': Phi3FlashAttention2,
+ 'sdpa': Phi3SdpaAttention,
+}
+
+
+class Phi3DecoderLayer(nn.Module):
+ def __init__(self, config: Phi3Config, layer_idx: int):
+ super().__init__()
+
+ self.config = config
+ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
+
+ self.mlp = Phi3MLP(config)
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ if 'padding_mask' in kwargs:
+ warnings.warn(
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
+ )
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PHI3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.', + PHI3_START_DOCSTRING, +) +class Phi3PreTrainedModel(PreTrainedModel): + config_class = Phi3Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['Phi3DecoderLayer'] + _skip_keys_device_placement = 'past_key_values' + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + _version = '0.0.5' + + def __init__(self, config: Phi3Config): + if not has_flash_attn: + config._attn_implementation = 'eager' + print('Warning: Flash attention is not available, using eager attention instead.') + super().__init__(config) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.', + PHI3_START_DOCSTRING, +) +class Phi3Model(Phi3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi3DecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time') + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError('You have to specify either input_ids or inputs_embeds') + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' 
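+ # Editor's note (illustrative): a legacy cache is a per-layer tuple of
+ # (key, value) tensors of shape (batch_size, num_kv_heads, seq_len, head_dim);
+ # DynamicCache.from_legacy_cache() below wraps it so each layer can append new
+ # key/value states via past_key_values.update().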
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == 'flash_attention_2' and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + ' this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to ' + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == 'flash_attention_2': + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Phi3ForCausalLM(Phi3PreTrainedModel): + _tied_weights_keys = ['lm_head.weight'] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 + def __init__(self, config): + super().__init__(config) + self.model = Phi3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, 
bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
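+ # Editor's walk-through (assumed numbers): with past_length = 8 and 10 input
+ # ids, case 2 keeps input_ids[:, 8:] (the 2 unprocessed tokens); if the
+ # attention_mask were 12 long instead, case 1 would keep the last 12 - 8 = 4
+ # ids, i.e. those not yet covered by the cache.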
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if (inputs_embeds is not None and past_key_values is None) or (inputs_embeds is not None and len(past_key_values) == 0): + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update( + { + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + } + ) + return model_inputs + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The [`Phi3Model`] with a sequence classification head on top (linear layer). + + [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHI3_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs +class Phi3ForSequenceClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = model_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.') + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 
'single_label_classification':
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == 'multi_label_classification':
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + model_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=model_outputs.past_key_values,
+ hidden_states=model_outputs.hidden_states,
+ attentions=model_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ PHI3_START_DOCSTRING,
+)
+# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
+class Phi3ForTokenClassification(Phi3PreTrainedModel):
+ def __init__(self, config: Phi3Config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.model = Phi3Model(config)
+ if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, 'hidden_dropout') and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`;
+ a cross-entropy loss is computed over all tokens.
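+
+ Example (editor's illustrative sketch; the base checkpoint ships without a
+ trained classification head, so `num_labels` is an assumption and the scores
+ are untrained):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer
+
+ >>> tokenizer = AutoTokenizer.from_pretrained('microsoft/Phi-3-mini-4k-instruct')
+ >>> model = Phi3ForTokenClassification.from_pretrained('microsoft/Phi-3-mini-4k-instruct', num_labels=2)
+ >>> inputs = tokenizer('Hello world', return_tensors='pt')
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits # shape: (batch_size, sequence_length, num_labels)
+ ```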
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = model_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (logits,) + model_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) \ No newline at end of file diff --git a/src/vlm_backbone/intern_vl/processing_internvl.py b/src/vlm_backbone/intern_vl/processing_internvl.py new file mode 100644 index 0000000000000000000000000000000000000000..13dff10007ef4d55e76b1a8aae69ee4da9e1774e --- /dev/null +++ b/src/vlm_backbone/intern_vl/processing_internvl.py @@ -0,0 +1,160 @@ +# processing_internvl.py +from typing import List, Optional, Union +from transformers import ProcessorMixin, BatchFeature +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput, PaddingStrategy, TruncationStrategy +import torch +import re +import numpy as np + +IMG_START_TOKEN = "" +IMG_END_TOKEN = "" +IMG_CONTEXT_TOKEN = "" + +# class InternVLProcessor(ProcessorMixin): +# attributes = ["image_processor", "tokenizer"] +# image_processor_class = "AutoImageProcessor" +# tokenizer_class = "AutoTokenizer" + +# def __init__(self, image_processor, tokenizer, num_img_tokens=256): +# super().__init__(image_processor, tokenizer) +# self.num_img_tokens = num_img_tokens +# self._add_special_tokens() + +# def _add_special_tokens(self): +# special_tokens = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN] +# self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) + +# def __call__( +# self, +# text: Union[TextInput, List[TextInput]] = None, +# images: ImageInput = None, +# padding: Union[bool, str, PaddingStrategy] = False, +# truncation: Union[bool, str, TruncationStrategy] = None, +# max_length: Optional[int] = None, +# return_tensors: Optional[str] = "pt", +# ) -> BatchFeature: + +# # Process images +# pixel_values = [] +# if images is not None: +# image_inputs = self.image_processor(images, return_tensors=return_tensors) +# pixel_values = image_inputs.pixel_values + +# # Process text with image tokens +# processed_text = self._insert_image_tokens(text, num_images=len(pixel_values)) + +# # Tokenize text +# text_inputs = self.tokenizer( +# processed_text, +# padding=padding, +# truncation=truncation, +# max_length=max_length, +# return_tensors=return_tensors, +# add_special_tokens=False +# ) + +# # Build final inputs +# inputs = BatchFeature(data={ +# **text_inputs, +# "pixel_values": pixel_values, +# }) + +# return inputs + +# def _insert_image_tokens(self, text: str, num_images: int) -> str: +# """Replace tags with image context tokens""" +# image_tokens = [] +# for _ in range(num_images): +# 
+
+
+class InternVLProcessor(ProcessorMixin):
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, image_processor, tokenizer, num_img_tokens=256):
+        super().__init__(image_processor, tokenizer)
+        self.num_img_tokens = num_img_tokens
+        self.img_context_token = IMG_CONTEXT_TOKEN
+        self._add_special_tokens()
+
+    def _add_special_tokens(self):
+        # Register IMG_CONTEXT as a special token so the tokenizer never splits it.
+        special_tokens = [self.img_context_token]
+        self.tokenizer.add_special_tokens({
+            "additional_special_tokens": special_tokens
+        })
+
+    def __call__(
+        self,
+        text: Union[str, List[str]],
+        images: Union[ImageInput, List[ImageInput]] = None,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = None,
+        max_length: Optional[int] = None,
+        return_tensors: str = "pt"
+    ) -> BatchFeature:
+        # Normalize single-sample inputs to lists.
+        if isinstance(text, str):
+            text = [text]
+
+        if not isinstance(images, list):
+            images = [images] if images else []
+
+        # Build image_flags: 1 if the sample carries an image, 0 otherwise.
+        image_flags = [1] if len(images) else [0]
+
+        # Image preprocessing.
+        pixel_values = []
+        if any(image_flags):
+            pixel_values = self.image_processor(
+                [img for img in images if img],  # e.g. a PIL image sized (525, 704)
+                return_tensors=return_tensors
+            ).pixel_values  # torch.Size([1, 3, 448, 448])
+
+        # Text preprocessing: expand the <image> placeholder into context tokens.
+        processed_texts = [
+            self._insert_image_tokens(t, count)
+            for t, count in zip(text, image_flags)
+        ]
+
+        # Tokenize the processed text.
+        text_inputs = self.tokenizer(
+            processed_texts,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            return_tensors=return_tensors,
+            add_special_tokens=True
+        )
+
+        # Assemble the final model inputs.
+        return BatchFeature({
+            **text_inputs,
+            "pixel_values": pixel_values,
+            "image_flags": torch.tensor(image_flags),
+        }, tensor_type=return_tensors)
+
+    def _insert_image_tokens(self, text: str, image_count: int) -> str:
+        """Insert image context tokens for each image in the sample."""
+        if image_count == 0:
+            return text
+
+        image_tokens = f"{self.img_context_token * self.num_img_tokens * image_count}"
+        return text.replace("<image>", image_tokens, 1)
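+
+
+# Minimal usage sketch (runs only when this module is executed directly).
+# The checkpoint path is a placeholder; any InternVL-style checkpoint that
+# ships a tokenizer and an image processor should work the same way.
+if __name__ == "__main__":
+    from PIL import Image
+    from transformers import AutoTokenizer, AutoImageProcessor
+
+    ckpt = "/path/to/IDMR-2B"  # placeholder checkpoint directory
+    tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
+    image_processor = AutoImageProcessor.from_pretrained(ckpt, trust_remote_code=True, use_fast=False)
+    processor = InternVLProcessor(image_processor=image_processor, tokenizer=tokenizer)
+
+    inputs = processor(
+        text="<image>\n Find an image that matches the given caption.",
+        images=Image.new("RGB", (448, 448)),
+        return_tensors="pt",
+    )
+    # Expected keys: input_ids, attention_mask, pixel_values, image_flags.
+    print({k: getattr(v, "shape", v) for k, v in inputs.items()})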
diff --git a/src/vlm_backbone/phi3_v/configuration_phi3_v.py b/src/vlm_backbone/phi3_v/configuration_phi3_v.py
new file mode 100644
index 0000000000000000000000000000000000000000..f015f5f8ed98219263296d4e0677d1aa174d44ec
--- /dev/null
+++ b/src/vlm_backbone/phi3_v/configuration_phi3_v.py
@@ -0,0 +1,217 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Phi-3-V model configuration"""
+
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+PHI3V_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "microsoft/Phi-3-vision-128k-instruct": "https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/resolve/main/config.json",
+}
+
+
+class Phi3VConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Phi3VModel`]. It is used to instantiate a Phi-3
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the
+    [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32064):
+            Vocabulary size of the Phi-3-V model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`Phi3VModel`].
+        hidden_size (`int`, *optional*, defaults to 3072):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 8192):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            Dropout probability for mlp outputs.
+        embd_pdrop (`int`, *optional*, defaults to 0.0):
+            The dropout ratio for the embeddings.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio after computing the attention scores.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+        original_max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model was trained with. This is used to determine the size of the
+            original RoPE embeddings when using long scaling.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon value used for the RMSNorm.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie word embeddings with the output projection.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`dict`, *optional*):
+            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
+            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
+            divided by the number of attention heads divided by 2.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 32000):
+            The id of the "end-of-sequence" token.
+        pad_token_id (`int`, *optional*, defaults to 32000):
+            The id of the padding token.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If `None`, no sliding window is applied.
+        embd_layer (`str`, *optional*, defaults to `"default"`):
+            The embedding layer to use. Can be either `"default"` or `"image"`. `"default"` uses the standard embedding for text.
+
+    Example:
+
+    ```python
+    >>> from transformers import Phi3VModel, Phi3VConfig
+
+    >>> # Initializing a Phi-3-V style configuration
+    >>> configuration = Phi3VConfig.from_pretrained("microsoft/Phi-3-vision-128k-instruct")
+
+    >>> # Initializing a model from the configuration
+    >>> model = Phi3VModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "phi3_v"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32064,
+        hidden_size=3072,
+        intermediate_size=8192,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        resid_pdrop=0.0,
+        embd_pdrop=0.0,
+        attention_dropout=0.0,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        original_max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        bos_token_id=1,
+        eos_token_id=32000,
+        pad_token_id=32000,
+        sliding_window=None,
+        embd_layer: str = "default",
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.resid_pdrop = resid_pdrop
+        self.embd_pdrop = embd_pdrop
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()
+        self.sliding_window = sliding_window
+        self.embd_layer = embd_layer
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
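+
+        Expected structure (illustrative):
+
+            {"type": "su", "short_factor": [...], "long_factor": [...]}
+
+        where `type` is `"su"` or `"yarn"` and both factor lists have length
+        `hidden_size // num_attention_heads // 2` (48 for the default geometry).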
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) \ No newline at end of file diff --git a/src/vlm_backbone/phi3_v/image_embedding_phi3_v.py b/src/vlm_backbone/phi3_v/image_embedding_phi3_v.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7c2990f75eb3ba36be2cc5096c16083c250429 --- /dev/null +++ b/src/vlm_backbone/phi3_v/image_embedding_phi3_v.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import warnings + +import torch +from torch import nn +from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig +from transformers.models.clip.modeling_clip import CLIPAttention +from transformers.utils import logging + +try: + from flash_attn import flash_attn_func +except ImportError: + pass + +logger = logging.get_logger(__name__) + + +MAX_INPUT_ID = int(1e9) + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( + attention_dropout=0.0, + dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + initializer_factor=1.0, + initializer_range=0.02, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768 +) + +class CLIPAttentionFA2(CLIPAttention): + """Add flash attention 2 to CLIPAttention. (This is only used in the vision encoder)""" + + def forward(self, + hidden_states, + attention_mask=None, + causal_attention_mask=None, + output_attentions=False, + ): + """Input shape: Batch x Time x Channel""" + + assert attention_mask is None, "CLIPAttentionFA2 does not support attention_mask" + assert causal_attention_mask is None, "CLIPAttentionFA2 does not support causal_attention_mask" + assert output_attentions is False, "CLIPAttentionFA2 does not support output_attentions" + + bsz, tgt_len, embed_dim = hidden_states.size() + query_states = self.q_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=self.dropout if self.training else 0.0, + softmax_scale=self.scale, + causal=False, + ).reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None + + +class Phi3ImageEmbedding(nn.Module): + """Phi3 Image embedding.""" + + def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size + if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'): + embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + self.wte = wte + + if isinstance(config.img_processor, dict) and config.img_processor.get('name', None) == 'clip_vision_model': + assert 'model_name' in config.img_processor, 'model_name must be provided for CLIPVisionModel' + assert 'image_dim_out' in config.img_processor, 'image_dim_out must be provided for CLIPVisionModel' + assert 'num_img_tokens' in config.img_processor, 'num_img_tokens must be provided for CLIPVisionModel' + assert config.img_processor['model_name'] == 'openai/clip-vit-large-patch14-336' + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + self.img_processor = CLIPVisionModel(clip_config) + image_dim_out = config.img_processor['image_dim_out'] + self.num_img_tokens = config.img_processor['num_img_tokens'] + + # FA2 in CLIP + if config._attn_implementation == 'flash_attention_2': + for layer in self.img_processor.vision_model.encoder.layers: + clip_fa2 = CLIPAttentionFA2(clip_config) + del layer.self_attn + layer.self_attn = clip_fa2 + else: + raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented') + + 
self.image_dim_out = image_dim_out + self.img_sizes = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = kwargs.get('use_hd_transform', False) + self.with_learnable_separator = kwargs.get('with_learnable_separator', False) + self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub') + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have same value' + if self.with_learnable_separator: + assert self.use_hd_transform, 'learnable separator is only for hd transform' + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * 4])) + logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}') + + projection_cls = kwargs.get('projection_cls', 'linear') + if projection_cls == 'linear': + self.img_projection = nn.Linear(image_dim_out, hidden_size) + elif projection_cls == 'mlp' and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * 4, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == 'mlp': + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented') + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get('type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + + def set_img_features(self, img_features: torch.FloatTensor) -> None: + self.img_features = img_features + + def set_img_sizes(self, img_sizes: torch.LongTensor) -> None: + self.img_sizes = img_sizes + + def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == "patch": + patch_feature = img_feature[:, 1:] + return patch_feature + + raise NotImplementedError + + def forward( + self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None + ) -> torch.FloatTensor: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + new_input_ids = copy.deepcopy(input_ids) + # warnings.warn( + # "Phi-3-V modifies `input_ids` in-place and the tokens indicating images will be " + # "removed after model forward. If your workflow requires multiple forward passes on " + # "the same `input_ids`, please make a copy of `input_ids` before passing it to the " + # "model." + # ) + + # positions for image tokens. 
# -1 indicates the PADDING positions
+        positions = torch.nonzero((new_input_ids < 0) & (new_input_ids > -MAX_INPUT_ID), as_tuple=True)
+        has_image = len(positions[0].tolist()) > 0
+        # input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
+        new_input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
+
+        hidden_states = self.wte(new_input_ids)  # [BS, seq_len, hidden_dim]
+
+        if has_image:
+            assert self.use_hd_transform
+            num_images, num_crops, c, h, w = pixel_values.shape
+            assert c == 3 and h == w == 336
+            # pixel_values.shape=(num_images, num_crops, channel, height, width)
+            # img_features.shape=(num_images, num_crops, 576=24*24, hidden_dim)
+            img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape(
+                num_images, num_crops, -1, self.image_dim_out
+            )
+            # image_features_proj.shape=(1514,3072)=[BS*(num_token_crops+1+num_token_global),3072]
+            image_features_proj = self.hd_feature_transform(img_features, image_sizes)
+            # simply assign (accumulate=False) image_features_proj=(2,757,3072) into hidden_states=(2,784,3072), offset by 1 token. So hidden_states[:,1:758,:] becomes image_features_proj
+            hidden_states = hidden_states.index_put(
+                positions, image_features_proj, accumulate=False
+            )
+
+        if self.drop is not None:
+            hidden_states = self.drop(hidden_states)
+
+        return hidden_states
+
+    def hd_feature_transform(self, image_features, image_sizes):
+        """
+        image_features: (num_images, num_crops+1, 24*24, 1024)
+        """
+        assert (
+            self.hd_transform_order == 'sub_glb'
+        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
+        if isinstance(self.img_projection, nn.Sequential):
+            target_device = self.img_projection[0].bias.device
+            target_dtype = self.img_projection[0].bias.dtype
+        else:  # It's a single nn.Linear layer
+            target_device = self.img_projection.bias.device
+            target_dtype = self.img_projection.bias.dtype
+
+        global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
+        # global feature can be viewed as a special HD case with num_crops 1x1
+        global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
+        global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)
+
+        all_image_embeddings = []
+        # need a for loop to process each image because of different image sizes
+        # (patch arrangement is different for each image)
+        for i, img_size in enumerate(image_sizes):
+            h, w = img_size.int()
+            h_crop = h // 336
+            w_crop = w // 336
+            num_crops = h_crop * w_crop
+            if num_crops > 0:
+                # NOTE: real num_crops is padded
+                # sub_image_features.shape=(num_crops, 24*24, 1024), sub_image_features_hd.shape=(1,24,24,4096), sub_image_features_hd_newline.shape=(1,600,4096)=(num_images, (h_crop*12) * (w_crop*12+1), 4096)
+                sub_image_features = image_features[i, 1 : 1 + num_crops]
+                sub_image_features_hd = self.reshape_hd_patches_2x2merge(
+                    sub_image_features, h_crop, w_crop
+                )
+                sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)
+
+                # [sub features, separator, global features]
+                all_image_embeddings.append(
+                    torch.cat([
+                        sub_image_features_hd_newline.squeeze(0),  # hd crops, (h_crop*12*(w_crop*12+1), 4096)
+                        self.glb_GN.squeeze(0),  # separator, (1, 4096) after squeeze
+                        global_image_features_hd_newline[i],  # global thumbnails, (156,4096)==[(h_crop*12) * (w_crop*12+1), 4096]
+                    ])
+                )
+            # else:
+            #     all_image_embeddings.append(None)
+        # img_embedding_shape = [v.shape for v in all_image_embeddings if isinstance(v, torch.Tensor)][0]
+        # all_image_embeddings = [v if isinstance(v, torch.Tensor)
else torch.zeros(img_embedding_shape) for v in all_image_embeddings] + # all_image_embeddings = [v.to(image_sizes.device) for v in all_image_embeddings] + # concatenate embeddings of all images (both HD crops and global thumbnails) in the batch + # [BS*(num_token_crops+1+num_token_global),4096]=[BS*(600+1+156),4096]=[BS*757,4096]->[BS*757,3072] + image_features_proj = self.img_projection(torch.cat(all_image_embeddings, dim=0).to(target_device).to(target_dtype)) + return image_features_proj + + def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): + """ + image_features: (num_images*num_crops, 24*24, 1024) + output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops + """ + N, L, C = image_features.shape + assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0 + num_images = N // (h_crop * w_crop) + H = int(L**0.5) + image_features_hd = ( + image_features.reshape(N, H, H, C) # N, 24, 24, 1024 + .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 + .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 + .reshape(N, -1, 4 * C) # N, 144, 4096 + .reshape( + num_images, h_crop, w_crop, H // 2, H // 2, -1 + ) # n_img, h_crop, w_crop, 12, 12, 4096 + .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 + .reshape( + num_images, h_crop * H // 2, w_crop * H // 2, 4 * C + ) # n_img, h_crop*12, w_crop*12, 4096 + ) + + # alternative implementation using einops + # from einops import rearrange + # image_features_nhwc = rearrange( + # image_features, + # 'N (H W) c -> N H W c', + # H=H, + # W=H, + # ) + # image_features_2x2merge = rearrange( + # image_features_nhwc, + # 'N (h h_pool) (w w_pool) c -> N h w (h_pool w_pool c)', + # h_pool=2, + # w_pool=2, + # ) + # image_features_hd = rearrange( + # image_features_2x2merge, + # '(n_img h_crop w_crop) h w C -> n_img (h_crop h) (w_crop w) C', + # h_crop=h_crop, + # w_crop=w_crop, + # ) + + return image_features_hd + + def add_image_newline(self, image_features_hd): + """ + image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) + output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) + """ + num_images, h, w, hid_dim = image_features_hd.shape + # add the newline token to the HD image feature patches + newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim) + image_features_hd_newline = torch.cat( + [image_features_hd, newline_embeddings], dim=2 + ).reshape(num_images, -1, hid_dim) + return image_features_hd_newline \ No newline at end of file diff --git a/src/vlm_backbone/phi3_v/image_processing_phi3_v.py b/src/vlm_backbone/phi3_v/image_processing_phi3_v.py new file mode 100644 index 0000000000000000000000000000000000000000..be42fac041f0a1ec87239e1e3abd6dfde893357c --- /dev/null +++ b/src/vlm_backbone/phi3_v/image_processing_phi3_v.py @@ -0,0 +1,286 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
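+
+# Worked example of the HD transform implemented below, assuming hd_num=16:
+# a 1344x672 input has aspect ratio 2, so the largest `scale` satisfying
+# scale * ceil(scale / 2) <= 16 is 5. The image is resized to 1680x840 and
+# padded to 1680x1008 (multiples of 336), giving 5 x 3 = 15 local crops plus
+# one global 336x336 thumbnail, i.e. (15 + 1) * 144 + 1 + (3 + 1) * 12 = 2353
+# image tokens.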
+ +"""Image processor class for Phi3-V.""" + +from typing import List, Optional, Union + +import numpy as np + +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_transforms import ( + convert_to_rgb, +) +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ImageInput, + make_list_of_images, + valid_images, +) +from transformers.utils import TensorType, is_vision_available, logging + +from transformers import AutoImageProcessor + +logger = logging.get_logger(__name__) + +if is_vision_available(): + from PIL import Image + +import torch +import torchvision + + +def padding_336(b): + width, height = b.size + tar = int(np.ceil(height / 336) * 336) + top_padding = int((tar - height) / 2) + bottom_padding = tar - height - top_padding + left_padding = 0 + right_padding = 0 + b = torchvision.transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], + fill=[255, 255, 255]) + + return b + + +def calc_padded_size(width, height, padding_unit=336): + target_height = int(np.ceil(height / padding_unit) * padding_unit) + top_padding = int((target_height - height) / 2) + bottom_padding = target_height - height - top_padding + left_padding = 0 + right_padding = 0 + padded_width = width + left_padding + right_padding + padded_height = height + top_padding + bottom_padding + return padded_width, padded_height + + +def HD_transform(img, hd_num=16): + width, height = img.size + trans = False + if width < height: + img = img.transpose(Image.TRANSPOSE) + trans = True + width, height = img.size + ratio = (width / height) + scale = 1 + while scale * np.ceil(scale / ratio) <= hd_num: + scale += 1 + scale -= 1 + new_w = int(scale * 336) + new_h = int(new_w / ratio) + + img = torchvision.transforms.functional.resize(img, [new_h, new_w], ) + img = padding_336(img) + width, height = img.size + if trans: + img = img.transpose(Image.TRANSPOSE) + + return img + + +def calc_hd_transform_size(width, height, hd_num=16): + transposed = False + if width < height: + width, height = height, width + transposed = True + + ratio = width / height + scale = 1 + while scale * np.ceil(scale / ratio) <= hd_num: + scale += 1 + scale -= 1 + + new_width = int(scale * 336) + new_height = int(new_width / ratio) + + padded_width, padded_height = calc_padded_size(new_width, new_height) + + if transposed: + padded_width, padded_height = padded_height, padded_width + + return padded_width, padded_height + + +def pad_to_max_num_crops_tensor(images, max_crops=5): + """ + images: B x 3 x H x W, B<=max_crops + """ + B, _, H, W = images.shape + if B < max_crops: + pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) + images = torch.cat([images, pad], dim=0) + return images + + +class Phi3VImageProcessor(BaseImageProcessor): + r""" + Constructs a Phi3 image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques + for processing high resolution images as explained in the [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512) + + Args: + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        num_crops: int = 1,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.num_crops = num_crops
+        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+        self.do_convert_rgb = do_convert_rgb
+
+    def calc_num_image_tokens(
+        self,
+        images: ImageInput
+    ):
+        """ Calculate the number of image tokens for each image.
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+        """
+        images = make_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        images = [image.convert('RGB') for image in images]
+        # (H, W, C)
+        elems = [HD_transform(im, hd_num=self.num_crops) for im in images]
+        shapes = [[im.size[1], im.size[0]] for im in elems]
+        num_img_tokens = [int((h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12) for h, w in shapes]
+        return num_img_tokens
+
+    def calc_num_image_tokens_from_image_size(self, width, height):
+        """
+        Calculate the number of image tokens for a given image size.
+        Args:
+            width (`int`): Width of the image.
+            height (`int`): Height of the image.
+        """
+        new_width, new_height = calc_hd_transform_size(width, height, hd_num=self.num_crops)
+        num_img_tokens = int((new_height // 336 * new_width // 336 + 1) * 144 + 1 + (new_height // 336 + 1) * 12)
+        return num_img_tokens
+
+    def preprocess(
+        self,
+        images: ImageInput,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ):
+        """
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+                `True`.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+        """
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+        images = make_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        if do_convert_rgb:
+            images = [convert_to_rgb(image) for image in images]
+
+        image_sizes = []
+        img_processor = torchvision.transforms.Compose([
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize(image_mean, image_std)
+        ])
+
+        # PIL images
+        # HD_transform pads each image to a multiple of 336 on each side
+        # convert to RGB first
+        images = [image.convert('RGB') for image in images]
+        elems = [HD_transform(im, hd_num=self.num_crops) for im in images]
+        # tensor transform and normalize
+        hd_images = [img_processor(im) for im in elems]
+        # create global image
+        global_image = [
+            torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(336, 336), mode='bicubic', ).to(im.dtype) for
+            im in hd_images]
+
+        # [(3, h, w)], where h, w is multiple of 336
+        shapes = [[im.size(1), im.size(2)] for im in hd_images]
+        num_img_tokens = [int(((h // 336) * (w // 336) + 1) * 144 + 1 + (h // 336 + 1) * 12) for h, w in shapes]
+        # reshape to channel dimension -> (num_images, num_crops, 3, 336, 336)
+        # (1, 3, h//336, 336, w//336, 336) -> (1, h//336, w//336, 3, 336, 336) -> (h//336*w//336, 3, 336, 336)
+        hd_images_reshape = [
+            im.reshape(1, 3, h // 336, 336, w // 336, 336).permute(0, 2, 4, 1, 3, 5).reshape(-1, 3, 336,
+                                                                                             336).contiguous() for
+            im, (h, w) in zip(hd_images, shapes)]
+        # concat global image and local image
+        hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in
+                             zip(global_image, hd_images_reshape)]
+
+        # pad to max_num_crops
+        image_transformed = [pad_to_max_num_crops_tensor(im, self.num_crops + 1) for im in hd_images_reshape]
+        image_transformed = torch.stack(image_transformed, dim=0)
+        image_sizes = [torch.LongTensor(_shapes) for _shapes in shapes]
+        padded_images = image_transformed
+        image_sizes = shapes
+
+        data = {"pixel_values": padded_images,
+                "image_sizes": image_sizes,
+                "num_img_tokens": num_img_tokens
+                }
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+AutoImageProcessor.register("Phi3VImageProcessor", Phi3VImageProcessor)
\ No newline at end of file
diff --git a/src/vlm_backbone/phi3_v/modeling_phi3_v.py b/src/vlm_backbone/phi3_v/modeling_phi3_v.py
new file mode 100644
index 0000000000000000000000000000000000000000..389753fd5bd2dda96689113e9b04b9e8482ee45a
--- /dev/null
+++ b/src/vlm_backbone/phi3_v/modeling_phi3_v.py
@@ -0,0 +1,1633 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" PyTorch Phi-3-V model.""" + +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_phi3_v import Phi3VConfig +from .image_embedding_phi3_v import Phi3ImageEmbedding + + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +except ImportError: + pass + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-vision-128k-instruct" +_CONFIG_FOR_DOC = "Phi3VConfig" + +PHI3V_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/Phi-3-vision-128k-instruct", + # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3 +] + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.register_buffer("inv_freq", None, persistent=False) + + 
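+    # `inv_freq` is registered as None above and materialized lazily on the
+    # first forward() call, on the activations' device. forward() returns
+    # cos/sin tables of shape (batch, seq_len, dim) in the dtype of `x`.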
@torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, 
dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = 0.1 * math.log(scale) + 1.0 + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3VConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.original_max_position_embeddings = config.original_max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.rope_scaling is None: + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + if scaling_type == "su": + self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config) + elif scaling_type == "yarn": + self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class Phi3FlashAttention2(Phi3Attention):
+    """
+    Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
+    untouched. The only required change is on the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
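+        # Added illustration (not upstream logic): for q_len=2 and kv_len=4, a top-left
+        # aligned causal mask vs. the bottom-right aligned one required with a KV cache:
+        #     top-left              bottom-right
+        #     [1 0 0 0]             [1 1 1 0]
+        #     [1 1 0 0]             [1 1 1 1]
+        # Bottom-right alignment treats the queries as the *last* q_len positions of the
+        # key sequence, which is what autoregressive decoding with cached keys needs.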
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # Phi3FlashAttention2 attention does not support output_attentions
+
+        if not _flash_supports_window_size:
+            logger.warning_once(
+                "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade the flash-attn library."
+            )
+            raise ValueError("The current flash attention version does not support sliding window attention.")
+
+        output_attentions = False
+
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv = self.qkv_proj(hidden_states)
+        query_pos = self.num_heads * self.head_dim
+        query_states = qkv[..., :query_pos]
+        key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+        value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+        # Flash attention expects inputs of shape
+        # batch_size x seq_length x num_heads x head_dim,
+        # so we only reshape here and transpose back right before the flash-attention call.
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        # Because the input can be padded, the absolute sequence length depends on the max position id.
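+        # Note added for clarity: `kv_seq_len` counts cached plus current keys, while
+        # `position_ids[:, -1]` holds the absolute position of the newest token, which can be
+        # larger under padding; sizing the rotary table to the max of both guarantees a cos/sin
+        # entry exists for every position id that is gathered below.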
+        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        use_sliding_windows = (
+            _flash_supports_window_size
+            and getattr(self.config, "sliding_window", None) is not None
+            and kv_seq_len > self.config.sliding_window
+        )
+
+        if past_key_value is not None:
+            # Activate slicing cache only if the config has a `sliding_window` attribute
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
+                slicing_tokens = 1 - self.config.sliding_window
+
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
+
+                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+                if past_key.shape[-2] != self.config.sliding_window - 1:
+                    raise ValueError(
+                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+                        f" {past_key.shape}"
+                    )
+
+                if attention_mask is not None:
+                    attention_mask = attention_mask[:, slicing_tokens:]
+                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_dropout = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability reasons;
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32.
+
+        if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.qkv_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to have been silently cast to float32; this might be because"
+                f" you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=attn_dropout,
+            use_sliding_windows=use_sliding_windows,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        first unpad the input, then compute the attention scores, and finally pad the attention scores back.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+            use_sliding_windows (`bool`, *optional*):
+                Whether to activate sliding window attention.
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
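+            # Added note: with flash_attn < 2.1 the mask is top-left aligned, which is only
+            # correct when q_len == kv_len or q_len == 1; during single-token decoding the one
+            # query may attend every key anyway, so `causal=False` reproduces the right result.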
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
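+            # Added illustration: for a mask row like [0, 0, 1, 1, 1] (left padding) and
+            # query_length == 3, keeping the last 3 columns aligns the mask with the query
+            # tokens so `unpad_input` can derive per-sequence cu_seqlens for the varlen kernel.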
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 +# TODO @Arthur no longer copied from LLama after static cache +class Phi3SdpaAttention(Phi3Attention): + """ + Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Phi3Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with 
custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+PHI3_ATTENTION_CLASSES = {
+    "eager": Phi3Attention,
+    "flash_attention_2": Phi3FlashAttention2,
+    "sdpa": Phi3SdpaAttention,
+}
+
+
+class Phi3DecoderLayer(nn.Module):
+    def __init__(self, config: Phi3VConfig, layer_idx: int):
+        super().__init__()
+
+        self.config = config
+        self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
+
+        self.mlp = Phi3MLP(config)
+        self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
+        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
+        self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+            )
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+                Indices of positions of each input sequence token in the position embeddings. Selected in the range
+                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PHI3V_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi3VConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phi-3-V model outputting raw hidden-states without any specific head on top.", + PHI3V_START_DOCSTRING, +) +class Phi3VPreTrainedModel(PreTrainedModel): + config_class = Phi3VConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI3V_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`]. + See [`Phi3ImageProcessor.__call__`] for details. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*): + The sizes of the images in the batch, being (height, width) for each image. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Phi-3-V model outputting raw hidden-states without any specific head on top.", + PHI3V_START_DOCSTRING, +) +class Phi3VModel(Phi3VPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi3DecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3VConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + + self.vision_embed_tokens = None + if isinstance(config.embd_layer, dict): + # vision embedding layer + embedding_config = { + 'embedding_cls': config.embd_layer['embedding_cls'], + **config.embd_layer + } + self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config) + # # set wte the same for vision embedding + # self.vision_embed_tokens.wte.weight = self.embed_tokens.weight + + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + if pixel_values is not None and image_sizes is not None: + assert self.vision_embed_tokens is not None, "Vision embedding layer is not defined" + inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes) + else: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Phi3VForCausalLM(Phi3VPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from 
transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 + def __init__(self, config): + super().__init__(config) + self.model = Phi3VModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
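+            # Added note: `cache_length + input_ids.shape[1]` is the total context after this
+            # step; when it would exceed `max_cache_length`, dropping the oldest attention-mask
+            # columns keeps the mask aligned with a cache capped at `max_cache_length` entries.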
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_sizes": image_sizes, + } + ) + return model_inputs + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The [`Phi3VModel`] with a sequence classification head on top (linear layer). + + [`Phi3VForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHI3V_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs +class Phi3VForSequenceClassification(Phi3VPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3VModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = model_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + model_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) + + +@add_start_docstrings( + """ + [`Phi3VModel`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + PHI3V_START_DOCSTRING, +) +# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs +class Phi3VForTokenClassification(Phi3VPreTrainedModel): + def __init__(self, config: Phi3VConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = Phi3VModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = model_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (logits,) + model_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) \ No newline at end of file diff --git a/src/vlm_backbone/phi3_v/processing_phi3_v.py b/src/vlm_backbone/phi3_v/processing_phi3_v.py new file mode 100644 index 0000000000000000000000000000000000000000..e3109983827310e39e0c7fb475adad390efd5ed1 --- /dev/null +++ b/src/vlm_backbone/phi3_v/processing_phi3_v.py @@ -0,0 +1,223 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Processor class for Phi3-V. +""" +import re +from typing import List, Optional, Union + +import torch + +import transformers +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy +from transformers.utils import TensorType +from .image_processing_phi3_v import Phi3VImageProcessor + +transformers.Phi3VImageProcessor = Phi3VImageProcessor + + +class Phi3VProcessor(ProcessorMixin): + r""" + Constructs a Phi3-V processor which wraps a Phi3-V image processor and a LLaMa tokenizer into a single processor. + + [`Phi3VProcessor`] offers all the functionalities of [`Phi3VImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~Phi3VProcessor.__call__`] and [`~Phi3VProcessor.decode`] for more information. + + Args: + image_processor ([`Phi3VImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. 
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "Phi3VImageProcessor" + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + special_image_token = "<|image|>" + + def __init__(self, image_processor, tokenizer): + self.image_processor = image_processor + self.tokenizer = tokenizer + self.num_img_tokens = image_processor.num_img_tokens + self.img_tokens = [f"<|image_{i + 1}|>" for i in range(1000000)] + + def __call__( + self, + text: Union[TextInput, List[TextInput]], + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + Phi3ImageProcessor's [`~Phi3ImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). 
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is not None: + image_inputs = self.image_processor(images, return_tensors=return_tensors) + else: + image_inputs = {} + inputs = self._convert_images_texts_to_inputs(image_inputs, text, padding=padding, truncation=truncation, + max_length=max_length, return_tensors=return_tensors) + return inputs + + def calc_num_image_tokens(self, images: ImageInput): + """ Calculate the number of image tokens for each image. + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + """ + return self.image_processor.calc_num_image_tokens(images) + + def calc_num_image_tokens_from_image_size(self, width, height): + """ Calculate the number of image token for an image with given width and height. + Args: + width (`int`): + Width of the image. + height (`int`): + Height of the image. + """ + return self.image_processor.calc_num_image_tokens_from_image_size(width, height) + + @property + def special_image_token_id(self): + return self.tokenizer.convert_tokens_to_ids(self.special_image_token) + + def get_special_image_token_id(self): + return self.tokenizer.convert_tokens_to_ids(self.special_image_token) + + def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, + return_tensors=None): + + if not len(images): + model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, + max_length=max_length) + return BatchFeature(data={**model_inputs}) + + pattern = r"<\|image_\d+\|>" + prompt_chunks = [self.tokenizer(chunk).input_ids for chunk in re.split(pattern, texts)] + + if 'num_img_tokens' in images: + num_img_tokens = images['num_img_tokens'] + else: + assert 'num_crops' in images, 'num_crops must be provided in images if num_img_tokens is not provided' + num_crops = images['num_crops'] + num_img_tokens = [_num_crops * self.num_img_tokens for _num_crops in num_crops] + + images, image_sizes = images['pixel_values'], images['image_sizes'] + + # image_tags needs to start from 1 to n + image_tags = re.findall(pattern, texts) + # image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags] + # image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)] + image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] + unique_image_ids = sorted(list(set(image_ids))) + # image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5] + # check the condition + assert unique_image_ids == list(range(1, + len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. 
[1, 2, 3], cannot be {unique_image_ids}" + # total images must be the same as the number of image tags + assert len(unique_image_ids) == len( + images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images" + + image_ids_pad = [[-iid] * num_img_tokens[iid - 1] for iid in image_ids] + + def insert_separator(X, sep_list): + if len(X) > len(sep_list): + sep_list.append([]) + return [ele for sublist in zip(X, sep_list) for ele in sublist] + + input_ids = [] + offset = 0 + for x in insert_separator(prompt_chunks, image_ids_pad): + input_ids.extend(x[offset:]) + + input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + attention_mask = (input_ids > -1000000).to(torch.long) + + return BatchFeature(data={"input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": images, + "image_sizes": image_sizes}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) \ No newline at end of file