# HekeReplicate / app.py — by Clemylia (commit e2bb3c6, verified)
# Hugging Face Space that packages an HF model with cog and pushes it to Replicate.
import gradio as gr
import os
import subprocess
import stat
import urllib.request
def setup_cog():
    """Download the ``cog`` CLI binary locally if absent and make it executable.

    Returns:
        str: Absolute path to the cog executable.
    """
    cog_path = "./bin/cog"
    if not os.path.exists(cog_path):
        # Bug fix: urlretrieve cannot create intermediate directories, so the
        # first run used to fail with FileNotFoundError when ./bin was missing.
        os.makedirs(os.path.dirname(cog_path), exist_ok=True)
        url = "https://github.com/replicate/cog/releases/latest/download/cog_Linux_x86_64"
        urllib.request.urlretrieve(url, cog_path)
    # chmod +x so subprocess can execute the downloaded binary.
    st = os.stat(cog_path)
    os.chmod(cog_path, st.st_mode | stat.S_IEXEC)
    return os.path.abspath(cog_path)
def deploy_to_replicate(hf_repo, replicate_repo):
    """Generate cog packaging files for *hf_repo* and push them to Replicate.

    Args:
        hf_repo: Hugging Face model repo id, e.g. "user/model".
        replicate_repo: Replicate destination "owner/name" (pushed to r8.im).

    Returns:
        str: Human-readable status message (success or error); never raises.
    """
    try:
        # Fail fast on a missing secret *before* downloading the cog binary.
        token = os.getenv("REPLICATE_API_TOKEN")
        if not token:
            return "❌ Erreur : Ajoute REPLICATE_API_TOKEN dans les Secrets du Space."

        COG_EXE = setup_cog()

        # 1. Generate predict.py — the cog entry point wrapping the HF model.
        predict_code = f'''
from cog import BasePredictor, Input
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class Predictor(BasePredictor):
    def setup(self):
        self.tokenizer = AutoTokenizer.from_pretrained("{hf_repo}")
        self.model = AutoModelForCausalLM.from_pretrained(
            "{hf_repo}",
            torch_dtype=torch.float16,
            device_map="auto"
        )

    def predict(self, prompt: str = Input(description="Prompt")) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda")
        outputs = self.model.generate(**inputs, max_new_tokens=100)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
'''
        with open("predict.py", "w") as f:
            f.write(predict_code.strip())

        # 2. Generate cog.yaml — build config (GPU + python deps).
        cog_yaml = """
build:
  gpu: true
  python_packages:
    - "torch"
    - "transformers"
    - "accelerate"
predict: "predict.py:Predictor"
"""
        with open("cog.yaml", "w") as f:
            f.write(cog_yaml.strip())

        # 3. Push — pass the token through the environment so cog can
        # authenticate in subcommands too.
        env = os.environ.copy()
        env["REPLICATE_API_TOKEN"] = token

        # Explicit login. Capture output so a failure reports cog's own
        # message instead of an opaque CalledProcessError (the old check=True
        # path swallowed stderr).
        login = subprocess.run(
            [COG_EXE, "login", "--token", token],
            capture_output=True,
            text=True,
            env=env,
        )
        if login.returncode != 0:
            return f"❌ Erreur login Replicate :\n{login.stderr or login.stdout}"

        # Push the packaged model to Replicate.
        process = subprocess.run(
            [COG_EXE, "push", f"r8.im/{replicate_repo}"],
            capture_output=True,
            text=True,
            env=env,
        )
        if process.returncode != 0:
            # cog frequently writes diagnostics to stdout, so include both streams.
            return f"❌ Erreur Replicate :\n{process.stderr}\n{process.stdout}"

        return f"✅ Déployé avec succès sur r8.im/{replicate_repo}"
    except Exception as e:
        return f"💥 Erreur : {str(e)}"
# Gradio UI: two text inputs (source HF repo, target Replicate repo),
# one button that triggers the deployment and streams the result into Logs.
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Finisha HF -> Replicate")
    hf_repo_box = gr.Textbox(label="Dépôt HF", placeholder="Finisha-f-scratch/charlotte-amity")
    replicate_repo_box = gr.Textbox(label="Dépôt Replicate", placeholder="clemylia27/charlotte-amity")
    deploy_btn = gr.Button("Déployer")
    logs_box = gr.Textbox(label="Logs")
    # Wire the click: inputs (hf, replicate) -> deploy_to_replicate -> logs.
    deploy_btn.click(deploy_to_replicate, [hf_repo_box, replicate_repo_box], logs_box)

demo.launch()