# vapi-gemma-api / app.py
# ULTRA SPEED: 8-bit quantization, greedy decoding, 40 tokens, inference_mode
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os
import json
import time
app = Flask(__name__)
# Allow all origins - required for Vapi
CORS(app, resources={
r"/*": {
"origins": "*",
"methods": ["GET", "POST", "OPTIONS"],
"allow_headers": ["Content-Type", "Authorization"],
"expose_headers": ["Content-Type"],
"supports_credentials": False
}
})
# Read the Hugging Face token from an environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
# Model names
MODEL_NAME = "ilkeileri/gemma-sales-comprehensive"
BASE_MODEL = "google/gemma-1.1-2b-it"
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True, token=HF_TOKEN)
print("Loading base model with 8-bit quantization for speed...")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
load_in_8bit=True, # 8-bit quantization for 2-3x speedup
low_cpu_mem_usage=True,
trust_remote_code=True,
token=HF_TOKEN,
device_map="auto"
)
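# Note (assumption about library versions): newer transformers releases recommend passing
# quantization_config=BitsAndBytesConfig(load_in_8bit=True) instead of the bare load_in_8bit
# flag; either way, 8-bit loading requires the bitsandbytes package and a CUDA device.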
print("Loading LoRA adapters...")
model = PeftModel.from_pretrained(base_model, MODEL_NAME, token=HF_TOKEN)
model.eval()
# Enable torch compile for faster inference (if available)
try:
model = torch.compile(model, mode="reduce-overhead")
print("Torch compile enabled for faster inference")
except Exception as e:
print(f"Torch compile not available: {e}")
# Determine the device (logging only; device_map="auto" has already placed the model)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Model loaded successfully!")
@app.route("/", methods=["GET"])
def health_check():
return jsonify({
"status": "ok",
"model": MODEL_NAME,
"base_model": BASE_MODEL
}), 200
@app.route("/debug", methods=["POST"])
def debug_endpoint():
"""Debug endpoint to see what Vapi is sending"""
try:
data = request.get_json()
headers = dict(request.headers)
return jsonify({
"received_data": data,
"received_headers": headers,
"status": "debug_ok"
}), 200
except Exception as e:
return jsonify({"error": str(e)}), 500
# Vapi calls /v1/chat/completions; register both routes
@app.route("/chat/completions", methods=["POST", "OPTIONS"])
@app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
def chat_completions():
# Log incoming request
print("=" * 60)
print(f"NEW REQUEST: {request.method} from {request.remote_addr}")
print(f"Headers: {dict(request.headers)}")
if request.method == "POST":
print(f"Body: {request.get_data(as_text=True)[:500]}...")
print("=" * 60)
    # Preflight response for OPTIONS (CORS) requests
if request.method == "OPTIONS":
response = jsonify({"status": "ok"})
response.headers.add("Access-Control-Allow-Origin", "*")
response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization")
response.headers.add("Access-Control-Allow-Methods", "GET,POST,OPTIONS")
return response, 200
try:
data = request.get_json()
print(f"Full request data: {data}")
# Check if streaming is requested
stream = data.get("stream", False)
messages = data.get("messages", [])
# Build conversation history in Gemma format
# Only use user and assistant messages, skip system (Vapi's system prompt is just context)
conversation_parts = []
for msg in messages:
role = msg.get("role")
content = msg.get("content", "")
if role == "user":
conversation_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
elif role == "assistant":
conversation_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
# Skip system messages - Vapi's system prompt is too long for context
# Add the model turn start
conversation_parts.append("<start_of_turn>model\n")
formatted_prompt = "\n".join(conversation_parts)
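        # (transformers tokenizers also expose apply_chat_template for building this prompt;
        # the manual format above just keeps Gemma's turn markers explicit.)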
print(f"Formatted prompt length: {len(formatted_prompt)} chars")
print(f"Last user message: {[m['content'] for m in messages if m.get('role')=='user'][-1] if messages else 'none'}")
        # Generate the model response
inputs = tokenizer(formatted_prompt, return_tensors="pt")
        # Move the inputs to the same device as the model
if hasattr(model, 'device'):
inputs = {k: v.to(model.device) for k, v in inputs.items()}
print(f"Generating response...")
import time
start_time = time.time()
# Ultra-fast generation settings for <1s response
with torch.inference_mode(): # Disable gradient computation for speed
outputs = model.generate(
**inputs,
max_new_tokens=40, # Minimal tokens for ultra-fast response
do_sample=False, # Greedy decoding (fastest)
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True # Enable KV cache for faster generation
)
elapsed = time.time() - start_time
print(f"⚡ Response generated in {elapsed:.2f}s")
full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
        # Strip the Gemma-format turn tags
response_text = full_response.split("<start_of_turn>model\n")[-1]
response_text = response_text.replace("<end_of_turn>", "").strip()
# If streaming requested, return SSE format
if stream:
def generate():
# Send the complete response as a single chunk for simplicity
chunk = {
"id": "chatcmpl-" + str(hash(prompt))[-10:],
"object": "chat.completion.chunk",
"created": int(__import__('time').time()),
"model": MODEL_NAME,
"choices": [{
"index": 0,
"delta": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}]
}
yield f"data: {__import__('json').dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
return app.response_class(generate(), mimetype='text/event-stream')
# OpenAI-compatible response format (non-streaming)
vapi_response = {
"id": "chatcmpl-" + str(hash(prompt))[-10:],
"object": "chat.completion",
"created": int(__import__('time').time()),
"model": MODEL_NAME,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": len(inputs["input_ids"][0]),
"completion_tokens": len(outputs[0]) - len(inputs["input_ids"][0]),
"total_tokens": len(outputs[0])
}
}
return jsonify(vapi_response), 200
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"ERROR: {str(e)}")
print(f"Traceback: {error_details}")
# Return error in both streaming and non-streaming format
error_response = {
"error": {
"message": str(e),
"type": type(e).__name__,
"code": "internal_error"
}
}
# If streaming was requested, return error in SSE format
if 'stream' in locals() and stream:
def error_stream():
yield f"data: {__import__('json').dumps(error_response)}\n\n"
yield "data: [DONE]\n\n"
return app.response_class(error_stream(), mimetype='text/event-stream')
return jsonify(error_response), 500
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860)
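
# Example request (a sketch, assuming the server runs locally on port 7860; the payload
# mirrors the OpenAI-style body that Vapi sends to /v1/chat/completions):
#
#   curl -X POST http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": false}'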