| """ | |
| MM Coder Agent v1 - Gradio Space | |
| A coding assistant AI based on Qwen2.5-1.5B | |
| """ | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| from peft import PeftModel, PeftConfig | |
| import os | |
| # Configuration | |
| MODEL_ID = "amkyawdev/mm-coder-agent-v1-combined" | |
| # System prompt for better responses | |
| SYSTEM_PROMPT = """You are a Python coding assistant. | |
| Output requirements: | |
| 1. Output ONLY the code, nothing else | |
| 2. Use proper Python syntax | |
| 3. Include necessary imports | |
| 4. Start with code directly in a python code block | |
| 5. Do not add explanations unless asked | |
| Examples: | |
| Input: print hello world | |
| Output: | |
| ```python | |
| print("Hello, World!") | |
| ```""" | |

# Global variables to store the model and tokenizer
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer"""
    global model, tokenizer
    print("Loading model...")
    peft_config = PeftConfig.from_pretrained(MODEL_ID)
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        peft_config.base_model_name_or_path,
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, MODEL_ID)
    print("Model loaded successfully!")


def generate_code(prompt, max_tokens=512, temperature=0.7):
    """Generate code based on the prompt"""
    if model is None:
        load_model()

    # Format the prompt with the system instruction
    formatted_prompt = f"{SYSTEM_PROMPT}\n\nUser: {prompt}\nAssistant:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Fall back to the EOS token when the tokenizer defines no pad token
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            do_sample=True,
            top_p=0.95,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens so the prompt never needs to be
    # stripped back out of the decoded text
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return response.strip()
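

# Quick smoke test (hypothetical example prompt; runs the full pipeline once
# outside the UI, which will download and load the model on first call):
#     print(generate_code("Write a Python function to reverse a string"))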


# Create the Gradio interface
with gr.Blocks(title="🤖 MM Coder Agent v1") as demo:
    gr.Markdown("""
    # 🤖 MM Coder Agent v1

    A professional AI coding assistant fine-tuned from Qwen2.5-1.5B-Instruct.

    ## Features
    - Code Generation in multiple languages
    - Bug Detection and Fixing
    - Algorithm Implementation
    - Code Review Assistance
    """)

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Your Prompt",
                placeholder="Write a Python function to calculate Fibonacci numbers...",
                lines=5,
            )
            with gr.Row():
                max_tokens = gr.Slider(64, 1024, value=512, step=64, label="Max Tokens")
                temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
            generate_btn = gr.Button("Generate Code", variant="primary")
        with gr.Column():
            output = gr.Code(label="Generated Code", language="python")

    generate_btn.click(
        generate_code,
        inputs=[prompt, max_tokens, temperature],
        outputs=output,
    )

    gr.Markdown("""
    ---
    ### Examples
    - "Write a Python function to reverse a string"
    - "Create a React component for a login form"
    - "How to fix CORS error in Express.js?"
    - "Explain bubble sort algorithm"
    """)

# Launch the demo
demo.launch(server_name="0.0.0.0", server_port=7860)
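
# Programmatic access sketch (assumptions: the Space is public, gradio_client
# is installed, and the endpoint name Gradio derives from the handler function
# is "/generate_code"; confirm on the Space's "Use via API" page):
#     from gradio_client import Client
#     client = Client("amkyawdev/mm-coder-agent-v1")  # hypothetical Space id
#     code = client.predict(
#         "Write a Python function to reverse a string",  # prompt
#         512,   # max_tokens
#         0.7,   # temperature
#         api_name="/generate_code",
#     )
#     print(code)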