File size: 4,537 Bytes
30bc440 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | import os
import json
import time
import gc
from threading import Thread, Lock
from flask import Flask, request, jsonify, Response, abort
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from duckduckgo_search import DDGS
app = Flask(__name__)

# Cap PyTorch's CPU intra-op parallelism at two threads.
torch.set_num_threads(2)

# Hugging Face repo ids this worker will load; any other requested id is
# replaced by DEFAULT_MODEL in load_model_into_memory().
AVAILABLE_MODELS = [
    "google/gemma-4-E2B-it",
    "Qwen/Qwen3-4B-Instruct-2507",
    "HuggingFaceTB/SmolLM3-3B",
]
# The first registry entry doubles as the fallback model.
DEFAULT_MODEL = AVAILABLE_MODELS[0]

# Lazily-populated state for the single model held in memory; these globals
# are (re)assigned by load_model_into_memory().
active_model_repo = None
model = None
tokenizer = None

# Optional Hugging Face token from the environment, forwarded to from_pretrained().
HF_TOKEN = os.environ.get("HF_TOKEN")

# Guards model loading and generation inside chat_completions().
model_lock = Lock()
def search_web(query):
    """Return a plain-text digest of the top (at most three) DuckDuckGo text hits.

    The digest is a header line ``CONTEXTO DE INTERNET ACTUALIZADO:`` followed
    by one ``- <title>: <body>`` line per hit, each terminated by a newline.

    Web context is optional, so this is best-effort: it returns ``""`` when
    the search yields nothing or fails for any reason. Failures are logged
    rather than silently discarded.
    """
    try:
        results = DDGS().text(query, max_results=3)
        if not results:
            return ""
        lines = ["CONTEXTO DE INTERNET ACTUALIZADO:"]
        # ``or ''`` keeps a missing field from rendering as the text "None".
        lines.extend(
            f"- {r.get('title') or ''}: {r.get('body') or ''}" for r in results
        )
        return "\n".join(lines) + "\n"
    except Exception:
        # Deliberately non-fatal: the caller simply proceeds without web context.
        app.logger.warning("Web search failed; continuing without web context", exc_info=True)
        return ""
def load_model_into_memory(repo_id):
    """Make ``repo_id`` the active model, loading it onto the CPU if necessary.

    Ids not present in ``AVAILABLE_MODELS`` are replaced by ``DEFAULT_MODEL``.
    If the requested model is already active this returns immediately.
    Otherwise the previous model/tokenizer are released first, so only one
    model is resident at a time, and the new pair is loaded in bfloat16.

    This function does no locking of its own; chat_completions() calls it
    while holding ``model_lock``.

    Raises:
        Exception: whatever ``from_pretrained`` raises. On failure no model is
        marked active (``model``/``tokenizer``/``active_model_repo`` are None),
        so the next call retries the load instead of using stale state.
    """
    global active_model_repo, model, tokenizer
    if repo_id not in AVAILABLE_MODELS:
        repo_id = DEFAULT_MODEL
    if active_model_repo == repo_id:
        return
    if model is not None or tokenizer is not None:
        # Rebind to None rather than ``del``: deleting a ``global`` removes the
        # module-level name, so a failed load would leave ``model`` undefined
        # (NameError on the next request) while ``active_model_repo`` still
        # claimed the old repo was loaded.
        model = None
        tokenizer = None
        gc.collect()
    active_model_repo = None
    # trust_remote_code executes repo-provided code; this is tolerable only
    # because repo_id is restricted to the AVAILABLE_MODELS allow-list above.
    new_tokenizer = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN, trust_remote_code=True)
    new_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    # Publish only after both loads succeed so the globals never describe a
    # half-loaded state.
    tokenizer, model, active_model_repo = new_tokenizer, new_model, repo_id
@app.route('/', methods=['GET'])
def health_check():
    """Answer GET / with a fixed status payload naming this process's role."""
    status_payload = {"status": "online", "role": "mirror_worker"}
    return jsonify(status_payload)
@app.route('/v1/models', methods=['GET'])
def list_models():
    """Return every entry of AVAILABLE_MODELS wrapped in an OpenAI-style list envelope."""
    entries = []
    for repo in AVAILABLE_MODELS:
        entries.append({"id": repo, "object": "model"})
    envelope = {"object": "list", "data": entries}
    return jsonify(envelope)
def _parse_generation_params(data):
    """Return ``(temperature, max_new_tokens)`` validated from the request body.

    Missing or null values take the defaults 0.6 and 1024. ``max_new_tokens``
    is clamped to the range [1, 2048]. Aborts with HTTP 400 if either value
    cannot be converted to a number.
    """
    raw_temperature = data.get('temperature')
    raw_max_tokens = data.get('max_tokens')
    try:
        temperature = 0.6 if raw_temperature is None else float(raw_temperature)
        max_new_tokens = 1024 if raw_max_tokens is None else int(raw_max_tokens)
    except (TypeError, ValueError):
        abort(400, description="Petición inválida")
    return temperature, max(1, min(max_new_tokens, 2048))


def _augment_with_web_search(messages):
    """Prepend web search results to the last message if it is a user text message.

    Mutates ``messages[-1]['content']`` in place. Messages that are not dicts,
    whose role is not ``user``, or whose content is not a non-empty string are
    left unchanged.
    """
    if not messages:
        return
    last = messages[-1]
    if not isinstance(last, dict) or last.get('role') != 'user':
        return
    user_query = last.get('content')
    if not isinstance(user_query, str) or not user_query.strip():
        return
    web_context = search_web(user_query)
    if web_context:
        last['content'] = f"Responde usando esta info de internet si es útil:\n{web_context}\n\nPregunta: {user_query}"


def _generate_and_release(gen_model, generation_kwargs, streamer):
    """Background-thread target: run streamed generation, then release ``model_lock``.

    The request thread acquires ``model_lock`` and hands ownership to this
    thread, so the lock covers the entire generation rather than just the
    request handler. ``threading.Lock`` may be released from any thread.
    """
    try:
        gen_model.generate(**generation_kwargs)
    except Exception:
        app.logger.exception("Streaming generation failed")
        # Push the end-of-stream signal so the SSE consumer, which iterates the
        # streamer with no timeout, is not left blocked forever.
        streamer.end()
    finally:
        model_lock.release()


def _sse_events(streamer):
    """Yield one server-sent event per decoded text chunk, then a ``[DONE]`` marker."""
    for new_text in streamer:
        if new_text:
            yield f"data: {json.dumps({'choices': [{'delta': {'content': new_text}}]})}\n\n"
    yield "data: [DONE]\n\n"


@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Generate a chat reply for an OpenAI-style request body.

    Recognized body fields: ``messages`` (required, non-empty list), ``model``,
    ``temperature`` (<= 0 means greedy decoding), ``max_tokens`` (clamped to
    1..2048), ``stream`` (server-sent events), and ``web_search`` (prepend
    DuckDuckGo results to the last user message).

    Returns a JSON ``choices`` payload, or a ``text/event-stream`` Response
    when streaming. Aborts with 400 on a malformed body and 500 on load or
    generation errors.
    """
    # Rate limiting is intentionally not applied here; the main server handles it.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('messages'), list) or not data['messages']:
        abort(400, description="Petición inválida")
    messages = data['messages']
    requested_model = data.get('model', DEFAULT_MODEL)
    temperature, max_new_tokens = _parse_generation_params(data)
    stream = bool(data.get('stream', False))

    # Network I/O happens before taking the lock so a slow search does not
    # stall other requests waiting to generate.
    if data.get('web_search', False):
        _augment_with_web_search(messages)

    model_lock.acquire()
    lock_handed_off = False
    try:
        load_model_into_memory(requested_model)
        # Bind locally. The lock stays held until generation finishes, so these
        # objects cannot be swapped out underneath us.
        current_model, current_tokenizer = model, tokenizer
        prompt = current_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = current_tokenizer([prompt], return_tensors="pt")
        generation_kwargs = dict(inputs, max_new_tokens=max_new_tokens)
        if temperature > 0.0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
        else:
            # Greedy decoding: sampling knobs are omitted because they are ignored.
            generation_kwargs['do_sample'] = False

        if stream:
            streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs['streamer'] = streamer
            thread = Thread(
                target=_generate_and_release,
                args=(current_model, generation_kwargs, streamer),
            )
            thread.start()
            # From here on, the generation thread owns the lock and releases it.
            lock_handed_off = True
            return Response(_sse_events(streamer), mimetype='text/event-stream')

        outputs = current_model.generate(**generation_kwargs)
        generated_ids = outputs[0][inputs['input_ids'].shape[-1]:]
        reply = current_tokenizer.decode(generated_ids, skip_special_tokens=True)
        return jsonify({"choices": [{"message": {"role": "assistant", "content": reply}}]})
    except Exception as e:
        app.logger.exception("Chat completion failed")
        abort(500, description=str(e))
    finally:
        if not lock_handed_off:
            model_lock.release()
if __name__ == '__main__':
    # Development entry point: serve on every interface at port 7860.
    app.run(host='0.0.0.0', port=7860)