# --------------------------------------------------------------
# app.py – a Gradio chat UI for maya-research/maya1
# --------------------------------------------------------------
import json
import os
import uuid
from pathlib import Path
from typing import Dict, List, Tuple

import gradio as gr
import torch
from huggingface_hub import HfApi
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# ------------------- CONFIGURATION -----------------------------
MODEL_ID = "maya-research/maya1"  # the model you want to use
MAX_NEW_TOKENS = 256              # generation length
TEMPERATURE = 0.7
TOP_P = 0.9
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Folder inside the Space where we keep per‑session JSON files
HISTORY_DIR = Path("history")
HISTORY_DIR.mkdir(exist_ok=True)

# ----------------------------------------------------------------
# 1️⃣ Load the model once (global, reused across requests)
# ----------------------------------------------------------------
print(f"🔧 Loading {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    do_sample=True,
)

# --------------------------------------------------------------
# 2️⃣ Helper functions for history persistence
# --------------------------------------------------------------
def _history_path(session_id: str) -> Path:
    """File that stores a JSON list of (user, assistant) pairs."""
    safe_id = session_id.replace("/", "_")
    return HISTORY_DIR / f"{safe_id}.json"


def load_history(session_id: str) -> List[Tuple[str, str]]:
    """Read the JSON file → list of (user, assistant). Return [] if absent."""
    p = _history_path(session_id)
    if p.is_file():
        try:
            # JSON stores the pairs as 2-element lists; convert back to tuples.
            return [tuple(pair) for pair in json.loads(p.read_text(encoding="utf-8"))]
        except Exception as e:
            print(f"⚠️ Failed to read history for {session_id}: {e}")
    return []


def save_history(session_id: str, chat: List[Tuple[str, str]]) -> None:
    """Write the JSON file and push it back to the repo."""
    p = _history_path(session_id)
    p.write_text(json.dumps(chat, ensure_ascii=False, indent=2), encoding="utf-8")

    # -----------------------------------------------------------------
    # Push the new file to the repo (so it survives container restarts)
    # -----------------------------------------------------------------
    # NOTE: this only works if the Space has a write token (see step 3).
    try:
        api = HfApi()
        # `SPACE_ID` (owner/space-name) is set automatically on Spaces.
        repo_id = os.getenv("SPACE_ID")
        if repo_id:
            api.upload_file(
                path_or_fileobj=str(p),
                path_in_repo=str(p),
                repo_id=repo_id,
                repo_type="space",
                token=os.getenv("HF_TOKEN"),
            )
    except Exception as exc:
        # Failing to push is not fatal – the file stays on the container.
        print(f"⚠️ Could not push history to hub: {exc}")


def list_sessions() -> List[str]:
    """Return a list of all stored session IDs (file names)."""
    return [f.stem for f in HISTORY_DIR.glob("*.json")]
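
# --------------------------------------------------------------
# Fallback prompt builder.
# `generate_reply` below prefers the tokenizer's bundled chat template.
# If maya1's tokenizer ships without one, this minimal sketch builds a
# plain role-tagged prompt instead.  The exact format the model expects
# is an assumption here – check the model card if generations look off.
# --------------------------------------------------------------
def _fallback_prompt(messages: List[Dict[str, str]]) -> str:
    """Join the messages as `role: content` lines and cue the assistant."""
    lines = [f"{m['role']}: {m['content']}" for m in messages]
    return "\n".join(lines) + "\nassistant:"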
# --------------------------------------------------------------
# 3️⃣ The generation function – called by Gradio
# --------------------------------------------------------------
def generate_reply(
    user_message: str,
    chat_history: List[Tuple[str, str]],
    session_id: str,
) -> Tuple[List[Tuple[str, str]], str]:
    """
    1. Append the user's new message.
    2. Build the full `messages` list in the format expected by the
       model's chat template.
    3. Render the prompt via `apply_chat_template(..., add_generation_prompt=True)`
       (or the plain fallback above if no template is bundled).
    4. Run the pipeline, decode, and strip the extra tokens.
    5. Append the assistant answer and persist the whole chat.
    """
    # ----- 1️⃣ Append user message to history -----
    chat_history.append((user_message, ""))  # placeholder for assistant

    # ----- 2️⃣ Build the messages list for the template -----
    # Chronological order: previous exchanges first, new message last.
    # (No system message is needed here.)
    messages = []
    for user, assistant in chat_history[:-1]:  # exclude the placeholder
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})

    # ----- 3️⃣ Render the prompt with the model's chat template -----
    # (Optionally trim `messages` first – see the context‑window guard below.)
    if tokenizer.chat_template:
        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
    else:
        prompt = _fallback_prompt(messages)

    # Debug – uncomment if you want to see the raw prompt in the logs
    # print("\n--- Prompt sent to model ---\n", prompt)

    # ----- 4️⃣ Generate the answer -----
    result = generator(prompt, max_new_tokens=MAX_NEW_TOKENS)[0]["generated_text"]
    # The pipeline returns the **whole** text (prompt + answer) – drop the prompt.
    answer = result[len(prompt):].strip()

    # Some models still emit special tokens like <|eot_id|>; strip them.
    stop_strings = ["<|eot_id|>"]
    if tokenizer.eos_token:
        stop_strings.append(tokenizer.eos_token)
    for stop in stop_strings:
        answer = answer.replace(stop, "").strip()

    # ----- 5️⃣ Update history and persist -----
    chat_history[-1] = (user_message, answer)  # replace the placeholder
    save_history(session_id, chat_history)

    return chat_history, answer
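
# --------------------------------------------------------------
# Optional context‑window guard.
# `generate_reply` sends the *entire* history on every turn, so a very
# long chat can eventually overflow the model's context window.  Below
# is a minimal sketch; the 40‑message budget is an assumption – tune it,
# or count real tokens with `tokenizer` for a precise cut‑off.  Drop it
# in with `messages = trim_messages(messages)` just before the prompt
# is rendered in `generate_reply`.
# --------------------------------------------------------------
MAX_HISTORY_MESSAGES = 40  # assumed budget: 20 user/assistant exchanges


def trim_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Return only the most recent MAX_HISTORY_MESSAGES entries."""
    return messages[-MAX_HISTORY_MESSAGES:]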
# --------------------------------------------------------------
# 4️⃣ UI definition
# --------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default()) as demo:
    # -----------------------------------------------------------------
    # Top bar – session selector + "New chat" button
    # -----------------------------------------------------------------
    with gr.Row():
        session_dropdown = gr.Dropdown(
            choices=list_sessions(),
            label="🗂️ Load previous chat",
            interactive=True,
            allow_custom_value=False,
        )
        new_chat_btn = gr.Button("🆕 New chat", variant="primary")

    status_txt = gr.Markdown("", visible=False)

    # -----------------------------------------------------------------
    # Main chat area
    # -----------------------------------------------------------------
    chatbot = gr.Chatbot(label="🗨️ Maya‑1 Chat", height=600)
    txt = gr.Textbox(
        placeholder="Type your message and hit Enter …",
        label="Your message",
        container=False,
    )
    submit_btn = gr.Button("Send", variant="secondary")

    # -----------------------------------------------------------------
    # Hidden state – we keep the full list of (user, assistant) tuples
    # -----------------------------------------------------------------
    chat_state = gr.State([])     # List[Tuple[str, str]]
    session_state = gr.State("")  # current session_id (string)

    # -----------------------------------------------------------------
    # 5️⃣ Callbacks
    # -----------------------------------------------------------------
    # When the app loads, generate a fresh anonymous session ID.
    def init_session():
        sid = str(uuid.uuid4())
        return sid, []  # session_state, chat_state

    demo.load(fn=init_session, outputs=[session_state, chat_state])

    # -----------------------------------------------------------------
    # New chat → reset everything and hand out a brand‑new session ID
    # -----------------------------------------------------------------
    def new_chat():
        sid = str(uuid.uuid4())
        return sid, [], []  # session_id, empty chat_state, empty UI

    new_chat_btn.click(
        fn=new_chat,
        outputs=[session_state, chat_state, chatbot],
    )

    # -----------------------------------------------------------------
    # Load a saved session from the dropdown
    # -----------------------------------------------------------------
    def load_session(selected: str):
        if not selected:
            return "", [], []  # nothing selected → blank
        # The file name (stem) is the session_id we used when saving.
        session_id = selected
        history = load_history(session_id)
        # Convert List[Tuple] → the format expected by gr.Chatbot
        ui_history = [(u, a) for u, a in history]
        return session_id, history, ui_history

    session_dropdown.change(
        fn=load_session,
        inputs=[session_dropdown],
        outputs=[session_state, chat_state, chatbot],
    )

    # -----------------------------------------------------------------
    # When the user hits Enter or clicks "Send" → generate a reply
    # -----------------------------------------------------------------
    def user_submit(user_msg: str, chat_hist: List[Tuple[str, str]], sid: str):
        # Call the generation function
        updated_hist, answer = generate_reply(user_msg, chat_hist, sid)
        # The UI expects List[Tuple[user, assistant]]
        ui_hist = [(u, a) for u, a in updated_hist]
        return "", ui_hist, updated_hist, answer

    txt.submit(
        fn=user_submit,
        inputs=[txt, chat_state, session_state],
        outputs=[txt, chatbot, chat_state, status_txt],
    )
    submit_btn.click(
        fn=user_submit,
        inputs=[txt, chat_state, session_state],
        outputs=[txt, chatbot, chat_state, status_txt],
    )

    # -----------------------------------------------------------------
    # Keep the session dropdown up to date after each new save
    # -----------------------------------------------------------------
    def refresh_dropdown():
        # `gr.update` replaces the Gradio‑3‑era `gr.Dropdown.update`.
        return gr.update(choices=list_sessions())

    # Whenever chat_state changes (i.e., after every reply) refresh the list.
    chat_state.change(fn=refresh_dropdown, inputs=None, outputs=session_dropdown)

# --------------------------------------------------------------
# 6️⃣ Run the demo
# --------------------------------------------------------------
if __name__ == "__main__":
    demo.queue()   # enables concurrent users
    demo.launch()
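
# --------------------------------------------------------------
# Rough dependency notes (versions are assumptions – pin what you test):
#   pip install gradio transformers accelerate torch huggingface_hub
# `device_map="auto"` requires `accelerate`, and the `chat_state.change`
# listener above needs a recent Gradio release (gr.State did not support
# event listeners in older versions).
# --------------------------------------------------------------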