Spaces:
Sleeping
Sleeping
| import torch | |
| from tokenizers import Tokenizer | |
| from torch.utils.data import DataLoader | |
| import uvicorn | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel, Field | |
| from model import CustomDataset, TransformerEncoder, load_model_to_cpu | |
# ASGI application object that the route handlers attach to.
app = FastAPI()

# Tag vocabulary: name -> class index. The sentiment names are Turkish
# ("olumsuz" = negative, "nötr" = neutral, "olumlu" = positive).
tag2id = {"O": 0, "olumsuz": 1, "nötr": 2, "olumlu": 3, "org": 4}

# Reverse lookup: class index -> tag name.
id2tag = {index: name for name, index in tag2id.items()}

# Inference is pinned to the CPU.
device = torch.device("cpu")
def predict_fonk(model, device, example, tokenizer):
    """Tag ``example`` token by token and collect entities with their sentiment.

    The text is encoded with ``tokenizer``, wrapped in a one-item
    ``CustomDataset``/``DataLoader``, and passed through ``model``. Each token
    receives the class index with the highest score. A token whose label is a
    sentiment class (anything other than 0 or 4) starts an entity; the
    immediately following tokens labelled 4 are appended to it, separated by
    single spaces. Each distinct entity string is reported once, with the
    sentiment of its first occurrence.

    Args:
        model: Callable network invoked as ``model(input_ids, attention_mask)``;
            it is moved to ``device`` and switched to eval mode.
        device: Torch device the model and the input tensors are moved to.
        example: Raw input text.
        tokenizer: Object whose ``encode(example)`` result exposes ``tokens``,
            ``ids``, ``attention_mask`` and ``type_ids``.

    Returns:
        dict: ``{"entity_list": [...], "results": [{"entity": ...,
        "sentiment": ...}, ...]}`` where ``sentiment`` is the tag name looked
        up in ``id2tag``.
    """
    model.to(device)
    model.eval()

    encoding = tokenizer.encode(example)
    attention_mask = encoding.attention_mask

    # The dataset constructor takes a labels argument; ``type_ids`` is reused
    # for it here and the labels are never read back in this function.
    predict_data = CustomDataset(
        [encoding.tokens],
        [encoding.ids],
        [attention_mask],
        [encoding.type_ids],
        [encoding.type_ids],
    )
    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)

    predictions = []
    with torch.no_grad():
        for batch in predict_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_att_mask = batch['attention_mask'].to(device)
            outputs = model(batch_input_ids, batch_att_mask)
            # Collapse every leading dimension so each row holds the class
            # scores of one token (presumably (batch, seq, classes) -> confirm).
            logits = outputs.view(-1, outputs.size(-1))
            _, predicted = torch.max(logits, 1)
            predictions.append(predicted)

    # Convert once up front instead of re-materialising the token list and the
    # label list for every entity encountered in the loop below.
    tokens = predict_loader.dataset[0]["text"]
    labels = predictions[0].tolist()

    outside_id = 0       # "O" in tag2id
    continuation_id = 4  # "org" in tag2id

    entity_list = []
    results_list = []
    for position, (token, label, attention) in enumerate(
        zip(tokens, labels, attention_mask)
    ):
        if attention == 0 or label in (outside_id, continuation_id):
            continue

        # Extend the entity with the run of continuation-labelled tokens that
        # directly follows its first token.
        parts = [token]
        cursor = position + 1
        while cursor < len(labels) and labels[cursor] == continuation_id:
            parts.append(tokens[cursor])
            cursor += 1
        entity = " ".join(parts)

        if entity not in entity_list:
            entity_list.append(entity)
            results_list.append({"entity": entity, "sentiment": id2tag.get(label)})

    return {"entity_list": entity_list, "results": results_list}
# Instantiate the network and hand it to the project loader together with the
# checkpoint path "model.pth" (loader semantics live in the project's
# ``model`` module; by its name it loads the weights onto the CPU -- confirm).
model = load_model_to_cpu(TransformerEncoder(), "model.pth")

# Tokenizer definition is read from "tokenizer.json".
tokenizer = Tokenizer.from_file("tokenizer.json")
# Sample request text attached to the ``text`` field below (kept verbatim,
# including the trailing space).
_ITEM_TEXT_EXAMPLE = (
    "Fiber 100mb SuperOnline kullanıcısıyım yaklaşık 2 haftadır @Twitch @Kick_Turkey "
    "gibi canlı yayın platformlarında 360p yayın izlerken donmalar yaşıyoruz. "
    "Başka hiç bir operatörler bu sorunu yaşamazken ben parasını verip alamadığım "
    "hizmeti neden ödeyeyim ? @Turkcell "
)


# Request body: a single required ``text`` string.
class Item(BaseModel):
    text: str = Field(..., example=_ITEM_TEXT_EXAMPLE)
# NOTE(review): the handler was not registered on ``app`` anywhere in the
# visible file, so the server exposed no prediction route. The "/predict/"
# path is an assumption -- confirm it matches what clients call.
@app.post("/predict/", response_model=dict)
async def predict(item: Item):
    """Run entity-level sentiment prediction on ``item.text``.

    Args:
        item: Request body carrying the text to analyse.

    Returns:
        dict: The value returned by ``predict_fonk`` -- a dictionary with the
        keys ``"entity_list"`` and ``"results"``.
    """
    # The model's output is produced here and returned as the response body.
    # NOTE(review): inference runs synchronously inside this coroutine.
    predict_list = predict_fonk(
        model=model, device=device, example=item.text, tokenizer=tokenizer
    )
    return predict_list
# When executed as a script, serve ``app`` with uvicorn on all interfaces,
# port 8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )