|
|
|
|
|
import os |
|
|
import time |
|
|
import uuid |
|
|
import base64 |
|
|
import subprocess |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
import flask |
|
|
from flask import request, jsonify |
|
|
import requests |
|
|
from bs4 import BeautifulSoup |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Runtime configuration ---------------------------------------------------
# Each value is read from the environment variable named in the call; the
# literal next to it is the fallback used when that variable is unset.
MODEL_ID = os.getenv("MODEL_ID", "HuggingFaceTB/SmolLM2-360M-Instruct")
PORT = int(os.getenv("PORT", 7860))

# Working directory for files created through the API and for scripts written
# by /run_code. Created eagerly at import time.
FILES_DIR = Path(os.getenv("FILES_DIR", "engine_files"))
FILES_DIR.mkdir(parents=True, exist_ok=True)

# Generation parameters applied whenever a request does not override them.
GEN_DEFAULTS = dict(
    max_new_tokens=int(os.getenv("MAX_NEW_TOKENS", 512)),
    do_sample=os.getenv("DO_SAMPLE", "true").lower() == "true",
    temperature=float(os.getenv("TEMPERATURE", 0.6)),
    top_p=float(os.getenv("TOP_P", 0.9)),
    repetition_penalty=float(os.getenv("REPETITION_PENALTY", 1.05)),
)

# Token budget shared by the prompt and the newly generated tokens.
MODEL_CONTEXT_TOKENS = int(os.getenv("MODEL_CONTEXT_TOKENS", 4096))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Flask application object and the timestamp /health measures uptime from.
app = flask.Flask(__name__)
_start_time = time.time()

print(f"🔄 Loading model {MODEL_ID} ... (this may take a while the first time)")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# bfloat16 when CUDA is available, float32 otherwise.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"✅ Model loaded: {MODEL_ID} on {device} (dtype={dtype})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def safe_filename(name: str) -> str:
    """Reduce *name* to a filename that is safe to place inside FILES_DIR.

    Keeps only alphanumeric characters plus ``.``, ``_``, ``-`` and space
    (so path separators are dropped), then trims surrounding whitespace.

    If nothing usable is left, a random UUID string is returned instead.
    That includes names consisting solely of dots and spaces such as ``.``
    or ``..``, which would otherwise refer to a directory rather than a file.
    """
    cleaned = "".join(ch for ch in name if ch.isalnum() or ch in "._- ").strip()
    # "." / ".." / "" carry no real name once dots and spaces are removed.
    if not cleaned.strip(". "):
        cleaned = str(uuid.uuid4())
    return cleaned
|
|
|
|
|
def _truncate_prompt_for_context(prompt: str, max_new_tokens: int) -> str:
    """
    Shorten *prompt* so prompt tokens plus *max_new_tokens* fit the context.

    The budget is MODEL_CONTEXT_TOKENS minus *max_new_tokens* minus a fixed
    32-token margin, and never drops below 32 tokens. When the prompt is over
    budget its leading tokens are discarded and the tail is kept; otherwise
    the prompt is returned unchanged.
    """
    safety_margin = 32
    budget = MODEL_CONTEXT_TOKENS - max_new_tokens - safety_margin
    if budget < 32:
        budget = 32

    token_ids = tokenizer.encode(prompt, add_special_tokens=False)
    overflow = len(token_ids) - budget
    if overflow <= 0:
        return prompt

    # Drop the oldest `overflow` tokens, keeping exactly `budget` from the end.
    kept = token_ids[overflow:]
    return tokenizer.decode(kept, clean_up_tokenization_spaces=True)
|
|
|
|
|
def generate_from_model(prompt: str,
                        max_new_tokens: Optional[int] = None,
                        do_sample: Optional[bool] = None,
                        temperature: Optional[float] = None,
                        top_p: Optional[float] = None,
                        repetition_penalty: Optional[float] = None) -> str:
    """
    Run the loaded model on *prompt* and return the decoded first sequence.

    Any argument left as ``None`` falls back to the matching GEN_DEFAULTS
    entry. ``max_new_tokens`` is coerced with ``int``; ``temperature``,
    ``top_p`` and ``repetition_penalty`` with ``float``; ``do_sample`` is
    passed through as given. The prompt is truncated to fit the context
    window before tokenization. Special tokens are skipped when decoding.
    """
    def _resolve(override, key, cast=None):
        # Explicit override wins (optionally coerced); otherwise the default.
        if override is None:
            return GEN_DEFAULTS[key]
        return override if cast is None else cast(override)

    n_new = _resolve(max_new_tokens, "max_new_tokens", int)
    sample = _resolve(do_sample, "do_sample")
    temp = _resolve(temperature, "temperature", float)
    nucleus = _resolve(top_p, "top_p", float)
    rep_penalty = _resolve(repetition_penalty, "repetition_penalty", float)

    prompt = _truncate_prompt_for_context(prompt, n_new)

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MODEL_CONTEXT_TOKENS,
    ).to(device)

    with torch.no_grad():
        sequences = model.generate(
            **encoded,
            max_new_tokens=n_new,
            do_sample=sample,
            temperature=temp,
            top_p=nucleus,
            repetition_penalty=rep_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(sequences[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/health", methods=["GET"])
def health():
    """
    Report liveness information as JSON.

    Includes whole seconds elapsed since ``_start_time``, the device string,
    the model id, and ``psutil`` virtual-memory statistics (or a placeholder
    dict when ``psutil`` cannot be imported or queried).
    """
    elapsed = time.time() - _start_time

    try:
        import psutil
        memory_stats = psutil.virtual_memory()._asdict()
    except Exception:
        memory_stats = {"info": "psutil not installed or unavailable"}

    payload = {
        "status": "ok",
        "uptime_seconds": int(elapsed),
        "device": str(device),
        "model_id": MODEL_ID,
        "memory": memory_stats,
    }
    return jsonify(payload)
|
|
|
|
|
@app.route("/model_info", methods=["GET"])
def model_info():
    """
    Describe the loaded model as JSON.

    Reports the model id, device, dtype, and the tokenizer's ``vocab_size``
    and ``is_fast`` attributes (``None`` when the tokenizer lacks them).
    """
    info = {
        "model_id": MODEL_ID,
        "device": str(device),
        "dtype": str(dtype),
    }
    info["vocab_size"] = getattr(tokenizer, "vocab_size", None)
    info["tokenizer_fast"] = getattr(tokenizer, "is_fast", None)
    return jsonify(info)
|
|
|
|
|
|
|
|
@app.route("/chat", methods=["POST"])
def chat():
    """
    Generate a single-turn assistant reply.

    POST JSON:
    {
      "message": "text",             # required ("prompt" is accepted as an alias)
      "max_new_tokens": 256,         # optional
      "do_sample": true/false,       # optional
      "temperature": 0.7,            # optional
      "top_p": 0.9,                  # optional
      "repetition_penalty": 1.05     # optional
    }

    Returns ``{"reply": ...}``.
    Responds 400 with ``{"error": ...}`` when the body is not a JSON object
    or carries no usable message, and 500 when generation fails.
    """
    # silent=True yields None for malformed JSON instead of raising, so a bad
    # request is answered with 400 rather than ending up in the 500 handler.
    body = request.get_json(force=True, silent=True)
    if not isinstance(body, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    raw_msg = body.get("message") or body.get("prompt") or ""
    if not isinstance(raw_msg, str):
        return jsonify({"error": "message must be a string"}), 400
    msg = raw_msg.strip()
    if not msg:
        return jsonify({"error": "No message provided"}), 400

    try:
        prompt = f"User: {msg}\nAssistant:"

        full = generate_from_model(prompt,
                                   max_new_tokens=body.get("max_new_tokens"),
                                   do_sample=body.get("do_sample"),
                                   temperature=body.get("temperature"),
                                   top_p=body.get("top_p"),
                                   repetition_penalty=body.get("repetition_penalty"))

        if full.startswith(prompt):
            # Preferred path: drop the exact prompt prefix. Splitting on the
            # first "Assistant:" would cut inside the user's own text whenever
            # the message itself contains that marker.
            reply = full[len(prompt):].strip()
        elif "Assistant:" in full:
            # The decoded text did not reproduce the prompt verbatim (e.g. it
            # was truncated or re-spaced); fall back to the marker.
            reply = full.split("Assistant:", 1)[1].strip()
        else:
            # No marker means the prompt is not present either.
            reply = full.strip()

        return jsonify({"reply": reply})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
@app.route("/search", methods=["POST"])
def search():
    """
    Run a DuckDuckGo HTML search and return the top results.

    POST JSON:
    { "q": "your query", "top_k": 5 }

    Returns ``{"query": q, "results": [{"title", "url", "snippet"}, ...]}``.
    Responds 400 when the body is not a JSON object, ``q`` is missing, or
    ``top_k`` is not an integer; 500 when the upstream request or parsing
    fails.
    """
    # silent=True turns malformed JSON into None so it is reported as 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    q = data.get("q") or ""
    if not isinstance(q, str) or not q.strip():
        return jsonify({"error": "Query 'q' missing"}), 400
    q = q.strip()

    try:
        # Clamp at 0: a negative value would slice from the end of the list.
        top_k = max(int(data.get("top_k", 5)), 0)
    except (TypeError, ValueError):
        return jsonify({"error": "top_k must be an integer"}), 400

    try:
        url = "https://html.duckduckgo.com/html/"
        r = requests.post(url, data={"q": q}, timeout=10)
        r.raise_for_status()
        soup = BeautifulSoup(r.text, "html.parser")

        results = []
        for a in soup.select("a.result__a")[:top_k]:
            title = a.get_text().strip()
            href = a.get("href")

            # NOTE(review): in DuckDuckGo's HTML markup the snippet appears to
            # be a sibling of the title heading, not a child of the anchor's
            # direct parent — confirm against live markup. Search the
            # enclosing ".result" container, falling back to the direct parent.
            container = a.find_parent(class_="result") or a.parent
            snippet_node = container.select_one(".result__snippet") if container else None
            snippet = snippet_node.get_text().strip() if snippet_node else ""

            results.append({"title": title, "url": href, "snippet": snippet})

        return jsonify({"query": q, "results": results})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
@app.route("/fetch_url", methods=["POST"])
def fetch_url():
    """
    Fetch a URL and return its body text, truncated to ``max_chars``.

    POST JSON: { "url": "https://...", "max_chars": 10000 }

    Returns ``{"url": url, "content": text}``; truncated content ends with
    a ``...[truncated]`` marker.
    Responds 400 when the body is not a JSON object, ``url`` is missing or
    not an http(s) URL, or ``max_chars`` is not an integer; 500 when the
    fetch fails.

    NOTE(security): the request is made from the server's network to a
    caller-supplied address. Only expose this endpoint to trusted callers, or
    add host allow-listing before deploying it publicly.
    """
    # silent=True turns malformed JSON into None so it is reported as 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    url = data.get("url", "")
    if not url:
        return jsonify({"error": "url missing"}), 400
    if not isinstance(url, str) or not url.lower().startswith(("http://", "https://")):
        return jsonify({"error": "url must be an http(s) URL"}), 400

    try:
        # Clamp at 0: a negative value would slice from the end of the text.
        max_chars = max(int(data.get("max_chars", 10000)), 0)
    except (TypeError, ValueError):
        return jsonify({"error": "max_chars must be an integer"}), 400

    try:
        r = requests.get(url, timeout=10)
        r.raise_for_status()
        text = r.text
        if len(text) > max_chars:
            text = text[:max_chars] + "\n\n...[truncated]"
        return jsonify({"url": url, "content": text})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
@app.route("/summarize", methods=["POST"])
def summarize():
    """
    Summarize the supplied text with the loaded model.

    POST JSON: { "text": "...", "max_new_tokens": 200 }

    Returns ``{"summary": ...}``.
    Responds 400 when the body is not a JSON object, ``text`` is missing, or
    ``max_new_tokens`` is not an integer; 500 when generation fails.
    """
    # silent=True turns malformed JSON into None so it is reported as 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get("text") or ""
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "text missing"}), 400
    text = text.strip()

    try:
        max_new_tokens = int(data.get("max_new_tokens", GEN_DEFAULTS["max_new_tokens"]))
    except (TypeError, ValueError):
        return jsonify({"error": "max_new_tokens must be an integer"}), 400

    try:
        prompt = f"Summarize the following text concisely and clearly:\n\n{text}\n\nSummary:"
        out = generate_from_model(prompt, max_new_tokens=max_new_tokens)

        if out.startswith(prompt):
            # Drop the exact prompt prefix; splitting on the first "Summary:"
            # would cut inside the input whenever the text contains the marker.
            summary = out[len(prompt):].strip()
        elif "Summary:" in out:
            # Prompt was not reproduced verbatim (e.g. truncated); use the marker.
            summary = out.split("Summary:", 1)[1].strip()
        else:
            # No marker means the prompt is not present either.
            summary = out.strip()

        return jsonify({"summary": summary})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
@app.route("/run_code", methods=["POST"])
def run_code():
    """
    Execute caller-supplied Python source in a ``python3`` subprocess.

    POST JSON: { "code": "print('hi')", "timeout": 8 }
    Returns stdout, stderr, exit_code and the job_id naming the script file.

    Responds 400 when the body is not a JSON object, ``code`` is missing or
    not a string, or ``timeout`` is not a number; 500 with
    ``{"error": "timeout", ...}`` when the run exceeds ``timeout`` seconds,
    and 500 on any other failure.

    SECURITY: this runs arbitrary code with the privileges of the service
    process and no sandboxing. It must only be reachable by trusted callers.

    NOTE(review): the script is written to FILES_DIR as ``job_<id>.py`` and
    is not removed afterwards — confirm whether that retention is intended.
    """
    # silent=True turns malformed JSON into None so it is reported as 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    code = data.get("code", "")
    if not code:
        return jsonify({"error": "code missing"}), 400
    if not isinstance(code, str):
        return jsonify({"error": "code must be a string"}), 400

    try:
        timeout = float(data.get("timeout", 8))
    except (TypeError, ValueError):
        return jsonify({"error": "timeout must be a number"}), 400

    try:
        job_id = str(uuid.uuid4())
        tmp_file = FILES_DIR / f"job_{job_id}.py"
        tmp_file.write_text(code, encoding="utf-8")

        proc = subprocess.run(
            ["python3", str(tmp_file)],
            capture_output=True,
            text=True,
            timeout=timeout,
        )

        return jsonify({
            "stdout": proc.stdout,
            "stderr": proc.stderr,
            "exit_code": proc.returncode,
            "job_id": job_id,
        })
    except subprocess.TimeoutExpired as te:
        return jsonify({"error": "timeout", "detail": str(te)}), 500
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
@app.route("/create_file", methods=["POST"])
def create_file():
    """
    Create (or overwrite) a file inside FILES_DIR.

    POST JSON: { "filename": "name.txt", "content": "...", "encode_base64": false }

    The filename is passed through ``safe_filename``; when it is omitted a
    ``file_<uuid>.txt`` name is generated. Despite its name, a truthy
    ``encode_base64`` means *content is base64-encoded*: it is decoded and
    written as bytes. Otherwise the content is written as UTF-8 text.

    Returns ``{"path": ..., "filename": ...}``.
    Responds 400 when the body is not a JSON object, ``filename``/``content``
    is not a string, or the base64 payload is invalid; 500 when the write
    fails.
    """
    # silent=True turns malformed JSON into None so it is reported as 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    raw_name = data.get("filename")
    if raw_name is None:
        raw_name = f"file_{uuid.uuid4()}.txt"
    elif not isinstance(raw_name, str):
        return jsonify({"error": "filename must be a string"}), 400

    content = data.get("content", "")
    if not isinstance(content, str):
        return jsonify({"error": "content must be a string"}), 400

    payload = None
    if bool(data.get("encode_base64", False)):
        try:
            payload = base64.b64decode(content)
        except ValueError:
            # binascii.Error (bad padding/characters) is a ValueError subclass.
            return jsonify({"error": "content is not valid base64"}), 400

    try:
        filename = safe_filename(raw_name)
        path = FILES_DIR / filename
        if payload is not None:
            path.write_bytes(payload)
        else:
            path.write_text(content, encoding="utf-8")
        return jsonify({"path": str(path), "filename": filename})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
@app.route("/list_files", methods=["GET"])
def list_files():
    """
    List the regular files directly inside FILES_DIR.

    Returns ``{"files": [{"name", "size", "path"}, ...]}``; directories and
    other non-file entries are skipped.
    """
    entries = [
        {"name": entry.name, "size": entry.stat().st_size, "path": str(entry)}
        for entry in FILES_DIR.iterdir()
        if entry.is_file()
    ]
    return jsonify({"files": entries})
|
|
|
|
|
|
|
|
@app.route("/download_file", methods=["POST"])
def download_file():
    """
    Return the contents of a file stored under FILES_DIR.

    POST JSON: { "filename": "name.txt", "as_base64": false }

    Returns ``{"filename", "content"}`` (UTF-8 text, undecodable bytes
    replaced) or ``{"filename", "content_base64"}`` when ``as_base64`` is
    truthy.
    Responds 400 when the body is not a JSON object, ``filename`` is missing
    or not a string, or the name resolves outside FILES_DIR; 404 when no such
    regular file exists; 500 on read errors.
    """
    # silent=True turns malformed JSON into None so it is reported as 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    filename = data.get("filename", "")
    if not filename:
        return jsonify({"error": "filename missing"}), 400
    if not isinstance(filename, str):
        return jsonify({"error": "filename must be a string"}), 400

    try:
        # Resolve before checking containment: joining an absolute path or a
        # name containing ".." would otherwise escape FILES_DIR and expose
        # arbitrary files readable by the service.
        base = FILES_DIR.resolve()
        path = (base / filename).resolve()
        if base not in path.parents:
            return jsonify({"error": "invalid filename"}), 400
        if not path.is_file():
            return jsonify({"error": "file not found"}), 404

        if bool(data.get("as_base64", False)):
            encoded = base64.b64encode(path.read_bytes()).decode()
            return jsonify({"filename": filename, "content_base64": encoded})

        text = path.read_text(encoding="utf-8", errors="replace")
        return jsonify({"filename": filename, "content": text})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
@app.route("/delete_file", methods=["POST"])
def delete_file():
    """
    Delete a file stored under FILES_DIR.

    POST JSON: { "filename": "name.txt" }

    Returns ``{"deleted": filename}``.
    Responds 400 when the body is not a JSON object, ``filename`` is missing
    or not a string, or the name resolves outside FILES_DIR; 404 when no such
    regular file exists; 500 when the deletion fails.
    """
    # silent=True turns malformed JSON into None so it is reported as 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    filename = data.get("filename", "")
    if not filename:
        return jsonify({"error": "filename missing"}), 400
    if not isinstance(filename, str):
        return jsonify({"error": "filename must be a string"}), 400

    try:
        # Resolve before checking containment: joining an absolute path or a
        # name containing ".." would otherwise let a caller delete arbitrary
        # files writable by the service.
        base = FILES_DIR.resolve()
        path = (base / filename).resolve()
        if base not in path.parents:
            return jsonify({"error": "invalid filename"}), 400
        if not path.is_file():
            return jsonify({"error": "file not found"}), 404

        path.unlink()
        return jsonify({"deleted": filename})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
def _resolve_result_url(href):
    """Turn a DuckDuckGo result href into a directly fetchable URL.

    NOTE(review): the HTML endpoint is understood to emit protocol-relative
    redirect links of the form ``//duckduckgo.com/l/?uddg=<encoded target>``
    — confirm against live markup. Such links are unwrapped to the target;
    anything else is returned as-is (with ``https:`` prepended to
    protocol-relative URLs). Falsy input is returned unchanged.
    """
    from urllib.parse import parse_qs, urlparse

    if not href:
        return href
    if href.startswith("//"):
        href = "https:" + href
    parts = urlparse(href)
    if parts.netloc.endswith("duckduckgo.com") and parts.path.startswith("/l/"):
        targets = parse_qs(parts.query).get("uddg")
        if targets:
            return targets[0]
    return href


@app.route("/ask_search", methods=["POST"])
def ask_search():
    """
    Answer a question from web search results.

    POST JSON: { "q": "question", "top_k": 3, "max_new_tokens": 300 }
    Returns search results + LLM synthesized answer:
    ``{"query", "search_results": [{"title", "url"}, ...], "answer"}``.

    The text of up to three result pages (4000 characters each) is extracted
    and given to the model as context. Pages that cannot be fetched are
    skipped.
    Responds 400 when the body is not a JSON object, ``q`` is missing, or
    ``top_k``/``max_new_tokens`` is not an integer; 500 when the search
    request or generation fails.
    """
    # silent=True turns malformed JSON into None so it is reported as 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    q = data.get("q") or ""
    if not isinstance(q, str) or not q.strip():
        return jsonify({"error": "q missing"}), 400
    q = q.strip()

    try:
        # Clamp at 0: a negative value would slice from the end of the list.
        top_k = max(int(data.get("top_k", 3)), 0)
        max_new_tokens = int(data.get("max_new_tokens", GEN_DEFAULTS["max_new_tokens"]))
    except (TypeError, ValueError):
        return jsonify({"error": "top_k and max_new_tokens must be integers"}), 400

    max_pages = 3      # number of page excerpts handed to the model
    page_chars = 4000  # characters kept from each page

    try:
        search_resp = requests.post("https://html.duckduckgo.com/html/", data={"q": q}, timeout=10)
        search_resp.raise_for_status()
        soup = BeautifulSoup(search_resp.text, "html.parser")

        snippets = []
        results = []
        for a in soup.select("a.result__a")[:top_k]:
            title = a.get_text().strip()
            # Redirect-style hrefs are not fetchable as-is; unwrap them so the
            # page download below does not fail for every result.
            page_url = _resolve_result_url(a.get("href"))
            results.append({"title": title, "url": page_url})

            if not page_url or len(snippets) >= max_pages:
                continue
            try:
                page = requests.get(page_url, timeout=5)
            except requests.RequestException:
                # Deliberately best-effort: an unreachable page only means
                # one fewer excerpt for the model.
                continue

            # Feed readable text to the model rather than raw markup.
            page_soup = BeautifulSoup(page.text, "html.parser")
            for tag in page_soup(["script", "style"]):
                tag.decompose()
            page_text = page_soup.get_text(" ", strip=True)[:page_chars]
            # Label each excerpt with its URL so the model can cite it.
            snippets.append(f"Source: {page_url}\n{page_text}")

        combined = "\n\n---\n\n".join(snippets)
        prompt = f"Question: {q}\n\nUse the following snippets from web pages to answer the question (be concise and cite urls where useful):\n\n{combined}\n\nAnswer:"
        answer = generate_from_model(prompt, max_new_tokens=max_new_tokens)

        if answer.startswith(prompt):
            # Drop the exact prompt prefix; splitting on the first "Answer:"
            # would cut inside page text that happens to contain the marker.
            answer_text = answer[len(prompt):].strip()
        elif "Answer:" in answer:
            # Prompt was not reproduced verbatim (e.g. truncated); use the marker.
            answer_text = answer.split("Answer:", 1)[1].strip()
        else:
            # No marker means the prompt is not present either.
            answer_text = answer.strip()

        return jsonify({"query": q, "search_results": results, "answer": answer_text})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Print an overview of the available endpoints, then start the Flask
    # development server on all interfaces at the configured port.
    # NOTE(review): host 0.0.0.0 makes /run_code reachable from the network.
    banner = (
        "Engine ready. Endpoints:",
        " /health (GET)",
        " /model_info (GET)",
        " /chat (POST) -> {message}",
        " /search (POST) -> {q, top_k}",
        " /fetch_url (POST) -> {url, max_chars}",
        " /summarize (POST) -> {text}",
        " /run_code (POST) -> {code, timeout}",
        " /create_file (POST) -> {filename, content}",
        " /list_files (GET)",
        " /download_file (POST) -> {filename, as_base64}",
        " /delete_file (POST) -> {filename}",
        " /ask_search (POST) -> {q, top_k}",
    )
    for line in banner:
        print(line)

    app.run(host="0.0.0.0", port=PORT, threaded=True)