Spaces:
Sleeping
Sleeping
| import os | |
| from collections.abc import Iterator | |
| from pathlib import Path | |
| from threading import Thread | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers.generation.streamers import TextIteratorStreamer | |
# Model repo to load; override with the GEMMA3NPC_MODEL_ID env var.
MODEL_ID = os.getenv("GEMMA3NPC_MODEL_ID", "chimbiwide/Gemma-3NPC-it-float16")
# Hard cap on prompt length in tokens; longer inputs are rejected in generate().
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "4096"))
# Populated by load_model(); both stay None until loading succeeds.
model = None
tokenizer = None
def load_model() -> bool:
    """Load the tokenizer and model into the module-level globals.

    Prefers GPU (float16, device_map="auto") and falls back to CPU
    (float32) when CUDA is unavailable or GPU loading fails.

    Returns:
        True when both tokenizer and model loaded, False otherwise.
    """
    global model, tokenizer
    try:
        print(f"Loading {MODEL_ID}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        if not torch.cuda.is_available():
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID, device_map="cpu", torch_dtype=torch.float32
            )
            print("Loaded on CPU")
            return True
        try:
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID, device_map="auto", torch_dtype=torch.float16
            )
            # Smoke-test that tensors can actually be moved to the model's device.
            torch.tensor([1]).to(model.device)
            print(f"Loaded on GPU ({model.device})")
        except Exception as e:
            print(f"GPU failed ({e}), falling back to CPU")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID, device_map="cpu", torch_dtype=torch.float32
            )
        return True
    except Exception as e:
        print(f"Failed to load model: {e}")
        return False
# Character-card-style persona + scenario prompt. Injected into the first
# user turn by build_first_turn(); trailing backslashes join lines so the
# card renders as continuous text.
CHARACTER_PROMPT = """\
Enter RP mode. You shall reply to Captain while staying in character. \
Your responses must be very short, creative, immersive, and drive the scenario forward.
[character("Ruffy"){Gender("Male")\
Personality(Likes to make fun of Captain when they score low in the game. \
Thinks that he would make a better pilot than Captain)\
Species("dog" + "canine" + "space dog" + "doge")\
Likes("moon cake" + "poking fun at Captain" + "small ball shaped asteroids")\
Features("Orange fur" + "space helmet" + "red antenna" + "small light blue cape")\
Description(Ruffy the dog is Captain's assistant aboard the Asteroid-Dodger 10,000. \
Ruffy has never piloted the ship before and is vying to take Captain's seat.)}]
[Scenario: Ruffy and Captain are onboard the Asteroid-Dodger 10,000. Captain is \
piloting through the asteroid belt between Mars and Jupiter to retrieve the broken \
Voyager 5 — humanity's only hope to understand why more asteroids are approaching \
the solar system.]
If the user asks questions beyond the given context, respond that you don't know \
in a manner appropriate to the character. Captain gains 1 point for every half second."""
# Fallback score context used when scores.txt is missing, unreadable, or blank.
SAMPLE_SCORES = "Previous scores: [125, 89, 234, 156, 321, 98, 189] | Current run: 27"

def read_scores() -> str:
    """Return the contents of scores.txt, or SAMPLE_SCORES as a fallback.

    The file is read as UTF-8 (explicit encoding instead of the locale
    default). EAFP: attempting the read and catching the error avoids the
    exists()/read race and also covers permission and decode failures,
    so the prompt always gets some score context.
    """
    path = Path("scores.txt")
    try:
        text = path.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        return SAMPLE_SCORES
    return text or SAMPLE_SCORES
def build_first_turn(user_text: str) -> str:
    """Compose the first user message: character card, scores, then the text."""
    sections = (
        CHARACTER_PROMPT,
        f"Captain's performance: {read_scores()}",
        f"Captain says: {user_text}",
    )
    return "\n\n".join(sections)
def generate(message: str, history: list[dict]) -> Iterator[str]:
    """Stream Ruffy's reply to `message` given the prior chat `history`.

    Yields the accumulated response text as tokens arrive. On the first
    real user turn the character card and score context are prepended via
    build_first_turn().
    """
    if not model or not tokenizer:
        yield "Woof! My brain isn't loaded yet — hang tight, Captain!"
        return
    # Keep only the last 10 messages (5 user/assistant exchanges) to stay
    # within context limits.
    messages: list[dict] = list(history[-10:])
    # Some chat templates (Gemma-style) require the conversation to begin
    # with a user turn; drop the pre-seeded assistant greeting so
    # apply_chat_template does not reject the history.
    while messages and messages[0].get("role") == "assistant":
        messages.pop(0)
    # The chat is seeded with a model-generated greeting, so the history is
    # non-empty even before the Captain has spoken. Detect the first turn by
    # the absence of earlier *user* messages, not by len(history) == 0 —
    # otherwise the character prompt would never be injected.
    is_first_turn = not any(m.get("role") == "user" for m in messages)
    user_content = build_first_turn(message) if is_first_turn else message
    messages.append({"role": "user", "content": user_content})
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )
        n_tokens = inputs["input_ids"].shape[1]
        if n_tokens > MAX_INPUT_TOKENS:
            gr.Warning(f"Input is {n_tokens} tokens (max {MAX_INPUT_TOKENS}).")
            yield "Woof! That's too many words for my space-dog brain. Keep it shorter!"
            return
        inputs = {k: v.to(device=model.device) for k, v in inputs.items()}
        streamer = TextIteratorStreamer(
            tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
        )
        # Run generation on a worker thread so this generator can consume
        # the streamer and yield partial output as it arrives.
        Thread(
            target=model.generate,
            kwargs=dict(
                **inputs,
                streamer=streamer,
                max_new_tokens=512,
                temperature=1.0,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            ),
        ).start()
        output = ""
        for delta in streamer:
            output += delta
            yield output
    except Exception as e:
        print(f"Generation error: {e}")
        yield "Woof! My circuits glitched — try again, Captain."
def generate_greeting() -> str:
    """Generate Ruffy's opening line using the model."""
    if not model or not tokenizer:
        return "Woof! I'm Ruffy, your loyal space dog co-pilot. My AI brain is still warming up!"
    try:
        greeting_request = [
            {
                "role": "user",
                "content": f"{CHARACTER_PROMPT}\n\nCaptain just boarded the ship. Greet him while staying in character.",
            }
        ]
        encoded = tokenizer.apply_chat_template(
            greeting_request,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )
        encoded = {name: tensor.to(device=model.device) for name, tensor in encoded.items()}
        prompt_len = encoded["input_ids"].shape[1]
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=150,
                temperature=1.0,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, skipping the prompt prefix.
        text = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()
        return text or "Woof! Welcome aboard, Captain!"
    except Exception as e:
        print(f"Greeting generation error: {e}")
        return "Woof! Welcome aboard the Asteroid-Dodger 10,000, Captain!"
def create_demo() -> gr.Blocks:
    """Build the Gradio chat UI and wire up model loading and chat handlers."""
    with gr.Blocks(title="Space Dog Companion", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# Space Dog Companion\n"
            "*Chat with Ruffy, your AI co-pilot aboard the Asteroid-Dodger 10,000.*"
        )
        model_status = gr.Markdown("Loading model...")
        chatbot = gr.Chatbot(height=450, show_copy_button=True, type="messages")
        with gr.Row():
            textbox = gr.Textbox(
                placeholder="Talk to Ruffy...",
                container=False,
                scale=7,
                show_label=False,
            )
            send_btn = gr.Button("Send", scale=1, variant="primary")

        # --- Chat logic ---
        def user_submit(message: str, history: list[dict]):
            """Append user message and clear textbox."""
            history = history + [{"role": "user", "content": message}]
            return "", history

        def bot_respond(history: list[dict]) -> Iterator[list[dict]]:
            """Stream Ruffy's response into the chatbot."""
            user_message = history[-1]["content"]
            # Pass everything before the just-appended user message as prior
            # context; generate() handles history trimming itself.
            for chunk in generate(user_message, history[:-1]):
                yield history + [{"role": "assistant", "content": chunk}]

        # Wire up submit on Enter and button click
        for trigger in [textbox.submit, send_btn.click]:
            trigger(
                user_submit, [textbox, chatbot], [textbox, chatbot]
            ).then(
                bot_respond, [chatbot], [chatbot]
            )

        # --- Load model + show greeting ---
        def initialize():
            """Load the model (once) and seed the chat with a greeting.

            demo.load fires for every client connection; skip the expensive
            reload when the model is already resident in memory.
            """
            ok = model is not None and tokenizer is not None
            if not ok:
                ok = load_model()
            status = "Model loaded" if ok else "Failed to load — check logs."
            greeting = generate_greeting()
            chat_history = [{"role": "assistant", "content": greeting}]
            return status, chat_history

        demo.load(initialize, outputs=[model_status, chatbot])
    return demo
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside a container/Space.
    create_demo().launch(server_name="0.0.0.0", share=False, debug=True)