from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os
import time
import json
app = Flask(__name__)

# Allow all origins - required for Vapi
CORS(app, resources={
    r"/*": {
        "origins": "*",
        "methods": ["GET", "POST", "OPTIONS"],
        "allow_headers": ["Content-Type", "Authorization"],
        "expose_headers": ["Content-Type"],
        "supports_credentials": False
    }
})
# Read the Hugging Face token from an environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model names
MODEL_NAME = "ilkeileri/gemma-sales-comprehensive"
BASE_MODEL = "google/gemma-1.1-2b-it"
| print("Loading tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True, token=HF_TOKEN) | |
| print("Loading base model with 8-bit quantization for speed...") | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| load_in_8bit=True, # 8-bit quantization for 2-3x speedup | |
| low_cpu_mem_usage=True, | |
| trust_remote_code=True, | |
| token=HF_TOKEN, | |
| device_map="auto" | |
| ) | |
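# Note: recent transformers releases deprecate the load_in_8bit= keyword in
# favor of quantization_config=BitsAndBytesConfig(load_in_8bit=True); the
# older form is kept here as in the original code.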
| print("Loading LoRA adapters...") | |
| model = PeftModel.from_pretrained(base_model, MODEL_NAME, token=HF_TOKEN) | |
| model.eval() | |
| # Enable torch compile for faster inference (if available) | |
| try: | |
| model = torch.compile(model, mode="reduce-overhead") | |
| print("Torch compile enabled for faster inference") | |
| except Exception as e: | |
| print(f"Torch compile not available: {e}") | |
# Determine the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Model loaded successfully!")
# Route path assumed; the decorator was missing from the source.
@app.route("/health", methods=["GET"])
def health_check():
    return jsonify({
        "status": "ok",
        "model": MODEL_NAME,
        "base_model": BASE_MODEL
    }), 200
# Route path assumed; the decorator was missing from the source.
@app.route("/debug", methods=["GET", "POST"])
def debug_endpoint():
    """Debug endpoint to see what Vapi is sending"""
    try:
        data = request.get_json()
        headers = dict(request.headers)
        return jsonify({
            "received_data": data,
            "received_headers": headers,
            "status": "debug_ok"
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Vapi uses /v1/chat/completions, so register both routes
# (decorators were missing from the source; paths follow the comment above).
@app.route("/chat/completions", methods=["GET", "POST", "OPTIONS"])
@app.route("/v1/chat/completions", methods=["GET", "POST", "OPTIONS"])
def chat_completions():
    # Log the incoming request
    print("=" * 60)
    print(f"NEW REQUEST: {request.method} from {request.remote_addr}")
    print(f"Headers: {dict(request.headers)}")
    if request.method == "POST":
        print(f"Body: {request.get_data(as_text=True)[:500]}...")
    print("=" * 60)

    # Preflight response for OPTIONS requests
    if request.method == "OPTIONS":
        response = jsonify({"status": "ok"})
        response.headers.add("Access-Control-Allow-Origin", "*")
        response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization")
        response.headers.add("Access-Control-Allow-Methods", "GET,POST,OPTIONS")
        return response, 200
    try:
        data = request.get_json()
        print(f"Full request data: {data}")

        # Check if streaming is requested
        stream = data.get("stream", False)
        messages = data.get("messages", [])

        # Build the conversation history in Gemma format.
        # Only use user and assistant messages; skip system messages
        # (Vapi's system prompt is just context).
        conversation_parts = []
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content", "")
            if role == "user":
                conversation_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
            elif role == "assistant":
                conversation_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
            # Skip system messages - Vapi's system prompt is too long for the context

        # Add the model turn start
        conversation_parts.append("<start_of_turn>model\n")
        formatted_prompt = "\n".join(conversation_parts)
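        # The resulting Gemma chat format looks like this (illustrative):
        #   <start_of_turn>user
        #   Hello<end_of_turn>
        #   <start_of_turn>model
        #   Hi there!<end_of_turn>
        #   <start_of_turn>model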
| print(f"Formatted prompt length: {len(formatted_prompt)} chars") | |
| print(f"Last user message: {[m['content'] for m in messages if m.get('role')=='user'][-1] if messages else 'none'}") | |
        # Generate the model response
        inputs = tokenizer(formatted_prompt, return_tensors="pt")

        # Move the inputs to the same device as the model
        if hasattr(model, "device"):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        print("Generating response...")
        start_time = time.time()

        # Fast generation settings aimed at sub-second responses
        with torch.inference_mode():  # Disable gradient computation for speed
            outputs = model.generate(
                **inputs,
                max_new_tokens=40,  # Few new tokens for a very fast response
                do_sample=False,  # Greedy decoding (fastest)
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True  # Enable the KV cache for faster generation
            )

        elapsed = time.time() - start_time
        print(f"⚡ Response generated in {elapsed:.2f}s")

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Strip the Gemma format tags
        response_text = full_response.split("<start_of_turn>model\n")[-1]
        response_text = response_text.replace("<end_of_turn>", "").strip()
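        # decode() returns the full prompt plus the new tokens, so the split
        # above keeps only the text after the last model-turn marker.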
        # If streaming is requested, return SSE format
        if stream:
            def generate():
                # Send the complete response as a single chunk for simplicity
                chunk = {
                    "id": "chatcmpl-" + str(hash(formatted_prompt))[-10:],
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": MODEL_NAME,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                            "content": response_text
                        },
                        "finish_reason": "stop"
                    }]
                }
                yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"
            return app.response_class(generate(), mimetype="text/event-stream")
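        # On the wire the stream above is standard SSE, e.g. (illustrative):
        #   data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", ...}
        #
        #   data: [DONE]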
        # OpenAI-compatible response format (non-streaming)
        vapi_response = {
            "id": "chatcmpl-" + str(hash(formatted_prompt))[-10:],
            "object": "chat.completion",
            "created": int(time.time()),
            "model": MODEL_NAME,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": len(inputs["input_ids"][0]),
                "completion_tokens": len(outputs[0]) - len(inputs["input_ids"][0]),
                "total_tokens": len(outputs[0])
            }
        }
        return jsonify(vapi_response), 200
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"ERROR: {str(e)}")
        print(f"Traceback: {error_details}")

        # Return the error in both streaming and non-streaming formats
        error_response = {
            "error": {
                "message": str(e),
                "type": type(e).__name__,
                "code": "internal_error"
            }
        }

        # If streaming was requested, return the error in SSE format
        if 'stream' in locals() and stream:
            def error_stream():
                yield f"data: {json.dumps(error_response)}\n\n"
                yield "data: [DONE]\n\n"
            return app.response_class(error_stream(), mimetype="text/event-stream")

        return jsonify(error_response), 500
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
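# Example client call (a minimal sketch; assumes the server is reachable at
# http://localhost:7860 and uses the /v1/chat/completions route registered above):
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hello"}]},
#   )
#   print(r.json()["choices"][0]["message"]["content"])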