| |
| |
| |
| |
| |
|
|
| r""" |
| Low-Rank Adaptation (LoRA) for LLMs scheme. |
| |
| ┌───────────────────┐ |
| ┆ h ┆ |
| └───────────────────┘ |
| ▲ |
| | |
| + |
| / \ |
| ┌─────────────────┐ ╭───────────────╮ Matrix initialization: |
| ┆ ┆ \ B / B = 0 |
| ┆ pretrained ┆ \ r*d / A = N(0, sigma^2) |
| ┆ weights ┆ ╰─────────╯ |
| ┆ ┆ | r | r - rank |
| ┆ W e R^(d*d) ┆ | ◀─────▶ | |
| ┆ ┆ ╭─────────╮ |
| └─────────────────┘ / A \ |
| ▲ / d*r \ |
| \ ╰───────────────╯ |
| \ ▲ |
| \ / |
| \ / |
| ┌───────────────────┐ |
| ┆ x ┆ |
| └───────────────────┘ |
| |
| With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning weights of size d*d, |
| we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they store the weight updates |
| for the pretrained weights). The number of trainable parameters is reduced drastically (depending on the rank, of |
| course), yet multiplying the d*r and r*d matrices still yields a d*d matrix which we can add to the frozen |
| pretrained weights and thus fine-tune the model. |
| |
| The goal of this approach is to move the weight updates into a separate matrix which is decomposed into |
| two matrices of lower rank. |
| """ |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Tuple, Type, Union |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from typing_extensions import Self |
|
|
| import lit_gpt |
| from lit_gpt.config import Config as BaseConfig |
| from lit_gpt.model import GPT as BaseModel |
| from lit_gpt.model import Block as BaseBlock |
| from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention |
| from lit_gpt.model import KVCache |
| from lit_gpt.utils import map_old_state_dict_weights |
|
|
|
|
class LoRALayer(nn.Module):
    """Base class that stores the hyperparameters shared by every LoRA-augmented layer."""

    def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
        """Store LoRA specific attributes in a class.

        Args:
            r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)

        Raises:
            AssertionError: if `r` is negative.
        """
        super().__init__()
        assert r >= 0, f"LoRA rank must be non-negative, got {r}"
        self.r = r
        self.lora_alpha = lora_alpha

        # Optional dropout on the LoRA branch input. When dropout is disabled we use `nn.Identity`
        # instead of a lambda: it is a parameter-free no-op (state-dict keys are unaffected) and,
        # unlike a locally defined lambda, it keeps the module picklable.
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()

        # Flag that subclasses set once the LoRA update has been folded into the pretrained weights.
        self.merged = False
|
|
|
|
class LoRALinear(LoRALayer):
    """Dense layer with an optional low-rank (LoRA) update branch."""

    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        """LoRA wrapper around linear class.

        The layer owns up to three weight matrices:
            1. the pretrained weights, held by the wrapped `self.linear`
            2. the LoRA A matrix, `self.lora_A` of shape (r, in_features)
            3. the LoRA B matrix, `self.lora_B` of shape (out_features, r)
        The LoRA matrices (and `self.scaling`) exist only when `r > 0`.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            r: rank of the weight update matrices; `0` disables the LoRA branch entirely
            lora_alpha: numerator of the update scaling factor `lora_alpha / r`
                (https://arxiv.org/pdf/2106.09685.pdf, section 4.1)
            lora_dropout: dropout probability applied to the input of the LoRA branch (before matrix A)
            **kwargs: forwarded unchanged to `torch.nn.Linear` (e.g. `bias`)
        """
        super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)

        if r <= 0:
            # No LoRA branch: the layer behaves exactly like the wrapped linear layer.
            return

        # Both matrices are allocated as zeros; `reset_parameters` then gives A its real initialization.
        self.lora_A = nn.Parameter(torch.zeros((r, in_features)))
        self.lora_B = nn.Parameter(torch.zeros((out_features, r)))
        self.scaling = self.lora_alpha / self.r
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the LoRA matrices; the wrapped linear layer is left untouched.

        Does nothing when the layer has no LoRA matrices.
        """
        if not hasattr(self, "lora_A"):
            return
        # A is initialized with Kaiming-uniform (a=sqrt(5)) and B with zeros, so the initial
        # update B @ A is zero and the layer starts out equal to the pretrained one.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def merge(self) -> None:
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W).

        The merge happens at most once; afterwards `forward` skips the LoRA branch.
        """
        if self.r <= 0 or self.merged:
            return
        delta_w = (self.lora_B @ self.lora_A) * self.scaling
        self.linear.weight.data += delta_w
        self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pretrained linear layer and, if active, add the scaled LoRA update.

        Args:
            x: input tensor whose last dimension is `in_features`

        Returns:
            `linear(x)` when LoRA is disabled (`r == 0`) or already merged, otherwise
            `linear(x) + (dropout(x) @ A^T @ B^T) * scaling`.
        """
        pretrained = self.linear(x)
        if self.r == 0 or self.merged:
            return pretrained
        # Project down to rank r first, then back up to out_features.
        down = self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
        up = down @ self.lora_B.transpose(0, 1)
        return pretrained + up * self.scaling
|
|
|
|
class LoRAQKVLinear(LoRALinear):
    """LoRA wrapper for a fused linear layer that produces query, key and value in one matmul."""

    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        n_head: int,
        n_query_groups: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
        **kwargs,
    ):
        """LoRA wrapper around linear class that is used for calculation of q, k and v matrices.

        The layer owns up to three weight matrices:
            1. the pretrained weights, held by the wrapped `self.linear`
            2. the LoRA A matrix, `self.lora_A`
            3. the LoRA B matrix, `self.lora_B`
        The LoRA matrices exist only when `r > 0` and at least one of q/k/v is enabled.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            n_head: number of attention heads
            n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
            r: rank of the weight update matrices; `0` disables LoRA
            lora_alpha: numerator of the update scaling factor `lora_alpha / r`
                (https://arxiv.org/pdf/2106.09685.pdf, section 4.1)
            lora_dropout: dropout probability applied to the input of the LoRA branch (before matrix A)
            enable_lora: which of (query, key, value) receive a LoRA update. A single bool applies to all
                three; e.g. `[True, False, True]` adapts query and value but leaves key untouched.
            **kwargs: forwarded unchanged to `torch.nn.Linear` (e.g. `bias`)
        """
        # Call the grandparent (`LoRALayer`) initializer on purpose: `LoRALinear.__init__` would
        # allocate A/B with shapes that do not fit the fused q/k/v layout built below.
        super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        self.n_head = n_head
        self.n_query_groups = n_query_groups
        if isinstance(enable_lora, bool):
            enable_lora = [enable_lora] * 3
        assert len(enable_lora) == 3
        self.enable_lora = enable_lora

        if not (r > 0 and any(enable_lora)):
            return

        # One rank-r block of A per enabled projection, stacked along dim 0.
        self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features)))
        enable_q, enable_k, enable_v = enable_lora

        q_size = self.linear.in_features
        self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
        kv_size = self.kv_embd_size

        # Output sizes of the enabled projections only (disabled ones multiply to 0 and are dropped).
        self.qkv_shapes = []
        for candidate in (q_size * enable_q, kv_size * enable_k, kv_size * enable_v):
            if candidate:
                self.qkv_shapes.append(candidate)
        self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r))

        self.scaling = self.lora_alpha / self.r

        # Column indices of the fused output that belong to the enabled projections. The output is
        # treated as three consecutive spans: [0, q_size) for query, the next kv_size columns for
        # key and the remainder for value. `zero_pad` scatters LoRA updates into these positions.
        k_start = q_size
        v_start = q_size + kv_size
        indices = []
        if enable_q:
            indices += range(0, k_start)
        if enable_k:
            indices += range(k_start, v_start)
        if enable_v:
            indices += range(v_start, self.linear.out_features)
        self.lora_ind = indices
        self.reset_parameters()

    def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
        """Properly pad weight updates with zeros.

        When only some of q/k/v are adapted, the LoRA branch produces values just for the enabled
        parts. This method places them at their positions inside a tensor as wide as the fused
        output and fills the deselected parts with zeros, e.g. for query and value enabled:

            | ΔW ... ΔW | 0 ... 0 | ΔW ... ΔW |
            |   query   |   key   |   value   |

        Args:
            x: tensor with weights update that will be padded with zeros if necessary

        Returns:
            A tensor with weight updates and zeros for deselected q, k or v
        """
        # Nothing to pad when every projection is adapted.
        if all(self.enable_lora):
            return x

        out_features = self.linear.out_features
        x = x.transpose(0, 1)
        padded_shape = (*x.shape[:-1], out_features)
        # Work on a 2-D view so a single `index_copy` along dim 1 fills all enabled columns.
        flat = x.new_zeros(padded_shape).view(-1, out_features)
        columns = torch.tensor(self.lora_ind, device=flat.device)
        flat = flat.index_copy(1, columns, x.reshape(-1, sum(self.qkv_shapes)))
        return flat.view(padded_shape).transpose(0, 1)

    def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.

        If the number of heads equals the number of query groups, the enabled q/k/v parts are equally
        sized, so a single grouped convolution (`groups` = number of enabled parts) applies each part
        of the weight to its own slice of the input.

        Otherwise the parts differ in size: input and weight are split manually, each weight part is
        applied to the corresponding input part and the results are concatenated.

        Args:
            input: input matrix of shape (B, C, T)
            weight: weight matrix of shape (C_output, rank, 1).
                "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).

        Returns:
            A tensor with a shape (B, C_output, T)
        """
        n_enabled = sum(self.enable_lora)
        if self.n_head == self.n_query_groups:
            return F.conv1d(input, weight, groups=n_enabled)

        input_parts = input.chunk(n_enabled, dim=1)
        weight_parts = weight.split(self.qkv_shapes)
        outputs = [F.conv1d(part, kernel) for part, kernel in zip(input_parts, weight_parts)]
        return torch.cat(outputs, dim=1)

    def merge(self) -> None:
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W).

        The merge happens at most once; afterwards `forward` skips the LoRA branch.
        """
        if self.r <= 0 or not any(self.enable_lora) or self.merged:
            return
        # Treat A as a batch of one "sequence" so the same conv1d path as in `forward` computes B @ A
        # per enabled projection.
        delta_w = self.conv1d(
            self.lora_A.data.unsqueeze(0),
            self.lora_B.data.unsqueeze(-1),
        ).squeeze(0)
        # Scale first, then scatter into the full-size weight layout.
        self.linear.weight.data += self.zero_pad(delta_w * self.scaling)
        self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Do the forward pass.

        If LoRA is disabled or already merged, this is just the wrapped linear layer. Otherwise the
        LoRA branch (dropout -> A -> B -> zero padding -> scaling) is added to the pretrained output.

        Args:
            x: input tensor of shape (batch_size, context_length, embedding_size)

        Returns:
            Output tensor of shape (batch_size, context_length, out_features)
        """
        pretrained = self.linear(x)
        if self.r == 0 or not any(self.enable_lora) or self.merged:
            return pretrained

        after_A = F.linear(self.lora_dropout(x), self.lora_A)
        # conv1d expects channels before time, hence the transposes around the call.
        after_B = self.conv1d(after_A.transpose(-2, -1), self.lora_B.unsqueeze(-1))
        after_B = after_B.transpose(-2, -1)
        lora = self.zero_pad(after_B) * self.scaling
        return pretrained + lora
|
|
|
|
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.

    Args:
        model: model with LoRA layers
        bias:
            ``"none"``: all bias weights will be frozen,
            ``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
            ``"all"``: all bias weights will be unfrozen.

    Raises:
        NotImplementedError: if `bias` not in ["none", "lora_only", "all"]. The check runs before any
            parameter is touched, so an invalid value leaves the model unchanged.
    """
    # Validate up front so a typo in `bias` cannot leave the model half-configured (everything frozen).
    if bias not in ("none", "lora_only", "all"):
        raise NotImplementedError

    # Freeze every parameter whose name does not contain "lora_".
    for name, param in model.named_parameters():
        if "lora_" not in name:
            param.requires_grad = False

    if bias == "none":
        return

    if bias == "all":
        for name, param in model.named_parameters():
            if "bias" in name:
                param.requires_grad = True
        return

    # bias == "lora_only"
    for module in model.modules():
        if not isinstance(module, LoRALayer):
            continue
        # A bias stored directly on the LoRA layer.
        candidates = [getattr(module, "bias", None)]
        # Layers that wrap a linear module keep the bias on `module.linear`; without this lookup
        # "lora_only" would never unfreeze anything for them. Only layers that actually own LoRA
        # weights (`lora_A`) are considered.
        if hasattr(module, "lora_A"):
            candidates.append(getattr(getattr(module, "linear", None), "bias", None))
        for bias_param in candidates:
            if bias_param is not None:
                bias_param.requires_grad = True
|
|
|
|
def lora_filter(key: str, value: Any) -> bool:
    """Tell whether `key` names a LoRA entry, i.e. contains the substring "lora_".

    Args:
        key: name to test (e.g. a state-dict key)
        value: ignored; accepted so the function can be used as a (key, value) predicate

    Returns:
        True if "lora_" occurs anywhere in `key`, False otherwise.
    """
    is_lora_key = "lora_" in key
    return is_lora_key
|
|
|
|
@dataclass
class Config(BaseConfig):
    """Base model configuration extended with LoRA settings.

    Args:
        r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
            the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
        alpha: alpha is needed for scaling updates as alpha/r
            "This scaling helps to reduce the need to retune hyperparameters when we vary r"
            https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
        dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        to_*: either apply LoRA to the specified weights or not
    """

    # LoRA hyperparameters; the defaults (rank 0, nothing enabled) leave LoRA switched off.
    r: int = 0
    alpha: int = 1
    dropout: float = 0.0
    # Per-weight switches, read by the model classes of this module when they build their layers.
    to_query: bool = False
    to_key: bool = False
    to_value: bool = False
    to_projection: bool = False
    to_mlp: bool = False
    to_head: bool = False

    @property
    def mlp_class(self) -> Type:
        """Look up the MLP class named by `self._mlp_class` in `lit_gpt.lora` rather than the base model module."""
        lora_module = lit_gpt.lora
        return getattr(lora_module, self._mlp_class)
|
|
|
|
class GPT(BaseModel):
    """GPT model whose linear layers are LoRA-aware wrappers (see `LoRALinear` / `LoRAQKVLinear`)."""

    def __init__(self, config: Config) -> None:
        """Build the model from `config`.

        Args:
            config: LoRA-extended model configuration; `padded_vocab_size` must be set.
        """
        # Initialize only `nn.Module`: the base model's constructor is bypassed and all
        # submodules are created here with their LoRA counterparts.
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = LoRALinear(
            config.n_embd,
            config.padded_vocab_size,
            bias=config.lm_head_bias,
            r=(config.r if config.to_head else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    def forward(
        self,
        idx: Union[torch.Tensor, Tuple[torch.Tensor, Any, Any]],
        input_pos: Optional[torch.Tensor] = None,
        lm_head_chunk_size: int = 0,
        maxlen: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run the model.

        Args:
            idx: either a tensor of token ids, or a 3-tuple
                `(stack_before_tokens_x, motion_tokens, before_len)`. In the tuple form,
                `stack_before_tokens_x` is embedded and then, for every batch row `i`, the embedded
                positions `before_len[i] : before_len[i] + len(motion_tokens[i])` are overwritten with
                `motion_tokens[i]` (presumably precomputed embeddings of width `n_embd` - confirm
                against the caller).
            input_pos: optional positions used to select rotary embeddings and the attention mask;
                requires the mask cache to be set up.
            lm_head_chunk_size: if > 0, the final hidden states are split along dim 1 into chunks of
                this size and a list of per-chunk outputs is returned.
            maxlen: sequence length used for the length check and for slicing the rotary cache.
                Defaults to the length (dim 1) of the token-id tensor.

        Returns:
            The `lm_head` output, or a list of them when `lm_head_chunk_size > 0`.

        Raises:
            ValueError: if the sequence length exceeds `self.max_seq_length`.
            TypeError: if `input_pos` is given but `self.mask_cache` is None.
        """
        # Resolve the token-id tensor first: the tuple form has no `.size`, so deriving the length
        # directly from `idx` would fail whenever `maxlen` is not supplied.
        multimodal = isinstance(idx, tuple)
        if multimodal:
            stack_before_tokens_x, motion_tokens, before_len = idx
            token_ids = stack_before_tokens_x
        else:
            token_ids = idx

        T = token_ids.size(1) if maxlen is None else maxlen
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        x = self.transformer.wte(token_ids)
        if multimodal:
            # Splice the provided motion features into the embedded sequence, row by row.
            # Indexing `x[i]` (instead of iterating over `x`) keeps the in-place write on a plain view.
            for i in range(len(x)):
                start = before_len[i]
                x[i][start : start + len(motion_tokens[i])] = motion_tokens[i]

        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        if lm_head_chunk_size > 0:
            # chunk the lm head logits to reduce the peak memory used by autograd
            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
        return self.lm_head(x)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Create a model from the configuration registered under `name` (extra kwargs go to `Config.from_name`)."""
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if isinstance(module, LoRALinear):
            module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # Base checkpoints store the head weight directly; here it lives on the wrapped linear layer.
        mapping = {"lm_head.weight": "lm_head.linear.weight"}
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
|
|
|
class Block(BaseBlock):
    """Transformer block that uses the LoRA-aware attention and MLP classes of this module."""

    def __init__(self, config: Config) -> None:
        """Create the norms, attention and MLP submodules described by `config`."""
        # Only `nn.Module` is initialized; the submodules of the base block are rebuilt below.
        nn.Module.__init__(self)

        def build_norm() -> nn.Module:
            return config.norm_class(config.n_embd, eps=config.norm_eps)

        self.norm_1 = build_norm()
        self.attn = CausalSelfAttention(config)
        # A second norm is created only when attention and MLP do not share one.
        if not config.shared_attention_norm:
            self.norm_2 = build_norm()
        self.mlp = config.mlp_class(config)

        self.config = config
|
|
|
|
class CausalSelfAttention(BaseCausalSelfAttention):
    """Causal self-attention whose qkv and output projections are LoRA-aware."""

    def __init__(self, config: Config) -> None:
        """Create the fused qkv projection, the output projection and an empty kv-cache slot."""
        # Skip the parent's __init__ and create the layers ourselves so that they carry LoRA weights.
        nn.Module.__init__(self)

        # Fused output width: one head-sized slice per query head plus one key and one value
        # slice per query group.
        qkv_out_features = (config.n_head + 2 * config.n_query_groups) * config.head_size

        # key, query, value projections for all heads, but in a batch
        self.attn = LoRAQKVLinear(
            in_features=config.n_embd,
            out_features=qkv_out_features,
            n_head=config.n_head,
            n_query_groups=config.n_query_groups,
            r=config.r,
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
            enable_lora=(config.to_query, config.to_key, config.to_value),
            bias=config.bias,
        )

        # output projection; LoRA is applied only when `config.to_projection` is set
        projection_rank = config.r if config.to_projection else 0
        self.proj = LoRALinear(
            config.n_embd,
            config.n_embd,
            bias=config.bias,
            r=projection_rank,
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

        # disabled by default
        self.kv_cache: Optional[KVCache] = None

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # Base checkpoints keep weight/bias directly on `attn` / `proj`; in this class they belong
        # to the wrapped `linear` submodule, so the keys are renamed before loading.
        renamed_keys = {
            "attn.weight": "attn.linear.weight",
            "attn.bias": "attn.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        remapped = map_old_state_dict_weights(state_dict, renamed_keys, prefix)
        super()._load_from_state_dict(remapped, prefix, *args, **kwargs)
|
|
|
|
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    """GPT-NeoX style MLP whose two linear layers are LoRA-aware."""

    def __init__(self, config: Config) -> None:
        """Create the `fc` (n_embd -> intermediate_size) and `proj` (intermediate_size -> n_embd) layers."""
        nn.Module.__init__(self)

        # Both layers share the same LoRA settings; LoRA is active only when `config.to_mlp` is set.
        shared_kwargs = dict(
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.fc = LoRALinear(config.n_embd, config.intermediate_size, **shared_kwargs)
        self.proj = LoRALinear(config.intermediate_size, config.n_embd, **shared_kwargs)

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # Weights and biases live on the wrapped `linear` submodules, so base-checkpoint keys are renamed.
        renamed_keys = {
            "fc.weight": "fc.linear.weight",
            "fc.bias": "fc.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        remapped = map_old_state_dict_weights(state_dict, renamed_keys, prefix)
        super()._load_from_state_dict(remapped, prefix, *args, **kwargs)
|
|
|
|
class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    """LLaMA style MLP whose three linear layers are LoRA-aware."""

    def __init__(self, config: Config) -> None:
        """Create `fc_1` and `fc_2` (n_embd -> intermediate_size) and `proj` (intermediate_size -> n_embd)."""
        nn.Module.__init__(self)

        # All three layers share the same LoRA settings; LoRA is active only when `config.to_mlp` is set.
        shared_kwargs = dict(
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.fc_1 = LoRALinear(config.n_embd, config.intermediate_size, **shared_kwargs)
        self.fc_2 = LoRALinear(config.n_embd, config.intermediate_size, **shared_kwargs)
        self.proj = LoRALinear(config.intermediate_size, config.n_embd, **shared_kwargs)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        # Weights and biases live on the wrapped `linear` submodules, so base-checkpoint keys are renamed.
        renamed_keys = {
            "fc_1.weight": "fc_1.linear.weight",
            "fc_1.bias": "fc_1.linear.bias",
            "fc_2.weight": "fc_2.linear.weight",
            "fc_2.bias": "fc_2.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        remapped = map_old_state_dict_weights(state_dict, renamed_keys, prefix)
        super()._load_from_state_dict(remapped, prefix, *args, **kwargs)
|
|
|
|
def merge_lora_weights(model: GPT) -> None:
    """Merge LoRA weights into the full-rank weights to speed up inference.

    Calls `merge()` on every `LoRALinear` (including subclasses) found in `model.modules()`.
    """
    lora_layers = (layer for layer in model.modules() if isinstance(layer, LoRALinear))
    for layer in lora_layers:
        layer.merge()
|
|