| """This module contains functionalities for running inference on Gemma 2 model |
| finetuned for urgency detection using the HuggingFace library. |
| """ |
|
|
| |
| import ast |
|
|
| from textwrap import dedent |
| from typing import Any, Optional |
|
|
| |
| import torch |
|
|
| from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase |
|
|
|
|
def _construct_prompt(*, rules_list: list[str]) -> str:
    """Construct the prompt for the finetuned model.

    Parameters
    ----------
    rules_list
        The list of urgency rules to match against the user message.

    Returns
    -------
    str
        The prompt for the finetuned model.
    """

    _prompt_base: str = dedent(
        """
        You are a highly sensitive urgency detector. Score if ANY part of the
        user message corresponds to any part of the urgency rules provided below.
        Ignore any part of the user message that does not correspond to the rules.
        Respond with (a) the rule that is most consistent with the user message,
        (b) the probability between 0 and 1 with increments of 0.1 that ANY part of
        the user message matches the rule, and (c) the reason for the probability.


        Respond in json string:

        {
            best_matching_rule: str
            probability: float
            reason: str
        }
        """
    ).strip()
    _prompt_rules: str = dedent(
        """
        Urgency Rules:
        {urgency_rules}
        """
    ).strip()
    urgency_rules_str = "\n".join(
        [f"{i}. {rule}" for i, rule in enumerate(rules_list, 1)]
    )
    prompt = (
        _prompt_base + "\n\n" + _prompt_rules.format(urgency_rules=urgency_rules_str)
    )
    return prompt
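
# Illustrative example (hypothetical rules, not executed): calling
# _construct_prompt(rules_list=["Trouble breathing", "Severe bleeding"]) returns
# the instruction block above followed by:
#
#     Urgency Rules:
#     1. Trouble breathing
#     2. Severe bleeding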


def get_completions(
    *,
    model,
    rules_list: list[str],
    skip_special_tokens_during_decode: bool = False,
    text_generation_params: Optional[dict[str, Any]] = None,
    tokenizer: PreTrainedTokenizerBase,
    user_message: str,
) -> dict[str, Any]:
    """Get completions from the model for the given data.

    Parameters
    ----------
    model
        The model for inference.
    rules_list
        The list of urgency rules to match against the user message.
    skip_special_tokens_during_decode
        Specifies whether to skip special tokens during the decoding process.
    text_generation_params
        Dictionary containing text generation parameters for the LLM model. If not
        specified, then default values will be used.
    tokenizer
        The tokenizer for the model.
    user_message
        The user message to match against the urgency rules.

    Returns
    -------
    dict[str, Any]
        The completion from the model. If the model output does not produce a valid
        JSON string, then the original output is returned in the "generated_json" key.
    """

    assert all(x for x in rules_list), "Rules must be non-empty strings!"
    text_generation_params = text_generation_params or {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }
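    # Note on the defaults above: with `do_sample=True` and a near-zero temperature
    # (1e-6), the sampling distribution collapses onto the most likely token, so
    # generation is effectively greedy/deterministic while `top_p` and
    # `repetition_penalty` still apply.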
    tokenizer.add_special_tokens = False

    # Gemma chat turns are wrapped in <start_of_turn>/<end_of_turn> special tokens;
    # the markers built below delimit the model's reply in the decoded output.
    start_of_turn, end_of_turn = tokenizer.additional_special_tokens
    eos = tokenizer.eos_token
    start_of_turn_model = f"{start_of_turn}model"
    end_of_turn_model = f"{end_of_turn}{eos}"

    # Build the chat-formatted prompt: urgency rules followed by the user message.
    input_ = (
        _construct_prompt(rules_list=rules_list) + f"\n\nUser Message:\n{user_message}"
    )
    chat = [{"role": "user", "content": input_}]
    prompt = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(
        input_ids=inputs.to(model.device), **text_generation_params
    )
    decoded_output = tokenizer.decode(
        outputs[0], skip_special_tokens=skip_special_tokens_during_decode
    )

    completion_dict = {"user_message": user_message, "generated_json": decoded_output}
    try:
        # Extract the model's turn from the decoded output and parse it as a dict.
        # This only succeeds if the turn markers survive decoding, i.e. when
        # skip_special_tokens_during_decode is False.
        start_of_turn_model_index = decoded_output.index(start_of_turn_model)
        end_of_turn_model_index = decoded_output.index(end_of_turn_model)
        generated_response = decoded_output[
            start_of_turn_model_index
            + len(start_of_turn_model) : end_of_turn_model_index
        ].strip()
        completion_dict["generated_json"] = ast.literal_eval(generated_response)
    except (SyntaxError, ValueError):
        # Fall back to returning the raw decoded output in "generated_json".
        pass
    return completion_dict
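
# Illustrative shape of the value returned by `get_completions` when the model's
# output parses cleanly (rule/probability/reason values here are hypothetical):
#
#     {
#         "user_message": "...",
#         "generated_json": {
#             "best_matching_rule": "Trouble breathing",
#             "probability": 0.9,
#             "reason": "The user reports difficulty breathing.",
#         },
#     }
#
# If parsing fails, "generated_json" holds the raw decoded model output instead.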


if __name__ == "__main__":
    DTYPE = torch.bfloat16
    MODEL_ID = "idinsight/gemma-2-2b-it-ud"

    # Load the tokenizer for the finetuned model.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_eos_token=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load the model, letting `device_map="auto"` (which requires the `accelerate`
    # package) place it on the available device(s).
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", return_dict=True, torch_dtype=DTYPE
    )

    text_generation_params = {
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }

    response = get_completions(
        model=model,
        rules_list=[
            "NOT URGENT",
            "Bleeding from the vagina",
            "Bad tummy pain",
            "Bad headache that won’t go away",
| "Bad headache that won’t go away", |
| "Changes to vision", |
| "Trouble breathing", |
| "Hot or very cold, and very weak", |
| "Fits or uncontrolled shaking", |
| "Baby moves less", |
| "Fluid from the vagina", |
| "Feeding problems", |
| "Fits or uncontrolled shaking", |
| "Fast, slow or difficult breathing", |
| "Too hot or cold", |
| "Baby’s colour changes", |
| "Vomiting and watery poo", |
| "Infected belly button", |
| "Swollen or infected eyes", |
| "Bulging or sunken soft spot", |
| ], |
| skip_special_tokens_during_decode=False, |
| text_generation_params=text_generation_params, |
| tokenizer=tokenizer, |
| user_message="If my newborn can't able to breathe what can i do", |
| ) |
| print(f"{response = }") |
|
|