| | import gradio as gr |
| | import torch |
| | from mamba_model import MambaModel |
| |
|
| | |
# Load the BlackMamba checkpoint once at import time (heavy: fetches/loads a
# 2.8B-parameter model), then move it to the GPU in half precision.
# NOTE(review): assumes a CUDA device is present — .cuda() raises on CPU-only hosts.
model = (
    MambaModel.from_pretrained(pretrained_model_name="Zyphra/BlackMamba-2.8B")
    .cuda()
    .half()
)
| |
|
| | |
def generate_output(input_text):
    """Run the BlackMamba model on a comma-separated string of token ids.

    Parameters
    ----------
    input_text : str
        Comma-separated integer token ids, e.g. ``"1, 2, 3"``.

    Returns
    -------
    list | str
        The model output converted to nested Python lists, or a
        human-readable ``"Error: ..."`` string. This is a Gradio UI
        boundary, so failures are surfaced as text instead of raising.
    """
    try:
        # int() tolerates surrounding whitespace; skip empty pieces so
        # inputs like "1, 2," or "" don't blow up with a cryptic message.
        token_ids = [int(piece) for piece in input_text.split(",") if piece.strip()]
        if not token_ids:
            return "Error: no input ids provided"

        # Build a (1, seq_len) long tensor on the GPU to match the model.
        inputs = torch.tensor(token_ids, dtype=torch.long).cuda().unsqueeze(0)

        with torch.no_grad():  # inference only — skip autograd bookkeeping
            out = model(inputs)

        # Bring the result back to CPU and convert to plain Python lists
        # so Gradio can display it.
        return out.cpu().numpy().tolist()
    except ValueError:
        # Parse failure: the user typed something non-numeric.
        return f"Error: could not parse {input_text!r} as comma-separated integers"
    except Exception as e:
        # Top-level UI boundary: report any model/CUDA failure as text.
        return f"Error: {e}"
| |
|
| | |
# --- Gradio UI wiring ---

# Free-text box where the user types comma-separated token ids.
input_component = gr.Textbox(
    label="Input IDs (comma-separated)",
    placeholder="Enter input IDs like: 1, 2",
)

# Text box that displays the model output (or an error string).
output_component = gr.Textbox(label="Output")

# Single-function interface: one textbox in, one textbox out.
iface = gr.Interface(
    fn=generate_output,
    inputs=input_component,
    outputs=output_component,
    title="BlackMamba Model",
)

if __name__ == "__main__":
    # Start the local Gradio server only when run as a script.
    iface.launch()