Spaces:
Sleeping
Sleeping
| """Developed by Ruslan Magana Vsevolodovna""" | |
| from collections.abc import Iterator | |
| from datetime import datetime | |
| from pathlib import Path | |
| from threading import Thread | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration | |
| import random | |
| from themes.research_monochrome import theme | |
| # ============================================================================= | |
| # Constants & Prompts | |
| # ============================================================================= | |
# Human-readable date injected into the system prompt, e.g. "April 3, 2025".
# The day comes from `datetime.day` instead of strftime("%-d"): the "-" flag is
# a non-portable libc extension and raises ValueError on platforms that lack it
# (notably Windows). The rendered text is unchanged where "%-d" did work.
today_date = "{0:%B} {0.day}, {0:%Y}".format(datetime.today())  # noqa: DTZ002

# System prompt prepended to every text-only conversation.
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""

# Page title and HTML blurb rendered at the top of the UI.
TITLE = "IBM Granite 3.1 8b Instruct & Vision Preview"
DESCRIPTION = """
<p>Granite 3.1 8b instruct is an open‐source LLM supporting a 128k context window and Granite Vision 3.1 2B Preview for vision‐language capabilities. Start with one of the sample prompts
or enter your own. Keep in mind that AI can occasionally make mistakes.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
</span>
</p>
"""

# Text model defaults. The prompt is truncated to
# MAX_INPUT_TOKEN_LENGTH minus the requested number of new tokens.
MAX_INPUT_TOKEN_LENGTH = 128_000
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05

# Vision defaults (advanced settings)
VISION_TEMPERATURE = 0.2
VISION_TOP_P = 0.95
VISION_TOP_K = 50
VISION_MAX_TOKENS = 128

if not torch.cuda.is_available():
    print("This demo may not work on CPU.")
| # ============================================================================= | |
| # Text Model Loading | |
| # ============================================================================= | |
# Both checkpoints are loaded in half precision and placed by `device_map="auto"`.
text_model_path = "ibm-granite/granite-3.1-8b-instruct"
_half_precision_auto = {"torch_dtype": torch.float16, "device_map": "auto"}

text_model = AutoModelForCausalLM.from_pretrained(text_model_path, **_half_precision_auto)
tokenizer = AutoTokenizer.from_pretrained(text_model_path)
# The app always supplies its own system message (SYS_PROMPT).
tokenizer.use_default_system_prompt = False

# =============================================================================
# Vision Model Loading
# =============================================================================
vision_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
vision_processor = LlavaNextProcessor.from_pretrained(vision_model_path, use_fast=True)
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
    vision_model_path,
    # Ensure the custom code is used so that weight shapes match.
    trust_remote_code=True,
    **_half_precision_auto,
)
| # ============================================================================= | |
| # Text Generation Function (for text-only chat) | |
| # ============================================================================= | |
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream a reply from the text model for a text-only chat turn.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts;
            it is copied into the prompt and not modified.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        repetition_penalty: Repetition penalty forwarded to ``generate``.
        top_p: Nucleus sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        max_new_tokens: Upper bound on generated tokens.

    Yields:
        The reply accumulated so far, growing with every streamed chunk.
    """
    conversation = [
        {"role": "system", "content": SYS_PROMPT},
        *chat_history,
        {"role": "user", "content": message},
    ]

    # UI sliders can hand whole numbers over as int (and vice versa), while
    # transformers validates the concrete types of the sampling parameters.
    # Coerce once here so a bad type cannot blow up inside the worker thread,
    # which would leave the streamer loop below waiting forever.
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)

    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
        truncation=True,
        # Reserve room in the context window for the tokens to be generated.
        max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
    )
    input_ids = input_ids.to(text_model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "num_beams": 1,
        "repetition_penalty": float(repetition_penalty),
    }
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # Sampling with temperature 0 is rejected by transformers; greedy
        # decoding is the well-defined equivalent.
        generate_kwargs["do_sample"] = False

    # Generation runs in a background thread; the streamer feeds decoded text
    # back to this generator as it is produced.
    worker = Thread(target=text_model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
| # ============================================================================= | |
| # Vision Chat Inference Function (for image+text chat) | |
| # ============================================================================= | |
def get_text_from_content(content):
    """Flatten a list of content parts into one space-separated string.

    Parts of type "text" contribute their text, parts of type "image" contribute
    the literal placeholder "<image>", and any other part type is skipped.
    """
    return " ".join(
        part["text"] if part["type"] == "text" else "<image>"
        for part in content
        if part["type"] in ("text", "image")
    )
def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, top_p=VISION_TOP_P, top_k=VISION_TOP_K, max_tokens=VISION_MAX_TOKENS):
    """Run one image+text turn through the vision model.

    Args:
        image: Image for this turn, or None.
        text: User text for this turn; surrounding whitespace is stripped.
        conversation: Vision conversation so far (list of role/content dicts),
            or None to start a new one. It is extended in place.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        top_p: Nucleus sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        max_tokens: Upper bound on generated tokens.

    Returns:
        A ``(chat_history, conversation)`` pair: the display tuples produced by
        ``display_vision_conversation`` and the updated conversation.
    """
    if conversation is None:
        conversation = []

    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        # Nothing to send: leave the conversation as it is.
        return display_vision_conversation(conversation), conversation

    conversation.append({"role": "user", "content": user_content})
    # Follow the model's own device (it is loaded with device_map="auto")
    # instead of assuming CUDA, mirroring what the text branch does.
    inputs = vision_processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(vision_model.device)

    torch.manual_seed(random.randint(0, 10000))

    # Sliders may deliver whole numbers as int; transformers validates the
    # concrete types of sampling parameters, so coerce them here.
    generation_kwargs = {"max_new_tokens": int(max_tokens)}
    temperature = float(temperature)
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # Sampling with temperature 0 is rejected; use greedy decoding instead.
        generation_kwargs["do_sample"] = False

    output = vision_model.generate(**inputs, **generation_kwargs)

    # `generate` returns prompt + continuation. Decode only the continuation so
    # the stored assistant turn does not contain the whole transcript, which
    # would otherwise be fed back into the prompt on every later turn.
    prompt_length = inputs["input_ids"].shape[1]
    assistant_response = vision_processor.decode(output[0][prompt_length:], skip_special_tokens=True)

    conversation.append({"role": "assistant", "content": [{"type": "text", "text": assistant_response.strip()}]})
    return display_vision_conversation(conversation), conversation
| # ============================================================================= | |
| # Helper Functions to Format Conversation for Display | |
| # ============================================================================= | |
def display_text_conversation(conversation):
    """Pair up a text conversation into (user, assistant) tuples for display.

    Each "user" message is paired with the "assistant" message directly after
    it, or with "" when there is none. Messages with any other role, and
    assistant messages not directly preceded by a user message, are dropped.
    """
    pairs = []
    total = len(conversation)
    pos = 0
    while pos < total:
        current = conversation[pos]
        pos += 1
        if current["role"] != "user":
            continue
        question = current["content"]
        answer = ""
        # `pos` now points at the message following the user turn.
        if pos < total and conversation[pos]["role"] == "assistant":
            answer = conversation[pos]["content"]
            pos += 1
        pairs.append((question, answer))
    return pairs
def display_vision_conversation(conversation):
    """Pair up a vision conversation into (user, assistant) tuples for display.

    User content (a list of typed parts) is flattened with
    ``get_text_from_content``. The assistant text is taken from the first
    content part of the following assistant message, keeping only what comes
    after the last "<|assistant|>" marker; it is "" when no assistant message
    follows. Messages with other roles are dropped.
    """
    pairs = []
    total = len(conversation)
    pos = 0
    while pos < total:
        turn = conversation[pos]
        pos += 1
        if turn["role"] != "user":
            continue
        prompt = get_text_from_content(turn["content"])
        reply = ""
        # `pos` now points at the message following the user turn.
        if pos < total and conversation[pos]["role"] == "assistant":
            raw_reply = conversation[pos]["content"][0]["text"]
            # Extract assistant text; remove any special tokens if present.
            reply = raw_reply.rpartition("<|assistant|>")[2].strip()
            pos += 1
        pairs.append((prompt, reply))
    return pairs
| # ============================================================================= | |
| # Unified Send-Message Function | |
| # ============================================================================= | |
def send_message(image, text,
                 text_temperature, text_repetition_penalty, text_top_p, text_top_k, text_max_new_tokens,
                 vision_temperature, vision_top_p, vision_top_k, vision_max_tokens,
                 text_state, vision_state):
    """
    If an image is uploaded, use the vision model; otherwise, use the text model.
    Returns updated conversation (as a list of tuples) and state for each branch.

    The return value is ``(chat_history, text_state, vision_state)``. Only the
    state of the branch that handled the message is updated; the other one is
    passed through unchanged. A text-only submission with an empty or
    whitespace-only message is ignored and the current text history is shown.
    """
    if image is not None:
        # Vision branch
        conv = vision_state if vision_state is not None else []
        chat_history, updated_conv = chat_inference(
            image, text, conv,
            temperature=vision_temperature,
            top_p=vision_top_p,
            top_k=vision_top_k,
            max_tokens=vision_max_tokens,
        )
        # In vision mode, the conversation display is produced from the vision branch.
        return chat_history, text_state, updated_conv

    # Text branch
    conv = text_state if text_state is not None else []
    if not text or not text.strip():
        # Do not send an empty user turn to the model or record it in history
        # (the vision branch likewise ignores empty submissions).
        return display_text_conversation(conv), conv, vision_state

    # The UI updates once per click, so drain the stream and keep only the
    # final accumulated text.
    output_text = ""
    for chunk in generate(
        text, conv,
        temperature=text_temperature,
        repetition_penalty=text_repetition_penalty,
        top_p=text_top_p,
        top_k=text_top_k,
        max_new_tokens=text_max_new_tokens,
    ):
        output_text = chunk

    conv.append({"role": "user", "content": text})
    conv.append({"role": "assistant", "content": output_text})
    return display_text_conversation(conv), conv, vision_state
def clear_chat():
    """Reset the chat display, both conversation states and both input widgets.

    Returns one value per component bound to the clear button, in order:
    (chat_history, text_state, vision_state, text_input, image_input).
    The handler is wired to five outputs, so five values must be returned;
    returning fewer makes Gradio reject the callback result.
    """
    return [], [], [], "", None
| # ============================================================================= | |
| # UI Layout with Gradio | |
| # ============================================================================= | |
# Static assets live next to this file.
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)
    chatbot = gr.Chatbot(label="Chat History", height=500)

    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
        with gr.Column(scale=1):
            # Slider minimums for temperature and repetition penalty start one
            # step above zero: with sampling enabled, transformers rejects a
            # temperature or repetition penalty of 0.
            with gr.Accordion("Text Advanced Settings", open=False):
                text_temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"])
                repetition_penalty_slider = gr.Slider(minimum=0.05, maximum=2.0, value=REPETITION_PENALTY, step=0.05, label="Repetition Penalty", elem_classes=["gr_accordion_element"])
                top_p_slider = gr.Slider(minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"])
                top_k_slider = gr.Slider(minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"])
                max_new_tokens_slider = gr.Slider(minimum=1, maximum=2000, value=MAX_NEW_TOKENS, step=1, label="Max New Tokens", elem_classes=["gr_accordion_element"])
            with gr.Accordion("Vision Advanced Settings", open=False):
                vision_temperature_slider = gr.Slider(minimum=0.01, maximum=2.0, value=VISION_TEMPERATURE, step=0.01, label="Vision Temperature", elem_classes=["gr_accordion_element"])
                vision_top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=VISION_TOP_P, step=0.01, label="Vision Top p", elem_classes=["gr_accordion_element"])
                vision_top_k_slider = gr.Slider(minimum=0, maximum=100, value=VISION_TOP_K, step=1, label="Vision Top k", elem_classes=["gr_accordion_element"])
                vision_max_tokens_slider = gr.Slider(minimum=10, maximum=300, value=VISION_MAX_TOKENS, step=1, label="Vision Max Tokens", elem_classes=["gr_accordion_element"])

    send_button = gr.Button("Send Message")
    clear_button = gr.Button("Clear Chat")

    # Conversation state variables for each branch.
    text_state = gr.State([])
    vision_state = gr.State([])

    # The order of `inputs` must match the positional parameters of send_message.
    send_button.click(
        send_message,
        inputs=[
            image_input, text_input,
            text_temperature_slider, repetition_penalty_slider, top_p_slider, top_k_slider, max_new_tokens_slider,
            vision_temperature_slider, vision_top_p_slider, vision_top_k_slider, vision_max_tokens_slider,
            text_state, vision_state,
        ],
        outputs=[chatbot, text_state, vision_state],
    )
    # clear_chat must return one value per component listed in `outputs`.
    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, text_state, vision_state, text_input, image_input],
    )

    # Each example is an [image, text] pair filling the two input widgets.
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/cheetah1.jpg", "What is in this image?"],
            [None, "Explain quantum computing to a beginner."],
            [None, "What is OpenShift?"],
            [None, "Importance of low latency inference"],
            [None, "Boosting productivity habits"],
            [None, "Explain and document your code"],
            [None, "Generate Java Code"],
        ],
        inputs=[image_input, text_input],
        example_labels=[
            "Vision Example: What is in this image?",
            "Explain quantum computing",
            "What is OpenShift?",
            "Importance of low latency inference",
            "Boosting productivity habits",
            "Explain and document your code",
            "Generate Java Code",
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
    # Enable request queuing on the app, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch()