import os
import threading

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
app = FastAPI()

# Fully permissive CORS so the front end can be served from any origin.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests; harmless while the API has no auth, but
# tighten this before adding any.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Front-end assets live in ./static (path relative to the working directory).
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
model_id = "google/gemma-2b"

# Prefer the Apple-silicon GPU backend (MPS) when present, else run on CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Loading {model_id} on {device}...")

# Bind explicit sentinels up front: if loading fails below, these stay None
# instead of being undefined names, so request handlers can detect the failure
# rather than crashing with NameError.
tokenizer = None
model = None

# Optional Hub token for the gated Gemma weights; when unset, presumably the
# cached `huggingface-cli login` credentials are used instead.
hf_token = os.environ.get("HF_TOKEN")

try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        # Eager attention materializes attention weights, which the analyzer
        # requests via output_attentions=True; fused kernels may not return them.
        attn_implementation="eager",
        token=hf_token,
    ).to(device)
    print("Model loaded successfully.")
except Exception as e:
    # Best-effort startup: keep the server (and static files) up, but make
    # sure no half-loaded state (tokenizer without model) is left behind.
    tokenizer = None
    model = None
    print(f"Error loading model: {e}")
    print("Make sure you are logged into Hugging Face and have access to the Gemma model.")
    print("Run `huggingface-cli login` in your terminal.")
|
|
class TextRequest(BaseModel):
    """JSON body accepted by ``POST /analyze``: the raw text to score."""

    text: str
|
|
# Plain-`def` handlers run in FastAPI's threadpool; this lock keeps forward
# passes one-at-a-time (as they effectively were when inference ran directly
# on the event loop) so concurrent requests don't contend for the device.
_inference_lock = threading.Lock()


def _normalize_importance(importance):
    """Min-max scale per-token importance to [0, 1], pinning position 0 to 1.0.

    Position 0 is excluded from the min/max because it is presumably the BOS
    token, which tends to soak up attention and would flatten every other
    score — TODO confirm the tokenizer always prepends BOS.

    Args:
        importance: 1-D numpy array of raw per-token scores.

    Returns:
        A sequence of floats in [0, 1], same length as ``importance``.
    """
    count = len(importance)
    if count <= 1:
        return [1.0] * count

    rest = importance[1:]
    lo = rest.min()
    hi = rest.max()
    # Zero range (e.g. a single word after BOS) would divide by zero and
    # produce NaN, which the JSON response cannot encode. `not (hi > lo)`
    # also catches NaN scores.
    if not (hi > lo):
        return [1.0] * count

    scaled = ((importance - lo) / (hi - lo)).clip(0, 1)
    scaled[0] = 1.0
    return scaled


@app.post("/analyze")
def analyze_text(request: TextRequest):
    """Score how much attention each token of ``request.text`` receives.

    Runs one forward pass with attention outputs enabled, takes the last
    layer, averages over heads and sums over query positions so each token's
    score is the total attention it receives; scores are then normalized by
    ``_normalize_importance``.

    NOTE(review): with causal attention a token at position j can only be
    attended by queries >= j, so this sum is biased toward early positions.

    Returns:
        ``{"words": [{"token": str, "word": str, "score": float}, ...]}``,
        one entry per token; ``{"words": []}`` for blank input.

    Raises:
        HTTPException: 503 when the model or tokenizer is unavailable.
    """
    text = request.text
    if not text.strip():
        # Same response shape as the normal path so clients need one parser.
        return {"words": []}

    # Refuse to run inference when startup loading did not produce a model.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    inputs = tokenizer(text, return_tensors="pt").to(device)

    with _inference_lock, torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    if not outputs.attentions:
        print("Warning: Model did not return attentions.")
        return {"words": []}

    # Per the Transformers docs each layer's attention is
    # (batch, heads, query, key): take batch 0, mean over heads, then sum over
    # queries to get attention received per key token.
    last_layer = outputs.attentions[-1]
    avg_attention = last_layer[0].mean(dim=0)
    importance = avg_attention.sum(dim=0).cpu().float().numpy()
    scores = _normalize_importance(importance)

    input_ids = inputs["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    return {
        "words": [
            {
                # SentencePiece marks word starts with U+2581; show it as a space.
                "token": token.replace("\u2581", " "),
                "word": tokenizer.decode([token_id]),
                "score": float(scores[i]),
            }
            for i, (token, token_id) in enumerate(zip(tokens, input_ids))
        ]
    }
|
|
if __name__ == "__main__":
    # Listen on all interfaces; the PORT env var overrides the 7860 default.
    # NOTE(review): reload=True makes uvicorn re-import "main:app" (assumes this
    # file is main.py) in a worker process, so the module-level model load runs
    # in both this process and the worker, roughly doubling memory use.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run("main:app", host="0.0.0.0", port=listen_port, reload=True)