# NOTE: scraped Hugging Face Space page header ("Spaces: Runtime error")
# preserved here as a comment so it is not part of the executable code.
import json

import gradio as gr
import spaces
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
# Initialize FastAPI app (the Gradio UI is mounted onto it further down).
app = FastAPI()
# Model configuration
CHECKPOINT = "bigcode/starcoder2-15b"  # Hugging Face model id loaded by the generation code
# NOTE(review): DEVICE is computed but never referenced by the visible code —
# model placement is done via device_map="auto"; confirm before removing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and tokenizer with 4-bit quantization.
# The pipeline is built lazily and cached at module level: StarCoder2-15B takes
# a long time to load, so rebuilding it on every request (as the original code
# did) made each call prohibitively slow and repeatedly re-allocated GPU memory.
_PIPELINE_CACHE = {}


def _get_pipeline():
    """Return a cached ``(pipe, tokenizer)`` pair, building both on first use.

    Builds the StarCoder2-15B tokenizer and a 4-bit NF4-quantized causal LM,
    wrapped in a transformers text-generation pipeline with automatic device
    placement.
    """
    if "pipe" not in _PIPELINE_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        # 4-bit NF4 quantization with double quantization keeps the 15B model
        # within a single-GPU memory budget.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            quantization_config=quantization_config,
            device_map="auto",
        )
        _PIPELINE_CACHE["tokenizer"] = tokenizer
        _PIPELINE_CACHE["pipe"] = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="auto",
        )
    return _PIPELINE_CACHE["pipe"], _PIPELINE_CACHE["tokenizer"]


def load_model_and_generate(prompt, max_length=256, temperature=0.2, top_p=0.95):
    """Generate a chat-style completion for *prompt* with StarCoder2-15B.

    Args:
        prompt: User message embedded into the chat template below.
        max_length: Maximum total length of the generated sequence; cast to
            int because Gradio sliders may deliver floats.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The generated continuation with the prompt prefix stripped, or an
        ``"Error: ..."`` string on failure (callers such as the Gradio UI
        display that string directly instead of crashing).
    """
    try:
        pipe, tokenizer = _get_pipeline()
        # Format prompt for chat-like interaction (template kept byte-identical
        # to the original, including the duplicated prompt and code fence).
        chat_prompt = f"User: {prompt}\nAssistant: Let's interpret this as a coding request. Please provide a code-related prompt, or I'll generate a response based on code context.\n{prompt} ```python\n```"
        result = pipe(
            chat_prompt,
            max_length=int(max_length),  # transformers requires an int here
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            # No dedicated pad token is configured; reuse EOS for padding.
            pad_token_id=tokenizer.eos_token_id,
            truncation=True,
        )
        generated_text = result[0]["generated_text"]
        # The pipeline echoes the prompt; return only the continuation.
        if generated_text.startswith(chat_prompt):
            return generated_text[len(chat_prompt):].strip()
        return generated_text
    except Exception as e:
        # Broad catch is deliberate: the UI/API surface the message as text.
        return f"Error: {str(e)}"
# FastAPI endpoint for backdoor-chat.
# The route decorator was missing in the original, so the endpoint advertised
# in the UI text was never actually registered with FastAPI.
@app.post("/backdoor-chat")
async def backdoor_chat(request: dict):
    """Generate a completion for a JSON request.

    Expects a JSON object with a required ``"message"`` field and optional
    ``"max_length"``, ``"temperature"`` and ``"top_p"`` overrides.

    Returns:
        ``{"response": <generated text>}``

    Raises:
        HTTPException: 400 if ``"message"`` is missing; 500 on any other
            failure during generation.
    """
    # Validate input outside the try-block: in the original, the broad
    # ``except Exception`` caught this 400 and re-raised it as a 500.
    if not isinstance(request, dict) or "message" not in request:
        raise HTTPException(status_code=400, detail="Request must contain 'message' field")
    try:
        prompt = request["message"]
        max_length = request.get("max_length", 256)
        temperature = request.get("temperature", 0.2)
        top_p = request.get("top_p", 0.95)
        # Generate response (returns an "Error: ..." string on model failure).
        response = load_model_and_generate(prompt, max_length, temperature, top_p)
        return {"response": response}
    except HTTPException:
        raise  # preserve deliberate HTTP errors instead of converting to 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Gradio interface setup: a prompt box, three generation-parameter sliders,
# and a button wired to load_model_and_generate.
with gr.Blocks() as demo:
    gr.Markdown("# StarCoder2-15B Chat Interface (4-bit Quantization)")
    gr.Markdown("Enter a prompt to generate code or simulate a chat. Use the API endpoint `/backdoor-chat` for programmatic access.")
    # Input components (slider defaults mirror load_model_and_generate's
    # default argument values).
    prompt = gr.Textbox(label="Message", placeholder="Enter your message (e.g., 'Write a Python function')")
    max_length = gr.Slider(50, 512, value=256, label="Max Length", step=1)
    temperature = gr.Slider(0.1, 1.0, value=0.2, label="Temperature", step=0.1)
    top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top P", step=0.05)
    # Output component
    output = gr.Textbox(label="Generated Response")
    # Submit button
    submit_btn = gr.Button("Generate")
    # Connect the button: component values map positionally onto the
    # function's (prompt, max_length, temperature, top_p) parameters.
    submit_btn.click(
        fn=load_model_and_generate,
        inputs=[prompt, max_length, temperature, top_p],
        outputs=output
    )
# Mount the Gradio UI at the root path of the FastAPI app so both the UI and
# the JSON API are served by the same ASGI application.
app = gr.mount_gradio_app(app, demo, path="/")
# Run the app (for local testing; Hugging Face Spaces launches it itself).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)