import os

import gradio as gr
import tensorflow as tf
from huggingface_hub import login
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# Set up mixed precision and the distribution strategy
policy = tf.keras.mixed_precision.Policy("mixed_bfloat16")
tf.keras.mixed_precision.set_global_policy(policy)
strategy = tf.distribute.MultiWorkerMirroredStrategy()
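# Without a TF_CONFIG cluster spec, MultiWorkerMirroredStrategy should fall back
# to a single local worker, so this is expected to work on a one-machine Space.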

# Log into Hugging Face (token read from the environment)
login(os.environ.get("hf_token"))

# Load tokenizer and model
name = "WICKED4950/GPT2-InstEsther0.28eV3.1"
tokenizer = AutoTokenizer.from_pretrained(name)
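# GPT-2 checkpoints ship without a dedicated pad token, so reuse EOS for padding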
tokenizer.pad_token = tokenizer.eos_token
with strategy.scope():
    model = TFAutoModelForCausalLM.from_pretrained(name)
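    # Note: with the mixed_bfloat16 policy, layers compute in bfloat16 while
    # keeping float32 variables; creating the model inside strategy.scope()
    # places those variables under the distribution strategy.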

# Raw prediction function: streams the model's reply one token at a time
def raw_pred(prompt, model, tokenizer, max_length=1024, temperature=0.2):
    input_ids = tokenizer.encode(prompt, return_tensors="tf")

    # Initialize generation state
    generated_ids = input_ids
    stop_token_id = tokenizer.encode("<|SOH|>", add_special_tokens=False)[0]
    all_generated_tokens = []  # Generated token IDs, for the final cleaned decode
    tokens_yielded = []        # Token texts in the order they were yielded
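
    # Streaming approach: call model.generate repeatedly, extending max_length
    # by one token each time and yielding the new token as it arrives. Each call
    # recomputes the whole prefix, so this trades throughput for streaming output.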
    with strategy.scope():
        for _ in range(max_length):  # Generate one token per iteration
            outputs = model.generate(
                generated_ids,
                max_length=generated_ids.shape[1] + 1,
                temperature=temperature,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_token_id,
                do_sample=True,
                num_return_sequences=1,
            )

            # Get the newly generated token and grow the context
            new_tokens = outputs[0, -1:]
            generated_ids = outputs

            # Store and yield the generated token
            all_generated_tokens.extend(new_tokens.numpy().tolist())
            tokens_text = tokenizer.decode(new_tokens, skip_special_tokens=False)
            tokens_yielded.append(tokens_text)
            yield tokens_text

            # Stop if the stop token is encountered; emit a final full decode
            if stop_token_id in new_tokens.numpy():
                final_text = tokenizer.decode(all_generated_tokens, skip_special_tokens=False)
                yield "<|Clean|>" + final_text
                break
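
# Hypothetical smoke test (not part of the Space): stream one reply locally,
# assuming the model and tokenizer loaded above.
#
#   for piece in raw_pred("<|SOH|>Hello!<|SOB|>", model, tokenizer):
#       print(piece, end="", flush=True)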

# Response handler for gr.ChatInterface (assumes tuple-style history:
# a list of [user_message, bot_message] pairs)
def respond(message, history):
    give_mod = ""
    history = history[-3:]  # Limit context to the last 3 exchanges
    for chunk in history:
        give_mod += f"<|SOH|>{chunk[0]}<|SOB|>{chunk[1]}"
    give_mod += f"<|SOH|>{message.capitalize()}<|SOB|>"
    print(give_mod)  # Debug: log the assembled prompt

    response = ""
    for token in raw_pred(give_mod, model, tokenizer):
        if "<|Clean|>" in token:
            # Final pass: swap the streamed text for the cleaned full reply
            response = token
            print(response)
        else:
            response += token
        yield response.replace("<|SOH|>", "").replace("<|Clean|>", "")
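
# For illustration (format inferred from the f-strings above), one prior
# exchange plus a new message produces a prompt like:
#   <|SOH|>Hi<|SOB|>Hello there!<|SOH|>Tell me a joke<|SOB|>
# i.e. <|SOH|> opens each human turn and <|SOB|> opens each bot turn, with the
# model expected to complete the final bot turn.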

# Gradio chat interface setup
demo = gr.ChatInterface(
    fn=respond,
    title="Chat with Esther",
    description="A friendly chatbot ready to help and chat with you! 😊 (Everything you say will be stored and used to further improve the model.)",
    theme="compact",
)

if __name__ == "__main__":
    demo.launch(debug=True)