# Utilities for the Fujitsu-LLM-KG-8x7B models (source: Hugging Face model card).
"""Utilities for the Fujitsu-LLM-KG-8x7B models."""
from typing import Literal, Sequence, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
###############################################################################
# Generation
###############################################################################
class Fujitsu_LLM_KG:
    """Wrapper around the Fujitsu-LLM-KG-8x7B causal language model.

    Bundles the Hugging Face model and tokenizer and exposes a
    deterministic (greedy / beam-search) :meth:`generate` helper.
    """

    def __init__(self, model_id: str, *, device_map: str = "auto") -> None:
        """Load the model and tokenizer.

        Args:
            model_id: Hugging Face model identifier or local checkpoint path.
            device_map: Device placement strategy forwarded to
                ``from_pretrained`` (default ``"auto"``).
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device_map,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
        # Left padding so batched generation continues from the prompt's end;
        # the model has no dedicated pad token, so reuse EOS.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 2048,
        num_beams: int = 1,
    ) -> str:
        """Generate an answer for *prompt* with sampling disabled.

        Args:
            prompt: The full input text.
            max_new_tokens: Upper bound on generated tokens.
            num_beams: Beam width; ``1`` means greedy decoding.

        Returns:
            The generated continuation with the echoed prompt removed.
        """
        tokenized = self.tokenizer(prompt, return_tensors="pt", padding=True)
        # Use the model's own device rather than a hard-coded "cuda" so the
        # wrapper also works on CPU-only hosts and with sharded placement
        # produced by device_map="auto".
        device = self.model.device
        with torch.no_grad():
            outputs = self.model.generate(
                tokenized["input_ids"].to(device),
                attention_mask=tokenized["attention_mask"].to(device),
                pad_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=num_beams,
            )
        # Decode the full sequence and drop the echoed prompt prefix.
        # NOTE(review): slicing by len(prompt) assumes decode() reproduces the
        # prompt text byte-for-byte; slicing by input token count would be
        # more robust — confirm against real tokenizer round-trips.
        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):]
        return answer
###############################################################################
# Extraction
###############################################################################
def extract_turtle(text: str, *, with_rationale: bool = False) -> str:
    """Extracts the RDF Turtle part from the output text of Fujitsu-LLM-KG-8x7B_inst-infer model.

    A line is kept if it is empty or if, after stripping, it starts with one
    of the Turtle-ish prefixes ``<``, ``rel:``, ``rdf:`` or ``]`` (plus the
    rationale marker ``#@`` when *with_rationale* is true).  Leading blank
    lines are dropped from the result.

    Args:
        text: Raw model output.
        with_rationale: Also keep ``#@``-prefixed rationale lines.

    Returns:
        The kept lines joined with newlines (``""`` if nothing matched).
    """
    prefixes = ("<", "rel:", "rdf:", "]")
    if with_rationale:
        prefixes += ("#@",)
    kept: list = []
    for line in text.splitlines():
        # str.startswith accepts a tuple, so one call covers all prefixes.
        if line == "" or line.strip().startswith(prefixes):
            # Skip blank lines until the first real Turtle line, matching the
            # original accumulator's behavior of never starting with "\n".
            if kept or line:
                kept.append(line)
    return "\n".join(kept)
def extract_answer(text: str) -> Tuple[str, Sequence[str]]:
    """Extracts the final answer part from the output text of Fujitsu-LLM-KG-8x7B_inst-infer model.

    Scans the text line by line: content after a ``## Explore Path`` heading
    is collected as the reasoning path, content after ``## Answer`` as the
    final answer.  Headings, code fences and blank lines are never collected;
    a code fence seen after any answer content ends the scan.

    Returns:
        ``(answer, path)`` — the stripped answer text and a tuple of
        stripped path lines.
    """
    collecting: Literal["path", "answer"] = "path"
    path_buf: list = []
    answer_buf: list = []
    for raw in text.splitlines():
        is_fence = "```" in raw
        is_heading = "## " in raw
        if raw.strip() and not is_fence and not is_heading:
            target = path_buf if collecting == "path" else answer_buf
            target.append(raw)
        if "## Explore Path" in raw:
            collecting = "path"
            path_buf = []
        elif "## Answer" in raw:
            collecting = "answer"
            answer_buf = []
        elif is_fence and answer_buf:
            # Closing fence after the answer: nothing more to collect.
            break
    return "\n".join(answer_buf).strip(), tuple(p.strip() for p in path_buf)