File size: 2,596 Bytes
c2ec8f9
 
 
 
 
 
071c3b1
 
c2ec8f9
 
 
 
 
 
071c3b1
c2ec8f9
 
071c3b1
c2ec8f9
 
 
 
 
 
 
071c3b1
c2ec8f9
071c3b1
 
 
 
 
c2ec8f9
 
071c3b1
 
 
 
c2ec8f9
071c3b1
 
 
c2ec8f9
071c3b1
c2ec8f9
 
071c3b1
c2ec8f9
 
 
 
 
 
 
 
 
 
 
071c3b1
c2ec8f9
 
 
 
 
 
 
 
 
071c3b1
c2ec8f9
071c3b1
c2ec8f9
 
071c3b1
 
 
c2ec8f9
071c3b1
c2ec8f9
 
 
 
 
071c3b1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import json

# Hugging Face Hub identifiers: the instruct base model and the LoRA adapter
# fine-tuned on top of it.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER = "AmadouDiarouga/telma-mistral-telecom-gn"

# --- Tokenizer -------------------------------------------------------------
print("Chargement tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.padding_side = "right"
# Reuse EOS as the padding token (presumably because the base tokenizer has
# no dedicated pad token — confirm if this ever matters for batching).
tokenizer.pad_token = tokenizer.eos_token

# --- Base model: full float32 precision, loaded entirely on CPU -------------
print("Chargement modele base...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.float32, device_map="cpu", low_cpu_mem_usage=True
)

# --- LoRA adapter wrapped around the base model; inference mode only ------
print("Chargement adapteur LoRA...")
model = PeftModel.from_pretrained(base_model, ADAPTER)
model.eval()
print("Modele pret !")

# System instructions placed at the start of the prompt, before the user
# question. They are runtime prompt text: editing them changes model behavior.

# French instruction (the default language).
INSTRUCTION_FR = " ".join(
    [
        "Tu es TELMA, un assistant vocal IA pour les telecoms en Guinee.",
        "Tu aides les clients d'Orange Guinee, MTN et Celcom.",
        "Reponds en francais simple, court et direct.",
        "Cite toujours le code USSD ou le numero quand c'est pertinent.",
    ]
)

# Pulaar instruction. NOTE(review): presumably a shorter counterpart of the
# French one — wording not verified here; confirm with a Pulaar speaker.
INSTRUCTION_PL = " ".join(
    [
        "Be njahii e on ngondi e telecoms Guinee.",
        "Jaabii dow e Pulaar, yitere e laawol.",
    ]
)

def repondre(question, langue="fr"):
    """Answer a telecom question with the LoRA-adapted Mistral model.

    Args:
        question: The user's question. ``None``, ``""`` and whitespace-only
            input are all treated as an empty question.
        langue: ``"poular"`` selects the Pulaar instruction; any other value
            falls back to the French instruction.

    Returns:
        A JSON string. For an empty question: ``{"error": "question vide"}``.
        Otherwise an object with the keys ``question`` (as received),
        ``reponse`` (the generated answer), ``langue`` and ``status`` ("ok"),
        serialized with non-ASCII characters kept as-is.
    """
    # API callers can send None as well as "" — treat both as empty instead
    # of crashing on None.strip().
    texte = (question or "").strip()
    if not texte:
        return json.dumps({"error": "question vide"})

    instruction = INSTRUCTION_PL if langue == "poular" else INSTRUCTION_FR
    # Mistral-Instruct prompt format. The prompt already starts with the
    # literal "<s>" BOS marker, so tell the tokenizer not to prepend a second
    # BOS token of its own.
    prompt = f"<s>[INST] {instruction}\n\n{texte} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.3,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # generate() returns prompt + continuation. Decode only the new tokens;
    # splitting the full text on "[/INST]" would pick the wrong segment if
    # the model itself emitted that marker.
    prompt_len = inputs["input_ids"].shape[1]
    reponse = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    return json.dumps({
        "question": question,
        "reponse": reponse,
        "langue": langue,
        "status": "ok"
    }, ensure_ascii=False)

# Web UI / HTTP API around `repondre`: a question box and a language selector
# as inputs; the JSON string returned by `repondre` is shown as plain text.
demo = gr.Interface(
    fn=repondre,
    inputs=[
        gr.Textbox(label="Question", placeholder="Comment verifier mon solde Orange ?"),
        gr.Dropdown(choices=["fr", "poular"], value="fr", label="Langue"),
    ],
    outputs=gr.Textbox(label="Reponse JSON"),
    title="TELMA - Assistant Telecoms Guinee",
    description="API vocale pour Orange, MTN et Celcom en Guinee",
    examples=[
        ["Comment verifier mon solde Orange ?", "fr"],
        ["Je veux internet MTN pour une semaine", "fr"],
        ["Hol balansi Orange am ?", "poular"],
    ],
    # Do not pre-compute example outputs at startup: each one is a full
    # sampled generation with the float32 7B model on CPU. Some Gradio
    # versions/environments (e.g. Hugging Face Spaces) enable example caching
    # by default, so disable it explicitly.
    cache_examples=False,
)

# Start the server only when executed as a script, so merely importing this
# module does not block on launch().
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)