"""Gradio demo that generates EEG-analysis Python code with a fine-tuned OLMo-2 model."""
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

REPO_ID = "Tbain20/olmo2-1b-eeg-v11"
# Instruction/response template wrapped around every user prompt.
PREFIX_I = "### Instruction:\n"
PREFIX_R = "\n\n### Response:\n"
# Per-request generation budget; also used to decide whether output was cut off.
MAX_NEW_TOKENS = 600
TRUNCATION_NOTE = "\n # [Note: generation truncated — use local demo for complete output]"

print("Loading tokenizer...")
# NOTE(review): the tokenizer comes from the 7B base repo, not REPO_ID;
# presumably the OLMo-2 family shares one tokenizer — confirm it matches
# the vocabulary the fine-tune was trained with.
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-1124-7B")
print("Loading model...")
# NOTE(review): per the transformers docs, ignore_mismatched_sizes=True
# re-initialises any checkpoint weights whose shapes do not match instead of
# failing — verify no layers are silently reset at load time.
model = AutoModelForCausalLM.from_pretrained(
    REPO_ID,
    ignore_mismatched_sizes=True,
)
model.eval()
print("Ready")


@torch.no_grad()
def generate_code(prompt):
    """Generate code for ``prompt`` using the fine-tuned model.

    Args:
        prompt: Free-text instruction from the UI textbox.

    Returns:
        The decoded model response (stripped). Returns a fixed message when
        the prompt is empty or blank. If generation used the full
        ``MAX_NEW_TOKENS`` budget without emitting EOS, a truncation note is
        appended.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    full = PREFIX_I + prompt.strip() + PREFIX_R
    encoded = tokenizer(full, return_tensors="pt")
    ids = encoded.input_ids
    out = model.generate(
        ids,
        # Pass the mask explicitly so generate() does not have to infer it
        # (pad_token_id == eos_token_id makes that inference ambiguous).
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_k=40,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    generated = out[0][ids.shape[1]:]
    result = tokenizer.decode(generated, skip_special_tokens=True).strip()

    # Output was cut off exactly when the token budget was exhausted without
    # an EOS. A text heuristic on the last characters misfires on complete
    # functions ending in e.g. `return x`.
    n_generated = generated.shape[0]
    hit_budget = n_generated >= MAX_NEW_TOKENS
    ended_with_eos = (
        n_generated > 0 and generated[-1].item() == tokenizer.eos_token_id
    )
    if result and hit_budget and not ended_with_eos:
        result += TRUNCATION_NOTE
    return result


EXAMPLES = [
    ["Write a Python function using MNE to filter EEG for beta waves (13-30 Hz)"],
    ["Write a Python function to compute beta band PSD using Welch method"],
    ["Write a Python function to load TDT block and extract RSn1 stream"],
]

with gr.Blocks(title="OLMo EEG Code Generator") as demo:
    # NOTE(review): 100.104.177.70 lies in the 100.64.0.0/10 shared address
    # range (RFC 6598, e.g. Tailscale); public visitors cannot reach it.
    gr.Markdown("""# OLMo EEG Code Generator
### NDML Lab — Cleveland Clinic

**Note:** This demo runs on CPU and may truncate long functions.
For complete output use the lab server: `http://100.104.177.70:7860`""")
    prompt_box = gr.Textbox(
        label="Prompt", lines=3, placeholder="Write a Python function to..."
    )
    generate_btn = gr.Button("Generate", variant="primary")
    output_box = gr.Textbox(label="Generated Code", lines=25)
    gr.Examples(examples=EXAMPLES, inputs=prompt_box)
    # Both the button and pressing Enter in the textbox trigger generation.
    generate_btn.click(fn=generate_code, inputs=prompt_box, outputs=output_box)
    prompt_box.submit(fn=generate_code, inputs=prompt_box, outputs=output_box)

demo.launch()