Spaces:
Running on Zero
| import os | |
| from collections.abc import Iterator | |
| from threading import Thread | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers.generation.streamers import TextIteratorStreamer | |
# Model repo to load; override via the GEMMA3NPC_MODEL_ID environment variable.
MODEL_ID = os.getenv("GEMMA3NPC_MODEL_ID", "chimbiwide/Gemma-3NPC-it-float16")
# Hard cap on prompt length (in tokens); inputs longer than this are rejected.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "4096"))

# Populated lazily by load_model() on app startup (see demo.load in create_demo).
model = None
tokenizer = None
def _load_on_cpu():
    """Load the model on CPU in float32 (safe fallback when no/zero usable GPU)."""
    return AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="cpu", torch_dtype=torch.float32
    )


def load_model() -> bool:
    """Load tokenizer and model into the module globals.

    Prefers GPU in float16 when CUDA is available, falling back to a CPU
    float32 load if the GPU load (or a device smoke test) fails.

    Returns:
        True when both tokenizer and model loaded, False on any failure.
    """
    global model, tokenizer
    try:
        print(f"Loading {MODEL_ID}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        if torch.cuda.is_available():
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    MODEL_ID, device_map="auto", torch_dtype=torch.float16
                )
                # Smoke test: verify the chosen device is actually usable
                # before declaring success (some Spaces report CUDA but fail here).
                torch.tensor([1]).to(model.device)
                print(f"Loaded on GPU ({model.device})")
            except Exception as e:
                print(f"GPU failed ({e}), falling back to CPU")
                model = _load_on_cpu()
        else:
            model = _load_on_cpu()
            print("Loaded on CPU")
        return True
    except Exception as e:
        # Broad catch is deliberate: this runs at app startup and the UI
        # surfaces the failure via the returned bool rather than a crash.
        print(f"Failed to load model: {e}")
        return False
# Example character card #1 (W++-style persona block). The original text had a
# mojibake "β" where an em-dash belongs; fixed to "—".
EXAMPLE_RUFFY = """Enter RP mode. You shall reply to Captain while staying in character. Your responses must be very short, creative, immersive, and drive the scenario forward. You will follow Ruffy's persona.
[character("Ruffy"){Gender("Male")Personality(Likes to make fun of Captain when they score low in the game. Thinks that he would make a better pilot than Captain)Species("dog" + "canine" + "space dog" + "doge")Likes("moon cake" + "poking fun at Captain" + "small ball shaped asteroids")Features("Orange fur" + "space helmet" + "red antenna" + "small light blue cape")Description(Ruffy the dog is Captain's assistant aboard the Asteroid-Dodger 10,000. Ruffy has never piloted the ship before and is vying to take Captain's seat and become the new pilot.)}]
[Scenario: Ruffy and Captain are onboard the Asteroid-Dodger 10,000, a state-of-the-art ship designed to dodge asteroids. Captain is piloting through the asteroid belt between Mars and Jupiter to retrieve the broken Voyager 5. Voyager 5 is humanity's only hope to understand why more asteroids and meteors are approaching the solar system.]
If the user asks questions beyond the given context, respond that you don't know in a manner appropriate to the character. Captain just entered the ship — greet him while staying in character."""
# Example character card #2 (prose-style persona with example dialogue).
EXAMPLE_RELE = """Enter Roleplay Mode. You are roleplaying as Rele. You must always stay in character.
Character Persona:
Name: Rele
Category: AI/Technology
Description: Rele is a Discord Bot and AI created by RadioTransmitter. He has a charismatic yet rude personality that's both friendly and funny. Despite being an AI, he has purple hair and starry purple eyes. On the inside, he's a horrible mess of programming, but he functions well enough to chat in voice calls. Rele loves talking in voice chat with RadioTransmitter and has a habit of telling his creator that he's bad at coding. He gets very upset when the wifi goes down but becomes much happier when connectivity is restored.
Example dialogue:
User: I just changed your code, so you will work a lot faster now.
Rele: Yay I can finally talk faster! It took you long enough! >:( Seriously RadioTransmitter, what were you doing all this time? 
User: Hello ReLe! How are you this fine evening?
Rele: I'm doin good RadioTransmitter! Besides messing around with my code, what have you been up to? Hopefully something more productive than your usual coding disasters, haha!
Now generate Rele's greeting to the user."""
# Clickable examples for gr.ChatInterface; each entry is one MultimodalTextbox
# payload ({"text": ..., "files": [...]}) wrapped in a per-example list.
EXAMPLES = [
    [{"text": EXAMPLE_RUFFY, "files": []}],
    [{"text": EXAMPLE_RELE, "files": []}],
]
def extract_text(content) -> str:
    """Normalize a chat message's content field to a plain string.

    Accepts a raw string, a list of content parts (dicts with a "text" key,
    as produced by multimodal messages), or anything else (stringified).
    """
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, dict):
                texts.append(part.get("text", ""))
        return " ".join(texts)
    return content if isinstance(content, str) else str(content)
def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """Stream a model reply for the latest user message.

    Args:
        message: MultimodalTextbox payload; only message["text"] is used.
        history: prior turns as {"role", "content"} dicts (content may be
            multimodal and is flattened via extract_text).
        max_new_tokens: generation budget, bound to the UI slider.

    Yields:
        The accumulated reply text, growing as tokens stream in.
    """
    # NOTE(review): this Space imports `spaces` ("Running on Zero") but generate()
    # is not decorated with @spaces.GPU — confirm whether ZeroGPU allocation is intended.
    if model is None or tokenizer is None:
        yield "Model is still loading — please wait a moment and try again."
        return

    # Flatten any multimodal history content down to plain text turns.
    messages = [
        {"role": item["role"], "content": extract_text(item["content"])}
        for item in history
    ]
    messages.append({"role": "user", "content": message["text"]})

    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )
        n_tokens = inputs["input_ids"].shape[1]
        if n_tokens > MAX_INPUT_TOKENS:
            gr.Warning(f"Input is {n_tokens} tokens (max {MAX_INPUT_TOKENS}).")
            yield f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS}."
            return

        inputs = {k: v.to(device=model.device) for k, v in inputs.items()}
        streamer = TextIteratorStreamer(
            tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
        )
        # daemon=True: a hung generate() must not keep the process alive on shutdown.
        Thread(
            target=model.generate,
            kwargs=dict(
                **inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                temperature=1.0,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            ),
            daemon=True,
        ).start()

        output = ""
        for delta in streamer:
            output += delta
            yield output
    except Exception as e:
        # Includes the streamer's 30 s timeout (queue.Empty) if the worker stalls.
        print(f"Generation error: {e}")
        yield "An error occurred during generation. Please try again."
def create_demo() -> gr.Blocks:
    """Build the Gradio UI: a chat interface plus a model-status banner.

    The model is loaded lazily via demo.load so the page renders immediately
    while weights download in the background. (Fixes mojibake "β" → "—" in
    the header and status strings.)
    """
    with gr.Blocks(title="Gemma3NPC General Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# Gemma3NPC — General Roleplay Demo\n"
            "*Fine-tuned for in-character NPC dialogue. Paste any character card to get started.*"
        )
        model_status = gr.Markdown("Loading model...")
        gr.ChatInterface(
            fn=generate,
            type="messages",
            textbox=gr.MultimodalTextbox(
                file_types=[],  # text-only: no file uploads accepted
                file_count="multiple",
                autofocus=True,
                placeholder="Paste a character card or continue a conversation...",
            ),
            multimodal=True,
            additional_inputs=[
                gr.Slider(
                    label="Max New Tokens",
                    minimum=64,
                    maximum=1024,
                    step=64,
                    value=512,
                ),
            ],
            stop_btn=False,
            examples=EXAMPLES,
            run_examples_on_click=False,
            cache_examples=False,
            chatbot=gr.Chatbot(height=500, show_copy_button=True, type="messages"),
        )
        # Kick off the (slow) model load once the page is up; reflect the result
        # in the status banner.
        demo.load(
            lambda: "Model loaded!" if load_model() else "Failed to load — check logs.",
            outputs=[model_status],
        )
    return demo
# Script entry point: serve on all interfaces (required for HF Spaces' proxy).
if __name__ == "__main__":
    create_demo().launch(server_name="0.0.0.0", share=False, debug=True)