File size: 3,556 Bytes
b379801
c65a8ef
 
 
 
b379801
c65a8ef
 
 
b379801
c65a8ef
b379801
c65a8ef
 
 
 
 
b379801
c65a8ef
b379801
c65a8ef
b379801
c65a8ef
 
 
 
 
 
 
 
 
 
 
b379801
 
c65a8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b379801
c65a8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b379801
c65a8ef
 
 
 
 
 
b379801
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import time
import json
import torch

# --- CONFIGURATION OMNIGROUP ---
# On utilise un modèle compact mais puissant pour le CPU gratuit
MODEL_ID = "HuggingFaceTB/SmolLM-135M-Instruct" 

print(f"Initialisation du moteur Pangea sur {MODEL_ID}...")

# Chargement du tokenizer et du modèle
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

def generate_response(prompt, max_tokens=128, temperature=0.7):
    """
    Génère une réponse avec calcul du débit (tokens/s)
    """
    start_time = time.time()
    
    # Encodage
    inputs = tokenizer(prompt, return_tensors="pt")
    input_length = inputs.input_ids.shape[1]
    
    # Génération
    outputs = model.generate(
        **inputs, 
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    
    end_time = time.time()
    
    # Décodage
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extraire uniquement la nouvelle réponse (après le prompt)
    new_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    
    # Métriques
    duration = end_time - start_time
    tokens_generated = len(outputs[0]) - input_length
    tokens_per_sec = round(tokens_generated / duration, 2) if duration > 0 else 0
    
    # Construction du JSON (Format Gemini-like)
    json_output = {
        "id": f"omni-{int(start_time)}",
        "object": "text_completion",
        "created": int(start_time),
        "model": MODEL_ID,
        "choices": [{
            "text": new_text,
            "index": 0,
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": input_length,
            "completion_tokens": tokens_generated,
            "total_tokens": input_length + tokens_generated,
            "speed": f"{tokens_per_sec} tokens/s"
        }
    }
    
    return new_text, json.dumps(json_output, indent=2), f"{tokens_per_sec} t/s"

# --- INTERFACE GRADIO PRO ---
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# 🚀 OmniGroup Pangea API v2")
    gr.Markdown("Endpoint haute performance avec métriques de débit en temps réel.")
    
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(label="Prompt", placeholder="Posez une question à l'IA...", lines=5)
            with gr.Row():
                slider_tokens = gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max New Tokens")
                slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Température")
            submit_btn = gr.Button("Générer l'inférence", variant="primary")
            
        with gr.Column(scale=1):
            speed_metric = gr.Label(label="Vitesse d'exécution (Débit)")
    
    with gr.Tabs():
        with gr.TabItem("Réponse Texte"):
            output_text = gr.Textbox(label="Sortie Brute", lines=10)
        with gr.TabItem("Réponse JSON (Format API)"):
            output_json = gr.Code(label="JSON Payload", language="json")

    # Mapping des fonctions
    submit_btn.click(
        fn=generate_response, 
        inputs=[input_text, slider_tokens, slider_temp], 
        outputs=[output_text, output_json, speed_metric],
        api_name="chat" # L'endpoint sera /chat
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)