import gradio as gr
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os
import traceback

DEFAULT_FILE = "default_models.json"
USER_FILE = "models.json"

def load_default_models():
    with open(DEFAULT_FILE, "r", encoding="utf-8") as f:
        return json.load(f)

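# Expected shape of default_models.json (an illustrative sketch; the real file
# ships with the Space and its names may differ):
#
# {
#   "Chat": {
#     "Llama": {
#       "Some-Model": {
#         "id": "user/model",
#         "description": "...",
#         "link": "https://huggingface.co/user/model",
#         "emoji": "🦙"
#       }
#     }
#   }
# }
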
def load_user_models():
    if os.path.exists(USER_FILE):
        with open(USER_FILE, "r", encoding="utf-8") as f:
            try:
                return json.load(f)
            except json.JSONDecodeError:
                return {}
    return {}

def save_user_models(data):
    with open(USER_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

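# Note: on Hugging Face Spaces the container filesystem is ephemeral, so
# models.json written here survives page reloads but not a Space rebuild or
# restart (unless persistent storage is attached). This matches the
# "after reloading the Space" wording in the add-model status message.
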
def merge_models():
    """
    Merge default + user models into one tree:
        Category -> Family -> Model -> meta
    User models can introduce new categories/families and override
    default entries with the same name.
    """
    base = load_default_models()
    user = load_user_models()

    for category, families in user.items():
        if category not in base:
            base[category] = {}
        for family, models in families.items():
            if family not in base[category]:
                base[category][family] = {}
            for model_name, meta in models.items():
                base[category][family][model_name] = meta

    return base

def flatten_models(model_tree):
    """
    Returns a dict:
        full_key -> (meta, category, family, model_name)
    where full_key = "Category / Family / Model"
    """
    flat = {}
    for category, families in model_tree.items():
        for family, models in families.items():
            for model_name, meta in models.items():
                full_key = f"{category} / {family} / {model_name}"
                flat[full_key] = (meta, category, family, model_name)
    return flat

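# Example (illustrative, using the sketch above): flatten_models would yield
#     {"Chat / Llama / Some-Model": ({...meta...}, "Chat", "Llama", "Some-Model")}
# so a single string key can round-trip through gr.State and the UI.
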
DEBUG_MESSAGES = []


def debug(msg):
    """Append a debug message to the global log."""
    DEBUG_MESSAGES.append(str(msg))
    if len(DEBUG_MESSAGES) > 300:
        DEBUG_MESSAGES.pop(0)
    return "\n".join(DEBUG_MESSAGES)


def get_debug_log():
    return "\n".join(DEBUG_MESSAGES)

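# Note: DEBUG_MESSAGES is a module-level list, and a Gradio app serves all
# visitors from one process, so this log is shared across every user of the
# Space. Fine for a demo; use per-session gr.State if that matters.
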
def add_model_box(
    category,
    family,
    model_name,
    model_id,
    description,
    link,
    emoji
):
    try:
        if not model_id:
            debug("Add model failed: no model_id provided")
            # Event handlers return the new component value directly;
            # gr.Markdown.update() was removed in Gradio 4+.
            return "Please provide a Model ID like `user/model`."

        if not category:
            category = "Custom"
        if not family:
            family = "User-Added"
        if not model_name:
            model_name = model_id.split("/")[-1]
        if not description:
            description = "User-added model."
        if not link:
            link = f"https://huggingface.co/{model_id}"
        if not emoji:
            emoji = "✨"

        user_models = load_user_models()

        if category not in user_models:
            user_models[category] = {}
        if family not in user_models[category]:
            user_models[category][family] = {}

        user_models[category][family][model_name] = {
            "id": model_id,
            "description": description,
            "link": link,
            "emoji": emoji
        }

        save_user_models(user_models)

        msg = (
            f"Added model under `{category} / {family}`: "
            f"{emoji} **{model_name}** (`{model_id}`)\n\n"
            f"It will appear in the model tree after reloading the Space."
        )
        debug(f"Model added: {category} / {family} / {model_name} ({model_id})")
        return msg
    except Exception:
        tb = traceback.format_exc()
        debug(f"ERROR in add_model_box:\n{tb}")
        return "An error occurred while adding the model. Check Debug Log."

def check_model_access(model_id, hf_token):
    """
    Try to get model info; return (ok: bool, message: str).
    This helps distinguish auth/gating issues from other failures.
    For local loading it is not strictly required, but it gives
    clearer messages for private/gated models.
    """
    try:
        api = HfApi(token=hf_token.token if hf_token else None)
        _ = api.model_info(model_id)
        return True, ""
    except Exception as e:
        tb = traceback.format_exc()
        debug(f"ERROR in check_model_access for {model_id}:\n{tb}")
        return False, str(e)

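# Example (illustrative): for a gated repo such as "meta-llama/Llama-2-7b-hf",
# api.model_info raises (typically huggingface_hub.utils.GatedRepoError) when
# the token has not been granted access, so this returns (False, "...") with
# the gating message rather than a generic failure.
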
LOCAL_MODEL_CACHE = {}

def load_local_model(model_id):
    """
    Load a model + tokenizer locally and cache them.
    This makes the Space behave like a dedicated model Space:
    models are executed inside the container, not via the Inference API.
    """
    if model_id in LOCAL_MODEL_CACHE:
        debug(f"Using cached model: {model_id}")
        return LOCAL_MODEL_CACHE[model_id]

    debug(f"Loading model locally: {model_id}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        debug(f"ERROR loading tokenizer for {model_id}: {e}")
        raise

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
    except Exception as e:
        debug(f"ERROR loading model weights for {model_id}: {e}")
        raise

    LOCAL_MODEL_CACHE[model_id] = (tokenizer, model)
    return tokenizer, model

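# The cache keeps every model ever selected resident in (GPU) memory, which a
# small Space cannot afford for long. A minimal eviction sketch -- an optional
# helper we are assuming, not wired in anywhere; call it before
# load_local_model() if memory is tight:

def evict_other_models(keep_id):
    """Drop all cached models except keep_id and release GPU memory."""
    for cached_id in list(LOCAL_MODEL_CACHE):
        if cached_id != keep_id:
            del LOCAL_MODEL_CACHE[cached_id]
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
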
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    active_model_key,
    hf_token: gr.OAuthToken
):
    if active_model_key is None:
        yield "No model selected. Please choose a model in the sidebar and click 'Use this model'."
        return

    models = merge_models()
    flat = flatten_models(models)

    meta_tuple = flat.get(active_model_key)
    if meta_tuple is None:
        yield "Selected model not found. Please choose a model again."
        return

    meta, _, _, _ = meta_tuple
    model_id = meta["id"]

    debug(f"Chat using local model: {model_id}")

    ok, msg = check_model_access(model_id, hf_token)
    if not ok:
        yield (
            f"Could not access model `{model_id}` on Hugging Face.\n\n"
            f"This is usually because:\n"
            f"- The repo is private or gated and this token has no access\n"
            f"- Or the token is invalid/expired\n\n"
            f"Raw error:\n{msg}\n\n"
            f"Check Debug Log for more details."
        )
        return

    try:
        tokenizer, model = load_local_model(model_id)
    except Exception:
        tb = traceback.format_exc()
        debug(f"ERROR in load_local_model for {model_id}:\n{tb}")
        yield (
            f"Failed to load model `{model_id}` locally inside the Space.\n"
            f"Check the Debug Log for details (likely out of memory or missing files)."
        )
        return

    # Build a plain-text chat transcript as the prompt.
    prompt = system_message.strip() + "\n\n"
    for turn in history or []:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        if role == "user":
            prompt += f"User: {content}\n"
        else:
            prompt += f"Assistant: {content}\n"
    prompt += f"User: {message}\nAssistant:"

    debug(f"Prompt length (chars): {len(prompt)}")

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id,
        )

        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Keep only the text after the final "Assistant:" marker.
        if "Assistant:" in output_text:
            answer = output_text.split("Assistant:")[-1].strip()
        else:
            answer = output_text.strip()

        yield answer

    except Exception:
        tb = traceback.format_exc()
        debug(f"ERROR during local generation for {model_id}:\n{tb}")
        yield (
            "An error occurred during local text generation.\n"
            "This is often due to running out of memory for large models.\n"
            "Try a smaller model, fewer max tokens, or check the Debug Log."
        )

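# For token-by-token streaming (rather than one final yield), transformers'
# TextIteratorStreamer could be swapped into the generation step above. A
# minimal sketch, assuming the same tokenizer/model/inputs (not wired in here):
#
#     from threading import Thread
#     from transformers import TextIteratorStreamer
#
#     streamer = TextIteratorStreamer(
#         tokenizer, skip_prompt=True, skip_special_tokens=True
#     )
#     thread = Thread(
#         target=model.generate,
#         kwargs=dict(**inputs, max_new_tokens=int(max_tokens), streamer=streamer),
#     )
#     thread.start()
#     partial = ""
#     for new_text in streamer:
#         partial += new_text
#         yield partial
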
def use_model(fk, old_fk):
    """
    fk: full key "Category / Family / Model" (from gr.State(full_key))
    old_fk: previous active model key (from active_model_state)
    Returns: (new_active_key, current_model_label_text)
    """
    try:
        models_local = merge_models()
        flat_local = flatten_models(models_local)
        meta_loc_tuple = flat_local.get(fk)

        if not meta_loc_tuple:
            debug(f"use_model: key not found: {fk}")
            return old_fk, "**Current model:** _none selected_"

        meta_loc, _, _, mname = meta_loc_tuple
        emoji_local = meta_loc.get("emoji", "✨")
        label_text = f"**Current model:** {emoji_local} {mname}"

        debug(f"use_model: selected {fk}")
        return fk, label_text

    except Exception:
        tb = traceback.format_exc()
        debug(f"ERROR in use_model:\n{tb}")
        return old_fk, "**Current model:** _error occurred (see Debug Log)_"

def build_model_tree(
    models,
    active_model_state,
    current_model_label
):
    """
    models: merged models dict (Category -> Family -> Model -> meta)
    active_model_state: gr.State storing the current active full key
    current_model_label: gr.Markdown showing 'Current model: ...'
    """
    for category, families in models.items():
        with gr.Accordion(category, open=False):
            for family, model_dict in families.items():
                with gr.Accordion(family, open=False):
                    for model_name, meta in model_dict.items():
                        emoji = meta.get("emoji", "✨")
                        full_key = f"{category} / {family} / {model_name}"

                        with gr.Accordion(f"{emoji} {model_name}", open=False):
                            info_text = (
                                f"**Model ID:** `{meta['id']}` \n"
                                f"**Description:** {meta['description']} \n"
                                f"[Model card]({meta['link']})"
                            )
                            gr.Markdown(info_text)

                            use_btn = gr.Button("Use this model", size="sm")

                            # A fresh gr.State per button freezes this model's
                            # full_key, avoiding the late-binding-closure bug
                            # that arises when creating handlers in a loop.
                            use_btn.click(
                                use_model,
                                inputs=[gr.State(full_key), active_model_state],
                                outputs=[active_model_state, current_model_label],
                            )

with gr.Blocks() as demo:
    models_tree = merge_models()

    # Full key ("Category / Family / Model") of the currently selected model.
    active_model_key = gr.State(value=None)

    with gr.Sidebar():
        gr.LoginButton()

        with gr.Accordion("Add New Model", open=False):
            category_input = gr.Textbox(
                label="Category (e.g. Exotic, or a new category)",
                placeholder="Exotic"
            )
            family_input = gr.Textbox(
                label="Family (e.g. RWKV)",
                placeholder="RWKV"
            )
            model_name_input = gr.Textbox(
                label="Model Name (e.g. RWKV-World-7B)",
                placeholder="RWKV-World-7B"
            )
            model_id_input = gr.Textbox(
                label="Model ID (e.g. BlinkDL/rwkv-7-world)",
                placeholder="BlinkDL/rwkv-7-world"
            )
            description_input = gr.Textbox(
                label="Description (optional)",
                lines=2
            )
            link_input = gr.Textbox(
                label="Link (optional, defaults to https://huggingface.co/ModelID if empty)",
                lines=1
            )
            emoji_input = gr.Textbox(
                label="Emoji (optional, e.g. 🌍)",
                lines=1
            )

            add_button = gr.Button("Add Model")
            add_status = gr.Markdown("")

            add_button.click(
                add_model_box,
                inputs=[
                    category_input,
                    family_input,
                    model_name_input,
                    model_id_input,
                    description_input,
                    link_input,
                    emoji_input,
                ],
                outputs=add_status,
            )

        with gr.Accordion("Debug Log", open=False):
            debug_log = gr.Textbox(
                label="System Debug Output",
                value="",
                lines=15,
                max_lines=200,
                interactive=False,
                show_copy_button=True,
            )

            refresh_debug = gr.Button("Refresh Debug Log", size="sm")

            refresh_debug.click(
                get_debug_log,
                inputs=None,
                outputs=debug_log
            )

        current_model_label = gr.Markdown("**Current model:** _none selected_")

        gr.Markdown("### Models")

        build_model_tree(
            models_tree,
            active_model_state=active_model_key,
            current_model_label=current_model_label,
        )

    # ChatInterface lives at the Blocks level, next to the sidebar.
    # Its `title` expects a string, not a component, so the live model
    # name is shown via current_model_label in the sidebar instead.
    chatbot = gr.ChatInterface(
        respond,
        type="messages",
        additional_inputs=[
            gr.Textbox(
                value="You are a friendly chatbot.",
                label="System message"
            ),
            gr.Slider(
                minimum=1,
                maximum=100000,
                value=512,
                step=1,
                label="Max new tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p"
            ),
            active_model_key,
        ],
    )

if __name__ == "__main__":
    demo.launch()