import os
from threading import Event, Thread

import gradio as gr
import torch
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Authenticate against the Hugging Face Hub; HF_TOKEN is read from the environment.
login(os.getenv("HF_TOKEN"))
model_name = "richardr1126/spider-natsql-wizard-coder-8bit"
tok = AutoTokenizer.from_pretrained(model_name)
max_new_tokens = 1536

print(f"Starting to load the model {model_name}")
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=0,       # place the whole model on GPU 0
    load_in_8bit=True,  # 8-bit quantization via bitsandbytes
)
# Pad with the EOS token so generate() does not warn about a missing pad token.
m.config.pad_token_id = m.config.eos_token_id
m.generation_config.pad_token_id = m.config.eos_token_id

# Generation halts as soon as the newest token matches one of these ids. Note
# that convert_tokens_to_ids only does what we want here if each stop string
# is a single token in the model's vocabulary.
stop_tokens = [";", "###", "Result"]
stop_token_ids = tok.convert_tokens_to_ids(stop_tokens)
print(f"Successfully loaded the model {model_name} into memory")
class StopOnTokens(StoppingCriteria):
    """Stop generation once the most recently generated token is a stop token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
def bot(input_message: str, temperature=0.1, top_p=0.9, top_k=0, repetition_penalty=1.08):
    stop = StopOnTokens()

    input_ids = tok(input_message, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,  # fall back to greedy decoding at temperature 0
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    # Run generation on a background thread so this thread can drain the streamer.
    stream_complete = Event()

    def generate_and_signal_complete():
        m.generate(**generate_kwargs)
        stream_complete.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    # Accumulate the streamed chunks into the final response.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
    return partial_text
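
# Note: as written, bot() buffers the whole stream and only returns once
# generation has finished. gr.Interface also accepts generator functions, so a
# variant that yields inside the loop would stream partial output to the UI:
#
#     for new_text in streamer:
#         partial_text += new_text
#         yield partial_text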
gradio_interface = gr.Interface(
    fn=bot,
    inputs=[
        "text",
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01),
        gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.1),
    ],
    outputs="text",
    title="REST API with Gradio and Huggingface Spaces",
    description="This is a demo of how to build an AI-powered REST API with Gradio and Hugging Face Spaces – for free! See the **Use via API** link at the bottom of this page.",
)
gradio_interface.launch()
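
# Usage sketch (run from a separate process, since launch() blocks): once the
# app is up, the same function is exposed as a REST endpoint. A minimal example
# with the official gradio_client package -- the URL below is a placeholder for
# your local server or Space:
#
#     from gradio_client import Client
#
#     client = Client("http://127.0.0.1:7860/")
#     result = client.predict(
#         "your prompt here",  # input_message
#         0.1,                 # temperature
#         0.9,                 # top-p
#         0,                   # top-k
#         1.08,                # repetition penalty
#         api_name="/predict",
#     )
#     print(result)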