import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from fastapi import FastAPI, HTTPException
import uvicorn

# Initialize FastAPI app
app = FastAPI()

# Model configuration
CHECKPOINT = "bigcode/starcoder2-15b"

# Load model and tokenizer with 4-bit quantization, then generate.
# The load happens inside the GPU-decorated function because ZeroGPU only
# exposes CUDA within it and releases the GPU between calls; on a persistent
# GPU the load could be hoisted out and cached instead.
@spaces.GPU(duration=120)
def load_model_and_generate(prompt, max_length=256, temperature=0.2, top_p=0.95):
    try:
        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

        # Configure 4-bit quantization (NF4 with double quantization,
        # bfloat16 compute)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        # Initialize model with 4-bit quantization; device_map="auto" already
        # places the weights, so the pipeline is not given a device of its own
        model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            quantization_config=quantization_config,
            device_map="auto",
        )

        # Create text generation pipeline around the already-placed model
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

        # Format the prompt for chat-style code completion. StarCoder2 is a
        # base (non-instruct) model, so leave the code fence open and let the
        # model fill it in.
        chat_prompt = f"User: {prompt}\nAssistant:\n```python\n"

        # Generate the response; max_new_tokens bounds the completion length
        # independently of the prompt length
        result = pipe(
            chat_prompt,
            max_new_tokens=int(max_length),
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        generated_text = result[0]["generated_text"]

        # Strip the prompt so only the completion is returned
        if generated_text.startswith(chat_prompt):
            return generated_text[len(chat_prompt):].strip()
        return generated_text
    except Exception as e:
        return f"Error: {e}"

# FastAPI endpoint for backdoor-chat
@app.post("/backdoor-chat")
async def backdoor_chat(request: dict):
    try:
        # Validate input (FastAPI already guarantees a JSON object here)
        if "message" not in request:
            raise HTTPException(status_code=400, detail="Request must contain 'message' field")
        prompt = request["message"]
        max_length = request.get("max_length", 256)
        temperature = request.get("temperature", 0.2)
        top_p = request.get("top_p", 0.95)

        # Generate response
        response = load_model_and_generate(prompt, max_length, temperature, top_p)
        return {"response": response}
    except HTTPException:
        raise  # preserve the intended 4xx status instead of masking it as a 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("# StarCoder2-15B Chat Interface (4-bit Quantization)")
    gr.Markdown(
        "Enter a prompt to generate code or simulate a chat. "
        "Use the API endpoint `/backdoor-chat` for programmatic access."
    )

    # Input components
    prompt = gr.Textbox(label="Message", placeholder="Enter your message (e.g., 'Write a Python function')")
    max_length = gr.Slider(50, 512, value=256, label="Max New Tokens", step=1)
    temperature = gr.Slider(0.1, 1.0, value=0.2, label="Temperature", step=0.1)
    top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top P", step=0.05)

    # Output component
    output = gr.Textbox(label="Generated Response")

    # Submit button
    submit_btn = gr.Button("Generate")

    # Connect button to function
    submit_btn.click(
        fn=load_model_and_generate,
        inputs=[prompt, max_length, temperature, top_p],
        outputs=output,
    )

# Mount the Gradio UI onto the FastAPI app at the root path
app = gr.mount_gradio_app(app, demo, path="/")

# Run the app (for local testing; Hugging Face handles this in Spaces)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
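
# --- Example client call (an illustrative sketch, not executed by this app) ---
# Once the server is running, the /backdoor-chat endpoint accepts a JSON POST.
# The field names below match the handler above; the URL, prompt, and timeout
# are placeholder assumptions.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/backdoor-chat",
#       json={
#           "message": "Write a Python function that reverses a string",
#           "max_length": 200,
#           "temperature": 0.2,
#           "top_p": 0.95,
#       },
#       timeout=600,  # each request loads the 15B model, which is slow
#   )
#   resp.raise_for_status()
#   print(resp.json()["response"])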