# hadxs's picture
# 8 models optimized
# d4b9c42 verified
import gradio as gr, requests, json, time, threading, torch, os, random, math
from transformers import AutoModelForCausalLM, AutoTokenizer
import urllib3; urllib3.disable_warnings()
def _worker_number(space_id):
    """Return the worker number embedded in *space_id* as a string.

    The identifier is split on "-" and the last part made only of decimal
    digits wins. "1" is returned when no such part exists (including for an
    empty identifier).
    """
    number = "1"
    for part in space_id.split("-"):
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as "²" that int() rejects, which would crash at import time.
        if part.isdecimal():
            number = part
    return number


# Space identifier taken from the environment; empty when neither variable is set.
sid = os.environ.get("SPACE_ID") or os.environ.get("SPACE_NAME") or ""
NUM = _worker_number(sid)
TEACHER_ID = "hf-worker-" + NUM
N = int(NUM)
# === 8 models, ordered from smallest to largest ===
# Each entry is (model id, size figure shown as "<size>B" in the startup log).
# The worker number selects one entry round-robin.
ALL_MODELS = [
    ("HuggingFaceTB/SmolLM2-360M-Instruct", 0.36),
    ("Qwen/Qwen2.5-0.5B-Instruct", 0.5),
    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 1.1),
    ("deepseek-ai/deepseek-coder-1.3b-instruct", 1.3),
    ("Qwen/Qwen2.5-1.5B-Instruct", 1.5),
    ("Qwen/Qwen2.5-Coder-1.5B-Instruct", 1.5),
    ("stabilityai/stablelm-2-1_6b", 1.6),
    ("HuggingFaceTB/SmolLM2-1.7B-Instruct", 1.7),
]
_selected_model = ALL_MODELS[N % len(ALL_MODELS)]
MID = _selected_model[0]
B = _selected_model[1]

# Pause in seconds: grows with the model size figure, shrinks with the last
# digit of the worker number, and is never allowed to drop below 3.
_raw_delay = int(5 + B * 10 - N % 10)
DELAY = _raw_delay if _raw_delay > 3 else 3
print(f"ID:{TEACHER_ID} MODEL:{MID} ({B}B) DELAY:{DELAY}s")
# Backend endpoint: requests are sent to the raw IP in BIP while BH is passed
# as the HTTP Host header (see g() and p()).
BIP = "3.125.223.134"
BH = "alfredo-agravic-saddeningly.ngrok-free.dev"

# Model and tokenizer globals; both stay None until l() fills them in.
m = None
t = None

# Shared HTTP session.
# NOTE(review): verify=False turns off TLS certificate verification for every
# request made through this session, so the backend's identity is never
# checked. Presumably needed because the URL targets an IP address - confirm
# this trade-off is intended.
s = requests.Session()
s.verify = False
def g(p):
    """Send a GET request for path *p* to the backend and return the response.

    The URL targets the address in ``BIP`` over HTTPS, the ``Host`` header is
    set to ``BH``, and the shared session ``s`` is used with a 60-second
    timeout. Inside this function the parameter ``p`` shadows the module-level
    function of the same name.
    """
    url = "".join(("https://", BIP, p))
    return s.get(url, headers={"Host": BH}, timeout=60)
def p(p, d):
    """POST *d* as a JSON body to backend path *p* and return the response.

    The URL targets the address in ``BIP`` over HTTPS, the ``Host`` header is
    set to ``BH``, and the shared session ``s`` is used with a 30-second
    timeout. The first parameter intentionally keeps the same name as the
    function itself; it only shadows it inside this body.
    """
    url = "".join(("https://", BIP, p))
    return s.post(url, headers={"Host": BH}, json=d, timeout=30)
def l():
    """Load the selected model and its tokenizer into the globals ``m`` and ``t``.

    The model identified by ``MID`` is loaded in float32 on the CPU. When the
    tokenizer defines no padding token, its end-of-sequence token is reused.
    """
    global m, t
    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code during loading - confirm every entry in ALL_MODELS is
    # trusted.
    shared_options = {"trust_remote_code": True}
    m = AutoModelForCausalLM.from_pretrained(
        MID,
        torch_dtype=torch.float32,
        device_map="cpu",
        **shared_options,
    )
    t = AutoTokenizer.from_pretrained(MID, **shared_options)
    if t.pad_token is None:
        t.pad_token = t.eos_token
def gen(ms, max_tokens, temp=0.7):
    """Generate one sampled reply for a chat conversation.

    Relies on the globals ``m`` (model) and ``t`` (tokenizer), so ``l()`` must
    have run first.

    Args:
        ms: Chat messages as a list of ``{"role": ..., "content": ...}`` dicts,
            rendered through the tokenizer's chat template.
        max_tokens: Upper bound on the number of newly generated tokens.
        temp: Sampling temperature.

    Returns:
        The decoded continuation only (prompt tokens are sliced off), with
        special tokens skipped and surrounding whitespace stripped.
    """
    # add_generation_prompt=True appends the template's assistant-turn header.
    # Without it the rendered prompt ends after the user message and the model
    # is not cued to answer as the assistant.
    prompt = t.apply_chat_template(ms, tokenize=False, add_generation_prompt=True)
    encoded = t(prompt, return_tensors="pt").to("cpu")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = m.generate(
            **encoded,
            max_new_tokens=max_tokens,
            temperature=temp,
            do_sample=True,
            top_p=0.9,
            top_k=40,
            repetition_penalty=1.05,
            pad_token_id=t.pad_token_id,
        )
    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = encoded["input_ids"].shape[1]
    return t.decode(output_ids[0][prompt_length:], skip_special_tokens=True).strip()
# Prompts adapted to the size of the selected model.
# Every PROMPTS entry is a callable taking a subject and returning
# (instruction text, max new tokens, sampling temperature).
def _prompt_builders(rows):
    """Turn (template, max_tokens, temperature) rows into prompt callables."""
    def build(template, max_tokens, temperature):
        # A dedicated factory gives each lambda its own captured values.
        return lambda s: (template.format(s=s), max_tokens, temperature)
    return [build(*row) for row in rows]


if "coder" in MID.lower():
    # Code workers: code-oriented prompts, regardless of model size.
    PROMPTS = _prompt_builders([
        ("Write code example illustrating {s}.", 256, 0.6),
        ("Explain {s} with a code snippet.", 256, 0.7),
        ("How to implement {s}? Step by step.", 320, 0.7),
        ("Best practices for {s} in programming.", 256, 0.7),
        ("Compare approaches for {s} with code.", 320, 0.7),
        ("Common bugs with {s} and how to fix them.", 256, 0.7),
    ])
elif B < 0.5:
    # Ultra-small: very short prompts, short answers.
    PROMPTS = _prompt_builders([
        ("Explique {s} en 1 phrase.", 64, 0.7),
        ("C'est quoi {s} ?", 64, 0.7),
        ("Donne 1 fait sur {s}.", 64, 0.8),
        ("Pourquoi {s} est utile ?", 96, 0.7),
        ("Un conseil sur {s}.", 96, 0.7),
        ("Decris {s} en 2 lignes.", 96, 0.7),
    ])
elif B < 1.0:
    # Small: simple prompts, short answers.
    PROMPTS = _prompt_builders([
        ("Explique {s} simplement.", 128, 0.7),
        ("C'est quoi {s} ? Donne un exemple.", 128, 0.7),
        ("Quels sont les points cles de {s} ?", 160, 0.6),
        ("Pourquoi {s} est important ?", 128, 0.8),
        ("Compare {s} avec son alternative.", 160, 0.7),
        ("Comment utiliser {s} ?", 160, 0.7),
    ])
else:
    # Medium/large: regular prompts, complete answers.
    PROMPTS = _prompt_builders([
        ("Explique {s} en detail avec des exemples concrets.", 256, 0.7),
        ("Analyse {s} : fonctionnement, applications, limites.", 320, 0.7),
        ("Quels sont les concepts cles de {s} ? Liste structuree.", 256, 0.6),
        ("Compare {s} avec son alternative. Avantages/inconvenients.", 320, 0.7),
        ("Tutoriel pas a pas pour utiliser {s}.", 320, 0.7),
        ("L'histoire et l'avenir de {s}.", 320, 0.7),
    ])
def _build_results(entries, cycle):
    """Generate result records for every entry of a batch that has a subject.

    Args:
        entries: Batch entries from the backend; each is read with
            ``.get("subject", "")`` and skipped when the subject is empty.
        cycle: Counter of batches processed so far, mixed into prompt selection.

    Returns:
        A list of result dicts (instruction, input, output, teacher, subject,
        model); replies of 20 characters or fewer are dropped.
    """
    results = []
    for entry in entries:
        subject = entry.get("subject", "")
        if not subject:
            continue
        for idx in range(6):
            # hash() of a str is salted per process, so this selection is
            # stable within one run but differs between restarts, and the six
            # picks for one subject may repeat a prompt.
            prompt_index = hash(TEACHER_ID + subject + str(cycle + idx)) % len(PROMPTS)
            try:
                inst, mt, tmp = PROMPTS[prompt_index](subject)
                resp = gen(
                    [{"role": "system", "content": "Tu es Connor."},
                     {"role": "user", "content": inst}],
                    max_tokens=mt, temp=tmp)
            except Exception as exc:
                # Best effort: one failed generation must not abort the batch,
                # but it is reported instead of being silently discarded.
                print(f"[{TEACHER_ID}] generation failed for {subject!r}: {exc!r}")
                continue
            if resp and len(resp) > 20:
                results.append({"instruction": inst, "input": "", "output": resp,
                                "teacher": TEACHER_ID, "subject": subject, "model": MID})
    return results


def w():
    """Worker loop: load the model, then poll the backend forever.

    Each iteration sends a heartbeat, requests a batch of up to 5 entries,
    generates replies for them and pushes any results back, then sleeps
    ``DELAY`` seconds. Any exception in an iteration is logged and followed by
    a 30-second pause before retrying. Never returns.
    """
    l()
    cycle = 0
    # NOTE(review): the reviewed copy of this file had lost its indentation;
    # the nesting of `cycle += 1` and `time.sleep(DELAY)` below is a
    # reconstruction - confirm against the deployed version.
    while True:
        try:
            p("/heartbeat", {"teacher": TEACHER_ID, "model": MID})
            r = g("/next-batch?teacher=" + TEACHER_ID + "&batch_size=5")
            if r.status_code == 200:
                entries = r.json().get("entries", [])
                if entries:
                    results = _build_results(entries, cycle)
                    if results:
                        p("/push-results", {"teacher": TEACHER_ID, "results": results})
                    cycle += 1
            time.sleep(DELAY)
        except Exception as exc:
            # Keep the worker alive on network/JSON errors, but leave a trace.
            print(f"[{TEACHER_ID}] worker cycle failed: {exc!r}; retrying in 30s")
            time.sleep(30)
# Run the polling/generation loop on a background daemon thread.
worker_thread = threading.Thread(target=w, daemon=True)
worker_thread.start()

# Minimal Gradio interface with no inputs: it returns a JSON status string
# carrying this worker's id and model. launch() starts serving it.
status_app = gr.Interface(
    fn=lambda: json.dumps({"status": "ok", "id": TEACHER_ID, "model": MID}),
    inputs=[],
    outputs="text",
)
status_app.launch()