# NOTE(review): the three lines here ("Spaces:" / "Running" / "Running") were
# Hugging Face Spaces page-scrape residue, not Python source — converted to a
# comment so the file parses.
| import os | |
| import tarfile | |
| import requests | |
| import gradio as gr | |
| import gpt_2_simple as gpt2 | |
| import tensorflow as tf | |
# --- CONFIGURATION --- | |
# GitHub repository that hosts the model checkpoints as release assets.
REPO_OWNER = "PizzaTowerFanGD"
REPO_NAME = "owotgpt"
# List of models to load
# Each name is both the release asset stem (<name>.tar) and the
# gpt-2-simple run_name used at load/generate time.
MODELS = ['owotgpt', 'owotgpt-code']
# Directory gpt-2-simple reads checkpoints from (checkpoint/<run_name>);
# download_model() uses it to decide whether a model is already present.
CHECKPOINT_DIR = "checkpoint"
def download_model(run_name):
    """Download and extract a specific model by run_name.

    Fetches ``<run_name>.tar`` from the repo's latest GitHub release and
    extracts it into the current directory, which is expected to produce
    ``checkpoint/<run_name>``. Skips the download entirely if that
    directory already exists. Best-effort: failures are printed, not
    raised, so one bad model does not abort startup.

    Args:
        run_name: Model name; both the release asset stem and the
            gpt-2-simple run name.
    """
    model_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/latest/download/{run_name}.tar"
    tar_path = f"{run_name}.tar"
    # Guard clause: checkpoint already on disk, nothing to do.
    if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name)):
        return
    print(f"Downloading {run_name} from {model_url}...")
    response = requests.get(model_url, stream=True)
    if response.status_code != 200:
        print(f"Failed to download {run_name}. Status: {response.status_code}")
        return
    # Stream to disk in chunks so large checkpoints never sit in memory whole.
    with open(tar_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    print(f"Extracting {tar_path}...")
    try:
        with tarfile.open(tar_path, 'r') as tar_ref:
            # The archive comes from the network: reject any member whose
            # resolved path would escape the extraction directory
            # (tar path traversal, CVE-2007-4559).
            base = os.path.realpath(".")
            for member in tar_ref.getmembers():
                target = os.path.realpath(os.path.join(".", member.name))
                if target != base and not target.startswith(base + os.sep):
                    raise ValueError(f"unsafe path in archive: {member.name}")
            tar_ref.extractall(path=".")
        print(f"Extraction of {run_name} complete.")
    except Exception as e:
        print(f"Error extracting {run_name}: {e}")
    finally:
        # Always remove the tarball, even if extraction failed.
        if os.path.exists(tar_path):
            os.remove(tar_path)
# --- INITIALIZATION ---
# Dictionary to store sessions and graphs for each model
# Maps run_name -> {"sess": tf.Session, "graph": tf.Graph}; read by
# generate_text() to route requests to the right model.
loaded_models = {}
for model_name in MODELS:
    download_model(model_name)
    # Create a unique graph and session for each model to avoid collisions
    # (gpt-2-simple presumably builds its ops on the current default graph,
    # so loading two models into one graph would clash — TODO confirm).
    g = tf.Graph()
    with g.as_default():
        # Session and checkpoint load must both happen while this graph
        # is the default; order matters in TF1.
        s = gpt2.start_tf_sess()
        gpt2.load_gpt2(s, run_name=model_name)
    loaded_models[model_name] = {"sess": s, "graph": g}
def generate_text(model_choice, prompt, length=100, temperature=0.8, top_k=40):
    """Run inference on whichever pre-loaded model the user selected.

    Args:
        model_choice: Key into ``loaded_models`` (one of ``MODELS``).
        prompt: Text prefix to condition generation on.
        length: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.

    Returns:
        The generated text (including the prompt), or an error string if
        the requested model was never loaded.
    """
    entry = loaded_models.get(model_choice)
    if entry is None:
        return "Error: Model not loaded."
    # Activate this model's private graph and session before generating,
    # so the call runs against the right set of loaded weights.
    with entry["graph"].as_default(), entry["sess"].as_default():
        results = gpt2.generate(
            entry["sess"],
            run_name=model_choice,
            length=int(length),
            temperature=float(temperature),
            top_k=int(top_k),
            prefix=prompt,
            return_as_list=True,
            include_prefix=True,
        )
    return results[0]
# --- GRADIO UI ---
description = """
### OWoTGPT Multi-Model Inference
Choose between the standard conversation model or the OWoT source code model.
- **owotgpt**: Trained on chat logs.
- **owotgpt-code**: Trained on OWoT source code.
"""

# Build each input component up front, then assemble the interface.
model_picker = gr.Dropdown(choices=MODELS, value=MODELS[0], label="Select Model")
prompt_box = gr.Textbox(
    lines=10, label="Context / Prompt", placeholder="Type your prompt here..."
)
length_slider = gr.Slider(10, 500, value=100, step=1, label="Max Generation Length")
temp_slider = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature")
top_k_slider = gr.Slider(1, 100, value=40, step=1, label="Top K")

# Input order must match generate_text's positional parameters.
iface = gr.Interface(
    fn=generate_text,
    inputs=[model_picker, prompt_box, length_slider, temp_slider, top_k_slider],
    outputs=gr.Textbox(label="Generated Output"),
    title="OWoTGPT Model Hub",
    description=description,
)
# Entry point: start the Gradio server when run as a script.
if __name__ == "__main__":
    # mcp_server=True makes both models available as tools via MCP
    iface.launch(mcp_server=True)