import html
import os
import pickle  # NOTE(review): not referenced anywhere in the visible code
import re
from collections import Counter
from contextlib import asynccontextmanager

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Hugging Face Hub repositories of the three fine-tuned classifiers.
METER_MODEL_ID = "Rahaf2001/Lassen-meter-classifier"
ERA_MODEL_ID = "Rahaf2001/LassenEraClassifier"
TOPIC_MODEL_ID = "Rahaf2001/Lassen-topic-classifier"

# Read once at import time; an unset variable yields "" rather than None.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

# Inference device: the GPU when torch can see one, otherwise the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")

# Shared registry of tokenizers, models and the OpenAI client; filled in by
# lifespan() at startup and emptied again at shutdown.
models = {}


ARABIC_DIACRITICS = re.compile(r'[\u0617-\u061A\u064B-\u0652\u0670\u06D6-\u06ED]')

_HTML_TAG = re.compile(r"<.*?>")
_NON_ARABIC = re.compile(r"[^\u0600-\u06FF\s]")
_WHITESPACE_RUN = re.compile(r"\s+")

# One-pass character fixes. These characters are disjoint from the diacritic
# class above, so applying them before or after diacritic removal is equivalent.
#   - drop tatweel (U+0640)
#   - unify alef variants (U+0623/0625/0622/0671) to bare alef (U+0627)
#   - map teh marbuta (U+0629) to heh (U+0647)
#   - blank out ASCII digits and Arabic-Indic digits (U+0660..U+0669)
_CHAR_FIXES = str.maketrans({
    "\u0640": None,
    "\u0629": "\u0647",
    **{alef: "\u0627" for alef in "\u0623\u0625\u0622\u0671"},
    **{digit: " " for digit in "0123456789"},
    **{chr(code): " " for code in range(0x0660, 0x066A)},
})


def clean_arabic(text: str) -> str:
    """Normalize raw poem text before it is fed to the era/topic classifiers.

    Steps: unescape HTML entities, replace ``<...>`` tags with spaces, apply
    the character fixes above, strip Arabic diacritics, replace every
    remaining non-Arabic, non-whitespace character with a space, then
    collapse whitespace runs and trim.

    Falsy input (``None``, ``""``) returns ``""``; other non-str values are
    converted with ``str()`` first.
    """
    if not text:
        return ""
    # Tags are removed first so that their contents never reach later steps.
    result = _HTML_TAG.sub(" ", html.unescape(str(text)))
    result = ARABIC_DIACRITICS.sub("", result.translate(_CHAR_FIXES))
    result = _NON_ARABIC.sub(" ", result)
    return _WHITESPACE_RUN.sub(" ", result).strip()


# English meter (bahr) names, indexed by the meter classifier's predicted id.
# predict_meter() does LABELS_METER[argmax], so this order must match the
# model's label order. NOTE(review): confirm against the model's config.id2label.
LABELS_METER = ['saree', 'kamel', 'mutakareb', 'mutadarak', 'munsareh',
                'madeed', 'mujtath', 'ramal', 'baseet', 'khafeef',
                'taweel', 'wafer', 'hazaj', 'rajaz']


# English meter key -> Arabic meter name. The Arabic names are the strings
# shown to users and accepted by /generate.
# (Previously these literals were mojibake: UTF-8 Arabic decoded as cp874.)
METER_ARABIC = {
    'saree': 'السريع', 'kamel': 'الكامل', 'mutakareb': 'المتقارب',
    'mutadarak': 'المتدارك', 'munsareh': 'المنسرح', 'madeed': 'المديد',
    'mujtath': 'المجتث', 'ramal': 'الرمل', 'baseet': 'البسيط',
    'khafeef': 'الخفيف', 'taweel': 'الطويل', 'wafer': 'الوافر',
    'hazaj': 'الهزج', 'rajaz': 'الرجز',
}


# Taf'ila (metrical foot) pattern for each meter, keyed by the Arabic meter
# name (the same strings used as METER_ARABIC values). generate() inserts the
# pattern into its prompt, and list_meters() exposes the keys.
# The literals are reconstructed from mojibake (UTF-8 decoded as cp874).
# المجتث was missing; it is appended last so the existing key order is unchanged.
METER_PATTERNS = {
    'الطويل': 'فَعُولُنْ مَفَاعِيلُنْ فَعُولُنْ مَفَاعِلُنْ',
    'الكامل': 'مُتَفَاعِلُنْ مُتَفَاعِلُنْ مُتَفَاعِلُنْ',
    'البسيط': 'مُسْتَفْعِلُنْ فَاعِلُنْ مُسْتَفْعِلُنْ فَاعِلُنْ',
    'الوافر': 'مُفَاعَلَتُنْ مُفَاعَلَتُنْ فَعُولُنْ',
    'الخفيف': 'فَاعِلَاتُنْ مُسْتَفْعِلُنْ فَاعِلَاتُنْ',
    'الرجز': 'مُسْتَفْعِلُنْ مُسْتَفْعِلُنْ مُسْتَفْعِلُنْ',
    'الرمل': 'فَاعِلَاتُنْ فَاعِلَاتُنْ فَاعِلَاتُنْ',
    'السريع': 'مُسْتَفْعِلُنْ مُسْتَفْعِلُنْ مَفْعُولَاتُ',
    'المنسرح': 'مُسْتَفْعِلُنْ مَفْعُولَاتُ مُسْتَفْعِلُنْ',
    'الهزج': 'مَفَاعِيلُنْ مَفَاعِيلُنْ',
    'المتقارب': 'فَعُولُنْ فَعُولُنْ فَعُولُنْ فَعُولُنْ',
    'المتدارك': 'فَاعِلُنْ فَاعِلُنْ فَاعِلُنْ فَاعِلُنْ',
    'المديد': 'فَاعِلَاتُنْ فَاعِلُنْ فَاعِلَاتُنْ',
    'المجتث': 'مُسْتَفْعِلُنْ فَاعِلَاتُنْ',
}


def _load_classifier(model_id: str):
    """Load a (tokenizer, sequence-classification model) pair for *model_id*.

    The model is moved to ``device`` and switched to eval mode before being returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    model.to(device).eval()
    return tokenizer, model


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: fill the ``models`` registry at startup, clear it at shutdown.

    Registry keys populated:
      * ``meter_tokenizer``/``meter_model``, ``era_tokenizer``/``era_model``,
        ``topic_tokenizer``/``topic_model``
      * ``id2label_topic``: int id -> topic label
      * ``openai``: OpenAI client built from ``OPENAI_API_KEY``
    """
    print("Loading models...")

    # Same loading sequence for all three classifiers, in the original order.
    for prefix, model_id in (("meter", METER_MODEL_ID),
                             ("era", ERA_MODEL_ID),
                             ("topic", TOPIC_MODEL_ID)):
        tokenizer, model = _load_classifier(model_id)
        models[f"{prefix}_tokenizer"] = tokenizer
        models[f"{prefix}_model"] = model

    # Topic labels come from the model config. Fall back to numeric labels
    # when id2label is absent *or empty* (an empty map would make every
    # lookup in predict_topic fail).
    topic_cfg = models["topic_model"].config
    id2label = getattr(topic_cfg, "id2label", None)
    if id2label:
        models["id2label_topic"] = {int(k): v for k, v in id2label.items()}
    else:
        models["id2label_topic"] = {i: str(i) for i in range(topic_cfg.num_labels)}

    models["openai"] = OpenAI(api_key=OPENAI_API_KEY)

    print(f"All models loaded on {device} ✓")
    try:
        yield
    finally:
        # Release model references even if shutdown is triggered by an error.
        models.clear()


app = FastAPI(title="Bayan API", lifespan=lifespan)

# CORS is fully open: every origin, HTTP method and request header is allowed.
# NOTE(review): consider restricting allow_origins in production deployments.
_cors_settings = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)


def predict_meter(poem_text: str) -> dict:
    """Classify the poem's meter by majority vote over its verses.

    Each line of *poem_text* is one verse. ``#`` characters are replaced by
    spaces, and lines that end up empty are skipped. All verses are scored
    in one batch (each padded/truncated to 32 tokens).

    Returns:
        ``{"meter_ar", "meter_en", "confidence"}``, where ``meter_en`` is the
        most-voted label (ties go to the label seen first) and ``confidence``
        is the mean probability the model assigns to that winning meter
        across all verses, rounded to 3 places. When every verse agrees, this
        equals the mean of the per-verse top probabilities.

    Raises:
        ValueError: if the poem contains no non-empty verse.

    NOTE(review): assumes the model's label ids line up with LABELS_METER.
    """
    # Filter *after* removing '#', so a line made only of separators is not a verse.
    candidates = (line.replace("#", " ").strip() for line in poem_text.strip().split("\n"))
    verses = [verse for verse in candidates if verse]
    if not verses:
        raise ValueError("القصيدة فارغة")

    inputs = models["meter_tokenizer"](verses, return_tensors="pt", truncation=True,
                                       max_length=32, padding="max_length")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        probs = torch.softmax(models["meter_model"](**inputs).logits, dim=-1)  # (n_verses, n_labels)

    pred_ids = probs.argmax(dim=-1).tolist()
    # Counter keeps first-seen order, so ties resolve to the earliest verse's label.
    top_id = Counter(pred_ids).most_common(1)[0][0]
    top_meter = LABELS_METER[top_id]
    avg_conf = probs[:, top_id].mean().item()
    return {"meter_ar": METER_ARABIC[top_meter], "meter_en": top_meter,
            "confidence": round(avg_conf, 3)}


# Era label by classifier output index: 0 = classical ("قديم"), 1 = modern ("حديث").
# (Previously these literals were mojibake: UTF-8 Arabic decoded as cp874.)
_ERA_LABELS = ("قديم", "حديث")


def predict_era(poem_text: str) -> dict:
    """Classify the poem as classical or modern.

    The text is normalized with clean_arabic(), then padded/truncated to 256
    tokens for the era model.

    Returns:
        ``{"era", "classical_probability", "modern_probability"}``. The two
        probabilities come from output indices 0 and 1 and are rounded to 4 places.
    """
    cleaned = clean_arabic(poem_text)
    enc = models["era_tokenizer"](cleaned, padding="max_length", truncation=True,
                                  max_length=256, return_tensors="pt")
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        probs = torch.softmax(models["era_model"](**enc).logits, dim=-1).cpu().numpy()[0]
    pred_idx = int(np.argmax(probs))
    return {"era": _ERA_LABELS[pred_idx],
            "classical_probability": round(float(probs[0]), 4),
            "modern_probability": round(float(probs[1]), 4)}


def predict_topic(poem_text: str) -> dict:
    """Predict the poem's topic with the topic classifier.

    The text is normalized with clean_arabic() and truncated to 512 tokens.

    Returns:
        ``{"topic", "confidence", "top3"}``: the best label, its probability
        (3 places), and up to three ``{"label", "prob"}`` entries in
        descending probability order.
    """
    tokenizer = models["topic_tokenizer"]
    model = models["topic_model"]
    encoded = tokenizer(clean_arabic(poem_text), truncation=True, max_length=512,
                        return_tensors="pt", padding=True)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
        scores = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    ranked = np.argsort(scores)[::-1][:3]
    labels = models["id2label_topic"]
    best = ranked[0]
    return {
        "topic": labels[int(best)],
        "confidence": round(float(scores[best]), 3),
        "top3": [
            {"label": labels[int(idx)], "prob": round(float(scores[idx]), 3)}
            for idx in ranked
        ],
    }


class PoemRequest(BaseModel):
    """Request body for /fasserha: the poem to classify and analyse."""

    poem: str  # one verse per line; predict_meter() replaces '#' characters with spaces


class GenerateRequest(BaseModel):
    """Request body for /generate."""

    idea: str  # subject of the verses; generate() rejects blank values
    meter: str  # meter name, validated by generate() against the known meters
    num_verses: int = 4  # generate() rejects values outside 1..12


@app.get("/")
def root():
    """Liveness endpoint: confirm the service is running and identify it."""
    return dict(status="ok", service="Bayan API")


@app.get("/health")
def health():
    """Health endpoint: also reports which torch device inference runs on."""
    return dict(status="healthy", device=str(device))


# System prompt for the literary-analysis LLM call.
# (Previously mojibake: UTF-8 Arabic decoded as cp874.)
_FASSERHA_SYSTEM_PROMPT = """أنت ناقد أدبي متخصص في الشعر العربي الكلاسيكي والحديث.
تحلل القصائد بأسلوب أكاديمي راقٍ، وتستخدم المصطلحات البلاغية والعروضية بدقة.
ردك دائماً بالعربية الفصحى."""


def _build_fasserha_prompt(poem: str, meter: dict, era: dict, topic: dict) -> str:
    """Build the user prompt that embeds the classifier outputs as given facts.

    *meter*, *era* and *topic* are the dicts returned by predict_meter,
    predict_era and predict_topic respectively.
    """
    return f"""حلّل هذه القصيدة:

{poem}

معطيات النماذج (حقائق مؤكدة):
- البحر الشعري: {meter['meter_ar']} (ثقة: {meter['confidence']*100:.0f}%)
- العصر: {era['era']} (كلاسيكي: {era['classical_probability']*100:.0f}% | حديث: {era['modern_probability']*100:.0f}%)
- الموضوع: {topic['topic']} (ثقة: {topic['confidence']*100:.0f}%)

اكتب تحليلاً أدبياً شاملاً يتضمن:
1. الفكرة العامة والمعنى الكلي
2. المعنى التفصيلي للأبيات
3. الجماليات البلاغية والأسلوبية
4. البحر والإيقاع وأثرهما في المعنى
5. لمسة نقدية تقييمية

التزم بالترتيب أعلاه. لا تكرر المعلومات."""


@app.post("/fasserha")
def fasserha(req: PoemRequest):
    """"Fassirha li" (explain it to me): classify meter, era and topic, then ask gpt-4o for a literary analysis.

    Responses:
        * 400 if the poem is blank.
        * 500 if classification or the OpenAI call fails; the detail carries
          the underlying error text.
        * Otherwise ``{"success": True, "data": {meter, era, topic, explanation}}``.
    """
    if not req.poem.strip():
        raise HTTPException(400, "القصيدة فارغة")
    try:
        meter = predict_meter(req.poem)
        era = predict_era(req.poem)
        topic = predict_topic(req.poem)
    except Exception as e:
        # Chain the cause so the original traceback is preserved server-side.
        raise HTTPException(500, f"خطأ في التصنيف: {e}") from e

    user_prompt = _build_fasserha_prompt(req.poem, meter, era, topic)

    try:
        response = models["openai"].chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": _FASSERHA_SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=1200,
            temperature=0.7,
        )
        explanation = response.choices[0].message.content
    except Exception as e:
        raise HTTPException(500, f"خطأ في التفسير: {e}") from e

    return {
        "success": True,
        "data": {
            "meter": meter,
            "era": era,
            "topic": topic,
            "explanation": explanation,
        },
    }


# System prompt for verse generation.
# (Previously mojibake: UTF-8 Arabic decoded as cp874.)
_GENERATE_SYSTEM_PROMPT = "أنت شاعر عربي كلاسيكي. اكتب الأبيات فقط، سطر لكل بيت."


def _resolve_meter_name(meter: str) -> str:
    """Return the Arabic meter name for *meter*, or "" if it is not a known meter.

    Accepts an Arabic name (a METER_PATTERNS key or METER_ARABIC value, as
    before) or an English METER_ARABIC key such as ``"kamel"`` (the
    ``meter_en`` value /fasserha returns).
    """
    if meter in METER_PATTERNS or meter in METER_ARABIC.values():
        return meter
    return METER_ARABIC.get(meter, "")


@app.post("/generate")
def generate(req: GenerateRequest):
    """"Sa'idni fi al-kitaba" (help me write): generate classical Arabic verses with gpt-4o-mini.

    Responses:
        * 400 when the idea is blank, the meter is unknown, or num_verses is outside 1..12.
        * 500 if the OpenAI call fails.
        * Otherwise ``{"success": True, "data": {"verses", "meter", "pattern"}}``.
          ``meter`` is the Arabic meter name; ``pattern`` is its taf'ila, or ""
          when none is stored.
    """
    if not req.idea.strip():
        raise HTTPException(400, "الفكرة فارغة")
    meter = _resolve_meter_name(req.meter)
    if not meter:
        raise HTTPException(400, f"البحر غير معروف: {req.meter}")
    if not 1 <= req.num_verses <= 12:
        raise HTTPException(400, "عدد الأبيات بين 1 و 12")

    pattern = METER_PATTERNS.get(meter, "")
    pattern_line = f"تفعيلة البحر: {pattern}" if pattern else ""

    prompt = f"""أنت شاعر عربي متخصص في العروض الكلاسيكي.

الموضوع: "{req.idea}"
البحر: {meter}
{pattern_line}

المطلوب: اكتب {req.num_verses} أبيات شعرية بالفصحى الكلاسيكية.

القواعد الصارمة:
- كل بيت من شطرين صحيحين عروضياً
- قافية موحدة في جميع الأبيات
- فصحى كلاسيكية فقط
- الأبيات متصلة كقصيدة واحدة
- اكتب الأبيات فقط، بدون ترقيم أو شرح
- سطر واحد لكل بيت، {req.num_verses} سطور فقط"""

    try:
        response = models["openai"].chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": _GENERATE_SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=0.75,
            max_tokens=600,
        )
        raw = response.choices[0].message.content.strip()
    except Exception as e:
        raise HTTPException(500, f"خطأ في التوليد: {e}") from e

    # Lines of 10 characters or fewer are treated as model chatter, not verses.
    stripped_lines = (line.strip() for line in raw.split("\n"))
    verses = [line for line in stripped_lines if len(line) > 10]
    return {
        "success": True,
        "data": {
            "verses": verses[:req.num_verses],
            "meter": meter,
            "pattern": pattern,
        },
    }


@app.get("/meters")
def list_meters():
    """Return every Arabic meter name that /generate accepts.

    Meters with a stored taf'ila (METER_PATTERNS) come first, in their
    declared order, as before. Any meter the classifier can emit (a
    METER_ARABIC value) that has no stored pattern is appended, so clients
    can still offer it.
    """
    meters = list(METER_PATTERNS)
    meters += [name for name in METER_ARABIC.values() if name not in METER_PATTERNS]
    return {"meters": meters}