| import gradio as gr |
| import torch |
| from mamba_model import MambaModel |
|
|
| |
# Load the pretrained BlackMamba-2.8B weights (downloads on first run).
model = MambaModel.from_pretrained(pretrained_model_name="Zyphra/BlackMamba-2.8B")
# Move to GPU and cast to fp16 to halve memory use.
# NOTE(review): hard requirement on a CUDA device — this line raises if no GPU is present.
model = model.cuda().half()
|
|
| |
def generate_output(input_text):
    """Run the model on a comma-separated string of integer token IDs.

    Args:
        input_text: User-supplied string such as "1, 2, 3".

    Returns:
        The model output moved to CPU and converted to nested Python
        lists, or an "Error: ..." string suitable for the Gradio
        output textbox.
    """
    try:
        # Parse IDs, skipping blank segments so "" or "1,,2" does not
        # raise an opaque ValueError from int("").
        ids = [int(tok) for tok in input_text.split(",") if tok.strip()]
        if not ids:
            return "Error: please enter at least one integer ID."
    except ValueError:
        # Targeted message for malformed input instead of the raw traceback text.
        return "Error: input must be comma-separated integers."

    try:
        # Build the (1, seq_len) input batch directly with the right
        # dtype/device instead of constructing then converting.
        inputs = torch.tensor(ids, dtype=torch.long, device="cuda").unsqueeze(0)

        # Inference only: disable autograd to avoid tracking gradients.
        with torch.no_grad():
            out = model(inputs)

        # NOTE(review): assumes the model returns a single tensor — confirm
        # against MambaModel.forward if it can return a tuple/dict.
        return out.cpu().numpy().tolist()
    except Exception as e:  # UI boundary: surface the failure, don't crash the app
        return f"Error: {str(e)}"
|
|
| |
# Wire up the Gradio UI: one textbox in, one textbox out.
id_box = gr.Textbox(
    label="Input IDs (comma-separated)",
    placeholder="Enter input IDs like: 1, 2",
)
result_box = gr.Textbox(label="Output")

iface = gr.Interface(
    fn=generate_output,
    inputs=id_box,
    outputs=result_box,
    title="BlackMamba Model",
)
|
|
| |
if __name__ == "__main__":
    # Start the Gradio web server only when run as a script (not on import).
    iface.launch()