# app/models/llm.py
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from app.utils.config import DEVICE, QWEN_MODEL_ID

# Tokenizer and model are loaded once at import time and shared by all calls.
tokenizer = AutoTokenizer.from_pretrained(
    QWEN_MODEL_ID,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    QWEN_MODEL_ID,
    device_map="auto" if DEVICE == "cuda" else None,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    trust_remote_code=True,
)
model.eval()
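
# Note (assumption about deployment targets, not stated in this repo): with
# device_map=None the weights stay on CPU, so if DEVICE ever names a non-CUDA
# backend such as "mps", an explicit model.to(DEVICE) would be needed here
# before generation.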

def strip_markdown(text: str) -> str:
    """Remove the markdown the model sometimes emits despite the prompt."""
    text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)             # bold
    text = re.sub(r"\*(.*?)\*", r"\1", text)                 # italics
    text = re.sub(r"^#+\s*", "", text, flags=re.MULTILINE)   # headings
    text = re.sub(r"^-\s+", "", text, flags=re.MULTILINE)    # bullet points
    text = re.sub(r"`+", "", text)                           # inline code
    text = text.replace("---", "")                           # horizontal rules
    return text.strip()
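
# Illustrative behavior (example string, not from this repo):
#   strip_markdown("## Overview\n**Panthera leo** is a *big cat*.")
#   -> "Overview\nPanthera leo is a big cat."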

@torch.no_grad()
def infer_common_name(
    species: str,
    domain: str,
    max_tokens: int = 16,
) -> str | None:
    """
    Uses the LLM to infer the most widely accepted English common name.
    Returns None if the model reports that no common name exists.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are a taxonomy assistant. "
                "Return ONLY the most widely used English common name "
                "for the given scientific name. "
                "Do not explain or add extra text. "
                "If no widely accepted common name exists, respond with "
                "exactly 'None'."
            )
        },
        {
            "role": "user",
            "content": f"Scientific name: {species} ({domain})"
        }
    ]
    # Render the chat template and generate deterministically (greedy decoding;
    # temperature is omitted because it is ignored when do_sample=False).
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Keep only the newly generated tokens, not the echoed prompt.
    generated_ids = outputs[:, inputs.input_ids.shape[1]:]
    response = tokenizer.decode(
        generated_ids[0],
        skip_special_tokens=True
    ).strip()
    if not response or response.lower() == "none":
        return None
    return response
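
# Illustrative usage (species names are hypothetical examples, not outputs
# observed from this model):
#   infer_common_name("Panthera leo", "Animalia")  -> "Lion"
#   infer_common_name("Candidatus Pelagibacter", "Bacteria") may return None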

def _build_messages(
    species: str,
    confidence: float,
    domain: str,
    top_k: list | None = None,
):
    # List runner-up candidates; top_k[0] is assumed to be the predicted
    # species itself, so it is skipped.
    alternatives = ""
    if top_k:
        alternatives = "\n".join(
            f"{x['species']} ({x['similarity']:.2f})" for x in top_k[1:]
        )
    system_message = (
        "You are a scientific biodiversity assistant. "
        "Provide factual, neutral descriptions of species. "
        "Do not mention instructions, rules, or formatting. "
        "Do not use markdown or bullet points."
    )
    user_message = (
        f"Species: {species}\n"
        f"Domain: {domain}\n"
        f"Confidence: {confidence:.2f}\n\n"
        f"Alternative candidates:\n"
        f"{alternatives if alternatives else 'None'}\n\n"
        "Provide a factual description covering physical traits, "
        "natural habitat and distribution, diet or ecological role, "
        "conservation status, and relevant human interactions."
    )
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
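
# Illustrative result (values hypothetical): for species="Panthera leo",
# confidence=0.91, domain="Animalia", the returned user turn begins
# "Species: Panthera leo\nDomain: Animalia\nConfidence: 0.91".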

@torch.no_grad()
def explain_species(
    species: str,
    confidence: float,
    domain: str,
    top_k: list | None = None,
    max_tokens: int = 512,
):
    """
    Returns:
        {
            "common_name": str | None,
            "description": str
        }
    """
    # Skip the common-name lookup for essentially unmatched predictions.
    COMMON_NAME_MIN_CONFIDENCE = 0.01
    common_name = None
    if confidence >= COMMON_NAME_MIN_CONFIDENCE:
        common_name = infer_common_name(species, domain)

    messages = _build_messages(species, confidence, domain, top_k)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Keep only the newly generated tokens, then strip stray markdown.
    generated_ids = outputs[:, model_inputs.input_ids.shape[1]:]
    response = tokenizer.decode(
        generated_ids[0],
        skip_special_tokens=True
    )
    return {
        "common_name": common_name,
        "description": strip_markdown(response),
    }
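

if __name__ == "__main__":
    # Minimal smoke-test sketch. The species and similarity values below are
    # illustrative placeholders, not outputs of the real classifier.
    demo_top_k = [
        {"species": "Panthera leo", "similarity": 0.91},
        {"species": "Panthera tigris", "similarity": 0.74},
    ]
    result = explain_species(
        species="Panthera leo",
        confidence=0.91,
        domain="Animalia",
        top_k=demo_top_k,
    )
    print("Common name:", result["common_name"])
    print(result["description"])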