Spaces: Running on Zero
| """Template Demo for IBM Granite Hugging Face spaces.""" | |
| from collections.abc import Iterator | |
| from datetime import datetime | |
| from pathlib import Path | |
| from threading import Thread | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from themes.research_monochrome import theme | |
# Human-readable date for the system prompt, e.g. "January 5, 2025" (no zero-padded day).
# Built from datetime fields instead of the glibc-only "%-d" strftime directive, which is
# not supported by every platform's strftime (e.g. Windows).
# NOTE: evaluated once at import time, so a long-running process keeps reporting the
# date it started on.
today_date = "{0:%B} {0.day}, {0.year}".format(datetime.today())  # noqa: DTZ002

# System prompt prepended to every conversation sent to the model.
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""

# Page title (browser tab and header) and the HTML blurb rendered under it.
TITLE = "IBM Granite 3.1 8b Instruct"
DESCRIPTION = """
<p>Granite 3.1 is a general purpose large language model released in the open under an Apache 2.0 license. Granite
models support a 128k context length.</p>
<p>Try one of the sample prompts below or write your own. Remember, AI models can make mistakes.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation</a> <i class="fa fa-external-link"></i>
</span>
</p>
"""

# Token budget shared by the prompt and the generated reply.
MAX_INPUT_TOKEN_LENGTH = 128_000
# Default generation settings; also used as the initial values of the UI sliders.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
# Append a CPU notice to the page description when no CUDA device is visible.
DESCRIPTION += "" if torch.cuda.is_available() else "\nThis demo does not work on CPU."
# Hugging Face Hub repository providing both the weights and the tokenizer; kept in one
# place so the two can never drift apart.
MODEL_ID = "ibm-granite/granite-3.1-8b-instruct"

# Load fp16 weights; device_map="auto" lets transformers/accelerate place the layers on
# the available device(s).
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Do not let the tokenizer inject a default system prompt; generate() adds SYS_PROMPT itself.
tokenizer.use_default_system_prompt = False
def _build_input_ids(message: str, chat_history: list[dict], max_input_tokens: int) -> torch.Tensor:
    """Apply the chat template to system prompt + history + ``message`` within ``max_input_tokens``.

    Truncating the tokenized prompt from the right (the tokenizer default) would cut off the
    newest user message and the assistant generation prompt. Instead, whole messages are
    dropped from the *oldest* end of ``chat_history`` until the prompt fits.

    Args:
        message: The new user message.
        chat_history: Prior turns as dicts carrying at least ``"role"`` and ``"content"``.
        max_input_tokens: Maximum prompt length in tokens.

    Returns:
        The prompt token ids as returned by ``tokenizer.apply_chat_template(..., return_tensors="pt")``.

    Raises:
        gr.Error: If the prompt does not fit even with all history dropped.
    """
    # Pass only the keys the chat template needs; Gradio may attach extra keys (e.g. metadata).
    history = [{"role": turn["role"], "content": turn["content"]} for turn in chat_history]
    start = 0
    while True:
        conversation = [
            {"role": "system", "content": SYS_PROMPT},
            *history[start:],
            {"role": "user", "content": message},
        ]
        input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
        if input_ids.shape[-1] <= max_input_tokens:
            break
        if start >= len(history):
            raise gr.Error(f"Your message is too long: the prompt must fit in {max_input_tokens} tokens.")
        # Drop the oldest message, then skip ahead so the kept history starts on a user turn.
        start += 1
        while start < len(history) and history[start]["role"] != "user":
            start += 1
    if start:
        gr.Warning(f"Dropped the {start} oldest chat messages to fit the {max_input_tokens}-token input limit.")
    return input_ids


def _generate_in_background(
    generate_kwargs: dict, streamer: TextIteratorStreamer, errors: list[BaseException]
) -> None:
    """Run ``model.generate`` (meant for a worker thread), recording any exception in ``errors``.

    An exception raised inside ``model.generate`` would otherwise die with the thread without
    ever ending the streamer, leaving the consumer's ``for text in streamer`` loop blocked.
    """
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:  # noqa: BLE001 -- re-raised to the user by generate()
        errors.append(exc)
        streamer.end()  # push the stop signal so the consumer loop terminates


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    repetition_penalty: float = REPETITION_PENALTY,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    Decorated with ``spaces.GPU`` so a ZeroGPU Space allocates a GPU for each call.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio ``type="messages"`` history).
        temperature: Sampling temperature; a value ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling probability mass (ignored for greedy decoding).
        top_k: Top-k sampling cutoff; 0 disables it (ignored for greedy decoding).
        repetition_penalty: Repetition penalty; must be greater than 0.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The reply text accumulated so far, after each streamed chunk.

    Raises:
        gr.Error: For a non-positive repetition penalty, a prompt that cannot fit the
            context window, or a failure inside ``model.generate``.
    """
    # UI sliders may deliver ints or floats; transformers' logits processors require a
    # float temperature / repetition penalty and an int top_k, so normalize up front.
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)
    max_new_tokens = int(max_new_tokens)
    if repetition_penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")

    # The prompt and the reply share one context window.
    input_ids = _build_input_ids(message, chat_history, MAX_INPUT_TOKEN_LENGTH - max_new_tokens)
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        # A single unpadded sequence: every position is attended to.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p), top_k=int(top_k))
    else:
        # transformers rejects temperature <= 0 when sampling; treat it as greedy decoding.
        generate_kwargs["do_sample"] = False

    # Generation runs on a worker thread while this generator relays the streamed text.
    errors: list[BaseException] = []
    worker = Thread(target=_generate_in_background, args=(generate_kwargs, streamer, errors))
    worker.start()
    chunks: list[str] = []
    for chunk in streamer:
        chunks.append(chunk)
        yield "".join(chunks)
    worker.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]
# Static assets stored next to this file: custom CSS and extra <head> HTML for the app.
# (Path / str already yields a Path, so no extra Path(...) wrapping is needed.)
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

# Advanced generation settings (displayed in the Accordion below the chat).
# Temperature and repetition penalty start at 0.1 rather than 0: with sampling enabled,
# transformers rejects a temperature or repetition penalty of 0.
temperature_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"]
)
top_p_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"]
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"]
)
repetition_penalty_slider = gr.Slider(
    minimum=0.1,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.1,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
)
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)
# Page layout: logo + title, description, then the chat UI with an accordion of settings.
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    gr.HTML(
        f"<img src='https://www.ibm.com/granite/docs/images/granite-pictogram.svg'/><h1>{TITLE}</h1>",
        elem_classes=["gr_title"],
    )
    gr.HTML(DESCRIPTION)
    chat_interface = gr.ChatInterface(
        fn=generate,
        examples=[
            ["Explain quantum computing"],
            ["What is OpenShift?"],
            ["Importance of low latency inference"],
            ["Write a binary search in Python"],
        ],
        cache_examples=False,
        type="messages",
        # Gradio passes additional inputs to `fn` positionally after (message, history), so
        # this list must follow generate()'s parameter order:
        # temperature, top_p, top_k, repetition_penalty, max_new_tokens.
        additional_inputs=[
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            max_new_tokens_slider,
        ],
        additional_inputs_accordion=chat_interface_accordion,
    )
if __name__ == "__main__":
    # Enable request queuing on the app, then start the web server.
    queued_app = demo.queue()
    queued_app.launch()