File size: 4,537 Bytes
30bc440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import json
import time
import gc
from threading import Thread, Lock
from flask import Flask, request, jsonify, Response, abort
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from duckduckgo_search import DDGS

app = Flask(__name__)

# Cap the number of CPU threads PyTorch uses for intra-op parallelism.
torch.set_num_threads(2)

# Allow-list of Hugging Face repos this worker will load. Any other requested
# id is replaced by DEFAULT_MODEL in load_model_into_memory.
DEFAULT_MODEL = "google/gemma-4-E2B-it"
AVAILABLE_MODELS = [
    DEFAULT_MODEL,
    "Qwen/Qwen3-4B-Instruct-2507",
    "HuggingFaceTB/SmolLM3-3B",
]

# Optional Hugging Face access token, read from the environment and passed to
# from_pretrained.
HF_TOKEN = os.environ.get("HF_TOKEN")

# State of the currently loaded model. load_model_into_memory rebinds these,
# and the chat endpoint takes model_lock around loading and generation.
active_model_repo = None
model = None
tokenizer = None
model_lock = Lock()

def search_web(query):
    """Fetch up to three DuckDuckGo text results and format them as prompt context.

    Returns a block that starts with "CONTEXTO DE INTERNET ACTUALIZADO:",
    followed by one "- title: body" line per result. Returns "" when there are
    no results or the search fails.

    Web search is optional enrichment, so failures are logged rather than
    raised. The caller then simply answers without web context.
    """
    try:
        # Materialize inside the try so that errors raised while iterating the
        # results are handled here too, not only errors from the initial call.
        results = list(DDGS().text(query, max_results=3) or [])
    except Exception:
        app.logger.warning("Web search failed; answering without web context", exc_info=True)
        return ""
    if not results:
        return ""
    lines = ["CONTEXTO DE INTERNET ACTUALIZADO:"]
    # `or ''` keeps a missing or null field from rendering as the text "None".
    lines.extend(f"- {r.get('title') or ''}: {r.get('body') or ''}" for r in results)
    return "\n".join(lines) + "\n"

def load_model_into_memory(repo_id):
    """Make ``repo_id`` the loaded model, swapping out any other one.

    Ids outside AVAILABLE_MODELS fall back to DEFAULT_MODEL. If that model is
    already active, this returns immediately. Otherwise the current model and
    tokenizer are dropped first, so two models are never resident at once.
    The new pair is then loaded on CPU in bfloat16.

    This rebinds the module globals ``model``, ``tokenizer`` and
    ``active_model_repo``. The chat request handler calls it while holding
    ``model_lock``.

    Raises:
        Any exception from ``from_pretrained``, propagated unchanged. In that
        case all three globals are left as ``None``, so the next call retries
        the load instead of believing a model is still active.
    """
    global active_model_repo, model, tokenizer

    if repo_id not in AVAILABLE_MODELS:
        repo_id = DEFAULT_MODEL

    if active_model_repo == repo_id:
        return

    if model is not None:
        # Rebind to None rather than `del`. Deleting a global removes the name
        # itself, so a later `model is not None` check would raise NameError
        # after a failed load. Clearing active_model_repo as well stops it from
        # naming a model that is no longer in memory.
        active_model_repo = None
        model = None
        tokenizer = None
        gc.collect()

    # trust_remote_code executes Python shipped with the repo. This is
    # tolerable only because repo_id is restricted to the allow-list above.
    new_tokenizer = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN, trust_remote_code=True)
    new_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    # Publish only after both loads succeed, so a failure can never leave a
    # mismatched tokenizer/model pair or a stale active_model_repo.
    tokenizer, model, active_model_repo = new_tokenizer, new_model, repo_id

@app.route('/', methods=['GET'])
def health_check():
    """Report that this worker is up, along with its role."""
    status = dict(status="online", role="mirror_worker")
    return jsonify(status)

@app.route('/v1/models', methods=['GET'])
def list_models():
    """List every allow-listed model id in a ``{"object": "list", "data": [...]}`` payload.

    The shape presumably targets OpenAI-compatible clients of ``/v1/models``.
    """
    payload = {
        "object": "list",
        "data": [dict(id=repo_id, object="model") for repo_id in AVAILABLE_MODELS],
    }
    return jsonify(payload)

def _number_param(data, key, default, cast):
    """Read numeric field ``key`` from ``data`` via ``cast``; a missing or null value yields ``default``.

    Aborts with 400 when the value cannot be converted.
    """
    value = data.get(key)
    if value is None:
        return default
    try:
        return cast(value)
    except (TypeError, ValueError, OverflowError):
        abort(400, description="Petición inválida")

def _with_web_context(messages):
    """Return ``messages`` with web search results folded into the final user turn.

    This applies only when the last message is a ``user`` message with string
    content and the search returned something. Otherwise ``messages`` is
    returned as-is. The caller's message dicts are never mutated.
    """
    last = messages[-1]
    if not isinstance(last, dict) or last.get('role') != 'user':
        return messages
    user_query = last.get('content')
    if not isinstance(user_query, str):
        return messages
    web_context = search_web(user_query)
    if not web_context:
        return messages
    augmented = dict(last)
    augmented['content'] = f"Responde usando esta info de internet si es útil:\n{web_context}\n\nPregunta: {user_query}"
    return messages[:-1] + [augmented]

def _generate_then_release(generate, streamer, generation_kwargs):
    """Run ``generate(**generation_kwargs)`` on a worker thread, then release ``model_lock``.

    The streaming request handler acquires ``model_lock`` and hands ownership
    to this function, so the model cannot be swapped mid-generation. If
    generation fails, the streamer is ended so that the HTTP response iterator
    stops instead of blocking forever on an empty queue.
    """
    try:
        generate(**generation_kwargs)
    except Exception:
        app.logger.exception("Streaming generation failed")
        streamer.end()
    finally:
        model_lock.release()

def _sse_events(streamer):
    """Yield one Server-Sent Event per non-empty text piece, then a final ``[DONE]`` event."""
    for new_text in streamer:
        if new_text:
            yield f"data: {json.dumps({'choices': [{'delta': {'content': new_text}}]})}\n\n"
    yield "data: [DONE]\n\n"

@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Chat completion endpoint that returns either a JSON reply or an SSE stream.

    Request JSON fields:
        messages (required): non-empty list of chat messages. They are rendered
            with the tokenizer's chat template.
        model: a repo id. Ids outside AVAILABLE_MODELS fall back to DEFAULT_MODEL.
        temperature: defaults to 0.6. A value <= 0 selects greedy decoding.
        max_tokens: defaults to 1024 and is clamped to the range [1, 2048].
        stream: if truthy, respond with ``text/event-stream`` chunks that end
            with ``data: [DONE]``.
        web_search: if truthy, prepend search results to the last user message.

    Responds with 400 on malformed input. Responds with 500 (including the
    error text) if loading or generation fails.
    """
    # Rate limiting is intentionally not done here; the main server handles it.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('messages'), list) or not data['messages']:
        abort(400, description="Petición inválida")

    messages = data['messages']
    requested_model = data.get('model', DEFAULT_MODEL)
    temperature = _number_param(data, 'temperature', 0.6, float)
    max_new_tokens = max(1, min(_number_param(data, 'max_tokens', 1024, int), 2048))
    stream = data.get('stream', False)
    use_web_search = data.get('web_search', False)

    # Web search is network I/O that never touches the model. Doing it before
    # taking model_lock keeps it from stalling other requests.
    if use_web_search:
        messages = _with_web_context(messages)

    # Acquire manually rather than with `with`: in streaming mode, ownership of
    # the lock passes to the generation thread, which releases it only when
    # generation ends. Releasing on return would let another request reload or
    # reuse the model while that thread is still generating.
    model_lock.acquire()
    handed_off = False
    try:
        load_model_into_memory(requested_model)

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The rendered template already contains any special tokens it needs
        # (e.g. BOS), so the tokenizer must not add them a second time.
        inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False)

        generation_kwargs = dict(inputs, max_new_tokens=max_new_tokens)
        if temperature > 0.0:
            generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
        else:
            # Greedy decoding: sampling knobs would be ignored (and warned about).
            generation_kwargs["do_sample"] = False

        if stream:
            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs["streamer"] = streamer
            Thread(
                target=_generate_then_release,
                args=(model.generate, streamer, generation_kwargs),
            ).start()
            handed_off = True  # the worker thread now owns model_lock
            return Response(_sse_events(streamer), mimetype='text/event-stream')

        outputs = model.generate(**generation_kwargs)
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
        reply = tokenizer.decode(generated_ids, skip_special_tokens=True)
        return jsonify({"choices": [{"message": {"role": "assistant", "content": reply}}]})
    except Exception as e:
        abort(500, description=str(e))
    finally:
        if not handed_off:
            model_lock.release()

if __name__ == '__main__':
    # Run the built-in Flask server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")