Tbain20's picture
Increase to 600 tokens and add truncation note
1aa4511
Raw
History Blame Contribute Delete
2.26 kB
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Markers wrapped around the user's text to build the prompt passed to the model.
PREFIX_R = "\n\n### Response:\n"
PREFIX_I = "### Instruction:\n"

# Hub repository the model weights are loaded from.
REPO_ID = "Tbain20/olmo2-1b-eeg-v11"

print("Loading tokenizer...")
# NOTE(review): the tokenizer is pulled from the 7B base repository while the
# weights come from REPO_ID -- presumably both share one vocabulary; confirm.
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-1124-7B")

print("Loading model...")
# NOTE(review): ignore_mismatched_sizes=True lets loading continue when a
# checkpoint tensor's shape differs from the model's; confirm that is intended
# rather than hiding a config/checkpoint mismatch.
model = AutoModelForCausalLM.from_pretrained(REPO_ID, ignore_mismatched_sizes=True)

# Put the model in evaluation mode before serving requests.
model.eval()
print("Ready")
@torch.no_grad()
def generate_code(prompt):
    """Generate a completion for ``prompt`` using the module-level model.

    The stripped prompt is wrapped between ``PREFIX_I`` and ``PREFIX_R``,
    tokenized, and sampled from (temperature 0.7, top-k 40, at most 600 new
    tokens). Only the newly generated tokens are decoded.

    Args:
        prompt: Text entered by the user. ``None``, empty, or whitespace-only
            input is rejected without calling the model.

    Returns:
        The decoded completion with surrounding whitespace stripped, or
        ``"Please enter a prompt."`` for blank input. A trailing note is
        appended when the token budget was used up before an end-of-sequence
        token was produced and the text contains ``def ``.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    max_new_tokens = 600
    eos_id = tokenizer.eos_token_id

    full = PREFIX_I + prompt.strip() + PREFIX_R
    encoded = tokenizer(full, return_tensors="pt")
    ids = encoded.input_ids
    out = model.generate(
        ids,
        # The pad token is set to the EOS token below, so pass the tokenizer's
        # mask (when it provides one) instead of leaving generate() to infer it.
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_k=40,
        pad_token_id=eos_id,
        eos_token_id=eos_id,
    )

    # Drop the prompt tokens; keep only what the model produced.
    generated = out[0][ids.shape[1]:]
    result = tokenizer.decode(generated, skip_special_tokens=True).strip()

    # Generation was cut off when the whole budget was consumed and the final
    # token is not EOS. This replaces guessing from the text's last characters,
    # which flagged complete functions ending in e.g. "return x".
    hit_limit = (
        generated.shape[0] >= max_new_tokens
        and generated[-1].item() != eos_id
    )
    if result and hit_limit and "def " in result:
        result += "\n # [Note: generation truncated — use local demo for complete output]"
    return result
# Sample prompts shown beneath the prompt box; each inner list is one example
# row feeding the single prompt input.
EXAMPLES = [
    ["Write a Python function using MNE to filter EEG for beta waves (13-30 Hz)"],
    ["Write a Python function to compute beta band PSD using Welch method"],
    ["Write a Python function to load TDT block and extract RSn1 stream"],
]

# Build the page: header text, prompt input, trigger button, output area and
# the clickable examples.
with gr.Blocks(title="OLMo EEG Code Generator") as demo:
    gr.Markdown("""# OLMo EEG Code Generator
### NDML Lab — Cleveland Clinic
**Note:** This demo runs on CPU and may truncate long functions.
For complete output use the lab server: `http://100.104.177.70:7860`""")
    prompt_box = gr.Textbox(
        label="Prompt",
        lines=3,
        placeholder="Write a Python function to...",
    )
    generate_btn = gr.Button("Generate", variant="primary")
    output_box = gr.Textbox(label="Generated Code", lines=25)
    gr.Examples(examples=EXAMPLES, inputs=prompt_box)

    # The button click and the textbox submit event run the same handler with
    # the same input and output components.
    for trigger in (generate_btn.click, prompt_box.submit):
        trigger(fn=generate_code, inputs=prompt_box, outputs=output_box)

# Start the Gradio server for the assembled interface.
demo.launch()