|
|
import os |
|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
from datetime import datetime |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import login |
|
|
|
|
|
|
|
|
|
|
|
# Authenticate with the Hugging Face Hub when a token is configured.
# Gated checkpoints (e.g. CodeLlama) require an authenticated session;
# without a token the app still works for public models/datasets.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
|
|
|
|
|
|
|
|
|
|
|
# Page setup — st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Code Assistant",
    page_icon="🧠",
    layout="wide",
)
|
|
|
|
|
|
|
|
|
|
|
# Initialize per-session state. Streamlit reruns the entire script on every
# interaction, so anything that must survive a rerun lives in st.session_state.
if "history" not in st.session_state:
    st.session_state.history = []  # list of generation records (dicts)

if "datasets" not in st.session_state:
    st.session_state.datasets = {}  # dataset name -> loaded dataset / row dicts

if "pipelines" not in st.session_state:
    st.session_state.pipelines = {}  # model_id -> text-generation pipeline
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_textgen_pipeline(model_id: str):
    """Load and cache a text-generation pipeline for ``model_id``.

    Decorated with ``st.cache_resource`` so the tokenizer and model weights
    are loaded once per server process rather than on every Streamlit rerun.

    Args:
        model_id: Hugging Face Hub model repo id (e.g. "microsoft/phi-2").

    Returns:
        A ``transformers`` text-generation pipeline ready for inference.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",   # spread layers across available GPU(s)/CPU
        torch_dtype="auto",  # use the checkpoint's native precision
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.3,  # low temperature: mostly deterministic code output
    )
|
|
|
|
|
|
|
|
|
|
|
def generate_code(prompt, model_id, dataset_name=None):
    """Generate code for ``prompt`` using ``model_id``.

    Optionally prepends a small preview of the selected dataset so the model
    can ground its answer in the data's actual shape.

    Args:
        prompt: The user's coding request.
        model_id: Hub model repo id; its pipeline is created lazily and
            memoized in ``st.session_state.pipelines``.
        dataset_name: Optional key into ``st.session_state.datasets``.

    Returns:
        The raw generated text. NOTE: ``pipeline("text-generation")`` echoes
        the input prompt at the start of ``generated_text`` by default.
    """
    # Lazily create and memoize the pipeline for this model.
    if model_id not in st.session_state.pipelines:
        st.session_state.pipelines[model_id] = load_textgen_pipeline(model_id)
    pipe = st.session_state.pipelines[model_id]

    # Build optional dataset context. Use .get() so a stale/unknown dataset
    # name degrades to "no context" instead of raising KeyError mid-rerun.
    context = ""
    if dataset_name:
        dataset = st.session_state.datasets.get(dataset_name)
        if dataset is not None:
            context = f"\n# Dataset preview:\n{dataset[:3]}\n"

    full_prompt = f"""You are a precise coding assistant.

Write clean, correct code.
{context}
Prompt:
{prompt}
"""

    result = pipe(full_prompt)[0]["generated_text"]
    return result
|
|
|
|
|
# ----- Sidebar: model + dataset selection -----
|
|
|
|
|
with st.sidebar:
    st.header("⚙️ Control Plane")

    # Model picker: small, code-capable checkpoints.
    model_id = st.selectbox(
        "Model",
        [
            "microsoft/phi-2",
            "codellama/CodeLlama-7b-hf",
            "bigcode/starcoder2-3b",
        ],
    )

    st.divider()

    st.subheader("📦 Dataset")

    dataset_source = st.radio(
        "Dataset source",
        ["None", "Hugging Face Hub", "Upload file"],
    )

    dataset_name = None

    if dataset_source == "Hugging Face Hub":
        hf_dataset_id = st.text_input(
            "Dataset repo (e.g. squad, openwebtext)"
        )

        if st.button("Load dataset") and hf_dataset_id:
            # Only a small slice is needed for prompt context; avoid pulling
            # the full split into memory.
            ds = load_dataset(hf_dataset_id, split="train[:100]")
            st.session_state.datasets[hf_dataset_id] = ds
            dataset_name = hf_dataset_id
            st.success(f"Loaded {hf_dataset_id}")

    elif dataset_source == "Upload file":
        uploaded = st.file_uploader("Upload CSV", type=["csv"])
        if uploaded:
            df = pd.read_csv(uploaded)
            # Store as a list of row dicts so slicing ([:3]) works the same
            # way as for Hub datasets in generate_code().
            st.session_state.datasets[uploaded.name] = df.to_dict(orient="records")
            dataset_name = uploaded.name
            st.success(f"Loaded {uploaded.name}")

    # Active-dataset picker; overrides whatever was just loaded above.
    if st.session_state.datasets:
        dataset_name = st.selectbox(
            "Active dataset",
            options=[None] + list(st.session_state.datasets.keys()),
        )
|
|
|
|
|
# ----- Main UI: prompt entry and generation -----
|
|
|
|
|
st.title("🧠 Code Assistant")
st.caption("Transformers + Datasets + Hugging Face Hub")

prompt = st.text_area(
    "Coding prompt",
    height=150,
    placeholder="Generate a Python function that validates the dataset schema",
)

if st.button("Generate", type="primary"):
    # Guard against an empty prompt instead of sending blank input to the model.
    if not prompt.strip():
        st.warning("Please enter a prompt first.")
    else:
        with st.spinner("Thinking…"):
            output = generate_code(prompt, model_id, dataset_name)

        # Record the generation inside the button branch (output only exists
        # here) so it persists across Streamlit reruns.
        st.session_state.history.append({
            "time": datetime.now().strftime("%H:%M:%S"),
            "model": model_id,
            "dataset": dataset_name,
            "prompt": prompt,
            "output": output,
        })
|
|
|
|
|
# ----- Output: generation history, newest first -----
|
|
|
|
|
# Render generation history, newest first, one collapsible entry per run.
for item in reversed(st.session_state.history):
    label = f"[{item['time']}] {item['model']}"
    if item["dataset"]:
        label += f" | {item['dataset']}"

    with st.expander(label):
        st.code(item["output"], language="python")