# Hugging Face Space app.py — commit 7e6c20f ("Update app.py", SomePersonAlt)
import os
import tarfile
import requests
import gradio as gr
import gpt_2_simple as gpt2
import tensorflow as tf
# --- CONFIGURATION ---
# GitHub repository whose *latest release* hosts the model tarballs.
REPO_OWNER = "PizzaTowerFanGD"
REPO_NAME = "owotgpt"
# List of models to load
# Each entry is both a release asset name ({name}.tar) and a gpt-2-simple
# run name; the same list feeds the Gradio model dropdown.
MODELS = ['owotgpt', 'owotgpt-code']
# gpt_2_simple's checkpoint root — tarballs are expected to extract into
# checkpoint/{run_name} (see the existence check in download_model).
CHECKPOINT_DIR = "checkpoint"
def download_model(run_name):
    """Download and extract the release tarball for *run_name*.

    Fetches ``{run_name}.tar`` from the repo's latest GitHub release and
    unpacks it into the current directory, which is expected to populate
    ``checkpoint/{run_name}``. Skips everything if that checkpoint
    directory already exists. Best-effort: errors are printed, never
    raised, so startup can continue with the remaining models.

    Args:
        run_name: Release asset / gpt-2-simple run name (no extension).
    """
    model_url = f"https://github.com/{REPO_OWNER}/{REPO_NAME}/releases/latest/download/{run_name}.tar"
    tar_path = f"{run_name}.tar"
    if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name)):
        return  # checkpoint already present from a previous run
    print(f"Downloading {run_name} from {model_url}...")
    try:
        # Timeout keeps a hung connection from stalling startup forever.
        response = requests.get(model_url, stream=True, timeout=60)
    except requests.RequestException as e:
        # Original code let connection errors propagate and kill startup.
        print(f"Failed to download {run_name}: {e}")
        return
    if response.status_code != 200:
        print(f"Failed to download {run_name}. Status: {response.status_code}")
        return
    with open(tar_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    print(f"Extracting {tar_path}...")
    try:
        with tarfile.open(tar_path, 'r') as tar_ref:
            try:
                # The 'data' filter rejects path-traversal members
                # (CVE-2007-4559); added in 3.12, backported to 3.8+.
                tar_ref.extractall(path=".", filter="data")
            except TypeError:
                # Interpreter predates the filter argument.
                tar_ref.extractall(path=".")
        print(f"Extraction of {run_name} complete.")
    except Exception as e:
        print(f"Error extracting {run_name}: {e}")
    finally:
        # Always delete the tarball, even after a failed extraction.
        if os.path.exists(tar_path):
            os.remove(tar_path)
# --- INITIALIZATION ---
# Each model lives in its own TF graph + session so their variables and
# checkpoints never collide with one another.
loaded_models = {}

def _bring_up(run_name):
    """Fetch the checkpoint if needed, then restore it into a fresh graph/session."""
    download_model(run_name)
    graph = tf.Graph()
    with graph.as_default():
        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess, run_name=run_name)
    return {"sess": sess, "graph": graph}

for _model_name in MODELS:
    loaded_models[_model_name] = _bring_up(_model_name)
def generate_text(model_choice, prompt, length=100, temperature=0.8, top_k=40):
    """Sample text from whichever loaded model the user picked.

    Args:
        model_choice: Key into ``loaded_models`` (a value from MODELS).
        prompt: Text prefix the generation continues from.
        length: Number of tokens to generate (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_k: Top-k sampling cutoff (coerced to int).

    Returns:
        The generated text including the prompt, or an error string when
        the requested model was never loaded.
    """
    entry = loaded_models.get(model_choice)
    if entry is None:
        return "Error: Model not loaded."
    sess = entry["sess"]
    graph = entry["graph"]
    # Activate this model's private graph/session so sampling does not
    # touch the other model's variables.
    with graph.as_default(), sess.as_default():
        samples = gpt2.generate(
            sess,
            run_name=model_choice,
            length=int(length),
            temperature=float(temperature),
            top_k=int(top_k),
            prefix=prompt,
            return_as_list=True,
            include_prefix=True,
        )
    return samples[0]
# --- GRADIO UI ---
description = """
### OWoTGPT Multi-Model Inference
Choose between the standard conversation model or the OWoT source code model.
- **owotgpt**: Trained on chat logs.
- **owotgpt-code**: Trained on OWoT source code.
"""

# Input widgets, listed in the order generate_text expects its arguments.
_ui_inputs = [
    gr.Dropdown(choices=MODELS, value=MODELS[0], label="Select Model"),
    gr.Textbox(lines=10, label="Context / Prompt", placeholder="Type your prompt here..."),
    gr.Slider(10, 500, value=100, step=1, label="Max Generation Length"),
    gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature"),
    gr.Slider(1, 100, value=40, step=1, label="Top K"),
]

iface = gr.Interface(
    fn=generate_text,
    inputs=_ui_inputs,
    outputs=gr.Textbox(label="Generated Output"),
    title="OWoTGPT Model Hub",
    description=description,
)
if __name__ == "__main__":
    # Start the web UI when run as a script (not when imported).
    # mcp_server=True makes both models available as tools via MCP
    iface.launch(mcp_server=True)