# NOTE(review): the original first lines were Hugging Face file-viewer residue
# (runtime status, file size, commit hashes, and a line-number gutter), which
# is not valid Python; replaced with this comment so the module can be parsed.
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from fastapi import FastAPI, HTTPException
from transformers import BitsAndBytesConfig
import uvicorn
import json
# Initialize FastAPI app
# (re-bound at the bottom of the file by gr.mount_gradio_app, which wraps
# this app with the Gradio UI mounted at "/")
app = FastAPI()
# Model configuration
# Hub checkpoint loaded by load_model_and_generate below.
CHECKPOINT = "bigcode/starcoder2-15b"
# NOTE(review): DEVICE is computed but never used — model placement is done
# via device_map="auto" in load_model_and_generate; kept for compatibility.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and tokenizer with 4-bit quantization.
# The pipeline is cached at module level: rebuilding the tokenizer, the
# 15B-parameter quantized model, and the pipeline on every request (as the
# original code did) is prohibitively slow and memory-churning.
_PIPELINE = None


def _get_pipeline():
    """Build (once) and return the cached 4-bit text-generation pipeline."""
    global _PIPELINE
    if _PIPELINE is None:
        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        # Configure 4-bit quantization (NF4 with double quantization,
        # bfloat16 compute) to fit the 15B model on a single GPU.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            quantization_config=quantization_config,
            device_map="auto",
        )
        _PIPELINE = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="auto",
        )
    return _PIPELINE


@spaces.GPU(duration=120)
def load_model_and_generate(prompt, max_length=256, temperature=0.2, top_p=0.95):
    """Generate a chat-style completion from StarCoder2-15B.

    Args:
        prompt: The user's message.
        max_length: Total token budget (prompt + completion) for generation.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The generated continuation as a string, or an ``"Error: ..."`` string
        on failure (kept as a string so the Gradio UI and the API wrapper can
        pass it through unchanged).
    """
    try:
        pipe = _get_pipeline()
        tokenizer = pipe.tokenizer
        # Format prompt for chat-like interaction; StarCoder2 is a base code
        # model, so the template steers it toward a code-flavored reply.
        chat_prompt = f"User: {prompt}\nAssistant: Let's interpret this as a coding request. Please provide a code-related prompt, or I'll generate a response based on code context.\n{prompt} ```python\n```"
        # Generate response
        result = pipe(
            chat_prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            truncation=True,
        )
        generated_text = result[0]["generated_text"]
        # The pipeline echoes the prompt; strip it so only the reply remains.
        response = generated_text[len(chat_prompt):].strip() if generated_text.startswith(chat_prompt) else generated_text
        return response
    except Exception as e:
        # Surface failures as text rather than raising: both callers expect
        # a plain string back.
        return f"Error: {str(e)}"
# FastAPI endpoint for backdoor-chat
@app.post("/backdoor-chat")
async def backdoor_chat(request: dict):
    """JSON chat endpoint.

    Expects a body with a required ``message`` field and optional
    ``max_length``, ``temperature`` and ``top_p`` overrides. Returns
    ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 when ``message`` is missing, 500 on any other
            failure during generation.
    """
    try:
        # Validate input
        if not isinstance(request, dict) or "message" not in request:
            raise HTTPException(status_code=400, detail="Request must contain 'message' field")
        prompt = request["message"]
        max_length = request.get("max_length", 256)
        temperature = request.get("temperature", 0.2)
        top_p = request.get("top_p", 0.95)
        # Generate response
        response = load_model_and_generate(prompt, max_length, temperature, top_p)
        return {"response": response}
    except HTTPException:
        # Bug fix: without this pass-through, the 400 validation error above
        # was swallowed by the handler below and converted into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Gradio interface setup: a simple form wired to load_model_and_generate.
with gr.Blocks() as demo:
    # Page header and usage notes.
    gr.Markdown("# StarCoder2-15B Chat Interface (4-bit Quantization)")
    gr.Markdown("Enter a prompt to generate code or simulate a chat. Use the API endpoint `/backdoor-chat` for programmatic access.")

    # User inputs: the message plus the three sampling knobs.
    message_box = gr.Textbox(label="Message", placeholder="Enter your message (e.g., 'Write a Python function')")
    length_slider = gr.Slider(50, 512, value=256, label="Max Length", step=1)
    temperature_slider = gr.Slider(0.1, 1.0, value=0.2, label="Temperature", step=0.1)
    top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top P", step=0.05)

    # Where the generated reply is displayed.
    reply_box = gr.Textbox(label="Generated Response")

    # Clicking the button runs generation with the current input values.
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=load_model_and_generate,
        inputs=[message_box, length_slider, temperature_slider, top_p_slider],
        outputs=reply_box,
    )
# Mount the Gradio UI onto the FastAPI app so both share one server.
app = gr.mount_gradio_app(app, demo, path="/")
# Run the app (for local testing; Hugging Face handles this in Spaces)
# Bug fix: the original last line ended with a stray " |" (scrape artifact),
# which is a syntax error.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)