# Swarnimm22HF's picture
# Fix Gradio slider error - simplify interface
# b12132a
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
FINETUNED_MODEL = "Swarnimm22HF/ai-code-review-tinyllama"


def _load_tokenizer_and_models():
    """Load the tokenizer, the base model and the adapter-wrapped model.

    Returns:
        A ``(tokenizer, base_model, model)`` tuple, where ``model`` is the
        ``PeftModel`` built from ``FINETUNED_MODEL`` on top of the base model
        and already switched to evaluation mode.
    """
    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Reuse the end-of-sequence token as the padding token.
    tok.pad_token = tok.eos_token

    backbone = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    adapted = PeftModel.from_pretrained(
        backbone,
        FINETUNED_MODEL,
        device_map=None,
    )
    adapted.eval()
    return tok, backbone, adapted


# Loading happens once at import time; the progress messages bracket it.
print("Loading model...")
tokenizer, base_model, model = _load_tokenizer_and_models()
print("Model ready.")
# Maximum number of prompt tokens fed to the model.
_MAX_PROMPT_TOKENS = 512
# Slack kept when shortening the snippet, because decoding and re-tokenizing
# the shortened code may not reproduce the exact same token count.
_TRUNCATION_MARGIN = 8

# Prompt layout; ``{code}`` is the only placeholder.
_PROMPT_TEMPLATE = """### Instruction:
You are an expert Python code reviewer. Analyze the following function and provide specific, actionable feedback covering:
1. What the function does
2. Bugs or edge cases (e.g. division by zero, null inputs, type errors)
3. Code quality issues (naming, readability, structure)
4. Specific improvements with examples
### Code:
```python
{code}
```
### Response:
**What this function does:**"""


def generate_review(code):
    """Generate a code review for a Python snippet with the loaded model.

    Args:
        code: Source text of the function to review. ``None``, empty and
            whitespace-only values are treated as "no input".

    Returns:
        The decoded text the model generated after the prompt, stripped of
        surrounding whitespace, or a short hint asking for input when
        ``code`` is missing or blank.
    """
    # The UI component may hand over None instead of an empty string.
    if code is None or not code.strip():
        return "Please enter a Python function."

    prompt = _PROMPT_TEMPLATE.format(code=code)

    # Truncation cuts from the end of the prompt, which would drop the
    # "### Response:" trailer for long inputs. Shorten the snippet itself
    # instead so the instruction and the response header both survive.
    prompt_token_count = len(tokenizer(prompt)["input_ids"])
    if prompt_token_count > _MAX_PROMPT_TOKENS:
        code_ids = tokenizer(code, add_special_tokens=False)["input_ids"]
        overflow = prompt_token_count - _MAX_PROMPT_TOKENS + _TRUNCATION_MARGIN
        keep = max(len(code_ids) - overflow, 0)
        shortened = tokenizer.decode(code_ids[:keep], skip_special_tokens=True)
        prompt = _PROMPT_TEMPLATE.format(code=shortened)

    # truncation stays enabled as a safety net for the hard token limit.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=_MAX_PROMPT_TOKENS,
    )

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # Keep only the tokens produced after the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# Sample snippets offered as one-click examples in the UI. They use 4-space
# indentation so each one is syntactically valid Python when loaded into the
# editor (with collapsed indentation the nested loops do not parse).
# NOTE(review): the position of `return data` in the last snippet (after the
# loop, at function level) is a reconstruction — confirm it matches the
# intended example.
_EXAMPLE_SNIPPETS = [
    "def divide(a, b):\n    return a / b",
    (
        "def find_item(lst, target):\n"
        "    for i in range(len(lst)):\n"
        "        if lst[i] == target:\n"
        "            return i\n"
        "    return -1"
    ),
    (
        "def get_user(users, id):\n"
        "    for i in range(len(users)):\n"
        "        if users[i]['id'] == id:\n"
        "            data = users[i]\n"
        "    return data"
    ),
]

# Single-input, single-output interface: a Python code editor feeding
# generate_review, whose text result is shown in a textbox.
demo = gr.Interface(
    fn=generate_review,
    inputs=gr.Code(
        language="python",
        label="Python Function",
        # The editor starts pre-filled with the first example.
        value=_EXAMPLE_SNIPPETS[0],
    ),
    outputs=gr.Textbox(label="AI Code Review", lines=10),
    title="🔍 AI Code Review Assistant",
    description="Fine-tuned TinyLlama 1.1B for automated Python code review. Model: Swarnimm22HF/ai-code-review-tinyllama | ROUGE-L: +261% vs base",
    # One row per example, one column because the interface has one input.
    examples=[[snippet] for snippet in _EXAMPLE_SNIPPETS],
)
demo.launch()