# fingpt / app.py — Hugging Face Space upload (revana, commit f984d6b, verified)
#!/usr/bin/env python3
"""fingpt Gradio app β€” Chat + Code Correction tabs.
Serves two modes:
β€’ Chat β€” multi-turn conversation (history preserved per session)
β€’ Code β€” single-turn code correction with examples
Launch
------
python app.py --adapter weights_lora_coder_1b5/adapter_final.pt
uv run app.py --adapter weights_lora_coder_1b5/adapter_final.pt
make app # uses adapter_final.pt automatically
make app-share # public Gradio link
"""
import argparse
import os
import sys
from pathlib import Path
from typing import List, Tuple
_HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(_HERE))
from infer import generate, load_model # noqa: E402
# ── Generation helpers ────────────────────────────────────────────────────────
def _chat_generate(model, tokenizer, history: List[Tuple[str, str]],
message: str, max_new_tokens: int, temperature: float) -> str:
"""Multi-turn: build messages list from Gradio history + new message."""
from transformers import AutoTokenizer # noqa: F401 (already loaded)
import torch
messages = []
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
device = next(model.parameters()).device
inputs = tokenizer(text, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=temperature > 0,
temperature=temperature if temperature > 0 else 1.0,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
new_ids = outputs[0][inputs["input_ids"].shape[1]:]
return tokenizer.decode(new_ids, skip_special_tokens=True)
# ── Code examples ─────────────────────────────────────────────────────────────
CODE_EXAMPLES = [
["Fix this Python bug:\n\n```python\ndef factorial(n):\n if n == 0:\n return 1\n return n * factorial(n) # missing -1\n```\n\nError: RecursionError: maximum recursion depth exceeded"],
["What's wrong?\n\n```python\ndef divide(a, b):\n return a / b\n\nresult = divide(10, 0)\n```"],
["Fix the KeyError:\n\n```python\nd = {'a': 1, 'b': 2}\nprint(d['c'])\n```"],
["Rewrite without stack overflow:\n\n```python\ndef fib(n):\n if n <= 1: return n\n return fib(n-1) + fib(n-2)\n```"],
]
CHAT_EXAMPLES = [
"Explain the difference between a list and a tuple in Python.",
"What is the time complexity of binary search?",
"How does gradient descent work?",
"Write a Python function to check if a string is a palindrome.",
]
# ── App ───────────────────────────────────────────────────────────────────────
def make_app(model, tokenizer):
    """Build and return the Gradio Blocks UI (Chat tab + Code Correction tab).

    The two sliders (max tokens, temperature) are created once and shared
    by both tabs' callbacks. Component creation order inside the context
    managers determines the rendered layout, so statement order matters.
    """
    # Imported lazily so importing this module does not require gradio.
    import gradio as gr

    with gr.Blocks(title="fingpt", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# fingpt\n"
            "**Qwen2.5-Coder-1.5B-Instruct** fine-tuned with LoRA on "
            "[Code-Feedback](https://huggingface.co/datasets/m-a-p/Code-Feedback) "
            "(66K error→fix pairs, 3 epochs)."
        )
        # Settings row shared across both tabs.
        with gr.Row():
            max_tokens = gr.Slider(64, 1024, value=512, step=64,
                                   label="Max new tokens", scale=2)
            temperature = gr.Slider(0.0, 1.0, value=0.1, step=0.05,
                                    label="Temperature (0 = greedy)", scale=2)
        with gr.Tabs():
            # ── Chat tab: multi-turn conversation ─────────────────────────
            with gr.Tab("πŸ’¬ Chat"):
                # type="tuples": history is a list of (user, assistant) pairs,
                # the format _chat_generate expects.
                chatbot = gr.Chatbot(height=480, label="fingpt Chat", type="tuples")
                chat_input = gr.Textbox(
                    placeholder="Ask anything β€” coding, concepts, debugging…",
                    label="Message",
                    lines=2,
                )
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear")
                gr.Examples(CHAT_EXAMPLES, inputs=chat_input, label="Starter prompts")

                def chat_respond(message, history, max_tok, temp):
                    """Append (message, reply) to history; clear the input box."""
                    # Ignore empty/whitespace-only submissions.
                    if not message.strip():
                        return history, ""
                    reply = _chat_generate(model, tokenizer, history,
                                           message, int(max_tok), float(temp))
                    history = history + [(message, reply)]
                    return history, ""

                # Both the Send button and Enter-in-textbox trigger the same handler.
                send_btn.click(
                    chat_respond,
                    inputs=[chat_input, chatbot, max_tokens, temperature],
                    outputs=[chatbot, chat_input],
                )
                chat_input.submit(
                    chat_respond,
                    inputs=[chat_input, chatbot, max_tokens, temperature],
                    outputs=[chatbot, chat_input],
                )
                # Reset both the transcript and the input box.
                clear_btn.click(lambda: ([], ""), outputs=[chatbot, chat_input])
            # ── Code correction tab: single-turn prompt → fix ─────────────
            with gr.Tab("πŸ› οΈ Code Correction"):
                with gr.Row():
                    with gr.Column():
                        code_input = gr.Textbox(
                            label="Code / Error / Question",
                            lines=14,
                            placeholder="Paste broken code + error message…",
                        )
                        fix_btn = gr.Button("Fix β†’", variant="primary")
                    with gr.Column():
                        code_output = gr.Textbox(label="Response", lines=18)
                gr.Examples(CODE_EXAMPLES, inputs=code_input, label="Examples")

                def code_respond(prompt, max_tok, temp):
                    """Single-turn: delegate to infer.generate (no history)."""
                    if not prompt.strip():
                        return ""
                    return generate(model, tokenizer, prompt, int(max_tok), float(temp))

                fix_btn.click(
                    code_respond,
                    inputs=[code_input, max_tokens, temperature],
                    outputs=code_output,
                )
                code_input.submit(
                    code_respond,
                    inputs=[code_input, max_tokens, temperature],
                    outputs=code_output,
                )
    return demo
# ── Entry point ───────────────────────────────────────────────────────────────
def main() -> None:
    """CLI entry point: parse flags, load model + adapter, serve the UI."""
    cli = argparse.ArgumentParser(description="fingpt app β€” Chat + Code tabs")
    cli.add_argument(
        "--adapter",
        default=os.environ.get("ADAPTER", "weights_lora_coder_1b5/adapter_final.pt"),
        help="Path to adapter .pt file (or set ADAPTER env var)",
    )
    cli.add_argument("--share", action="store_true",
                     help="Create a public Gradio share link")
    cli.add_argument("--port", type=int, default=7860)
    opts = cli.parse_args()

    # Load base weights + LoRA adapter, then hand both to the UI builder.
    model, tokenizer = load_model(opts.adapter)
    make_app(model, tokenizer).launch(share=opts.share, server_port=opts.port)


if __name__ == "__main__":
    main()