Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from dataclasses import dataclass | |
| from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig, TextIteratorStreamer | |
| from optimum.onnxruntime import ORTModelForCausalLM | |
| import onnx | |
| import logging | |
| from threading import Thread | |
logging.basicConfig(level=logging.INFO)
# -----------------------------------------------------------------------------
# Configuration and Special Tokens
# -----------------------------------------------------------------------------
# Marker name -> literal token string. These strings are registered with the
# tokenizer below as additional special tokens.
SPECIAL_TOKENS = dict(
    bos="<|bos|>",
    eot="<|eot|>",
    user="<|user|>",
    assistant="<|assistant|>",
    system="<|system|>",
    think="<|think|>",
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})

# Marker name -> token id assigned by the tokenizer (same key order as SPECIAL_TOKENS).
SPECIAL_IDS = dict(
    zip(
        SPECIAL_TOKENS,
        tokenizer.convert_tokens_to_ids(list(SPECIAL_TOKENS.values())),
    )
)

# Torch device selection: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| # ----------------------------------------------------------------------------- | |
| # Custom Model Configuration | |
| # ----------------------------------------------------------------------------- | |
class Sam3Config(PretrainedConfig):
    """Hugging Face configuration for the custom Sam-3 causal language model.

    The architecture is described with the model's own field names
    (``d_model``, ``n_layers``, ``n_heads``). ``__init__`` mirrors them onto
    the conventional transformers names (``hidden_size``,
    ``num_hidden_layers``, ``num_attention_heads``), so code that reads the
    standard names sees the same values.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Model width; mirrored to ``hidden_size``.
        n_layers: Number of layers; mirrored to ``num_hidden_layers``.
        n_heads: Number of attention heads; mirrored to ``num_attention_heads``.
        ff_mult: Feed-forward width multiplier (presumably relative to
            ``d_model`` -- TODO confirm).
        dropout: Dropout rate.
        input_modality: Input modality tag (stored as-is).
        head_type: Output head type tag (stored as-is).
        version: Version string (stored as-is).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    NOTE(review): no ``model_type`` is declared. Libraries that dispatch on
    ``config.model_type`` may not recognise this config -- confirm.
    """

    # Class-level defaults. This class is not a dataclass (no @dataclass
    # decorator); these are plain class attributes, and __init__ assigns the
    # per-instance values.
    # NOTE(review): 50257 is presumably the base GPT-2 vocabulary size. The
    # module's tokenizer adds special tokens on top of it, so confirm this
    # matches the exported model's embedding size.
    vocab_size: int = 50257
    d_model: int = 384
    n_layers: int = 10
    n_heads: int = 6
    ff_mult: float = 4.0
    dropout: float = 0.1
    input_modality: str = "text"
    head_type: str = "causal_lm"
    version: str = "0.1"
    # NOTE(review): recent transformers versions assign
    # _attn_implementation_internal on the instance inside
    # PretrainedConfig.__init__, which would shadow this class default --
    # confirm it still has an effect.
    _attn_implementation_internal: str = "eager"
    is_encoder_decoder: bool = False
    # Standard transformers names, kept in sync with d_model / n_layers / n_heads.
    hidden_size: int = 384
    num_hidden_layers: int = 10
    num_attention_heads: int = 6

    def __init__(
        self,
        vocab_size=50257,
        d_model=384,
        n_layers=10,
        n_heads=6,
        ff_mult=4.0,
        dropout=0.1,
        input_modality="text",
        head_type="causal_lm",
        version="0.1",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version
        # Mirror the model-specific names onto the standard ones. This is set
        # after super().__init__ so it overrides any stale copies passed in
        # via kwargs (e.g. when reloading a saved config).
        self.hidden_size = self.d_model
        self.num_hidden_layers = self.n_layers
        self.num_attention_heads = self.n_heads
# Instantiate the custom configuration
model_config = Sam3Config()
# Load the ONNX model, passing the configuration explicitly.
# NOTE(review): trust_remote_code=True allows code from the Hub repository to
# be executed. Keep it only if the repository is trusted and the flag is
# actually required.
try:
    model = ORTModelForCausalLM.from_pretrained(
        "Smilyai-labs/Sam-3.0-2-onnx",
        config=model_config,
        trust_remote_code=True,
    )
except Exception as e:
    # logging.exception records the full traceback (logging.error did not).
    # A bare `raise` re-raises the original exception unchanged.
    logging.exception("Failed to load ONNX model: %s", e)
    raise
else:
    logging.info("ONNX model loaded successfully.")
| # ----------------------------------------------------------------------------- | |
| # Streaming Generation Function | |
| # ----------------------------------------------------------------------------- | |
def generate_text_stream(prompt, max_length, temperature, top_k, top_p):
    """Stream sampled text for *prompt*, yielding the accumulated output.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator drains the streamer. Gradio
    replaces the output component with every yielded value, so each yield
    carries all the text produced so far, not just the newest chunk.

    Args:
        prompt: Text to continue.
        max_length: Maximum total sequence length in tokens (prompt
            included). Cast to ``int`` because Gradio sliders can deliver
            floats.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff. Cast to ``int`` because the
            transformers top-k logits warper rejects non-integer values.
        top_p: Nucleus (top-p) sampling threshold.

    Yields:
        str: The decoded continuation so far, excluding the prompt and
        special tokens.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
        re-raised here once the stream has been closed.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Place inputs on the device the ONNX Runtime model reports. The torch
    # `device` global can be "cuda" while the ORT session runs on CPU, which
    # would mix devices inside generate().
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # We explicitly set use_cache=False to avoid the ONNX export bug
    gen_config = GenerationConfig(
        max_length=int(max_length),
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        do_sample=True,
        use_cache=False,
    )

    errors = []

    def _run_generation():
        """Run model.generate; on failure, record the error and close the stream."""
        try:
            model.generate(
                input_ids=input_ids,
                streamer=streamer,
                generation_config=gen_config,
            )
        except Exception as exc:
            errors.append(exc)
            # generate() never reached streamer.end(). Close the stream
            # ourselves, or the consumer loop below would block forever.
            streamer.end()

    thread = Thread(target=_run_generation, daemon=True)
    thread.start()

    generated = ""
    for new_text in streamer:
        generated += new_text
        yield generated

    thread.join()
    if errors:
        raise errors[0]
| # ----------------------------------------------------------------------------- | |
| # Gradio Interface | |
| # ----------------------------------------------------------------------------- | |
# Gradio UI / API around generate_text_stream. The five inputs are passed
# positionally in the order listed: prompt, max_length, temperature, top_k,
# top_p.
demo = gr.Interface(
    fn=generate_text_stream,
    inputs=[
        gr.Textbox(label="Prompt", lines=2),
        # step=1 keeps the integer-valued generation parameters integral.
        gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=60, step=1, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="SmilyAI Sam 3.0-2 ONNX Text Generation (Streaming)",
    description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2, with streaming output.",
)
# Streaming (generator) functions rely on Gradio's request queue. Enable it
# explicitly in case the installed Gradio version does not enable it by default.
demo.queue()
demo.launch()