Spaces:
Running
on
Zero
Running
on
Zero
v1
Browse files- app.py +8 -15
- meteor/arch/modeling_internlm2.py +2 -2
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -8,13 +8,17 @@ from PIL import Image
|
|
| 8 |
from utils.utils import *
|
| 9 |
from threading import Thread
|
| 10 |
import torch.nn.functional as F
|
|
|
|
| 11 |
from meteor.load_mmamba import load_mmamba
|
| 12 |
from meteor.load_meteor import load_meteor
|
| 13 |
from transformers import TextIteratorStreamer
|
| 14 |
from torchvision.transforms.functional import pil_to_tensor
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
# loading meteor model
|
| 17 |
-
mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
|
| 18 |
meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=16)
|
| 19 |
|
| 20 |
# freeze model
|
|
@@ -24,7 +28,6 @@ freeze_model(meteor)
|
|
| 24 |
# previous length
|
| 25 |
previous_length = 0
|
| 26 |
|
| 27 |
-
@spaces.GPU
|
| 28 |
def threading_function(inputs, image_token_number, streamer, device):
|
| 29 |
|
| 30 |
# Meteor Mamba
|
|
@@ -49,24 +52,14 @@ def threading_function(inputs, image_token_number, streamer, device):
|
|
| 49 |
generation_kwargs.update({'use_cache': True})
|
| 50 |
return meteor.generate(**generation_kwargs)
|
| 51 |
|
| 52 |
-
def add_message(history, message):
    """Append a user's multimodal message to the chat history.

    Each uploaded file becomes its own history entry (as a one-tuple with no
    bot reply yet); the text, if present, follows as a separate entry.
    Returns the updated history plus a cleared, temporarily non-interactive
    multimodal textbox so the UI blocks further input while generating.
    """
    history.extend(((file_path,), None) for file_path in message["files"])
    if message["text"] is not None:
        history.append((message["text"], None))
    return history, gr.MultimodalTextbox(value=None, interactive=False)
|
| 58 |
-
|
| 59 |
@spaces.GPU
|
| 60 |
def bot_streaming(message, history):
|
| 61 |
|
| 62 |
-
# device
|
| 63 |
-
device = torch.cuda.current_device()
|
| 64 |
-
|
| 65 |
# param
|
| 66 |
for param in mmamba.parameters():
|
| 67 |
-
param.data = param.to(device)
|
| 68 |
for param in meteor.parameters():
|
| 69 |
-
param.data = param.to(device)
|
| 70 |
|
| 71 |
# prompt type -> input prompt
|
| 72 |
image_token_number = int((490/14)**2)
|
|
@@ -83,7 +76,7 @@ def bot_streaming(message, history):
|
|
| 83 |
streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
|
| 84 |
|
| 85 |
# Threading generation
|
| 86 |
-
thread = Thread(target=threading_function, kwargs=dict(inputs=inputs, image_token_number=image_token_number, streamer=streamer, device=device))
|
| 87 |
thread.start()
|
| 88 |
|
| 89 |
# generated text
|
|
|
|
| 8 |
from utils.utils import *
|
| 9 |
from threading import Thread
|
| 10 |
import torch.nn.functional as F
|
| 11 |
+
from accelerate import Accelerator
|
| 12 |
from meteor.load_mmamba import load_mmamba
|
| 13 |
from meteor.load_meteor import load_meteor
|
| 14 |
from transformers import TextIteratorStreamer
|
| 15 |
from torchvision.transforms.functional import pil_to_tensor
|
| 16 |
|
| 17 |
+
# accel
|
| 18 |
+
accel = Accelerator()
|
| 19 |
+
|
| 20 |
# loading meteor model
|
| 21 |
+
mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
|
| 22 |
meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=16)
|
| 23 |
|
| 24 |
# freeze model
|
|
|
|
| 28 |
# previous length
|
| 29 |
previous_length = 0
|
| 30 |
|
|
|
|
| 31 |
def threading_function(inputs, image_token_number, streamer, device):
|
| 32 |
|
| 33 |
# Meteor Mamba
|
|
|
|
| 52 |
generation_kwargs.update({'use_cache': True})
|
| 53 |
return meteor.generate(**generation_kwargs)
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
@spaces.GPU
|
| 56 |
def bot_streaming(message, history):
|
| 57 |
|
|
|
|
|
|
|
|
|
|
| 58 |
# param
|
| 59 |
for param in mmamba.parameters():
|
| 60 |
+
param.data = param.to(accel.device)
|
| 61 |
for param in meteor.parameters():
|
| 62 |
+
param.data = param.to(accel.device)
|
| 63 |
|
| 64 |
# prompt type -> input prompt
|
| 65 |
image_token_number = int((490/14)**2)
|
|
|
|
| 76 |
streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
|
| 77 |
|
| 78 |
# Threading generation
|
| 79 |
+
thread = Thread(target=threading_function, kwargs=dict(inputs=inputs, image_token_number=image_token_number, streamer=streamer, device=accel.device))
|
| 80 |
thread.start()
|
| 81 |
|
| 82 |
# generated text
|
meteor/arch/modeling_internlm2.py
CHANGED
|
@@ -277,8 +277,8 @@ def rotate_half(x):
|
|
| 277 |
# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q: query tensor.
        k: key tensor.
        cos: rotary cosine cache, indexable by token position.
        sin: rotary sine cache, indexable by token position.
        position_ids: per-token position indices used to gather from the caches.
        unsqueeze_dim: axis inserted into the gathered cos/sin so they
            broadcast over the head dimension of q and k.

    Returns:
        (q_embed, k_embed): the rotated query and key tensors.
    """
    # BUG FIX: the original `cos = cos` / `sin = sin` were no-op assignments —
    # the caches were never gathered by position nor given a broadcast axis,
    # so the multiplies below only worked if callers pre-gathered them.
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
|
|
|
|
| 277 |
# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Gathers the per-position rows of the cos/sin caches via ``position_ids``,
    inserts a broadcast axis at ``unsqueeze_dim`` (so the caches broadcast
    over the head dimension), and rotates q and k in the standard RoPE form.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
|
requirements.txt
CHANGED
|
@@ -13,4 +13,5 @@ timm
|
|
| 13 |
shortuuid
|
| 14 |
matplotlib
|
| 15 |
gradio
|
| 16 |
-
spaces
|
|
|
|
|
|
| 13 |
shortuuid
|
| 14 |
matplotlib
|
| 15 |
gradio
|
| 16 |
+
spaces
|
| 17 |
+
accelerate
|