# nl-code-gen-bot / app.py
# Hugging Face Space by akshaybharadwaj96
# Commit a175106: "update app.py to share link"
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel
# =====================
# Config
# =====================
BASE_MODEL = "Salesforce/codegen-350M-mono" # base model
ADAPTER_MODEL = "akshaybharadwaj96/nl-code-gen-python" # fine-tuned LoRA adapter repo
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# =====================
# Load tokenizer & model
# =====================
# Tokenizer comes from the base model; the LoRA adapter does not change the vocab.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Load base model
# fp16 on GPU to halve memory; fp32 on CPU (fp16 CPU inference is poorly supported).
# device_map="auto" lets accelerate place weights only when a GPU is present.
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16 if DEVICE=="cuda" else torch.float32,
device_map="auto" if DEVICE=="cuda" else None
)
# Load LoRA adapter on top of base model
# (downloads/merges the adapter weights from ADAPTER_MODEL at startup)
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
model = model.to(DEVICE)
model.eval()  # inference mode: disables dropout etc.
# Define the code generation function
def generate_code(prompt, chat_history):
    """Generate a code reply for a natural-language prompt.

    Args:
        prompt: The user's latest message.
        chat_history: Previous turns as a list of (user, bot) pairs, or
            None/empty right after the chat has been cleared.

    Returns:
        (chat_history, chat_history): the updated history, duplicated so it
        can feed both Chatbot outputs wired up in the UI.
    """
    # The Clear button resets the Chatbot to None; treat that as an empty history.
    chat_history = list(chat_history) if chat_history else []
    # Rebuild the conversation as a flat User/Assistant transcript for context.
    turns = [f"User: {user}\nAssistant: {bot}\n" for user, bot in chat_history]
    full_prompt = "".join(turns) + f"User: {prompt}\nAssistant:"
    # Tokenize input
    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
    # Generate code output; no_grad avoids building an autograd graph at inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=500,
            temperature=0.2,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Extract only the newly generated part by slicing at the *token* level.
    # (Slicing the decoded string by len(full_prompt) is fragile: detokenization
    # does not round-trip whitespace/special characters exactly.)
    prompt_len = inputs["input_ids"].shape[1]
    new_response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    # Add to chat history
    chat_history.append((prompt, new_response))
    return chat_history, chat_history
# ---- Gradio interface ----
with gr.Blocks() as demo:
    gr.Markdown("## 💬 Text-to-Code Assistant")
    gr.Markdown("Type a natural language instruction and get code suggestions!")

    chat_window = gr.Chatbot(height=400)
    prompt_box = gr.Textbox(placeholder="e.g., 'Write a Python function to sort a list using merge sort'")
    clear_btn = gr.Button("Clear Chat")

    # Submitting the textbox runs generation; the chatbot is wired as both
    # outputs because generate_code returns the history twice.
    prompt_box.submit(generate_code, [prompt_box, chat_window], [chat_window, chat_window])
    # Clearing simply resets the chatbot component to empty.
    clear_btn.click(lambda: None, None, chat_window, queue=False)

if __name__ == "__main__":
    # share=True exposes a public gradio.live link in addition to localhost.
    demo.launch(share=True)