"""Gradio app: play the tangram multi-reference game against an Idefics2 listener.

The human takes the "speaker" role and types messages; the model (LoRA adapters
loaded with PEFT on top of Idefics2-8b) acts as the listener and toggles image
selections in response.
"""
import copy
import dataclasses
import logging
import os
from typing import Any, Dict, List

import gradio as gr  # type: ignore
import PIL.Image as Image
import PIL.ImageOps as ImageOps
import spaces  # type: ignore
import torch
from peft import PeftModel  # type: ignore
from transformers import AutoProcessor  # type: ignore
from transformers import Idefics2ForConditionalGeneration, Idefics2Processor

from adapter import IdeficsAdapter
from config_generator import GameConfig, generate_game_config
from utils import device, nested_to_device, sorted_list

### Constants

css = """
.radio-group .wrap {
    display: grid;
    grid-template-columns: repeat(5, 1fr);
    grid-template-rows: repeat(5, 1fr);
    width: 100%;
    height: 100%
}
"""

IMG_DIR = "tangram_pngs"

# Number of speaker turns allowed per game. Shared by GameState.has_ended,
# the turn counter and the on-screen instructions so they cannot drift apart.
MAX_TURNS = 10

### Bot server

GEN_KWS: Dict[str, Any] = {
    "max_new_tokens": 10,
    "do_sample": True,
    "temperature": 1.0,
    "output_logits": True,
    "return_dict_in_generate": True,
    "remove_invalid_values": True,  # just to be safe
    "renormalize_logits": True,
    "suppress_tokens": IdeficsAdapter.SUPPRESS_TOKEN_IDS,
}


@spaces.GPU(duration=20)
def get_model_response(  # predict
    model: PeftModel,
    adapter_name: str,
    adapter: IdeficsAdapter,
    image_paths: List[str],
    chat: str,
    chats: List[str],
    previous_selected: List[List[str]],
) -> List[str]:
    """Ask the listener model which images to click in response to ``chat``.

    Args:
        model: PEFT model holding the listener adapters.
        adapter_name: Adapter to activate for this call ("r0" or "r6_bp" in this app).
        adapter: Builds model inputs and parses the generated text into clicks.
        image_paths: The listener's context images.
        chat: The speaker's newest message.
        chats: The speaker's earlier messages in this game.
        previous_selected: Selection snapshot after each earlier turn.

    Returns:
        The images the model clicked this turn. If parsing yields no clicks,
        falls back to ``[image_paths[0]]`` and logs a warning.
    """
    if model.active_adapter != adapter_name:
        model.set_adapter(adapter_name)
    model.to(device())

    new_chats = chats + [chat]
    currently_selected = previous_selected[-1] if previous_selected else []
    # NOTE(review): the two positional booleans are interpreted by
    # IdeficsAdapter.compose (not visible here) — confirm their meaning there.
    model_input: Dict[str, Any] = adapter.compose(  # type: ignore
        image_paths, new_chats, previous_selected, True, False)
    model_input = nested_to_device(model_input)  # type: ignore

    with torch.inference_mode(), torch.autocast(device_type=device().type, dtype=torch.bfloat16):
        model_output = model.generate(**model_input, **GEN_KWS)  # type: ignore

    decoded_out: str = adapter.tokenizer.decode(  # type: ignore
        model_output.sequences[0], skip_special_tokens=True)
    model_clicks = adapter.parse(
        image_paths, decoded_out, currently_selected)  # type: ignore

    if not model_clicks:
        # Never return an empty action: fall back to clicking the first image.
        logging.warning("empty clicks by model")
        model_clicks = [image_paths[0]]
        logging.debug(f"{image_paths=}")
        logging.debug(f"selecting {model_clicks}")

    logging.info(f"User input: {chat}")
    logging.info(f"Model selected: {model_clicks}")
    logging.debug(f"Model output: {decoded_out}")
    return model_clicks


def get_model() -> PeftModel:
    """Load Idefics2-8b (bfloat16) with the "r6_bp" and "r0" listener adapters."""
    model_id = 'lil-lab/respect'
    checkpoint = "HuggingFaceM4/idefics2-8b"
    model = Idefics2ForConditionalGeneration.from_pretrained(  # type: ignore
        checkpoint,
        torch_dtype=torch.bfloat16,
    )
    peft_model = PeftModel.from_pretrained(  # type: ignore
        model, model_id, adapter_name="r6_bp", is_trainable=False, revision="r6_bp")

    # Add other adapter - hack to avoid conflict: reuse the first adapter's
    # config but target exactly the modules that already carry LoRA weights
    # (parameter names with the ".lora..." suffix stripped).
    lora_config = copy.deepcopy(peft_model.active_peft_config)
    targets = list({n[:n.find('lora') - 1] for n, _ in model.named_parameters() if 'lora' in n})
    lora_config.target_modules = targets
    peft_model.add_adapter("r0", lora_config)
    peft_model.load_adapter(model_id, "r0", is_trainable=False, revision="r0",
                            peft_config=lora_config)
    return peft_model


def get_processor() -> Idefics2Processor:
    """Load the Idefics2 processor with image splitting off and 224px images."""
    checkpoint = "HuggingFaceM4/idefics2-8b"
    processor = AutoProcessor.from_pretrained(  # type: ignore
        checkpoint, do_image_splitting=False,
        size={"longest_edge": 224, "shortest_edge": 224})
    return processor  # type: ignore


def get_adapter() -> IdeficsAdapter:
    """Build the IdeficsAdapter over IMG_DIR using the Idefics2 processor."""
    processor = get_processor()
    return IdeficsAdapter(IMG_DIR, processor)


### Game logic

@dataclasses.dataclass(frozen=False)
class GameState:
    """Snapshot of one game; GameFlow.progress returns a new snapshot per turn."""

    config: GameConfig
    adapter_name: str
    chats: List[str]                 # speaker messages, one per turn
    currently_selected: List[str]    # images the listener has selected right now
    selected_accum: List[List[str]]  # selection after each turn
    clicks_accum: List[List[str]]    # images clicked (toggled) in each turn
    turn: int = 0

    def has_ended(self):
        """True once the targets are exactly selected or MAX_TURNS turns were played."""
        return self.has_successfully_ended() or self.turn >= MAX_TURNS

    def has_successfully_ended(self):
        """True when the current selection equals the target set."""
        return set(self.currently_selected) == set(self.config.targets)

    ### UI helpers

    def serialize_conversation(self):
        """Return the chat history as "Turn k: message" lines."""
        return "\n".join(
            f"Turn {i}: {message}" for i, message in enumerate(self.chats, start=1))

    def markup_images(self):
        """Render the speaker's context with status borders and last-turn markers."""
        # The yellow marker shows the most recent *actions* (the images the
        # listener toggled last turn), including deselections. selected_accum[-1]
        # would just repeat the current selection.
        changes = self.clicks_accum[-1] if self.clicks_accum else []
        return self._display_context(
            self.config.speaker_context, self.config.targets, changes, self.currently_selected)

    @staticmethod
    def _display_context(context: List[str], targets: List[str], changes: List[str], selected: List[str]) -> List[Image.Image]:
        """Draw each context image with a border colour encoding its status.

        Colours: green = selected target, black = unselected target,
        red = selected non-target, white = anything else. Images listed in
        ``changes`` get a yellow-circle marker in the top-right corner.
        """
        tangram_list: List[Image.Image] = []
        arrow = Image.open("yellow_circle.png").resize((20, 20)).convert("RGBA")
        for img in context:
            image = Image.open(os.path.join(IMG_DIR, img)).resize((60, 60)).convert("RGB")
            image = ImageOps.expand(image, border=2, fill="white")
            if img in targets and img in selected:
                # listener selected a target image
                image = ImageOps.expand(image, border=10, fill="green")
            elif img in targets and img not in selected:
                # unselected target
                image = ImageOps.expand(image, border=10, fill="black")
            elif img in selected and img not in targets:
                # listener selected a wrong image
                image = ImageOps.expand(image, border=10, fill="red")
            else:
                image = ImageOps.expand(image, border=10, fill="white")
            image = ImageOps.expand(image, border=2, fill="white")
            if img in changes:
                # Final image is 60 + 2 * (2 + 10 + 2) = 88 px wide, so x=68 puts
                # the 20 px marker flush with the right edge.
                image.paste(arrow, (68, 0), mask=arrow)
            tangram_list.append(image)
        return tangram_list


class GameFlow:
    """Stateless helpers that create and advance GameState objects."""

    @classmethod
    def initialize(cls, model_iteration: str) -> GameState:
        """Start a new game with a fresh config and the adapter for ``model_iteration``."""
        config = generate_game_config()
        adapter_name = "r0" if model_iteration == "Initial System" else "r6_bp"
        return GameState(
            config=config,
            adapter_name=adapter_name,
            chats=[],
            currently_selected=[],
            selected_accum=[],
            clicks_accum=[],
            turn=0,
        )

    @classmethod
    def progress(cls, state: GameState, chat: str, model: PeftModel, adapter: IdeficsAdapter) -> GameState:
        """Play one turn with speaker message ``chat``; returns a new state (input untouched)."""
        model_clicks = get_model_response(
            model,
            state.adapter_name,
            adapter,
            state.config.listener_context,
            chat,
            state.chats,
            state.selected_accum,
        )

        # Each click toggles an image: the symmetric difference deselects
        # clicked images that were selected and selects the others.
        currently_selected2 = sorted_list(set(state.currently_selected) ^ set(model_clicks))

        return GameState(
            # constants
            config=state.config,
            adapter_name=state.adapter_name,
            # updates
            chats=[*state.chats, chat],
            currently_selected=currently_selected2,
            selected_accum=[*state.selected_accum, currently_selected2],
            clicks_accum=[*state.clicks_accum, model_clicks],
            turn=state.turn + 1,
        )


### UI

def create_app_inner():
    """Lay out the game UI and wire its callbacks (must run inside gr.Blocks)."""
    ### layout
    gr.Markdown("# Tangram Multi-Reference Game")
    gr.Markdown(
        '### You will be playing a multi-reference game against a model. '
        'To start a game, first select whether you wish to play against our '
        'initial trained model ("Initial System") or '
        'our model at the end of continual learning ("Final System") '
        'and press the "Start Game" button. '
        'You will take on a "speaker" role at each round. '
        'Your goal is to describe the target images (via a message in the textbox) '
        'so that the model can select them.'
    )
    gr.Markdown("Targets have black borders. Correctly selected targets have green borders. "
                "Incorrectly selected images have red borders. Actions are marked with a yellow dot.")
    gr.Markdown("The listener cannot see boxes or colors and the order is different.")
    gr.Markdown(
        '### Press "Send" to submit your action to proceed to the next turn. '
        f'You have {MAX_TURNS} turns in total.'
    )

    with gr.Row():
        model_iteration = gr.Radio(["Initial System", "Final System"],
                                   label="Model Iteration",
                                   value="Final System")
        start_btn = gr.Button("Start Game")

    with gr.Row():
        current_turn = gr.Textbox(label="TURN")
        success = gr.Textbox(label="Success")

    with gr.Row():
        image_output = gr.Gallery(
            label="CONTEXT", show_label=False, elem_id="gallery",
            columns=5, rows=2, object_fit="contain", height="250px",
            allow_preview=False, container=True, interactive=False,
        )

    with gr.Row():
        conversation_output = gr.Textbox(label="Interaction History")
        user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)
        send_btn = gr.Button("Send", interactive=True)

    ### globals
    model = get_model()
    adapter = get_adapter()
    game_state = gr.State(value=None)

    # Both callbacks return values for these components, in this order.
    outputs = [image_output, conversation_output, current_turn, success,
               user_input, send_btn, model_iteration, game_state]

    ### callbacks
    def output_from_state(state: GameState):
        """Map a GameState onto the values/updates for ``outputs``."""
        has_ended = state.has_ended()
        outcome = "success" if state.has_successfully_ended() else "failure"
        # While playing, show the turn about to be played; once the game is
        # over, show how many turns were used (so it never reads "11/10").
        shown_turn = state.turn if has_ended else state.turn + 1
        return (
            state.markup_images(),  # image_output
            state.serialize_conversation(),  # conversation_output
            f"{shown_turn}/{MAX_TURNS}",  # current_turn
            outcome if has_ended else "n/a",  # success
            gr.update(interactive=not has_ended, value=""),  # user_input
            gr.update(interactive=not has_ended),  # send_btn
            gr.update(interactive=has_ended),  # model_iteration
            state,  # game_state
        )

    def on_start_interaction(model_iteration: str):
        if model_iteration not in ("Initial System", "Final System"):
            raise ValueError(f"Unknown model iteration: {model_iteration!r}")
        state = GameFlow.initialize(model_iteration)
        return output_from_state(state)

    def on_send_message(message: str, state: GameState):
        if state is None:
            # Send is clickable before any game exists; without this guard both
            # GameFlow.progress and output_from_state fail on None.
            raise gr.Error('Please press "Start Game" first.')
        if not message.strip():
            logging.info("Empty message")
            return output_from_state(state)
        state = GameFlow.progress(state, message, model, adapter)
        return output_from_state(state)

    start_btn.click(
        on_start_interaction,
        inputs=[model_iteration],
        outputs=outputs,
        queue=False,
    )

    send_btn.click(
        on_send_message,
        inputs=[user_input, game_state],
        outputs=outputs,
        queue=True,
    )


def create_app():
    """Build the Blocks app with the custom CSS."""
    with gr.Blocks(css=css) as app:
        create_app_inner()
    return app


if __name__ == "__main__":
    app = create_app()
    app.queue()
    app.launch()