# Telma_AI / app.py
# Diallo
# Update app.py
# 071c3b1 verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import json
# Hub identifiers: the base instruct checkpoint and the LoRA adapter stacked on it.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER = "AmadouDiarouga/telma-mistral-telecom-gn"

# ---------------------------------------------------------------------------
# One-time model setup, executed at import time before the UI is built.
# Each step is announced on stdout (French progress messages) before it runs.
# ---------------------------------------------------------------------------

print("Chargement tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Reuse EOS as the padding token and pad on the right.
tokenizer.pad_token, tokenizer.padding_side = tokenizer.eos_token, "right"

print("Chargement modele base...")
# Full-precision (float32) weights placed entirely on CPU; low_cpu_mem_usage
# asks transformers to limit peak RAM while the checkpoint is loaded.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="cpu",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)

print("Chargement adapteur LoRA...")
model = PeftModel.from_pretrained(base_model, ADAPTER)
model.eval()  # inference mode: turns off training-only behaviour such as dropout

print("Modele pret !")
# System prompt used for every language other than "poular" (French).
# These strings are sent to the model verbatim; keep them byte-identical.
INSTRUCTION_FR = " ".join((
    "Tu es TELMA, un assistant vocal IA pour les telecoms en Guinee.",
    "Tu aides les clients d'Orange Guinee, MTN et Celcom.",
    "Reponds en francais simple, court et direct.",
    "Cite toujours le code USSD ou le numero quand c'est pertinent.",
))

# System prompt used when langue == "poular" (Pulaar). NOTE(review): the
# wording is not verified here; presumably a Pulaar counterpart of the French
# prompt — confirm with a Pulaar speaker.
INSTRUCTION_PL = " ".join((
    "Be njahii e on ngondi e telecoms Guinee.",
    "Jaabii dow e Pulaar, yitere e laawol.",
))
def repondre(question, langue="fr"):
    """Answer a telecom question with the fine-tuned model and return JSON.

    Args:
        question: The user's question. ``None`` or whitespace-only input is
            rejected with an error payload instead of raising.
        langue: ``"poular"`` selects the Pulaar system prompt; any other
            value falls back to the French one.

    Returns:
        A JSON string. On success it has the keys ``question`` (as received),
        ``reponse`` (generated answer), ``langue`` and ``status`` (``"ok"``),
        with non-ASCII characters kept as-is. On empty input it is
        ``{"error": "question vide"}``.
    """
    # The UI may hand over None (e.g. a cleared textbox); treat it like "".
    texte = (question or "").strip()
    if not texte:
        return json.dumps({"error": "question vide"})

    instruction = INSTRUCTION_PL if langue == "poular" else INSTRUCTION_FR
    # Mistral-Instruct prompt format. NOTE(review): the tokenizer may also
    # prepend its own BOS in addition to the literal "<s>"; kept unchanged so
    # the prompt matches whatever format the LoRA adapter was trained on —
    # verify against the training script.
    prompt = f"<s>[INST] {instruction}\n\n{texte} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt")
    n_prompt_tokens = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.3,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # generate() returns the prompt tokens followed by the new ones, so decode
    # only the tail. Splitting the full decoded text on "[/INST]" would drop
    # part of the answer whenever the model itself emits that marker.
    reponse = tokenizer.decode(
        outputs[0][n_prompt_tokens:], skip_special_tokens=True
    ).strip()

    return json.dumps({
        "question": question,
        "reponse": reponse,
        "langue": langue,
        "status": "ok"
    }, ensure_ascii=False)
# Gradio UI: a question textbox plus a language selector in, the JSON string
# returned by `repondre` out.
demo = gr.Interface(
    fn=repondre,
    inputs=[
        gr.Textbox(label="Question", placeholder="Comment verifier mon solde Orange ?"),
        gr.Dropdown(choices=["fr", "poular"], value="fr", label="Langue"),
    ],
    outputs=gr.Textbox(label="Reponse JSON"),
    title="TELMA - Assistant Telecoms Guinee",
    description="API vocale pour Orange, MTN et Celcom en Guinee",
    examples=[
        ["Comment verifier mon solde Orange ?", "fr"],
        ["Je veux internet MTN pour une semaine", "fr"],
        ["Hol balansi Orange am ?", "poular"],
    ],
)

# Start the server only when this file is executed as a script
# (`python app.py`). Importing the module (tests, other tooling) then builds
# `demo` without blocking on launch().
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 (Gradio's default port).
    demo.launch(server_name="0.0.0.0", server_port=7860)