import os
import tarfile

import requests
import gradio as gr
import gpt_2_simple as gpt2
import tensorflow as tf

# --- CONFIGURATION ---
REPO_OWNER = "PizzaTowerFanGD"
REPO_NAME = "owotgpt"

# List of models to load
MODELS = ['owotgpt', 'owotgpt-code']
CHECKPOINT_DIR = "checkpoint"


def download_model(run_name):
    """Downloads and extracts a specific model by run_name.

    Fetches ``<run_name>.tar`` from the repository's latest GitHub release
    and extracts it into the current directory (the tar is expected to
    contain ``checkpoint/<run_name>/...`` — TODO confirm archive layout).
    The download is skipped if ``checkpoint/<run_name>`` already exists.
    The tar file is removed after extraction, even if extraction fails.
    """
    model_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/latest/download/{run_name}.tar"
    tar_path = f"{run_name}.tar"
    if not os.path.exists(os.path.join(CHECKPOINT_DIR, run_name)):
        print(f"Downloading {run_name} from {model_url}...")
        # timeout= prevents app startup from hanging forever on a stalled
        # connection; stream=True avoids holding the whole tar in memory.
        response = requests.get(model_url, stream=True, timeout=60)
        if response.status_code == 200:
            with open(tar_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"Extracting {tar_path}...")
            try:
                with tarfile.open(tar_path, 'r') as tar_ref:
                    try:
                        # The "data" filter (Python 3.12+, backported to
                        # 3.8.17/3.9.17/3.10.12/3.11.4) rejects members that
                        # escape the destination (absolute paths, "..",
                        # symlinks pointing outside) — the archive comes
                        # from the network, so never extract it unfiltered.
                        tar_ref.extractall(path=".", filter="data")
                    except TypeError:
                        # Interpreter predates the filter= argument.
                        tar_ref.extractall(path=".")
                print(f"Extraction of {run_name} complete.")
            except Exception as e:
                print(f"Error extracting {run_name}: {e}")
            finally:
                if os.path.exists(tar_path):
                    os.remove(tar_path)
        else:
            print(f"Failed to download {run_name}. Status: {response.status_code}")


# --- INITIALIZATION ---
# Dictionary to store sessions and graphs for each model
loaded_models = {}

for model_name in MODELS:
    download_model(model_name)
    # Create a unique graph and session for each model to avoid collisions
    g = tf.Graph()
    with g.as_default():
        s = gpt2.start_tf_sess()
        gpt2.load_gpt2(s, run_name=model_name)
    loaded_models[model_name] = {"sess": s, "graph": g}


def generate_text(model_choice, prompt, length=100, temperature=0.8, top_k=40):
    """
    Inference function that selects the model based on user choice.

    Args:
        model_choice: Key into ``loaded_models`` (one of ``MODELS``).
        prompt: Text prefix the model continues from.
        length: Max number of tokens to generate (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_k: Top-k sampling cutoff (coerced to int).

    Returns:
        The generated text (prompt included), or an error string when the
        requested model was never loaded.
    """
    if model_choice not in loaded_models:
        return "Error: Model not loaded."
    selected = loaded_models[model_choice]
    sess = selected["sess"]
    graph = selected["graph"]
    # Each model lives in its own graph/session pair; both must be the
    # defaults while generating so gpt2.generate hits the right weights.
    with graph.as_default():
        with sess.as_default():
            output = gpt2.generate(
                sess,
                run_name=model_choice,
                length=int(length),
                temperature=float(temperature),
                top_k=int(top_k),
                prefix=prompt,
                return_as_list=True,
                include_prefix=True
            )[0]
    return output


# --- GRADIO UI ---
description = """
### OWoTGPT Multi-Model Inference
Choose between the standard conversation model or the OWoT source code model.
- **owotgpt**: Trained on chat logs.
- **owotgpt-code**: Trained on OWoT source code.
"""

iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Dropdown(choices=MODELS, value=MODELS[0], label="Select Model"),
        gr.Textbox(lines=10, label="Context / Prompt", placeholder="Type your prompt here..."),
        gr.Slider(10, 500, value=100, step=1, label="Max Generation Length"),
        gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(1, 100, value=40, step=1, label="Top K"),
    ],
    outputs=gr.Textbox(label="Generated Output"),
    title="OWoTGPT Model Hub",
    description=description
)

if __name__ == "__main__":
    # mcp_server=True makes both models available as tools via MCP
    iface.launch(mcp_server=True)