# Hugging Face Space app (scrape header noted: "Runtime error", file size 2,596 bytes)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import json
# ---- Model setup (runs at import time; downloads weights on first use) ----
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"  # base checkpoint on the HF Hub
ADAPTER = "AmadouDiarouga/telma-mistral-telecom-gn"  # LoRA adapter for Guinea telecom QA

print("Chargement tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Mistral ships with no pad token; reuse EOS so padding during generation works.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("Chargement modele base...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,  # full precision: CPU-only deployment below
    device_map="cpu",
    low_cpu_mem_usage=True,  # stream weights in to avoid a second full copy in RAM
)

print("Chargement adapteur LoRA...")
# Wrap the base model with the fine-tuned LoRA adapter weights.
model = PeftModel.from_pretrained(base_model, ADAPTER)
model.eval()  # inference mode: disables dropout etc.
print("Modele pret !")
# System prompt (French): assistant persona and answer-style rules.
# NOTE: these strings are runtime data injected into the model prompt —
# they must stay in their original language, not be translated.
INSTRUCTION_FR = (
    "Tu es TELMA, un assistant vocal IA pour les telecoms en Guinee. "
    "Tu aides les clients d'Orange Guinee, MTN et Celcom. "
    "Reponds en francais simple, court et direct. "
    "Cite toujours le code USSD ou le numero quand c'est pertinent."
)

# System prompt in Pular (Fula), selected when langue == "poular" in repondre().
INSTRUCTION_PL = (
    "Be njahii e on ngondi e telecoms Guinee. "
    "Jaabii dow e Pulaar, yitere e laawol."
)
def repondre(question, langue="fr"):
    """Generate a JSON-encoded answer to a telecom question.

    Args:
        question: Free-text user question. May be None or blank — Gradio
            passes None when the textbox is cleared, so this is guarded.
        langue: "fr" (default) or "poular"; selects the system instruction.

    Returns:
        A JSON string: {"question", "reponse", "langue", "status": "ok"}
        on success, or {"error": "question vide"} for a missing question.
    """
    # Guard first: original code called question.strip() unconditionally,
    # which raised AttributeError when Gradio handed us None.
    if not question or not question.strip():
        return json.dumps({"error": "question vide"})

    instruction = INSTRUCTION_PL if langue == "poular" else INSTRUCTION_FR
    # Mistral-Instruct chat template: "<s>[INST] system + user [/INST]".
    prompt = f"<s>[INST] {instruction}\n\n{question.strip()} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():  # inference only — no autograd bookkeeping
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.3,  # low temperature: keep answers mostly factual
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # silences missing-pad warning
            repetition_penalty=1.1,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the last [/INST].
    reponse = response.split("[/INST]")[-1].strip()
    return json.dumps({
        "question": question,
        "reponse": reponse,
        "langue": langue,
        "status": "ok"
    }, ensure_ascii=False)
# UI wiring: two inputs (question text + language selector), one JSON text output.
_question_box = gr.Textbox(
    label="Question",
    placeholder="Comment verifier mon solde Orange ?",
)
_langue_dropdown = gr.Dropdown(choices=["fr", "poular"], value="fr", label="Langue")
_reponse_box = gr.Textbox(label="Reponse JSON")

# Canned (question, language) pairs shown as clickable examples under the form.
_examples = [
    ["Comment verifier mon solde Orange ?", "fr"],
    ["Je veux internet MTN pour une semaine", "fr"],
    ["Hol balansi Orange am ?", "poular"],
]

demo = gr.Interface(
    fn=repondre,
    inputs=[_question_box, _langue_dropdown],
    outputs=_reponse_box,
    title="TELMA - Assistant Telecoms Guinee",
    description="API vocale pour Orange, MTN et Celcom en Guinee",
    examples=_examples,
)

# Listen on all interfaces, standard Hugging Face Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)