import json
import os
import warnings

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
# Load the StarCoder checkpoint and its tokenizer, then move the model onto
# the GPU when CUDA is available (otherwise it stays on the CPU).
model_name = "bigcode/starcoder"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# nn.Module.to() returns the module itself, so the moved model can be bound directly.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
|
| |
| CACHE_FILE = "cache.json" |
| cache = {} |
|
|
| |
| if os.path.exists(CACHE_FILE): |
| with open(CACHE_FILE, "r") as f: |
| cache = json.load(f) |
|
|
def code_assistant(prompt, language):
    """Answer a coding prompt with StarCoder, reusing recent answers from the cache.

    Args:
        prompt: Question or code snippet typed by the user.
        language: Value of the language dropdown. May be empty or None when
            the user has not picked one.

    Returns:
        The decoded generation, or a human-readable error message when the
        prompt is empty or longer than 256 characters.
    """
    # Gradio normally passes "" for an empty textbox; guard against None too.
    if not prompt or not prompt.strip():
        return "Error: The input prompt cannot be empty. Please provide a coding question or code snippet."
    if len(prompt) > 256:
        return "Error: The input prompt is too long. Please limit it to 256 characters."

    # JSON object keys must be strings, so the (prompt, language) pair is
    # stringified. The format is unchanged so existing cache files stay valid.
    cache_key = str((prompt, language))
    if cache_key in cache:
        return cache[cache_key]

    generated_text = _generate(prompt, language)
    _remember(cache_key, generated_text)
    return generated_text


def _generate(prompt, language):
    """Run StarCoder on *prompt*, prefixed with a "[language]" tag when one is given."""
    if language:
        prompt = f"[{language}] {prompt}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        # max_new_tokens bounds only the continuation. The old max_length also
        # counted prompt tokens, so long prompts left little or no room to answer.
        max_new_tokens=128,
        temperature=0.1,
        top_p=0.8,
        do_sample=True,
        # NOTE(review): StarCoder's tokenizer presumably has no pad token (verify).
        # Passing EOS explicitly avoids generate() falling back and warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _remember(cache_key, text):
    """Store *text* under *cache_key*, evict the oldest entries, and persist to disk."""
    # Dicts keep insertion order, so the first key is the oldest entry. A loop
    # (not a single pop) also trims a loaded cache that holds more than 10 entries.
    while len(cache) >= 10:
        cache.pop(next(iter(cache)))
    cache[cache_key] = text

    # Write to a temp file and swap it in, so a crash mid-write cannot leave a
    # truncated cache.json behind. Persistence is best-effort: a failed write
    # must not discard an answer that was already generated.
    tmp_path = CACHE_FILE + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(cache, f)
        os.replace(tmp_path, CACHE_FILE)
    except OSError as err:
        warnings.warn(f"Could not write cache file {CACHE_FILE!r}: {err}")
|
|
def _build_interface():
    """Assemble the Gradio UI: a prompt textbox plus a language dropdown feeding code_assistant."""
    languages = ["Python", "JavaScript", "Java", "C++", "HTML", "CSS", "SQL", "Other"]
    prompt_box = gr.Textbox(
        lines=5,
        placeholder="Ask a coding question or paste your code here...",
    )
    language_menu = gr.Dropdown(choices=languages, label="Programming Language")
    return gr.Interface(
        fn=code_assistant,
        inputs=[prompt_box, language_menu],
        outputs="text",
        title="Code Assistant with StarCoder",
        description=(
            "An AI code assistant to help you with coding queries, debugging, "
            "and code generation. Specify the programming language for more "
            "accurate responses."
        ),
    )


iface = _build_interface()

# Start the Gradio web app.
iface.launch()
|
|