# app.py — Hugging Face Space by obx0x3 (commit 9c14bb9, verified)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from fastapi.responses import JSONResponse
import torch
import uvicorn
app = FastAPI()
# === Load local HF models at import time (downloads weights on first run) ===
# Text generation (DialoGPT): tokenizer + causal LM loaded separately so we can
# control token IDs (EOS padding) during generation below.
dialogpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
dialogpt_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
# Emotion detection: single-label text classifier; called as emotion(text)
# and returns a list of {"label": ..., "score": ...} dicts.
emotion = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
# === Input schema ===
class PromptRequest(BaseModel):
    """Request body for the /generate endpoint."""

    message: str  # user's input text; rejected with HTTP 400 if empty after stripping
    lang: str = "en"  # Optional language hint for future logic; echoed back in the response
@app.post("/generate")
async def generate_response(payload: PromptRequest):
    """Classify the message's emotion and generate a DialoGPT chat reply.

    Returns a dict with:
      - reply:      generated response text (prompt tokens stripped)
      - emotion:    top emotion label from the classifier
      - confidence: classifier score rounded to 3 decimals
      - language:   the request's ``lang`` field, echoed back

    Responds with HTTP 400 and a fixed message when the input is empty.
    """
    message = payload.message.strip()
    if not message:
        return JSONResponse(content={"reply": "Please say something."}, status_code=400)

    # Step 1: Emotion detection — pipeline returns a list; take the top result.
    emotion_result = emotion(message)[0]
    detected_emotion = emotion_result["label"]
    emotion_score = round(emotion_result["score"], 3)

    # Step 2: Generate response. EOS token marks the end of the user turn,
    # and doubles as pad_token_id since DialoGPT has no dedicated pad token.
    input_ids = dialogpt_tokenizer.encode(message + dialogpt_tokenizer.eos_token, return_tensors="pt")
    # no_grad: pure inference — skip autograd bookkeeping to save memory/time.
    with torch.no_grad():
        output = dialogpt_model.generate(input_ids, max_length=100, pad_token_id=dialogpt_tokenizer.eos_token_id)
    # Decode only the newly generated tokens (slice off the echoed prompt).
    response_text = dialogpt_tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    return {
        "reply": response_text,
        "emotion": detected_emotion,
        "confidence": emotion_score,
        "language": payload.lang,
    }
# Required for running on HF Space
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — the port HF Spaces expects.
    uvicorn.run(app, host="0.0.0.0", port=7860)