Spaces:
Runtime error
Runtime error
| import dataclasses | |
| import logging | |
| import os | |
| from typing import Any, Dict, List | |
| import gradio as gr | |
| import PIL.Image as Image | |
| import PIL.ImageOps as ImageOps | |
| import spaces | |
| import torch | |
| from peft import PeftModel | |
| from transformers import AutoProcessor | |
| from transformers import Idefics2ForConditionalGeneration, Idefics2Processor | |
| from adapter import IdeficsAdapter | |
| from config_generator import GameConfig, generate_game_config | |
| from utils import device, nested_to_device, sorted_list | |
| import copy | |
### Constants
# Directory of tangram PNG files: handed to IdeficsAdapter and joined with
# image names when the speaker view is rendered.
IMG_DIR = "tangram_pngs"

### Bot server
# Keyword arguments forwarded verbatim to ``model.generate`` on every turn.
GEN_KWS: Dict[str, Any] = dict(
    max_new_tokens=10,
    do_sample=True,
    temperature=1.0,
    output_logits=True,
    return_dict_in_generate=True,
    remove_invalid_values=True,  # just to be safe
    renormalize_logits=True,
    suppress_tokens=IdeficsAdapter.SUPPRESS_TOKEN_IDS,
)
# NOTE(review): `spaces` is imported at file level but unused in the visible
# code; if this app targets ZeroGPU hardware a `@spaces.GPU` decorator may be
# expected on this function -- confirm against the deployment.
def get_model_response(  # predict
    model: PeftModel, adapter_name: str, adapter: IdeficsAdapter,
    image_paths: List[str], chat: str, chats: List[str],
    previous_selected: List[List[str]]
) -> List[str]:
    """Run one listener turn and return the images the model clicks.

    Args:
        model: PEFT-wrapped model; switched to ``adapter_name`` if another
            adapter is currently active, then moved to ``device()``.
        adapter_name: Name of the LoRA adapter to generate with.
        adapter: Composes the model input and parses the decoded output.
        image_paths: Context images as seen by the model.
        chat: The new speaker message for this turn.
        chats: Speaker messages from earlier turns (not mutated).
        previous_selected: Selection after each earlier turn; the last entry
            is treated as the current selection.

    Returns:
        The clicked image paths parsed from the generation. When parsing
        yields nothing, falls back to ``[image_paths[0]]``.
    """
    # Only switch adapters when the requested one is not already active.
    if model.active_adapter != adapter_name:
        model.set_adapter(adapter_name)
    model.to(device())

    history = chats + [chat]
    currently_selected = previous_selected[-1] if previous_selected else []

    # The two trailing positional flags are defined by IdeficsAdapter.compose.
    model_input: Dict[str, Any] = nested_to_device(
        adapter.compose(image_paths, history, previous_selected, True, True))

    with torch.inference_mode():
        with torch.autocast(device_type=device().type, dtype=torch.bfloat16):
            generation = model.generate(**model_input, **GEN_KWS)

    decoded_out: str = adapter.tokenizer.decode(
        generation.sequences[0], skip_special_tokens=True)
    model_clicks = adapter.parse(image_paths, decoded_out, currently_selected)

    # `prob` is only a marker for the debug log: -3 normal, -1 fallback.
    if len(model_clicks) > 0:
        prob = -3
    else:
        logging.warning("empty clicks by model")
        model_clicks = [image_paths[0]]
        logging.debug(f"{image_paths=}")
        logging.debug(f"selecting {model_clicks}")
        prob = -1

    logging.debug(f"{prob=}")
    logging.info(f"User input: {chat}")
    logging.info(f"Model selected: {model_clicks}")
    logging.debug(f"Model output: {decoded_out}")
    return model_clicks
def get_model() -> PeftModel:
    """Load the Idefics2 base model with the "r6_bp" and "r0" LoRA adapters.

    Returns:
        A ``PeftModel`` in bfloat16 holding both adapters, neither trainable.
    """
    base_checkpoint = "HuggingFaceM4/idefics2-8b"
    adapter_repo = 'lil-lab/respect'

    base_model = Idefics2ForConditionalGeneration.from_pretrained(
        base_checkpoint, torch_dtype=torch.bfloat16,)
    peft_model = PeftModel.from_pretrained(
        base_model, adapter_repo, adapter_name="r6_bp", is_trainable=False,
        revision="r6_bp")

    # Add other adapter - hack to avoid conflict
    lora_config = copy.deepcopy(peft_model.active_peft_config)
    # Collect each parameter name up to (excluding) the separator that
    # precedes its first 'lora' occurrence, de-duplicated.
    lora_modules = set()
    for param_name, _ in base_model.named_parameters():
        if 'lora' in param_name:
            lora_modules.add(param_name[:param_name.find('lora') - 1])
    lora_config.target_modules = list(lora_modules)

    peft_model.add_adapter("r0", lora_config)
    peft_model.load_adapter(adapter_repo, "r0", is_trainable=False,
                            revision="r0", peft_config=lora_config)
    return peft_model
def get_processor() -> Idefics2Processor:
    """Load the Idefics2 processor with image splitting off and 224px edges."""
    return AutoProcessor.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        do_image_splitting=False,
        size={"longest_edge": 224, "shortest_edge": 224},
    )
def get_adapter() -> IdeficsAdapter:
    """Build the IdeficsAdapter over ``IMG_DIR`` with a freshly loaded processor."""
    return IdeficsAdapter(IMG_DIR, get_processor())
| ### Game logic | |
@dataclasses.dataclass
class GameState:
    """Snapshot of a single game between the human speaker and the model.

    The class is instantiated with keyword arguments, so it must be a
    dataclass to get a generated ``__init__``.
    """

    # Game configuration; provides `targets`, `speaker_context` and
    # `listener_context`. Quoted so the class does not need the project
    # type at import time.
    config: "GameConfig"
    # Name of the LoRA adapter the model answers with.
    adapter_name: str
    # Speaker messages so far, one per turn, oldest first.
    chats: List[str]
    # Images selected after the latest turn.
    currently_selected: List[str]
    # Selection after each turn (last entry equals `currently_selected`).
    selected_accum: List[List[str]]
    # Images the model clicked on each turn.
    clicks_accum: List[List[str]]
    # Number of completed turns.
    turn: int = 0

    def has_ended(self) -> bool:
        """Return True once the game is won or 10 turns have been played."""
        return self.has_successfully_ended() or self.turn >= 10

    def has_successfully_ended(self) -> bool:
        """Return True when the selection matches the targets exactly."""
        return set(self.currently_selected) == set(self.config.targets)

    ### UI helpers
    def serialize_conversation(self) -> str:
        """Render the chat history as one ``Turn N: message`` line per turn."""
        return "\n".join(
            f"Turn {i+1}: {message}" for i, message in enumerate(self.chats))

    def markup_images(self) -> "List[Image.Image]":
        """Return the speaker-view tiles, annotated with the current state."""
        # NOTE(review): `changes` is the latest entry of `selected_accum`,
        # which GameFlow.progress keeps equal to `currently_selected`, so the
        # yellow dot marks every selected image. If only the last turn's
        # clicks should be marked, `clicks_accum[-1]` may be intended.
        changes = self.selected_accum[-1] if self.selected_accum else []
        return self._display_context(
            self.config.speaker_context, self.config.targets,
            changes, self.currently_selected)

    @staticmethod
    def _display_context(context: List[str], targets: List[str],
                         changes: List[str], selected: List[str]
                         ) -> "List[Image.Image]":
        """Draw one bordered tile per context image.

        Border colours: green = selected target, black = unselected target,
        red = selected non-target, white = everything else. Images listed in
        ``changes`` additionally get a yellow dot.

        Declared static because it is invoked through ``self`` with exactly
        these four arguments.
        """
        tangram_list: List[Image.Image] = []
        arrow = Image.open("yellow_circle.png").resize((20, 20)).convert("RGBA")
        for img in context:
            image = Image.open(
                os.path.join(IMG_DIR, img)).resize((60, 60)).convert("RGB")
            image = ImageOps.expand(image, border=2, fill="white")

            is_target = img in targets
            is_selected = img in selected
            if is_target and is_selected:  # listener selected a target image
                border_colour = "green"
            elif is_target:  # unselected target
                border_colour = "black"
            elif is_selected:  # listener selected a wrong image
                border_colour = "red"
            else:
                border_colour = "white"
            image = ImageOps.expand(image, border=10, fill=border_colour)
            image = ImageOps.expand(image, border=2, fill="white")

            if img in changes:
                # Tile is 60 + 2*(2+10+2) = 88px wide, so x=68 puts the 20px
                # dot flush with the top-right corner.
                image.paste(arrow, (68, 0), mask=arrow)
            tangram_list.append(image)
        return tangram_list
class GameFlow:
    """Transitions that create and advance ``GameState`` snapshots.

    Both operations are class methods: they are invoked on the class itself
    (``GameFlow.initialize(...)``, ``GameFlow.progress(...)``) and take
    ``cls`` as their first parameter.
    """

    @classmethod
    def initialize(cls, model_iteration: str) -> GameState:
        """Start a new game with a freshly generated configuration.

        Args:
            model_iteration: "Initial System" selects the "r0" adapter; any
                other value selects "r6_bp".

        Returns:
            A state at turn 0 with empty history and selection.
        """
        config = generate_game_config()
        adapter_name = "r0" if model_iteration == "Initial System" else "r6_bp"
        return GameState(
            config=config,
            adapter_name=adapter_name,
            chats=[],
            currently_selected=[],
            selected_accum=[],
            clicks_accum=[],
            turn=0,
        )

    @classmethod
    def progress(cls, state: GameState, chat: str,
                 model: PeftModel,
                 adapter: IdeficsAdapter) -> GameState:
        """Play one turn and return the resulting state.

        The input ``state`` is not modified; every list in the returned
        state is a new object.

        Args:
            state: State before the turn.
            chat: The speaker's message for this turn.
            model: Model used to produce the listener's clicks.
            adapter: Adapter passed through to ``get_model_response``.

        Returns:
            A new state with the message, selection and clicks appended and
            the turn counter incremented.
        """
        model_clicks = get_model_response(
            model, state.adapter_name, adapter,
            state.config.listener_context, chat,
            state.chats, state.selected_accum
        )
        # Clicks toggle membership: a click on a selected image deselects it,
        # a click on an unselected one selects it (symmetric difference).
        currently_selected2 = sorted_list(
            set(state.currently_selected) ^ set(model_clicks))
        return GameState(
            # constants
            config=state.config,
            adapter_name=state.adapter_name,
            # updates
            chats=[*state.chats, chat],
            currently_selected=currently_selected2,
            selected_accum=[*state.selected_accum, currently_selected2],
            clicks_accum=[*state.clicks_accum, model_clicks],
            turn=state.turn + 1,
        )
| ### UI | |
def create_app_inner():
    """Build the game UI and wire its callbacks.

    Must be called inside an open ``gr.Blocks`` context. The model and the
    adapter are loaded once here and shared by the callbacks through the
    enclosing scope; the per-session game is kept in a ``gr.State``.
    """
    ### layout
    # NOTE(review): the row/column nesting below was reconstructed from a
    # listing without indentation -- confirm against the intended layout.
    gr.Markdown("# Tangram Multi-Reference Game")
    gr.Markdown(
        "Your goal is to describe the target tangrams (with black borders) \
        by sending messages in the textbox. \
        You have 10 turns to complete the game.")
    gr.Markdown("Targets have black borders. \
        Correctly selected targets have green borders. \
        Incorrectly selected targets have red borders. \
        Actions are marked with yellow dot. \
        The model cannot see boxes or colors and the order is different.")

    with gr.Row():
        model_iteration = gr.Radio(["Initial System", "Final System"],
                                   label="Model Iteration",
                                   value="Final System")
        start_btn = gr.Button("Start Game")
        status = gr.Textbox(label="Status", interactive=False,
                            text_align="center", value='Press "Start Game"')

    with gr.Row():
        image_output = gr.Gallery(
            label="CONTEXT", show_label=False, elem_id="gallery",
            columns=5, rows=2, object_fit="contain", height="250px",
            allow_preview=False, container=True, interactive=False
        )

    with gr.Row():
        conversation_output = gr.Textbox(label="Interaction History")
        with gr.Column():
            user_input = gr.Textbox(label="Your Message as Speaker",
                                    interactive=True)
            send_btn = gr.Button("Send", interactive=True)

    ### globals
    model = get_model()
    adapter = get_adapter()
    game_state = gr.State(value=None)

    ### callbacks
    def output_from_state(state: GameState):
        """Map a game state onto the seven output components, in order."""
        has_ended = state.has_ended()
        success = "Success" if state.has_successfully_ended() else "Failure"
        status = f"{success} (Turn {state.turn}/10) - Start another game?" \
            if has_ended else f"Turn {state.turn+1}/10"
        return (
            state.markup_images(),  # image_output
            state.serialize_conversation(),  # conversation_output
            status,  # status
            gr.update(interactive=not has_ended, value=""),  # user_input
            gr.update(interactive=not has_ended),  # send_btn
            gr.update(interactive=has_ended),  # model_iteration
            state,  # game_state
        )

    def on_start_interaction(model_iteration: str):
        """Start a new game with the selected model iteration."""
        # Explicit check instead of `assert`, which is stripped under -O.
        if model_iteration not in ("Initial System", "Final System"):
            raise ValueError(f"Unknown model iteration: {model_iteration!r}")
        state = GameFlow.initialize(model_iteration)
        return output_from_state(state)

    def on_send_message(message: str, state: GameState):
        """Advance the game by one turn using the speaker's message."""
        if state is None:
            # "Send" is clickable before any game exists; leave every
            # component untouched instead of failing on a missing state.
            logging.info("Message sent before game start")
            return (
                gr.update(),  # image_output
                gr.update(),  # conversation_output
                'Press "Start Game"',  # status
                gr.update(),  # user_input
                gr.update(),  # send_btn
                gr.update(),  # model_iteration
                state,  # game_state
            )
        if message.strip() == "":
            logging.info("Empty message")
            return output_from_state(state)
        state = GameFlow.progress(state, message, model, adapter)
        return output_from_state(state)

    start_btn.click(
        on_start_interaction,
        inputs=[model_iteration],
        outputs=[image_output, conversation_output, status,
                 user_input, send_btn, model_iteration, game_state],
        queue=False
    )
    send_btn.click(
        on_send_message,
        inputs=[user_input, game_state],
        outputs=[image_output, conversation_output, status,
                 user_input, send_btn, model_iteration, game_state],
        queue=True
    )
def create_app():
    """Create the themed Gradio Blocks app and populate it with the game UI."""
    app = gr.Blocks(theme='saq1b/gradio-theme')
    # Components created inside the context attach to `app`.
    with app:
        create_app_inner()
    return app
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, then serve.
    app = create_app()
    app.queue().launch()