# Lassen / main.py
# Uploaded by Rahaf2001 — commit 5f56b43 (verified): "تحميل ملفات المشروع" (upload project files)
import os, re, html, pickle
import numpy as np
import torch
from collections import Counter
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from openai import OpenAI
# ── Config ──────────────────────────────────────────────────────────────
# Hugging Face Hub repositories of the three fine-tuned classifiers.
METER_MODEL_ID = "Rahaf2001/Lassen-meter-classifier"
ERA_MODEL_ID = "Rahaf2001/LassenEraClassifier"
TOPIC_MODEL_ID = "Rahaf2001/Lassen-topic-classifier"

# Taken from the environment; an empty string when the variable is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

# Inference device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ── Global model holders ────────────────────────────────────────────────
# Populated by the startup lifespan hook and cleared again on shutdown.
models = {}
# ── Arabic text cleaning ────────────────────────────────────────────────
ARABIC_DIACRITICS = re.compile(r'[\u0617-\u061A\u064B-\u0652\u0670\u06D6-\u06ED]')

# Single-character rewrites applied in one pass: delete tatweel, fold the
# hamza/madda/wasla alef forms to a bare alef, turn ta marbuta into ha, and blank
# out ASCII and Arabic-Indic digits. None of these characters overlap with
# ARABIC_DIACRITICS, so applying them together is order-independent.
_ARABIC_CHAR_MAP = str.maketrans(
    {
        "\u0640": None,
        "\u0623": "\u0627",
        "\u0625": "\u0627",
        "\u0622": "\u0627",
        "\u0671": "\u0627",
        "\u0629": "\u0647",
        **{digit: " " for digit in "0123456789"},
        **{chr(code): " " for code in range(0x0660, 0x066A)},
    }
)
_HTML_TAG = re.compile(r"<.*?>")
_NON_ARABIC = re.compile(r"[^\u0600-\u06FF\s]")
_WHITESPACE_RUN = re.compile(r"\s+")


def clean_arabic(text: str) -> str:
    """Normalize Arabic text before it is fed to the era/topic classifiers.

    Steps: unescape HTML entities, replace tags with spaces, strip diacritics and
    tatweel, unify alef variants to a bare alef, map ta marbuta to ha, replace
    digits and every non-whitespace character outside U+0600-U+06FF with a space,
    then collapse whitespace runs and trim.

    Args:
        text: Raw text; non-string values are converted with ``str()``.

    Returns:
        The normalized string, or ``""`` when ``text`` is falsy.
    """
    if not text:
        return ""
    normalized = html.unescape(str(text))
    normalized = _HTML_TAG.sub(" ", normalized)
    normalized = ARABIC_DIACRITICS.sub("", normalized)
    normalized = normalized.translate(_ARABIC_CHAR_MAP)
    normalized = _NON_ARABIC.sub(" ", normalized)
    return _WHITESPACE_RUN.sub(" ", normalized).strip()
# ── Meter labels ────────────────────────────────────────────────────────
# Output labels of the meter classifier. predict_meter indexes this list with the
# model's arg-max id, so its order must match the classifier head.
# NOTE(review): confirm against the meter model's config.id2label.
LABELS_METER = [
    "saree", "kamel", "mutakareb", "mutadarak", "munsareh",
    "madeed", "mujtath", "ramal", "baseet", "khafeef",
    "taweel", "wafer", "hazaj", "rajaz",
]

# Arabic display name for each English meter label (same order as LABELS_METER).
METER_ARABIC = {
    "saree": "ุงู„ุณุฑูŠุน",
    "kamel": "ุงู„ูƒุงู…ู„",
    "mutakareb": "ุงู„ู…ุชู‚ุงุฑุจ",
    "mutadarak": "ุงู„ู…ุชุฏุงุฑูƒ",
    "munsareh": "ุงู„ู…ู†ุณุฑุญ",
    "madeed": "ุงู„ู…ุฏูŠุฏ",
    "mujtath": "ุงู„ู…ุฌุชุซ",
    "ramal": "ุงู„ุฑู…ู„",
    "baseet": "ุงู„ุจุณูŠุท",
    "khafeef": "ุงู„ุฎููŠู",
    "taweel": "ุงู„ุทูˆูŠู„",
    "wafer": "ุงู„ูˆุงูุฑ",
    "hazaj": "ุงู„ู‡ุฒุฌ",
    "rajaz": "ุงู„ุฑุฌุฒ",
}
# ── Meter taf'ila patterns ──────────────────────────────────────────────
# Taf'ila (prosodic foot) pattern of each meter, keyed by the Arabic meter name.
# /generate injects the pattern into its prompt, and /meters lists these keys.
# NOTE(review): only 13 of the classifier's 14 meters appear here; mujtath
# (present in LABELS_METER / METER_ARABIC) has no pattern.
METER_PATTERNS = {
    "ุงู„ุทูˆูŠู„": "ููŽุนููˆู„ูู†ู’ ู…ูŽููŽุงุนููŠู„ูู†ู’ ููŽุนููˆู„ูู†ู’ ู…ูŽููŽุงุนูู„ูู†ู’",
    "ุงู„ูƒุงู…ู„": "ู…ูุชูŽููŽุงุนูู„ูู†ู’ ู…ูุชูŽููŽุงุนูู„ูู†ู’ ู…ูุชูŽููŽุงุนูู„ูู†ู’",
    "ุงู„ุจุณูŠุท": "ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’ ููŽุงุนูู„ูู†ู’ ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’ ููŽุงุนูู„ูู†ู’",
    "ุงู„ูˆุงูุฑ": "ู…ูููŽุงุนูŽู„ูŽุชูู†ู’ ู…ูููŽุงุนูŽู„ูŽุชูู†ู’ ููŽุนููˆู„ูู†ู’",
    "ุงู„ุฎููŠู": "ููŽุงุนูู„ูŽุงุชูู†ู’ ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’ ููŽุงุนูู„ูŽุงุชูู†ู’",
    "ุงู„ุฑุฌุฒ": "ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’ ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’ ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’",
    "ุงู„ุฑู…ู„": "ููŽุงุนูู„ูŽุงุชูู†ู’ ููŽุงุนูู„ูŽุงุชูู†ู’ ููŽุงุนูู„ูŽุงุชูู†ู’",
    "ุงู„ุณุฑูŠุน": "ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’ ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’ ู…ูŽูู’ุนููˆู„ูŽุงุชู",
    "ุงู„ู…ู†ุณุฑุญ": "ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’ ู…ูŽูู’ุนููˆู„ูŽุงุชู ู…ูุณู’ุชูŽูู’ุนูู„ูู†ู’",
    "ุงู„ู‡ุฒุฌ": "ู…ูŽููŽุงุนููŠู„ูู†ู’ ู…ูŽููŽุงุนููŠู„ูู†ู’",
    "ุงู„ู…ุชู‚ุงุฑุจ": "ููŽุนููˆู„ูู†ู’ ููŽุนููˆู„ูู†ู’ ููŽุนููˆู„ูู†ู’ ููŽุนููˆู„ูู†ู’",
    "ุงู„ู…ุชุฏุงุฑูƒ": "ููŽุงุนูู„ูู†ู’ ููŽุงุนูู„ูู†ู’ ููŽุงุนูู„ูู†ู’ ููŽุงุนูู„ูู†ู’",
    "ุงู„ู…ุฏูŠุฏ": "ููŽุงุนูู„ูŽุงุชูู†ู’ ููŽุงุนูู„ูู†ู’ ููŽุงุนูู„ูŽุงุชูู†ู’",
}
# ── Lifespan: load models once at startup ───────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model once at application startup and release them at shutdown.

    Code before ``yield`` runs when FastAPI starts; code after it runs on shutdown.
    Startup fills the module-level ``models`` dict with:

    * ``"<task>_tokenizer"`` / ``"<task>_model"`` for task in meter, era, topic
      (models are moved to ``device`` and put in eval mode),
    * ``"id2label_topic"``: int label id -> topic name, taken from the topic
      model's config (falls back to stringified ids when the config has none),
    * ``"openai"``: an OpenAI client built from ``OPENAI_API_KEY``.

    The dict is cleared on shutdown, even if the application exits with an error.
    """
    print("Loading models...")
    # Same load -> move -> eval sequence for each classifier, in the original order.
    for task, model_id in (("meter", METER_MODEL_ID),
                           ("era", ERA_MODEL_ID),
                           ("topic", TOPIC_MODEL_ID)):
        models[f"{task}_tokenizer"] = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        model.to(device).eval()
        models[f"{task}_model"] = model

    # Topic labels — loaded from the HF model config; an absent or empty mapping
    # falls back to the numeric ids so predict_topic can still name its output.
    topic_cfg = models["topic_model"].config
    id2label = getattr(topic_cfg, "id2label", None)
    if id2label:
        models["id2label_topic"] = {int(k): v for k, v in id2label.items()}
    else:
        models["id2label_topic"] = {i: str(i) for i in range(topic_cfg.num_labels)}

    # OpenAI client. An empty key still constructs a client, so surface the problem
    # at startup instead of on the first request.
    if not OPENAI_API_KEY:
        print("WARNING: OPENAI_API_KEY is not set; OpenAI-backed endpoints "
              "(/fasserha, /generate) will likely fail to authenticate.")
    models["openai"] = OpenAI(api_key=OPENAI_API_KEY)
    print(f"All models loaded on {device} โœ“")
    try:
        yield
    finally:
        models.clear()
# ── App ─────────────────────────────────────────────────────────────────
app = FastAPI(title="Bayan API", lifespan=lifespan)

# Allowed browser origins, configurable as a comma-separated list in
# CORS_ALLOW_ORIGINS (e.g. "https://bayan.example,http://localhost:3000").
# Unset or blank means "*" (any origin). allow_credentials is not passed, so it
# stays at Starlette's default.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Inference helpers ───────────────────────────────────────────────────
def predict_meter(poem_text: str) -> dict:
verses = [v.replace("#", " ").strip() for v in poem_text.strip().split("\n") if v.strip()]
if not verses:
raise ValueError("ุงู„ู‚ุตูŠุฏุฉ ูุงุฑุบุฉ")
predictions = []
for verse in verses:
inputs = models["meter_tokenizer"](verse, return_tensors="pt", truncation=True,
max_length=32, padding="max_length")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
probs = torch.softmax(models["meter_model"](**inputs).logits, dim=-1)[0]
pred_id = torch.argmax(probs).item()
predictions.append((LABELS_METER[pred_id], probs[pred_id].item()))
top_meter = Counter(p[0] for p in predictions).most_common(1)[0][0]
avg_conf = sum(c for _, c in predictions) / len(predictions)
return {"meter_ar": METER_ARABIC[top_meter], "meter_en": top_meter,
"confidence": round(avg_conf, 3)}
def predict_era(poem_text: str) -> dict:
    """Classify a poem as classical (class 0) or modern (class 1).

    The text is normalized with ``clean_arabic`` and encoded to exactly 256 tokens
    (truncated or padded) for a single forward pass of the era classifier.

    Returns:
        dict with ``era`` (Arabic label of the arg-max class) plus the softmax
        probabilities of class 0 (``classical_probability``) and class 1
        (``modern_probability``), each rounded to 4 decimals.
    """
    normalized = clean_arabic(poem_text)
    batch = models["era_tokenizer"](
        normalized,
        return_tensors="pt",
        max_length=256,
        truncation=True,
        padding="max_length",
    )
    on_device = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        logits = models["era_model"](**on_device).logits
        scores = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    era_labels = ("ู‚ุฏูŠู…", "ุญุฏูŠุซ")
    winner = int(np.argmax(scores))
    return {
        "era": era_labels[winner],
        "classical_probability": round(float(scores[0]), 4),
        "modern_probability": round(float(scores[1]), 4),
    }
def predict_topic(poem_text: str) -> dict:
    """Predict a poem's topic together with its three most likely candidates.

    The text is normalized with ``clean_arabic`` and truncated to 512 tokens before
    one forward pass of the topic classifier. Label names are taken from
    ``models["id2label_topic"]``.

    Returns:
        dict with ``topic`` (best label), ``confidence`` (its softmax probability,
        3 decimals) and ``top3``: up to three ``{"label", "prob"}`` entries in
        descending probability order.
    """
    encoded = models["topic_tokenizer"](
        clean_arabic(poem_text),
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = models["topic_model"](**on_device).logits
        scores = torch.softmax(logits, dim=-1)[0].cpu().numpy()
    ranking = np.argsort(scores)[::-1][:3]
    names = models["id2label_topic"]
    candidates = []
    for idx in ranking:
        candidates.append({"label": names[int(idx)], "prob": round(float(scores[idx]), 3)})
    best = ranking[0]
    return {
        "topic": names[int(best)],
        "confidence": round(float(scores[best]), 3),
        "top3": candidates,
    }
# ── Request / Response schemas ──────────────────────────────────────────
class PoemRequest(BaseModel):
    """JSON body accepted by the /fasserha route."""

    # Raw poem text; predict_meter splits it on newlines (one verse per line) and
    # replaces any '#' characters with spaces.
    poem: str


class GenerateRequest(BaseModel):
    """JSON body accepted by the /generate route."""

    # Theme of the verses to compose; the route rejects a blank value.
    idea: str
    # Meter name; the route checks it against METER_PATTERNS / METER_ARABIC.
    meter: str
    # Number of verses requested; the route only accepts 1..12.
    num_verses: int = 4
# ── Routes ──────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Landing endpoint: a static status payload naming the service."""
    return dict(status="ok", service="Bayan API")
@app.get("/health")
def health():
    """Health-check endpoint.

    Always reports ``"healthy"`` along with the torch device chosen at import time;
    it does not check whether the entries in ``models`` are loaded.
    """
    device_name = str(device)
    return dict(status="healthy", device=device_name)
@app.post("/fasserha")
def fasserha(req: PoemRequest):
    """Explain-it-to-me endpoint ("فسّرها لي").

    Classifies the poem's meter, era and topic with the local models, then asks
    OpenAI (gpt-4o) for a five-part literary analysis, giving it those predictions
    as established facts.

    Returns:
        ``{"success": True, "data": {"meter", "era", "topic", "explanation"}}``.
        ``explanation`` is ``""`` if the completion carries no text.

    Raises:
        HTTPException(400): the poem is blank, or a classifier rejected it with
            ValueError (predict_meter does this when no verse is left).
        HTTPException(500): a classifier or the OpenAI call failed.
    """
    if not req.poem.strip():
        raise HTTPException(400, "ุงู„ู‚ุตูŠุฏุฉ ูุงุฑุบุฉ")
    try:
        meter = predict_meter(req.poem)
        era = predict_era(req.poem)
        topic = predict_topic(req.poem)
    except ValueError as e:
        # Invalid input, not a server fault: report it to the client as such.
        raise HTTPException(400, str(e)) from e
    except Exception as e:
        raise HTTPException(500, f"ุฎุทุฃ ููŠ ุงู„ุชุตู†ูŠู: {str(e)}") from e
    system_prompt = """ุฃู†ุช ู†ุงู‚ุฏ ุฃุฏุจูŠ ู…ุชุฎุตุต ููŠ ุงู„ุดุนุฑ ุงู„ุนุฑุจูŠ ุงู„ูƒู„ุงุณูŠูƒูŠ ูˆุงู„ุญุฏูŠุซ.
ุชุญู„ู„ ุงู„ู‚ุตุงุฆุฏ ุจุฃุณู„ูˆุจ ุฃูƒุงุฏูŠู…ูŠ ุฑุงู‚ูุŒ ูˆุชุณุชุฎุฏู… ุงู„ู…ุตุทู„ุญุงุช ุงู„ุจู„ุงุบูŠุฉ ูˆุงู„ุนุฑูˆุถูŠุฉ ุจุฏู‚ุฉ.
ุฑุฏูƒ ุฏุงุฆู…ุงู‹ ุจุงู„ุนุฑุจูŠุฉ ุงู„ูุตุญู‰."""
    user_prompt = f"""ุญู„ู‘ู„ ู‡ุฐู‡ ุงู„ู‚ุตูŠุฏุฉ:
{req.poem}
ู…ุนุทูŠุงุช ุงู„ู†ู…ุงุฐุฌ (ุญู‚ุงุฆู‚ ู…ุคูƒุฏุฉ):
- ุงู„ุจุญุฑ ุงู„ุดุนุฑูŠ: {meter['meter_ar']} (ุซู‚ุฉ: {meter['confidence']*100:.0f}%)
- ุงู„ุนุตุฑ: {era['era']} (ูƒู„ุงุณูŠูƒูŠ: {era['classical_probability']*100:.0f}% | ุญุฏูŠุซ: {era['modern_probability']*100:.0f}%)
- ุงู„ู…ูˆุถูˆุน: {topic['topic']} (ุซู‚ุฉ: {topic['confidence']*100:.0f}%)
ุงูƒุชุจ ุชุญู„ูŠู„ุงู‹ ุฃุฏุจูŠุงู‹ ุดุงู…ู„ุงู‹ ูŠุชุถู…ู†:
1. ุงู„ููƒุฑุฉ ุงู„ุนุงู…ุฉ ูˆุงู„ู…ุนู†ู‰ ุงู„ูƒู„ูŠ
2. ุงู„ู…ุนู†ู‰ ุงู„ุชูุตูŠู„ูŠ ู„ู„ุฃุจูŠุงุช
3. ุงู„ุฌู…ุงู„ูŠุงุช ุงู„ุจู„ุงุบูŠุฉ ูˆุงู„ุฃุณู„ูˆุจูŠุฉ
4. ุงู„ุจุญุฑ ูˆุงู„ุฅูŠู‚ุงุน ูˆุฃุซุฑู‡ู…ุง ููŠ ุงู„ู…ุนู†ู‰
5. ู„ู…ุณุฉ ู†ู‚ุฏูŠุฉ ุชู‚ูŠูŠู…ูŠุฉ
ุงู„ุชุฒู… ุจุงู„ุชุฑุชูŠุจ ุฃุนู„ุงู‡. ู„ุง ุชูƒุฑุฑ ุงู„ู…ุนู„ูˆู…ุงุช."""
    try:
        response = models["openai"].chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=1200,
            temperature=0.7
        )
        # The SDK types message.content as optional; never return null to the client.
        explanation = response.choices[0].message.content or ""
    except Exception as e:
        raise HTTPException(500, f"ุฎุทุฃ ููŠ ุงู„ุชูุณูŠุฑ: {str(e)}") from e
    return {
        "success": True,
        "data": {
            "meter": meter,
            "era": era,
            "topic": topic,
            "explanation": explanation
        }
    }
@app.post("/generate")
def generate(req: GenerateRequest):
    """Help-me-write endpoint ("ساعدني في الكتابة"): compose classical Arabic verses.

    ``req.meter`` may be an Arabic meter name (a key of METER_PATTERNS or a value of
    METER_ARABIC) or an English classifier label such as the ``meter_en`` returned
    by /fasserha. English labels are translated to the Arabic name, and surrounding
    whitespace is ignored. When the meter has a taf'ila pattern, the pattern is
    added to the prompt sent to gpt-4o-mini.

    Returns:
        ``{"success": True, "data": {"verses", "meter", "pattern"}}``. ``verses``
        holds at most ``num_verses`` lines of the reply, each longer than 10
        characters; ``meter`` is the resolved Arabic name.

    Raises:
        HTTPException(400): blank idea, unknown meter, or num_verses outside 1..12.
        HTTPException(500): the OpenAI call failed.
    """
    if not req.idea.strip():
        raise HTTPException(400, "ุงู„ููƒุฑุฉ ูุงุฑุบุฉ")
    requested = req.meter.strip()
    meter = METER_ARABIC.get(requested, requested)
    if meter not in METER_PATTERNS and meter not in METER_ARABIC.values():
        raise HTTPException(400, f"ุงู„ุจุญุฑ ุบูŠุฑ ู…ุนุฑูˆู: {req.meter}")
    if not 1 <= req.num_verses <= 12:
        raise HTTPException(400, "ุนุฏุฏ ุงู„ุฃุจูŠุงุช ุจูŠู† 1 ูˆ 12")
    pattern = METER_PATTERNS.get(meter, "")
    pattern_line = f"ุชูุนูŠู„ุฉ ุงู„ุจุญุฑ: {pattern}" if pattern else ""
    prompt = f"""ุฃู†ุช ุดุงุนุฑ ุนุฑุจูŠ ู…ุชุฎุตุต ููŠ ุงู„ุนุฑูˆุถ ุงู„ูƒู„ุงุณูŠูƒูŠ.
ุงู„ู…ูˆุถูˆุน: "{req.idea}"
ุงู„ุจุญุฑ: {meter}
{pattern_line}
ุงู„ู…ุทู„ูˆุจ: ุงูƒุชุจ {req.num_verses} ุฃุจูŠุงุช ุดุนุฑูŠุฉ ุจุงู„ูุตุญู‰ ุงู„ูƒู„ุงุณูŠูƒูŠุฉ.
ุงู„ู‚ูˆุงุนุฏ ุงู„ุตุงุฑู…ุฉ:
- ูƒู„ ุจูŠุช ู…ู† ุดุทุฑูŠู† ุตุญูŠุญูŠู† ุนุฑูˆุถูŠุงู‹
- ู‚ุงููŠุฉ ู…ูˆุญุฏุฉ ููŠ ุฌู…ูŠุน ุงู„ุฃุจูŠุงุช
- ูุตุญู‰ ูƒู„ุงุณูŠูƒูŠุฉ ูู‚ุท
- ุงู„ุฃุจูŠุงุช ู…ุชุตู„ุฉ ูƒู‚ุตูŠุฏุฉ ูˆุงุญุฏุฉ
- ุงูƒุชุจ ุงู„ุฃุจูŠุงุช ูู‚ุทุŒ ุจุฏูˆู† ุชุฑู‚ูŠู… ุฃูˆ ุดุฑุญ
- ุณุทุฑ ูˆุงุญุฏ ู„ูƒู„ ุจูŠุชุŒ {req.num_verses} ุณุทูˆุฑ ูู‚ุท"""
    try:
        response = models["openai"].chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "ุฃู†ุช ุดุงุนุฑ ุนุฑุจูŠ ูƒู„ุงุณูŠูƒูŠ. ุงูƒุชุจ ุงู„ุฃุจูŠุงุช ูู‚ุทุŒ ุณุทุฑ ู„ูƒู„ ุจูŠุช."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.75,
            max_tokens=600
        )
        # message.content is optional in the SDK; treat a missing reply as empty.
        raw = (response.choices[0].message.content or "").strip()
    except Exception as e:
        raise HTTPException(500, f"ุฎุทุฃ ููŠ ุงู„ุชูˆู„ูŠุฏ: {str(e)}") from e
    # Keep lines longer than 10 characters; shorter ones are presumably stray
    # numbering or headings rather than verses.
    verses = [line.strip() for line in raw.split("\n") if len(line.strip()) > 10]
    return {
        "success": True,
        "data": {
            "verses": verses[:req.num_verses],
            "meter": meter,
            "pattern": pattern
        }
    }
@app.get("/meters")
def list_meters():
    """Return every Arabic meter name that /generate accepts.

    Meters with a taf'ila pattern (METER_PATTERNS, in definition order) come first.
    They are followed by any other Arabic names from METER_ARABIC, which /generate
    also accepts but without a pattern (currently the mujtath meter).
    """
    names = list(METER_PATTERNS)
    names.extend(name for name in METER_ARABIC.values() if name not in METER_PATTERNS)
    return {"meters": names}