"""Mirror worker exposing `/v1/models` and `/v1/chat/completions` endpoints.

Completions are produced by a locally loaded Hugging Face causal LM on CPU,
optionally prefixed with DuckDuckGo search results. Only one model is kept in
memory at a time; requesting a different model swaps it in under
``model_lock``.
"""
import gc
import json
import logging
import os
import time
from threading import Lock, Thread

import torch
from duckduckgo_search import DDGS
from flask import Flask, Response, abort, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

logger = logging.getLogger(__name__)

app = Flask(__name__)
# Limit PyTorch intra-op parallelism to two CPU threads.
torch.set_num_threads(2)

AVAILABLE_MODELS = [
    "google/gemma-4-E2B-it",
    "Qwen/Qwen3-4B-Instruct-2507",
    "HuggingFaceTB/SmolLM3-3B",
]
DEFAULT_MODEL = "google/gemma-4-E2B-it"

# Generation defaults / limits applied to client-supplied parameters.
DEFAULT_TEMPERATURE = 0.6
DEFAULT_MAX_TOKENS = 1024
MAX_TOKENS_CAP = 2048
SAMPLING_TOP_P = 0.9

# Currently resident model state; only touched while holding model_lock.
active_model_repo = None
model = None
tokenizer = None

HF_TOKEN = os.environ.get("HF_TOKEN")

# Serializes model swaps and generation. Streamed responses keep holding it
# until the stream is closed (see chat_completions).
model_lock = Lock()


def search_web(query, max_results=3):
    """Return a plain-text digest of DuckDuckGo results for *query*.

    Best-effort: returns "" when there are no results or the search fails,
    so callers can simply skip adding web context.

    Args:
        query: Search string.
        max_results: Maximum number of results to include (default 3).

    Returns:
        A header line followed by one "- title: body" line per result, each
        newline-terminated, or "" if nothing usable was found.
    """
    try:
        results = DDGS().text(query, max_results=max_results)
        if not results:
            return ""
        lines = [f"- {r.get('title')}: {r.get('body')}\n" for r in results]
        return "CONTEXTO DE INTERNET ACTUALIZADO:\n" + "".join(lines)
    except Exception:
        # Web context is optional; log and degrade instead of failing the request.
        logger.warning("Web search failed for query %r", query, exc_info=True)
        return ""


def load_model_into_memory(repo_id):
    """Make *repo_id* the resident model, unloading any previously loaded one.

    Unknown repo ids fall back to DEFAULT_MODEL. Callers should hold
    ``model_lock``. Loading errors from transformers propagate to the caller;
    in that case no model is left loaded, so the next call retries cleanly.
    """
    global active_model_repo, model, tokenizer
    if repo_id not in AVAILABLE_MODELS:
        repo_id = DEFAULT_MODEL
    if active_model_repo == repo_id and model is not None:
        return

    # Rebind to None instead of `del`-ing the globals: the names must stay
    # defined and the state consistent even if the load below fails.
    model = None
    tokenizer = None
    active_model_repo = None
    gc.collect()

    new_tokenizer = AutoTokenizer.from_pretrained(
        repo_id, token=HF_TOKEN, trust_remote_code=True
    )
    new_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    tokenizer = new_tokenizer
    model = new_model
    active_model_repo = repo_id


@app.route('/', methods=['GET'])
def health_check():
    """Liveness probe."""
    return jsonify({"status": "online", "role": "mirror_worker"})


@app.route('/v1/models', methods=['GET'])
def list_models():
    """List the model ids this worker accepts."""
    data = [{"id": m, "object": "model"} for m in AVAILABLE_MODELS]
    return jsonify({"object": "list", "data": data})


def _coerce_param(value, default, cast, field):
    """Return ``cast(value)``, or *default* when *value* is None; abort 400 if invalid."""
    if value is None:
        return default
    try:
        return cast(value)
    except (TypeError, ValueError):
        abort(400, description=f"Parámetro inválido: {field}")


def _with_web_context(messages):
    """Return *messages* with the last user turn augmented by web search results.

    Returns a new list (the request payload is not mutated). If the last
    message is not a user turn with string content, or the search yields
    nothing, *messages* is returned unchanged.
    """
    last = messages[-1]
    if not isinstance(last, dict) or last.get('role') != 'user':
        return messages
    user_query = last.get('content')
    if not isinstance(user_query, str):
        return messages
    web_context = search_web(user_query)
    if not web_context:
        return messages
    augmented = dict(
        last,
        content=f"Responde usando esta info de internet si es útil:\n{web_context}\n\nPregunta: {user_query}",
    )
    return messages[:-1] + [augmented]


def _build_generation_kwargs(inputs, max_new_tokens, temperature):
    """Build kwargs for ``model.generate``: greedy when temperature <= 0, else sampled."""
    kwargs = dict(inputs, max_new_tokens=max_new_tokens, do_sample=temperature > 0.0)
    if kwargs["do_sample"]:
        # Sampling knobs only matter (and only avoid warnings) when sampling.
        kwargs.update(temperature=temperature, top_p=SAMPLING_TOP_P)
    return kwargs


def _start_streaming_generation(generation_kwargs):
    """Start generation in a background thread; return ``(thread, streamer)``."""
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs["streamer"] = streamer
    active_model = model

    def _run():
        try:
            active_model.generate(**generation_kwargs)
        except Exception:
            logger.exception("Streaming generation failed")
            # generate() may not signal end-of-stream when it raises; do it
            # here so the consumer (and therefore model_lock) is released.
            streamer.end()

    thread = Thread(target=_run)
    thread.start()
    return thread, streamer


def _sse_events(streamer):
    """Yield server-sent-event frames for each streamed text chunk, then [DONE]."""
    for new_text in streamer:
        if new_text:
            payload = {'choices': [{'delta': {'content': new_text}}]}
            yield f"data: {json.dumps(payload)}\n\n"
    yield "data: [DONE]\n\n"


@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Generate a chat completion, either as JSON or as a server-sent-event stream.

    Request JSON fields: ``messages`` (required, non-empty list), ``model``,
    ``temperature``, ``max_tokens`` (capped at MAX_TOKENS_CAP), ``stream``,
    ``web_search``. Invalid input yields 400; generation failures yield 500.
    """
    # Rate limiting is intentionally not done here; the main server handles it.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'messages' not in data:
        abort(400, description="Petición inválida")
    messages = data.get('messages')
    if not isinstance(messages, list) or not messages:
        abort(400, description="Petición inválida")

    requested_model = data.get('model', DEFAULT_MODEL)
    temperature = _coerce_param(data.get('temperature'), DEFAULT_TEMPERATURE, float, 'temperature')
    max_new_tokens = _coerce_param(data.get('max_tokens'), DEFAULT_MAX_TOKENS, int, 'max_tokens')
    max_new_tokens = max(1, min(max_new_tokens, MAX_TOKENS_CAP))
    stream = bool(data.get('stream', False))
    use_web_search = bool(data.get('web_search', False))

    model_lock.acquire()
    # For streamed responses the lock is handed to the response's close
    # callback, so generation stays serialized until the stream finishes.
    lock_handed_off = False
    try:
        load_model_into_memory(requested_model)

        if use_web_search:
            messages = _with_web_context(messages)

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer([prompt], return_tensors="pt")
        generation_kwargs = _build_generation_kwargs(inputs, max_new_tokens, temperature)

        if stream:
            thread, streamer = _start_streaming_generation(generation_kwargs)
            response = Response(_sse_events(streamer), mimetype='text/event-stream')

            def _release_after_stream():
                # Runs when the WSGI server closes the response (normal end
                # or client disconnect); wait for generation before unlocking.
                thread.join()
                model_lock.release()

            response.call_on_close(_release_after_stream)
            lock_handed_off = True
            return response

        outputs = model.generate(**generation_kwargs)
        # Drop the prompt tokens; decode only the newly generated ones.
        generated_ids = outputs[0][len(inputs.input_ids[0]):]
        reply = tokenizer.decode(generated_ids, skip_special_tokens=True)
        return jsonify({"choices": [{"message": {"role": "assistant", "content": reply}}]})
    except Exception as e:
        logger.exception("Chat completion failed")
        abort(500, description=str(e))
    finally:
        if not lock_handed_off:
            model_lock.release()


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)