File size: 4,223 Bytes
15d6fe0
c880c8c
 
 
8203986
0b73bfa
8203986
 
 
 
 
15d6fe0
c880c8c
 
 
 
0b73bfa
8203986
c880c8c
 
 
 
 
0b73bfa
 
 
 
 
 
 
 
 
c880c8c
 
0b73bfa
c880c8c
 
 
 
 
 
 
 
0b73bfa
c880c8c
 
8203986
 
 
c880c8c
 
8203986
c880c8c
 
 
 
 
 
 
 
 
 
 
8203986
 
 
c880c8c
 
 
8203986
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c880c8c
 
0b73bfa
8203986
c880c8c
 
8203986
c880c8c
 
 
 
 
8203986
c880c8c
 
 
 
 
 
 
 
 
 
 
8203986
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from fastapi import FastAPI, HTTPException
from transformers import BitsAndBytesConfig
import uvicorn
import json

# Initialize FastAPI app
app = FastAPI()

# Model configuration
# CHECKPOINT: Hugging Face model id loaded lazily inside load_model_and_generate.
CHECKPOINT = "bigcode/starcoder2-15b"
# NOTE(review): DEVICE is computed but never read anywhere in this file —
# placement is handled by device_map="auto" in the loader. Confirm no external
# consumer before removing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer with 4-bit quantization
@spaces.GPU(duration=120)
def load_model_and_generate(prompt, max_length=256, temperature=0.2, top_p=0.95):
    """Load StarCoder2-15B (4-bit NF4 quantized) and generate one completion.

    NOTE: the tokenizer, model, and pipeline are rebuilt on every call. This
    is the ZeroGPU Spaces pattern (@spaces.GPU allocates the GPU per call),
    but it is expensive — cache the pipeline if running outside Spaces.

    Args:
        prompt: user message embedded into the pseudo-chat template.
        max_length: maximum total sequence length (prompt + generation).
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.

    Returns:
        Generated text with the echoed prompt stripped, or an "Error: ..."
        string on failure. Callers (Gradio UI and the API endpoint) rely on
        the error-as-string contract — do not raise from here.
    """
    try:
        # Gradio sliders and JSON payloads deliver floats; the pipeline
        # requires an int for max_length, so coerce all knobs explicitly.
        max_length = int(max_length)
        temperature = float(temperature)
        top_p = float(top_p)

        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

        # 4-bit NF4 with double quantization keeps the 15B model within a
        # single-GPU memory budget; bfloat16 for the compute dtype.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )

        # Initialize quantized model; device_map="auto" handles placement.
        model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            quantization_config=quantization_config,
            device_map="auto"
        )

        # Create text generation pipeline
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="auto"
        )

        # StarCoder2 is a base (non-chat) model: fake a chat turn and prime
        # the model with an opened Python code fence.
        chat_prompt = f"User: {prompt}\nAssistant: Let's interpret this as a coding request. Please provide a code-related prompt, or I'll generate a response based on code context.\n{prompt} ```python\n```"

        # Generate response
        result = pipe(
            chat_prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # model has no pad token; reuse EOS
            truncation=True
        )

        generated_text = result[0]["generated_text"]
        # The pipeline echoes the prompt; strip it when present.
        response = generated_text[len(chat_prompt):].strip() if generated_text.startswith(chat_prompt) else generated_text
        return response
    except Exception as e:
        # Deliberate catch-all: surface failures as readable text for both
        # the Gradio UI and the API endpoint rather than crashing the worker.
        return f"Error: {str(e)}"

# FastAPI endpoint for backdoor-chat
@app.post("/backdoor-chat")
async def backdoor_chat(request: dict):
    """Generate a completion for a JSON request.

    Expected body: {"message": str, "max_length"?: int,
    "temperature"?: float, "top_p"?: float}.

    Returns {"response": <generated text>}. Raises 400 for a missing
    "message" field and 500 for unexpected generation failures.
    """
    # Validate input OUTSIDE the try block: previously the blanket
    # `except Exception` caught this HTTPException(400) and re-raised it
    # as a 500, masking the client error status.
    if not isinstance(request, dict) or "message" not in request:
        raise HTTPException(status_code=400, detail="Request must contain 'message' field")

    prompt = request["message"]
    max_length = request.get("max_length", 256)
    temperature = request.get("temperature", 0.2)
    top_p = request.get("top_p", 0.95)

    try:
        # Generate response (returns "Error: ..." text on model-side failure)
        response = load_model_and_generate(prompt, max_length, temperature, top_p)
    except Exception as e:
        # Chain the cause so the traceback shows the original failure.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": response}

# Gradio interface setup: UI for interactive use; the /backdoor-chat
# endpoint covers programmatic access.
with gr.Blocks() as demo:
    gr.Markdown("# StarCoder2-15B Chat Interface (4-bit Quantization)")
    gr.Markdown("Enter a prompt to generate code or simulate a chat. Use the API endpoint `/backdoor-chat` for programmatic access.")

    # Inputs: the message box plus the three sampling knobs, mirroring the
    # keyword defaults of load_model_and_generate.
    message_box = gr.Textbox(label="Message", placeholder="Enter your message (e.g., 'Write a Python function')")
    length_slider = gr.Slider(50, 512, value=256, label="Max Length", step=1)
    temp_slider = gr.Slider(0.1, 1.0, value=0.2, label="Temperature", step=0.1)
    top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top P", step=0.05)

    # Output textbox for the generated completion.
    result_box = gr.Textbox(label="Generated Response")

    # Wire the Generate button straight to the model function.
    gr.Button("Generate").click(
        fn=load_model_and_generate,
        inputs=[message_box, length_slider, temp_slider, top_p_slider],
        outputs=result_box,
    )

# Mount Gradio app to FastAPI
# Gradio serves the UI at "/" while the FastAPI route /backdoor-chat remains
# available on the same server.
app = gr.mount_gradio_app(app, demo, path="/")

# Run the app (for local testing; Hugging Face handles this in Spaces)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)