Text Generation
Transformers
English
Russian
legal
SkillForge45 commited on
Commit
65fba8a
·
verified ·
1 Parent(s): e618d88

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from model import SimpleTransformerModel, FullChatDataset, VoiceInterface, generate_response
5
+ import torch
6
+ import uvicorn
7
+ import os
8
+ from typing import Optional
9
+
10
+ app = FastAPI()
11
+
12
+ # CORS middleware
13
+ app.add_middleware(
14
+ CORSMiddleware,
15
+ allow_origins=["*"],
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Initialize components
21
+ dataset = FullChatDataset()
22
+ model = SimpleTransformerModel(len(dataset.tokenizer))
23
+ voice_interface = VoiceInterface()
24
+
25
+ class ChatRequest(BaseModel):
26
+ prompt: str
27
+ max_length: int = 100
28
+ use_voice: bool = False
29
+
30
+ @app.post("/chat/")
31
+ async def chat_endpoint(
32
+ prompt: Optional[str] = Form(None),
33
+ max_length: int = Form(100),
34
+ use_voice: bool = Form(False),
35
+ audio_file: Optional[UploadFile] = File(None)
36
+ ):
37
+ try:
38
+ # Handle voice input if audio file provided
39
+ if audio_file:
40
+ contents = await audio_file.read()
41
+ with open("temp_audio.wav", "wb") as f:
42
+ f.write(contents)
43
+
44
+ with sr.AudioFile("temp_audio.wav") as source:
45
+ audio = voice_interface.recognizer.record(source)
46
+ prompt = voice_interface.recognizer.recognize_google(audio)
47
+ os.remove("temp_audio.wav")
48
+
49
+ # If no prompt provided (either text or voice)
50
+ if not prompt:
51
+ raise HTTPException(status_code=400, detail="No input provided")
52
+
53
+ response = generate_response(
54
+ model,
55
+ dataset.tokenizer,
56
+ prompt,
57
+ max_length,
58
+ voice_interface if use_voice else None
59
+ )
60
+
61
+ return {"response": response}
62
+
63
+ except Exception as e:
64
+ raise HTTPException(status_code=500, detail=str(e))
65
+
66
+ @app.get("/")
67
+ async def read_root():
68
+ return {"message": "CyberFuture Running"}
69
+
70
+ if __name__ == "__main__":
71
+ uvicorn.run(app, host="0.0.0.0", port=8000)