Spaces:
Running
Running
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fine-tuned OLMo-2 checkpoint served by this demo.
REPO_ID = "Tbain20/olmo2-1b-eeg-v11"

# Instruction-tuning template: a request is rendered as
# "### Instruction:\n<prompt>\n\n### Response:\n" before generation.
PREFIX_I = "### Instruction:\n"
PREFIX_R = "\n\n### Response:\n"

print("Loading tokenizer...")
# NOTE(review): the tokenizer is taken from the 7B base repo rather than
# REPO_ID — presumably the fine-tune shares that vocabulary; confirm.
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-1124-7B")

print("Loading model...")
# NOTE(review): with ignore_mismatched_sizes=True, from_pretrained does not
# raise on checkpoint/config shape mismatches; it re-initializes those weights
# instead. Check the load log for "newly initialized" warnings.
# Module.eval() returns the module itself, so it can be chained here.
model = AutoModelForCausalLM.from_pretrained(
    REPO_ID,
    ignore_mismatched_sizes=True,
).eval()
print("Ready")
def generate_code(prompt, *, max_new_tokens=600, temperature=0.7, top_k=40):
    """Generate code for a natural-language instruction with the loaded model.

    The prompt is wrapped in the ``PREFIX_I`` / ``PREFIX_R`` instruction
    template. The completion is then sampled with temperature / top-k
    sampling.

    Args:
        prompt: Instruction text from the UI. ``None`` or a whitespace-only
            string returns a short help message without running the model.
        max_new_tokens: Upper bound on generated tokens (default 600).
        temperature: Sampling temperature (default 0.7).
        top_k: Top-k sampling cutoff (default 40).

    Returns:
        The decoded completion (prompt tokens excluded), with surrounding
        whitespace stripped. If generation stopped because it used up
        ``max_new_tokens`` rather than emitting EOS, a truncation note is
        appended to the result.
    """
    if prompt is None or not prompt.strip():
        return "Please enter a prompt."

    full = PREFIX_I + prompt.strip() + PREFIX_R
    encoded = tokenizer(full, return_tensors="pt")
    prompt_len = encoded.input_ids.shape[1]
    eos_id = tokenizer.eos_token_id

    out = model.generate(
        input_ids=encoded.input_ids,
        # pad_token_id equals eos_token_id, so generate() cannot reliably
        # infer the mask from padding; pass it explicitly.
        attention_mask=encoded.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        pad_token_id=eos_id,
        eos_token_id=eos_id,
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    generated = out[0][prompt_len:]
    result = tokenizer.decode(generated, skip_special_tokens=True).strip()

    # Output is truncated exactly when the token budget ran out before the
    # model emitted EOS. Guessing from the text's last characters is
    # unreliable: a complete function ending in "return x" would look
    # truncated.
    hit_limit = generated.shape[0] >= max_new_tokens
    ended_with_eos = generated.shape[0] > 0 and int(generated[-1]) == eos_id
    if result and hit_limit and not ended_with_eos:
        result += "\n # [Note: generation truncated — use local demo for complete output]"
    return result
# Sample prompts for the gr.Examples widget. Each example is an inner list
# holding one value per input component (here only the prompt box).
EXAMPLES = [
    [example_prompt]
    for example_prompt in (
        "Write a Python function using MNE to filter EEG for beta waves (13-30 Hz)",
        "Write a Python function to compute beta band PSD using Welch method",
        "Write a Python function to load TDT block and extract RSn1 stream",
    )
]
with gr.Blocks(title="OLMo EEG Code Generator") as demo:
    # Header shown above the form.
    # NOTE(review): 100.104.177.70 lies in the 100.64.0.0/10 shared address
    # range (often a Tailscale/CGNAT address). It is presumably unreachable
    # for public visitors and exposes an internal host — confirm intent.
    header_md = (
        "# OLMo EEG Code Generator\n"
        "### NDML Lab — Cleveland Clinic\n"
        "**Note:** This demo runs on CPU and may truncate long functions.\n"
        "For complete output use the lab server: `http://100.104.177.70:7860`"
    )
    gr.Markdown(header_md)

    prompt_box = gr.Textbox(
        label="Prompt",
        lines=3,
        placeholder="Write a Python function to...",
    )
    generate_btn = gr.Button("Generate", variant="primary")
    output_box = gr.Textbox(label="Generated Code", lines=25)
    gr.Examples(examples=EXAMPLES, inputs=prompt_box)

    # The button click and pressing Enter in the prompt box run the same
    # handler, registered in that order.
    for trigger in (generate_btn.click, prompt_box.submit):
        trigger(fn=generate_code, inputs=prompt_box, outputs=output_box)

demo.launch()